Need help: how can functions with different parameter lists call each other using mutual tail calls?

const std = @import("std");

pub fn main() !void {
    // Shared counter that foo and bar keep bumping.
    var counter: i128 = 0;
    foo(&counter);
}

/// Bumps the counter, prints it, then calls `bar` with an ordinary call.
/// Nothing guarantees this call reuses the current frame, so the
/// foo -> bar -> foo ping-pong can keep growing the stack until it overflows.
pub fn foo(i: *i128) void {
    const next = i.* + 1;
    i.* = next;
    std.debug.print("i: {d}\n", .{next});
    bar(i);
}

/// Bumps the counter and bounces straight back to `foo` (plain call).
pub fn bar(i: *i128) void {
    i.* = i.* + 1;
    foo(i);
}

This code will work, but will eventually overflow the stack if it runs for a long time.

If I change it to the tail-recursive version below, it will run for a long time without any stack overflow.

const std = @import("std");

pub fn main() !void {
    // Shared counter that foo and bar keep bumping.
    var counter: i128 = 0;
    foo(&counter);
}

/// Bumps and prints the counter, then jumps to `bar`.
/// `.always_tail` either reuses this frame or fails to compile, so the
/// foo <-> bar ping-pong runs without growing the stack.
pub fn foo(i: *i128) void {
    const next = i.* + 1;
    i.* = next;
    std.debug.print("i: {d}\n", .{next});
    return @call(.always_tail, bar, .{i});
}

/// Bumps the counter and jumps back to `foo` with a guaranteed tail call.
/// This is accepted only because foo and bar have the same fn type.
pub fn bar(i: *i128) void {
    i.* = i.* + 1;
    return @call(.always_tail, foo, .{i});
}

However, a guaranteed tail call requires the caller and the callee to have the same function type. For example, foo and bar in the code above both take a single `*i128` parameter.

If their parameters are different, the compiler will issue an error, such as the following code.

const std = @import("std");

pub fn main() !void {
    var i: i128 = 0;
    foo(&i); // plain entry call; the rejected tail calls are in foo and bar
}

pub fn foo(i: *i128) void {
    i.* += 1;
    std.debug.print("i: {d}\n", .{i.*});
    @call(.always_tail, bar, .{ i, 10 }); // rejected: callee is fn (*i128, i32) void, caller is fn (*i128) void
}

pub fn bar(i: *i128, _: i32) void { // the unused i32 only changes the fn type
    i.* += 1;
    @call(.always_tail, foo, .{i}); // rejected: callee is fn (*i128) void, caller is fn (*i128, i32) void
}

The compiler will produce the following error:

src/main.zig:11:5: error: unable to perform tail call: type of function being called 'fn (*i128, i32) void' does not match type of calling function 'fn (*i128) void'
    @call(.always_tail, bar, .{ i, 10 });
    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
src/main.zig:16:5: error: unable to perform tail call: type of function being called 'fn (*i128) void' does not match type of calling function 'fn (*i128, i32) void'
    @call(.always_tail, foo, .{i});
    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

I understand why the compiler does this. Taking x86 as an example, stack-passed arguments sit right next to the return address, so if the callee needs a different parameter layout, reusing the caller's frame could overwrite the return address.

Because of my project's requirements, I need the same effect as above: functions with different parameters that call each other through tail calls.

How should this be implemented in zig? Do you have any suggestions?

You need labeled switch for this sort of thing:

const std = @import("std");

pub fn main() !void {
    var i: i128 = 0;
    // Start the state machine in the `foo` state.
    const start: F = .{ .foo = .{ .i = &i } };
    call(start);
}

/// One variant per "function": the active tag selects which body runs,
/// and the payload carries that body's own parameter list.
const F = union(enum) {
    /// Arguments of the `foo` step.
    foo: struct { i: *i128 },
    /// Arguments of the `bar` step; `j` is the extra parameter foo lacks.
    bar: struct { i: *i128, j: i32 },
};

