Generic retry function that re-calls a function if it returns errors I designate as retryable?

Is it possible to write a generic retry function that will re-call a function for me if it returns errors I designate as retryable?

// Pseudocode sketch of the desired API (not valid Zig): call `target` up to
// `retries + 1` times, `continue` on errors listed in `retryable_errors`,
// and return the last error once every attempt has failed.
pub fn retry(retries: u8, retryable_errors: errorset ???, target: ???) !target return type{
    // one initial attempt plus `retries` retries
    for (0..retries + 1) {
        target catch |err| {
            if err in retryable_errors continue
        }
    } else {
        return the last error
    }

}

Eh, there are too many unknowns — the location of whatever `target` returns, for example.

This seems to need `std.debug.print` levels of magic.

I got this far:

// Work in progress; this does not compile yet:
// - `retryable_errors: anyerror` holds a single error value, not a list of
//   errors, and it is never used (Zig rejects unused parameters);
// - `func` is declared as a `type`, but `@call` needs a function value;
// - the `;` after `@call(...)` ends the `return` statement, so the
//   following `catch` is a syntax error;
// - the loop capture `i` is unused.
pub fn retry(retries: u8, retryable_errors: anyerror, comptime func: type, args: anytype, ) @typeInfo(func).Fn.return_type.? {
    
    for (0..retries + 1) |i| {
        return @call(.auto, func, args); catch {
        }
    }
    
}

I don’t write that much Zig these days, so this is probably not very elegant, but it gets some of the way there.

const std = @import("std");

/// Calls `func` with `args`, retrying while it fails with one of `errors`.
///
/// `func` runs at most `retries + 1` times:
/// - The first `retries` attempts swallow the listed errors.
/// - Any error not listed in `errors` is returned immediately.
/// - A final attempt is then made, and its result (value or error) is
///   returned unchanged.
///
/// - `func`: function returning an error union. Its return type is read from
///   `@typeInfo`, so generic (`anytype`) functions are not supported.
/// - `args`: tuple of arguments forwarded to `func` through `@call`.
/// - `errors`: tuple of error values that should trigger another attempt.
/// - `retries`: number of attempts whose retryable errors are swallowed.
pub fn retry(
    func: anytype,
    args: anytype,
    errors: anytype,
    retries: usize,
) @typeInfo(@TypeOf(func)).Fn.return_type.? {
    for (0..retries) |_| {
        if (@call(.auto, func, args)) |value| {
            return value;
        } else |err| {
            if (!isRetryable(err, errors)) return err;
        }
    }
    // Retry budget used up: hand the last attempt's outcome to the caller as-is.
    return @call(.auto, func, args);
}

/// Reports whether `err` equals one of the error values in the `errors` tuple.
fn isRetryable(err: anyerror, errors: anytype) bool {
    // `inline for` because `errors` is a tuple whose fields may differ in type.
    inline for (errors) |candidate| {
        if (err == candidate) return true;
    }
    return false;
}

const Error = error{ TooHot, TooCold, JustRight };

/// Demo fallible function: draws a random digit `(m + random) % 10`, prints
/// it, and maps it to an outcome:
/// - 0...3 => `Error.TooCold`
/// - 7...9 => `Error.TooHot`
/// - 5     => `Error.JustRight`
/// - 4, 6  => returned as the success value
pub fn mayFail(m: u64) Error!u64 {
    var noise: u64 = undefined;
    // getrandom can fail; panic loudly instead of `catch unreachable`,
    // which is undefined behavior in release builds.
    std.posix.getrandom(std.mem.asBytes(&noise)) catch |err|
        std.debug.panic("getrandom failed: {s}", .{@errorName(err)});
    // Wrapping add: a plain `m + noise` traps on overflow for large inputs.
    const n = (m +% noise) % 10;
    std.debug.print("{}", .{n});
    if (n <= 3) return Error.TooCold;
    if (n >= 7) return Error.TooHot;
    if (n == 5) return Error.JustRight;
    return n;
}

/// Demo: retry `mayFail(11)` up to 4 extra times on `TooHot`/`TooCold`,
/// then print the final value or error.
pub fn main() void {
    const retryable = .{ Error.TooHot, Error.TooCold };
    const outcome = retry(mayFail, .{11}, retryable, 4);
    std.debug.print("\nreturned: {!}\n", .{outcome});
}

That said, I think retry code like this is already simple and short to write by hand, and probably clearer that way.

1 Like

Something like this:

const std = @import("std");