/// Runs the foo/bar state machine starting from `f`.
/// Each `continue :recur` re-dispatches the switch with a new payload, so
/// the "mutual recursion" is a loop and makes no nested calls.
pub fn call(f: F) void {
    recur: switch (f) {
        .foo => |args| {
            args.i.* += 1;
            std.debug.print("i: {d}\n", .{args.i.*});
            // foo -> bar: bar additionally receives j = 10.
            const next: F = .{ .bar = .{ .i = args.i, .j = 10 } };
            continue :recur next;
        },
        .bar => |args| {
            args.i.* += 1;
            // bar -> foo: only the pointer is passed on.
            continue :recur .{ .foo = .{ .i = args.i } };
        },
    }
}
3 Likes

Having written this, I am wondering, does the compiler ensure that enum discriminant is comptime-known for code like this? In AIR, do we encode “tailcalls” in the above example as the dispatch switch and rely on LLVM to notice direct jumps, or is there a direct jump already in the AIR?

The release notes say that if the operand is comptime-known, the jump is direct, but here only the discriminant (tag) of the operand is comptime-known!

Thanks for your reply! This is one solution, but in my project I need to scatter these tail-recursive functions everywhere, instead of concentrating them in a switch block as you showed.

Here I show an interesting example of this approach:

const std = @import("std");

pub fn main() !void {
    const start_state = ESt.EWitness(.a){};
    // handler() hands back state a's handler function; invoke it to start.
    start_state.handler()();
}
/// Builds a zero-sized "witness" type for state `start` of the enum `T`.
///
/// The state's behaviour lives in the declaration of `T` named `<tag>ST`
/// (e.g. `aST` for `.a`); `handler` returns that declaration's `handler`
/// function. `end` is only recorded as `witness_spec_end`.
pub fn Witness(T: type, end: T, start: T) type {
    switch (@typeInfo(T)) {
        .@"enum" => {
            // Look the tag up by name. Indexing `fields` with
            // `@intFromEnum(start)` picks the wrong field (or is out of
            // bounds) whenever tag values are not exactly 0, 1, 2, ... in
            // declaration order.
            const stru = @field(T, @tagName(start) ++ "ST");
            return struct {
                pub const witness_spec_type = T;
                pub const witness_spec_start = start;
                pub const witness_spec_end = end;

                pub fn handler(_: @This()) @TypeOf(stru.handler) {
                    return stru.handler;
                }
            };
        },
        else => @compileError("The type not support, it must be enum"),
    }
}

const ESt = enum {
    exit,
    a,
    b,
    yesOrNo,

    /// Witness type for state `s`; `.exit` is recorded as the end state.
    pub fn EWitness(s: ESt) type {
        return Witness(ESt, .exit, s);
    }

    /// Terminal state: its handler simply returns.
    pub const exitST = union(enum) {
        pub fn handler() void {}
    };

    /// Reusable confirmation state: reads "yes" or "no" from stdin and
    /// continues with `yes` or `no` respectively.
    ///
    /// Each branch is either a witness (its handler takes no arguments) or a
    /// struct `{ wit: <witness>, v: <payload> }` whose handler receives `v`.
    pub const yesOrNoST = union(enum) {
        pub fn handler(yes: anytype, no: anytype) void {
            if (genMsg()) {
                const ty = @TypeOf(yes);
                if (@hasField(ty, "wit")) {
                    yes.wit.handler()(yes.v);
                } else {
                    yes.handler()();
                }
            } else {
                const ty = @TypeOf(no);
                if (@hasField(ty, "wit")) {
                    no.wit.handler()(no.v);
                } else {
                    no.handler()();
                }
            }
        }

        var buf: [30]u8 = @splat(0);

        /// Returns true for "yes" and false for "no", re-prompting on any
        /// other line. Exits the process if stdin cannot supply a line
        /// (e.g. end of input) instead of reaching `unreachable`.
        fn genMsg() bool {
            // Obtained at run time rather than as a container-level constant
            // (which is evaluated at comptime and may not be portable to
            // every target).
            const stdin = std.io.getStdIn().reader();
            while (true) {
                const line = stdin.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
                    // Longer than `buf`: reject it and keep reading.
                    error.StreamTooLong => {
                        std.debug.print("input error, retry\n", .{});
                        continue;
                    },
                    else => {
                        std.debug.print("cannot read answer: {s}\n", .{@errorName(err)});
                        std.process.exit(1);
                    },
                };
                // Accept CRLF line endings as well.
                const st = std.mem.trimRight(u8, line, "\r");
                if (std.mem.eql(u8, st, "yes")) {
                    return true;
                } else if (std.mem.eql(u8, st, "no")) {
                    return false;
                }
                std.debug.print("input error, retry\n", .{});
            }
        }
    };

    /// Counter shared by states a and b.
    var counter: i32 = 0;

    /// State a: proposes `counter + 1` (routed through yesOrNo to b on
    /// "yes", back to a on "no") and goes to exit once counter reaches 4.
    /// NOTE: every transition in this demo is an ordinary call, so each
    /// state change adds stack frames.
    pub const aST = union(enum) {
        GoB: struct {
            wit: EWitness(.yesOrNo) = .{},
            yes: struct { wit: EWitness(.b) = .{}, v: i32 },
            no: EWitness(.a) = .{},
        },
        Exit: EWitness(.exit),

        pub fn handler() void {
            switch (genMsg()) {
                .GoB => |v| v.wit.handler()(v.yes, v.no),
                .Exit => |wit| wit.handler()(),
            }
        }

        fn genMsg() @This() {
            std.debug.print("counter: {d}\n", .{counter});
            if (counter == 4) return .Exit;
            return .{ .GoB = .{ .yes = .{ .v = counter + 1 } } };
        }
    };

    /// State b: stores the received value in `counter`, then returns to a.
    pub const bST = union(enum) {
        GoA: EWitness(.a),

        pub fn handler(i: i32) void {
            counter = i;
            switch (genMsg()) {
                .GoA => |wit| wit.handler()(),
            }
        }

        fn genMsg() @This() {
            return .GoA;
        }
    };
};

There are four states here: exit, a, b, yesOrNo

The union corresponding to each state is: exitST, aST, bST, yesOrNoST,

Take aST as an example:

    /// State a. Each message variant carries the witness(es) for where
    /// control goes next.
    pub const aST = union(enum) {
        // a -> yesOrNo -> (yes: b with payload v, no: back to a)
        GoB: struct {
            wit: EWitness(.yesOrNo) = .{},
            yes: struct { wit: EWitness(.b) = .{}, v: i32 },
            no: EWitness(.a) = .{},
        },
        // a -> exit
        Exit: EWitness(.exit),

        /// Back end: dispatches the message produced by genMsg.
        pub fn handler() void {
            switch (genMsg()) {
                .GoB => |v| v.wit.handler()(v.yes, v.no),
                .Exit => |wit| wit.handler()(),
            }
        }

        /// Front end: returns Exit once counter == 4, otherwise proposes counter + 1.
        fn genMsg() @This() {
            std.debug.print("counter: {d}\n", .{counter});
            if (counter == 4) return .Exit;
            return .{ .GoB = .{ .yes = .{ .v = counter + 1 } } };
        }
    };

GoB and Exit are messages; genMsg is the front end (interaction with the user) that generates a message, and handler is the back end (business logic) that processes it.

What the demo does: state a generates a new number (counter + 1) and sends it to state b.
State b stores it (counter = i) and then jumps back to state a. The special part is that the transition from a to b (GoB) must pass through the yesOrNo state, where the user has to enter yes or no: yes continues to state b, and no jumps back to state a.

The most interesting thing here is the implementation of yesOrNo. It is a common state that can be reused very conveniently. For example, you can modify the Exit code above like this:

        ...........
        // Exit now also goes through the yesOrNo confirmation state:
        // yes -> exit, no -> back to state a.
        Exit: struct {
            wit: EWitness(.yesOrNo) = .{},
            yes: EWitness(.exit) = .{},
            no: EWitness(.a) = .{},
        },

        pub fn handler() void {
            switch (genMsg()) {
                .GoB => |v| v.wit.handler()(v.yes, v.no),
                // Same reusable confirmation step as GoB.
                .Exit => |v| v.wit.handler()(v.yes, v.no),
            }
        }

In this way, you also need to manually enter yes to confirm when exiting. This is very cool, we have made a type-safe universal component!!!

This example shows the great power of typed-fsm. Unfortunately, this demo currently has a stack overflow problem.

1 Like

I have solved this issue, thanks everyone for the help.

1 Like

Did you gain any new insights, could you describe your solution, in case somebody else has a similar problem in the future?

3 Likes

The core idea is to put the parameters in one large structure: the caller stores the arguments the callee needs in this structure, and the callee reads them back out of it.
The core requirement is to ensure the type safety of the caller and callee parameters. I can’t describe it completely yet, I will write a more detailed description after using this idea in a project.
But here is a demo code:

const std = @import("std");

pub fn main() !void {
    // Only `counter` is initialised here; the per-state argument slots are
    // written by a witness before the matching handler reads them.
    var state: Example.InternelState = undefined;
    state.counter = 0;
    const entry = Example.EWitness(.a){};
    entry.handler_normal(&state);
}

/// Returns the byte offset of field `name` inside struct type `ty`.
///
/// Uses `@offsetOf`, so alignment padding and the compiler's freedom to
/// reorder fields of non-`extern` structs are accounted for. Summing the
/// sizes of the preceding fields accounts for neither.
/// Emits a compile error if `ty` has no such field.
pub fn fieldOffset(comptime name: []const u8, ty: type) usize {
    if (!@hasField(ty, name)) {
        @compileError(std.fmt.comptimePrint("No field {s}!", .{name}));
    }
    return @offsetOf(ty, name);
}

/// Builds the zero-sized witness type for state `start` of the enum `T`.
///
/// The state's behaviour is the declaration `<tag>ST` of `T` (e.g. `bST`),
/// and its `callMode` selects the generated API:
/// - `.IndirectCall`: `handler()` returns the state's handler function, and
///   the caller invokes it with whatever arguments it needs.
/// - `.DirectCallNoParams`: `handler_normal(ist)` calls the state's handler
///   normally; `handler(ist)` tail-calls it.
/// - `.DirectCallRequireParams`: as above, but `arg` is first stored in
///   `ist._internel_field_<tag>`. This keeps every handler at the single
///   signature `fn (*IST) void`, which `.always_tail` requires.
pub fn Witness(T: type, start: T, IST: type) type {
    switch (@typeInfo(T)) {
        .@"enum" => {
            // Resolve by name: `fields[@intFromEnum(start)]` is only correct
            // when tag values are exactly 0, 1, 2, ... in declaration order.
            const ename = @tagName(start);
            const union_ty = @field(T, ename ++ "ST");
            const callMode = union_ty.callMode;

            switch (callMode) {
                .IndirectCall => {
                    return struct {
                        pub inline fn handler(_: @This()) @TypeOf(union_ty.handler) {
                            return union_ty.handler;
                        }
                    };
                },
                .DirectCallNoParams => {
                    return struct {
                        pub inline fn handler_normal(_: @This(), ist: *IST) void {
                            union_ty.handler(ist);
                        }

                        pub inline fn handler(_: @This(), ist: *IST) void {
                            @call(.always_tail, union_ty.handler, .{ist});
                        }
                    };
                },
                .DirectCallRequireParams => |ty| {
                    // Slot that carries this state's argument inside IST.
                    const slot = "_internel_field_" ++ ename;
                    return struct {
                        pub inline fn handler_normal(_: @This(), ist: *IST, arg: ty) void {
                            // Typed field write instead of hand-computed byte
                            // offsets: correct for any field order or padding.
                            @field(ist.*, slot) = arg;
                            union_ty.handler(ist);
                        }

                        pub inline fn handler(_: @This(), ist: *IST, arg: ty) void {
                            @field(ist.*, slot) = arg;
                            @call(.always_tail, union_ty.handler, .{ist});
                        }
                    };
                },
            }
        },
        else => @compileError("The type not support, it must be enum"),
    }
}