/// Yields the type `func` returns when called with an argument tuple of type
/// `args_type`. `@TypeOf` only analyses the call and never runs it, so an
/// `undefined` tuple is fine. This also works for generic functions, whose
/// `@typeInfo(...).Fn.return_type` is null.
fn ReturnType(comptime func: anytype, comptime args_type: type) type {
    return @TypeOf(@call(.auto, func, @as(args_type, undefined)));
}

/// Reports whether `err` appears in the comptime list `allowable`.
fn is_allowed(comptime allowable: []const anyerror, err: anyerror) bool {
    for (allowable) |candidate| {
        if (candidate == err) return true;
    }
    return false;
}

/// Calls `func(args...)`, retrying while it fails with an `allowable` error.
///
/// The first `retries` calls swallow allowable errors; any other error is
/// returned immediately. A final call is then made and its result (value or
/// error) is returned unchanged, so `func` runs at most `retries + 1` times.
/// The return type comes from `ReturnType`, so generic functions work too.
pub fn attempt(
    retries: u8,
    comptime allowable: []const anyerror,
    comptime func: anytype,
    args: anytype,
) ReturnType(func, @TypeOf(args)) {
    var remaining = retries;
    while (remaining > 0) : (remaining -= 1) {
        if (@call(.auto, func, args)) |value| {
            return value;
        } else |err| {
            if (!is_allowed(allowable, err)) return err;
        }
    }
    // Retry budget exhausted: whatever the last call yields goes to the caller.
    return @call(.auto, func, args);
}

A few notes: I am not using `@typeInfo` for the return type, because generic functions have a null `return_type`. With a `@typeInfo`-based version, passing a function like the following would fail:

/// Generic identity: returns its argument unchanged. The return type depends
/// on `x`, so `@typeInfo(@TypeOf(foo)).Fn.return_type` is null.
pub fn foo(x: anytype) @TypeOf(x) {
    return x;
}

You’ll get this error:

main.zig:20:42: error: unable to unwrap null
) @typeInfo(@TypeOf(func)).Fn.return_type.? {

By doing this pseudo-invocation, we can grab the return type directly. This even works with runtime-only arguments (try passing an allocator, for instance).

Anyhow, here’s how you use it:

/// Becomes 1 after the first call to `buggy`; 0 means "not called yet".
var global: u32 = 0;

const Error = error{meh};

/// Fails with `Error.meh` on its very first call; afterwards returns `2 * x`.
pub fn buggy(x: anytype) !@TypeOf(x) {
    if (global != 0) return 2 * x;
    global += 1;
    return Error.meh;
}

/// Demo: `buggy` fails once, and one allowed retry on `Error.meh` is enough
/// to get the doubled value.
pub fn main() !void {
    const input: i32 = 42;
    const doubled = try attempt(1, &.{Error.meh}, buggy, .{input});
    std.debug.print("{}\n", .{doubled});
}
5 Likes

Here’s another approach. Given two functions, func and retry, we’ll create a third function that repeatedly calls func until it succeeds or retry returns false:

const std = @import("std");

/// Demo target: prints the sum of `a` and `b`, then always fails so that
/// the retry machinery has something to do.
pub fn hello(a: u32, b: u32) !void {
    const sum = a + b;
    std.debug.print("Hello world {d}\n", .{sum});
    return error.shaka_when_the_walls_fell;
}

/// Demo: wrap `hello` so that it is retried on the listed errors, at most
/// `Retry.max` times, before the error is propagated out of `main`.
pub fn main() !void {
    // Retry policy/context; a fresh instance is created per wrapped call.
    const Retry = struct {
        /// Errors that are worth another attempt.
        const allowed = [_]anyerror{
            error.the_beast_at_tanagra,
            error.shaka_when_the_walls_fell,
        };
        /// Maximum number of retries after the initial call.
        const max = 10;

        /// Retries performed so far.
        attempt: usize = 0,

        /// Returns true to request another call of the wrapped function.
        fn onError(self: *@This(), err: anyerror) bool {
            // `<` rather than `<=`: with `<=` the policy allowed `max + 1` retries.
            if (self.attempt < max) {
                if (std.mem.indexOfScalar(anyerror, &allowed, err) != null) {
                    self.attempt += 1;
                    return true;
                }
            }
            return false;
        }
    };
    const retryHello = Retriable(hello, Retry.onError);
    try retryHello(123, 456);
}

Doing things this way is more flexible. It gives you the ability to introduce a delay mechanism, for example.

And here’s the code for Retriable():

/// Returns a function with the same type as `func` that keeps re-calling
/// `func` while it fails and `retry` asks for another attempt.
///
/// `retry` is invoked as `retry(&context, err)` after each failure. Its first
/// parameter must be a pointer to a context struct. That struct is created
/// with `.{}` on the first failure of a call (so its fields need defaults)
/// and is kept across that call's attempts. When `retry` returns false, the
/// error is returned to the caller.
///
/// Only functions with 0-8 parameters are supported (more is a compile
/// error). `func`'s parameter types and return type must be concrete,
/// because they are unwrapped with `.?` from `@typeInfo`.
pub fn Retriable(func: anytype, retry: anytype) @TypeOf(func) {
    const f = @typeInfo(@TypeOf(func)).Fn;
    const r = @typeInfo(@TypeOf(retry)).Fn;
    // Return type of the wrapped function, reused as the wrappers' return type.
    const RT = f.return_type.?;
    // Type of `retry`'s first parameter: a pointer to the retry context.
    const RCPT = r.params[0].type.?;
    // The retry context struct that pointer refers to.
    const RCT = @typeInfo(RCPT).Pointer.child;
    // Parameter types of `func`, used to spell out the wrapper signatures.
    const PT = comptime extract: {
        var Types: [f.params.len]type = undefined;
        for (f.params, 0..) |param, index| {
            Types[index] = param.type.?;
        }
        break :extract Types;
    };
    // The wrappers keep `func`'s calling convention so the returned function
    // has exactly the type `@TypeOf(func)`.
    const cc = f.calling_convention;

    const ns = struct {
        // Shared retry loop; `args` is the tuple of arguments forwarded to `func`.
        fn call(args: anytype) RT {
            // Created lazily: only once a call has actually failed.
            var retry_context: ?RCT = null;
            while (true) {
                if (@call(.auto, func, args)) |result| {
                    return result;
                } else |err| {
                    if (retry_context == null) retry_context = .{};
                    if (!retry(&retry_context.?, err)) {
                        return err;
                    }
                }
            }
        }

        // One fixed-arity wrapper per parameter count (0-8). Each packs its
        // parameters into a tuple for `call`; the one matching `func`'s
        // arity is picked by name below.

        fn call0() callconv(cc) RT {
            return call(.{});
        }

        fn call1(a0: PT[0]) callconv(cc) RT {
            return call(.{a0});
        }

        fn call2(a0: PT[0], a1: PT[1]) callconv(cc) RT {
            return call(.{ a0, a1 });
        }

        fn call3(a0: PT[0], a1: PT[1], a2: PT[2]) callconv(cc) RT {
            return call(.{ a0, a1, a2 });
        }

        fn call4(a0: PT[0], a1: PT[1], a2: PT[2], a3: PT[3]) callconv(cc) RT {
            return call(.{ a0, a1, a2, a3 });
        }

        fn call5(a0: PT[0], a1: PT[1], a2: PT[2], a3: PT[3], a4: PT[4]) callconv(cc) RT {
            return call(.{ a0, a1, a2, a3, a4 });
        }

        fn call6(a0: PT[0], a1: PT[1], a2: PT[2], a3: PT[3], a4: PT[4], a5: PT[5]) callconv(cc) RT {
            return call(.{ a0, a1, a2, a3, a4, a5 });
        }

        fn call7(a0: PT[0], a1: PT[1], a2: PT[2], a3: PT[3], a4: PT[4], a5: PT[5], a6: PT[6]) callconv(cc) RT {
            return call(.{ a0, a1, a2, a3, a4, a5, a6 });
        }

        fn call8(a0: PT[0], a1: PT[1], a2: PT[2], a3: PT[3], a4: PT[4], a5: PT[5], a6: PT[6], a7: PT[7]) callconv(cc) RT {
            return call(.{ a0, a1, a2, a3, a4, a5, a6, a7 });
        }
    };
    // Select `call{N}`, where N is `func`'s parameter count.
    const caller_name = std.fmt.comptimePrint("call{d}", .{f.params.len});
    if (!@hasDecl(ns, caller_name)) {
        @compileError("Too many arguments");
    }
    return @field(ns, caller_name);
}

The function pyramid at the bottom is sort of goofy, but there’s currently no other way to implement this kind of function transform.

From working on code suggestions for this related topic:

I found it nice to use a `while` loop with an `else` branch. The retry logic stays inside an iterator, while all the error handling and result processing stays directly in the code you are working on, with no duplicate calls of the function that can fail. I prefer this because the reusable part (the retry logic) is abstracted away, while the part that differs from situation to situation stays un-abstracted and easy to change.

Here I simplified the suggestion from the topic above (removing the time-based backoff). I think one of the great things about this is that you could switch between many other iterators and still keep the same code within the `while` loop. Additionally, using the `else` branch for the "failed too many times" case makes the code more readable.

const std = @import("std");

/// Iterator that yields once per allowed attempt, at most `max_tries` times.
///
/// Intended for `while (tries.next()) |_| { ... } else { ... }` loops. The
/// body runs at most `max_tries` times, and the `else` branch runs when the
/// budget is used up without a `break`.
pub const Retry = struct {
    /// Attempts handed out so far.
    iter: u64,
    /// Total number of attempts allowed.
    max_tries: u16,

    /// Creates an iterator allowing `max_tries` attempts.
    pub fn init(max_tries: u16) Retry {
        return Retry{ .iter = 0, .max_tries = max_tries };
    }

    /// Consumes one attempt; returns `null` once all attempts are used up.
    pub fn next(self: *Retry) ?void {
        if (self.iter >= self.max_tries) return null;
        self.iter += 1;
        return;
    }
};

/// Simulated fallible operation driven by `random`. It draws a roll in
/// 0...99 and returns:
/// - `error.Retry` for 0...80
/// - `error.Other` for 81
/// - the roll itself for 82...99
pub fn attempt(random: std.Random) !u32 {
    const roll = random.uintLessThan(u8, 100);
    if (roll <= 80) return error.Retry;
    if (roll == 81) return error.Other;
    return roll;
}

/// Demo: call `attempt` until it succeeds, fails with a non-retryable
/// error, or the budget of 10 tries runs out (`error.MaxTriesExhausted`).
pub fn main() !void {
    // zig 0.13
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    const seed = try getSeed(allocator);
    // `std.Random` is the current namespace (and what `attempt` takes);
    // `std.rand` was only a deprecated alias for it.
    var gen = std.Random.DefaultPrng.init(seed);
    const random = gen.random();

    var tries = Retry.init(10);
    while (tries.next()) |_| {
        std.debug.print("new try\n", .{});
        const result = attempt(random) catch |err| switch (err) {
            error.Retry => continue,
            else => return err,
        };
        std.debug.print("got valid result: {}\n", .{result});
        break;
    } else {
        // Loop finished without `break`: every try hit `error.Retry`.
        return error.MaxTriesExhausted;
    }
}

/// Returns the PRNG seed. If a first command-line argument is present it is
/// parsed as a decimal `u64`; otherwise a seed is derived from the current
/// time. The argument list is allocated with `allocator` and freed before
/// returning.
fn getSeed(allocator: std.mem.Allocator) !u64 {
    const argv = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, argv);

    if (argv.len < 2) return guessSeedFromTime();
    return try parseSeed(argv[1]);
}