/// Extends struct type `baseState` with one field per state of the enum `st`
/// whose `<tag>ST.callMode` is `.DirectCallRequireParams`.
/// Each added field is named `_internel_field_<tag>` and has the required
/// parameter type.
/// These fields carry per-state arguments so that all handlers can share one
/// signature taking a pointer to the resulting struct.
pub fn createInternelState(baseState: type, st: type) type {
    const STArgField = std.builtin.Type.StructField;
    const tenum = switch (@typeInfo(st)) {
        .@"enum" => |info| info,
        else => @compileError("The type not support, it must be enum"),
    };

    // Grown by comptime concatenation, so the number of states is not
    // capped by a fixed-size buffer.
    var arg_fields: []const STArgField = &.{};
    for (tenum.fields) |field| {
        const union_ty = @field(st, field.name ++ "ST");
        switch (union_ty.callMode) {
            .IndirectCall, .DirectCallNoParams => {},
            .DirectCallRequireParams => |ty| {
                arg_fields = arg_fields ++ &[_]STArgField{.{
                    .name = "_internel_field_" ++ field.name,
                    .type = ty,
                    .default_value_ptr = null,
                    .is_comptime = false,
                    .alignment = @alignOf(ty),
                }};
            },
        }
    }

    var tmp_stru: std.builtin.Type.Struct = @typeInfo(baseState).@"struct";
    tmp_stru.fields = tmp_stru.fields ++ arg_fields;
    // `@Type` cannot reify declarations. Clearing them lets `baseState`
    // have methods/constants, but those are not copied to the new type.
    tmp_stru.decls = &.{};
    return @Type(.{ .@"struct" = tmp_stru });
}

/// Call-mode descriptor for a state type `ty` (a `<tag>ST` declaration).
pub fn CallMode(ty: type) type {
    return union(enum) {
        /// The state's handler takes no extra argument.
        DirectCallNoParams: void,
        /// The state's handler needs one argument of this type.
        DirectCallRequireParams: type,
        /// The state's handler function itself is handed to the caller.
        IndirectCall: void,

        /// Reads this state's argument from field `_internel_field_<tag>`
        /// of `internel_state_ref`.
        ///
        /// `<tag>` comes from `@typeName(ty)`: its last `.`-separated
        /// component must end in "ST" (e.g. `bST` -> `_internel_field_b`).
        /// Any other name is a compile error, rather than an underflowing or
        /// silently wrong slice.
        pub fn getArg(_: @This(), internel_state_ref: anytype) ty.callMode.DirectCallRequireParams {
            const field_name = comptime blk: {
                const full = @typeName(ty);
                const short = if (std.mem.lastIndexOfScalar(u8, full, '.')) |dot| full[dot + 1 ..] else full;
                if (!std.mem.endsWith(u8, short, "ST")) {
                    @compileError("CallMode: type name '" ++ full ++ "' must end in \"ST\"");
                }
                break :blk "_internel_field_" ++ short[0 .. short.len - 2];
            };
            return @field(internel_state_ref, field_name);
        }
    };
}

// st_enum, st_union

const Example = enum {
    exit,
    a,
    b,
    yesOrNo,

    /// User-visible state shared by all handlers.
    pub const State = struct { counter: i64 };
    /// `State` plus one argument slot per state that takes a parameter.
    pub const InternelState = createInternelState(State, @This());

    /// Witness type for state `s`.
    pub fn EWitness(s: Example) type {
        return Witness(Example, s, InternelState);
    }

    /// Final state: prints the internal state and returns.
    pub const exitST = union(enum) {
        pub const callMode: CallMode(@This()) = .DirectCallNoParams;

        pub fn handler(ist: *InternelState) void {
            std.debug.print("{any}\n", .{ist.*});
            std.debug.print("........finish!.......\n", .{});
        }
    };

    /// State a: sends `counter + 1` to b. Once the counter reaches
    /// 1_000_000_000 it asks (via yesOrNo) whether to exit.
    pub const aST = union(enum) {
        GoB: struct { wit: EWitness(.b) = .{}, v: i64 },
        Exit: struct {
            wit: EWitness(.yesOrNo) = .{},
            yes: EWitness(.exit) = .{},
            no: struct { wit: EWitness(.b) = .{}, v: i64 = 0 } = .{},
            msg: []const u8,
        },

        pub const callMode: CallMode(@This()) = .DirectCallNoParams;

        pub fn handler(ist: *InternelState) void {
            switch (genMsg(ist.counter)) {
                .GoB => |v| v.wit.handler(ist, v.v),
                .Exit => |v| v.wit.handler()(ist, v.yes, v.no, v.msg),
            }
        }

        fn genMsg(counter: i64) @This() {
            if (@mod(counter, 100_000_000) == 0) {
                std.debug.print(" counter: {d}\n", .{counter});
            }
            if (counter == 1_000_000_000) {
                std.debug.print("exit!\n", .{});
                return .{ .Exit = .{ .msg = 
                \\Are you sure exit!
                \\  yes or y: exit!
                \\  no or  n: goto B, set counter to 0
            } };
            }
            return .{ .GoB = .{ .v = counter + 1 } };
        }
    };

    /// State b: stores its argument into `counter`, then jumps back to a.
    pub const bST = union(enum) {
        GoA: EWitness(.a),

        pub const callMode: CallMode(@This()) = .{ .DirectCallRequireParams = i64 };

        pub fn handler(ist: *InternelState) void {
            const i = callMode.getArg(ist);
            ist.counter = i;
            switch (genMsg()) {
                .GoA => |wit| wit.handler(ist),
            }
        }

        fn genMsg() @This() {
            return .GoA;
        }
    };

    /// Reusable confirmation state.
    pub const yesOrNoST = union(enum) {
        pub const callMode: CallMode(@This()) = .IndirectCall;

        /// Prints `str`, asks for yes/no, and continues with `yes` or `no`.
        /// A branch is either a witness of a no-argument state, or a struct
        /// `{ wit, v }` whose witness receives `v` as its argument.
        pub inline fn handler(
            ist: *InternelState,
            yes: anytype,
            no: anytype,
            str: []const u8,
        ) void {
            if (genMsg(str)) {
                const ty = @TypeOf(yes);
                if (@hasField(ty, "wit")) {
                    // Witnesses expose `handler`; there is no `handler_tail`.
                    yes.wit.handler(ist, yes.v);
                } else {
                    yes.handler(ist);
                }
            } else {
                const ty = @TypeOf(no);
                if (@hasField(ty, "wit")) {
                    no.wit.handler(ist, no.v);
                } else {
                    no.handler(ist);
                }
            }
        }

        var buf: [30]u8 = @splat(0);

        /// Prints `str`, then returns true for "yes"/"y" and false for
        /// "no"/"n", re-prompting on anything else.
        /// Exits the process if stdin cannot supply a line (e.g. end of input)
        /// instead of retrying forever.
        fn genMsg(str: []const u8) bool {
            std.debug.print("{s}\n", .{str});
            // Obtained at run time rather than as a container-level constant
            // (which is evaluated at comptime and may not be portable to
            // every target).
            const stdin = std.io.getStdIn().reader();
            while (true) {
                const line = stdin.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
                    // Over-long line: report it and read on.
                    error.StreamTooLong => {
                        std.debug.print("{any}\n", .{err});
                        continue;
                    },
                    // End of input or a read failure cannot go away by retrying.
                    else => {
                        std.debug.print("{any}\n", .{err});
                        std.process.exit(1);
                    },
                };
                // Accept CRLF line endings as well.
                const st = std.mem.trimRight(u8, line, "\r");
                if (std.mem.eql(u8, st, "yes") or std.mem.eql(u8, st, "y")) {
                    return true;
                } else if (std.mem.eql(u8, st, "no") or std.mem.eql(u8, st, "n")) {
                    return false;
                }
                std.debug.print("input error, retry\n", .{});
            }
        }
    };
};

4 Likes