/// Derives a PRNG seed from the `i128` timestamp of `std.time.nanoTimestamp`
/// and prints it so that the run can be reproduced.
///
/// The timestamp's bits are reinterpreted and truncated to the low 64 bits.
/// This replaces `@intCast`, which panics (or is undefined behavior in
/// release builds) for a negative timestamp or one that does not fit in
/// `u64`. Typical present-day values convert to the same seed as before.
fn guessSeedFromTime() u64 {
    const now: i128 = std.time.nanoTimestamp();
    const seed: u64 = @truncate(@as(u128, @bitCast(now)));
    std.debug.print("guessed seed: {d}\n", .{seed});
    return seed;
}

/// Echoes `seed` to stderr, then parses it as a base-10 `u64`.
/// Parse errors from `std.fmt.parseInt` are propagated.
fn parseSeed(seed: []const u8) !u64 {
    std.debug.print("got seed: {s}\n", .{seed});
    const value = try std.fmt.parseInt(u64, seed, 10);
    return value;
}

I also added a random seed that is either derived from a timestamp or supplied as a command-line argument. That way you can explore different possibilities and also reproduce them by providing the seed.

If the source is stored in a `retry.zig` file, you can provide the seed like this: `zig run retry.zig -- <seed>`

| seed | result |
|---|---|
| 1726310939861915769 | error: MaxTriesExhausted |
| 1726311882627969760 | got valid result: 95 |
| 1726312050498686408 | error: Other |

You could also add an `abort` function to the iterator that makes the following `next()` call return `null` (and thus exit via the `else` branch). This is useful if you want logic within the `while` loop that considers giving up after every try, or based on specific error codes. That logic would then call `tries.abort(); continue;`.

2 Likes

Aside: I have a feeling there should be a “Statements as expressions” entry in the Docs section.

Feel free to open a doc :+1: