Polystate: Composable Finite State Machines

Complete Example

const std = @import("std");

/// Entry point: runs the example machine, starting in
/// `AOrB(State1, State2)`, until a transition reports `.Exit`.
pub fn main() !void {
    const Start = Witness(Example.Exit, Example.AOrB(Example.State1, Example.State2));

    var global_state: GlobalState = .init;

    // Every transition hands back the function to call next, so keep invoking
    // the most recent result until the machine signals `.Exit`.
    var step = Start.transition(&global_state);
    while (true) {
        switch (step) {
            .Current => |next_transition| step = next_transition(&global_state),
            .Exit => break,
        }
    }
}

/// Mutable data handed to every transition of the example machine.
pub const GlobalState = struct {
    /// Incremented by `State1`/`State2`; they exit once it reaches 10.
    counter: u64,

    /// Random source used by `AOrB` to pick its next state.
    prng: std.Random.DefaultPrng,

    /// Initial value: counter at zero and a PRNG seeded with 123.
    pub const init: GlobalState = .{
        .prng = std.Random.DefaultPrng.init(123),
        .counter = 0,
    };
};

/// Wraps the state type `Current` in a struct that exposes a uniform
/// `transition` function returning a `TransitionResult`.
///
/// - `End`: the state type that terminates the machine.
/// - `Current`: the state type being wrapped. It must declare `name`, and,
///   unless it is `End`, a `transitionInt(*GlobalState)` function returning a
///   tagged union whose payload types are themselves `Witness` wrappers.
pub fn Witness(comptime End: type, comptime Current: type) type {
    if (Current == End) {
        return struct {
            pub const name = Current.name;

            /// Logs that the end state was reached and reports `.Exit`.
            pub fn transition(gst: *GlobalState) TransitionResult {
                _ = gst;
                std.debug.print("end: {s} \n", .{Current.name});
                return .Exit;
            }
        };
    } else {
        return struct {
            pub const name = Current.name;

            /// Runs `Current.transitionInt`, logs the chosen edge, and returns
            /// the `transition` function of the state that was selected.
            pub fn transition(gst: *GlobalState) TransitionResult {
                switch (Current.transitionInt(gst)) {
                    // `inline else` makes the payload type comptime-known in
                    // each prong, so the next wrapper can be read off the type.
                    inline else => |wit| {
                        std.debug.print("{s} -> {s}\n", .{ Current.name, @TypeOf(wit).name });
                        return .{ .Current = @TypeOf(wit).transition };
                    },
                }
            }
        };
    }
}

/// Namespace holding the states of the example machine.
///
/// Each state is a tagged union whose fields list the states it may move to
/// (wrapped in `Witness`), plus a `transitionInt` function that picks one.
pub const Example = struct {
    /// Terminal state; it has no outgoing transitions.
    pub const Exit = union(enum) {
        pub const name = getName(Example, Exit);
    };

    /// Bumps the shared counter, then either exits or goes to `AOrB`.
    pub const State1 = union(enum) {
        exit: Witness(Exit, Exit),
        toAOrB: Witness(Exit, AOrB(State1, State2)),

        pub const name = getName(Example, State1);

        /// Increments `gst.counter`; returns `.exit` once it is at least 10,
        /// otherwise `.toAOrB`.
        pub fn transitionInt(gst: *GlobalState) @This() {
            gst.counter += 1;
            if (gst.counter >= 10) {
                return .exit;
            }

            return .toAOrB;
        }
    };

    /// Same behavior as `State1`; a distinct type so `AOrB` has two targets.
    pub const State2 = union(enum) {
        exit: Witness(Exit, Exit),
        toAOrB: Witness(Exit, AOrB(State1, State2)),

        pub const name = getName(Example, State2);

        /// Increments `gst.counter`; returns `.exit` once it is at least 10,
        /// otherwise `.toAOrB`.
        pub fn transitionInt(gst: *GlobalState) @This() {
            gst.counter += 1;
            if (gst.counter >= 10) {
                return .exit;
            }

            return .toAOrB;
        }
    };

    /// State parameterized over two other states: it randomly stays in
    /// itself or moves to `A` or `B`.
    pub fn AOrB(comptime A: type, comptime B: type) type {
        return union(enum) {
            toA: Witness(Exit, A),
            toB: Witness(Exit, B),
            toSelf: Witness(Exit, AOrB(A, B)),

            pub const name = getComposedName(Example, AOrB, .{ A, B });

            /// Draws from `gst.prng`: a non-zero draw from [0, 3) stays in
            /// this state; otherwise a random boolean picks `A` or `B`.
            pub fn transitionInt(gst: *GlobalState) @This() {
                const random = gst.prng.random();

                if (random.intRangeLessThan(usize, 0, 3) != 0) {
                    return .toSelf;
                }

                if (random.boolean()) {
                    return .toA;
                }

                return .toB;
            }
        };
    }
};

// Utilities:

/// Returns the name under which `decl` is declared in `Container`.
///
/// Must be called in a comptime context. The declarations reported by
/// `std.meta.declarations(Container)` are scanned for one whose value has the
/// same type as `decl` and compares equal to it. If none matches, compilation
/// fails with a message naming the container.
pub fn getName(comptime Container: type, comptime decl: anytype) []const u8 {
    comptime {
        for (std.meta.declarations(Container)) |possible_decl| {
            const name = possible_decl.name;
            const decl_value = @field(Container, name);

            // Check the types first: `==` is only evaluated when they match.
            if (@TypeOf(decl_value) == @TypeOf(decl) and decl_value == decl) {
                return name;
            }
        } else @compileError("getName: value is not among the declarations of " ++ @typeName(Container));
    }
}

/// What a `transition` function reports back to the driver loop.
pub const TransitionResult = union(enum) {
    /// The machine has finished; there is nothing left to call.
    Exit,
    /// The machine continues; call this function to perform the next step.
    Current: *const fn (gst: *GlobalState) TransitionResult,
};

/// Builds a display name of the form `Decl(Arg1, Arg2, ...)`.
///
/// `Decl` is looked up with `getName(Container, decl)`; every element of
/// `args` contributes its own `name` declaration. Must be called in a comptime
/// context.
pub fn getComposedName(comptime Container: type, comptime decl: anytype, comptime args: anytype) []const u8 {
    comptime {
        var result: []const u8 = getName(Container, decl) ++ "(";

        for (&args, 0..) |arg, index| {
            // Every argument after the first is preceded by a separator.
            const separator: []const u8 = if (index == 0) "" else ", ";
            result = result ++ separator ++ arg.name;
        }

        return result ++ ")";
    }
}

I feel like I’ve reached the end.
@milogreg Maybe we can talk about specific function naming!

Kill the Example shell! We get first-class composability, and have an Exit that is common to all state machines. We can use any state without any registration.

const std = @import("std");

/// Entry point: runs the machine, starting in `AOrB(State1, State2)`, until
/// a transition reports `.Exit`.
pub fn main() !void {
    const Start = Witness(AOrB(State1, State2));

    var global_state: GlobalState = .init;

    // Every transition hands back the function to call next, so keep invoking
    // the most recent result until the machine signals `.Exit`.
    var step = Start.transition(&global_state);
    while (true) {
        switch (step) {
            .Current => |next_transition| step = next_transition(&global_state),
            .Exit => break,
        }
    }
}

/// Terminal state shared by all state machines in this file. It has no
/// fields, i.e. no outgoing transitions; `Witness` special-cases it.
pub const Exit = union(enum) {};

/// Wraps the state type `Current` in a struct that exposes a uniform
/// `transition` function returning a `TransitionResult`.
///
/// The wrapper re-exports the wrapped state as `CST` and its type name as
/// `name`. `name` comes from `@typeName` because states in this version do not
/// declare their own `name` (in particular `Exit` does not).
///
/// Unless `Current` is `Exit`, it must declare a `transitionInt(*GlobalState)`
/// function returning a tagged union whose payload types are `Witness`
/// wrappers.
pub fn Witness(comptime Current: type) type {
    if (Current == Exit) {
        return struct {
            pub const CST = Current;
            pub const name = @typeName(Current);

            /// The exit state has nothing to run; it only reports `.Exit`.
            pub fn transition(gst: *GlobalState) TransitionResult {
                _ = gst;
                return .Exit;
            }
        };
    } else {
        return struct {
            pub const CST = Current;
            pub const name = @typeName(Current);

            /// Runs `Current.transitionInt`, logs the chosen edge, and returns
            /// the `transition` function of the state that was selected.
            pub fn transition(gst: *GlobalState) TransitionResult {
                switch (Current.transitionInt(gst)) {
                    // `inline else` makes the payload type comptime-known in
                    // each prong, so the next wrapper can be read off the type.
                    inline else => |wit| {
                        std.debug.print("{s} -> {s}\n", .{ @typeName(Current), @typeName(@TypeOf(wit).CST) });
                        return .{ .Current = @TypeOf(wit).transition };
                    },
                }
            }
        };
    }
}

/// Mutable data handed to every transition of the machine.
pub const GlobalState = struct {
    /// Incremented by `State1`/`State2`; they exit once it reaches 10.
    counter: u64,

    /// Random source used by `AOrB` to pick its next state.
    prng: std.Random.DefaultPrng,

    /// Initial value: counter at zero and a PRNG seeded with 123.
    pub const init: GlobalState = .{
        .prng = std.Random.DefaultPrng.init(123),
        .counter = 0,
    };
};

/// Counting state: each visit bumps the shared counter, then the machine
/// either exits or moves on to `AOrB(State1, State2)`.
pub const State1 = union(enum) {
    exit: Witness(Exit),
    toAOrB: Witness(AOrB(State1, State2)),

    /// Increments `gst.counter`; yields `.toAOrB` while the counter is below
    /// 10 and `.exit` from then on.
    pub fn transitionInt(gst: *GlobalState) @This() {
        gst.counter += 1;
        return if (gst.counter < 10) .toAOrB else .exit;
    }
};

/// Counting state with the same behavior as `State1`; it is a distinct type
/// so that `AOrB` has two different targets.
pub const State2 = union(enum) {
    exit: Witness(Exit),
    toAOrB: Witness(AOrB(State1, State2)),

    /// Increments `gst.counter`; yields `.toAOrB` while the counter is below
    /// 10 and `.exit` from then on.
    pub fn transitionInt(gst: *GlobalState) @This() {
        gst.counter += 1;
        return if (gst.counter < 10) .toAOrB else .exit;
    }
};

/// State parameterized over two other states: it randomly stays in itself
/// or moves to `A` or `B`.
pub fn AOrB(comptime A: type, comptime B: type) type {
    return union(enum) {
        toA: Witness(A),
        toB: Witness(B),
        toSelf: Witness(AOrB(A, B)),

        /// Draws from `gst.prng`: a non-zero draw from [0, 3) stays in this
        /// state; otherwise a random boolean picks `A` or `B`.
        pub fn transitionInt(gst: *GlobalState) @This() {
            const random = gst.prng.random();

            if (random.intRangeLessThan(usize, 0, 3) != 0) {
                return .toSelf;
            }

            if (random.boolean()) {
                return .toA;
            }

            return .toB;
        }
    };
}

/// What a `transition` function reports back to the driver loop.
pub const TransitionResult = union(enum) {
    /// The machine has finished; there is nothing left to call.
    Exit,
    /// The machine continues; call this function to perform the next step.
    Current: *const fn (gst: *GlobalState) TransitionResult,
};

This will produce the following output:

p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State1
p2.State1 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State1
p2.State1 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State2
p2.State2 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State2
p2.State2 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State2
p2.State2 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State2
p2.State2 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State1
p2.State1 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State2
p2.State2 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State1
p2.State1 -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.AOrB(p2.State1,p2.State2)
p2.AOrB(p2.State1,p2.State2) -> p2.State2
p2.State2 -> p2.Exit

Nice, I was considering using @typeName as well, but wanted to keep the manual names for edge cases where we want to tune them. Perhaps including a name definition could be optional, and the @typeName is used as a fallback.

In other news, I want to share a new way of structuring things I’ve been working on that should eliminate the need to wrap every transition type in Witness. Instead of using Witness to give the states a transition function everywhere, we simply wrap a state with it when we need to execute the transition function.

To achieve this, I made the result type of the transitionInt function be its own type, rather than the state type itself (@This()). This eliminates the circular dependency issues.

I also did some things to make the Exit state more consistent with other states, made the library functions generic over different GlobalState types, and made some naming changes.

The name changes are:

  • Witness → WithTransition. This more clearly expresses what Witness is doing, as it is basically giving you back your state, but with a transition function.
  • gst: *GlobalState → ctx: *Context. This removes the confusion between state machine states and the data state. It’s also a common zig naming convention for what we are doing with GlobalState.
  • transitionInt → nextState. This more explicitly describes the difference between transition and transitionInt, with transitionInt being the function that produces the next state and transition being the function that allows you to transition to the next state. The naming of nextState is similar to that of genMsg.

Here’s the code, it now also includes a Runner state to demonstrate composition of function states:

const std = @import("std");

/// Entry point: runs the same example machine twice, once through the
/// statically generated runner and once by following function pointers.
pub fn main() !void {
    const Start = Example.AOrB(Example.State1, Example.State2);

    // First run: the statically generated state machine.
    {
        var static_ctx: ExampleContext = .init;

        staticStateMachine(ExampleContext, Start, Example.Exit)(&static_ctx);
    }

    // Second run: each transition returns a pointer to the next transition.
    {
        var pointer_ctx: ExampleContext = .init;

        const First = WithTransition(ExampleContext, Start, Example.Exit);

        var step = First.transition(&pointer_ctx);
        while (true) {
            switch (step) {
                .current => |next_transition| step = next_transition(&pointer_ctx),
                .exit => break,
            }
        }
    }
}

/// Mutable data handed to every `nextState` function of the example machine.
pub const ExampleContext = struct {
    /// Incremented by `State1`/`State2`; they exit once it reaches 10.
    counter: u64,
    /// Random source used by `AOrB` and `Runner` to pick their next state.
    prng: std.Random.DefaultPrng,

    /// Initial value: counter at zero and a PRNG seeded with 123.
    pub const init: ExampleContext = .{
        .prng = std.Random.DefaultPrng.init(123),
        .counter = 0,
    };
};

/// Namespace holding the states of the example machine.
///
/// A state is a struct with a `name` and a `nextState` function. The return
/// type of `nextState` is a tagged union whose payload types are the states
/// that can follow.
pub const Example = struct {
    /// Terminal state; its only transition leads back to itself.
    pub const Exit = struct {
        pub const name = getName(Example, Exit, .{});

        pub fn nextState(_: *ExampleContext) union(enum) {
            to_self: Exit,
        } {
            return .to_self;
        }
    };

    /// Bumps the shared counter, then either exits or goes to `AOrB`.
    pub const State1 = struct {
        pub const name = getName(Example, State1, .{});

        /// Increments `ctx.counter`; yields `.to_a_or_b` while the counter is
        /// below 10 and `.exit` from then on.
        pub fn nextState(ctx: *ExampleContext) union(enum) {
            exit: Exit,
            to_a_or_b: AOrB(State1, State2),
        } {
            ctx.counter += 1;
            return if (ctx.counter < 10) .to_a_or_b else .exit;
        }
    };

    /// Same behavior as `State1`; a distinct type so `AOrB` has two targets.
    pub const State2 = struct {
        pub const name = getName(Example, State2, .{});

        /// Increments `ctx.counter`; yields `.to_a_or_b` while the counter is
        /// below 10 and `.exit` from then on.
        pub fn nextState(ctx: *ExampleContext) union(enum) {
            exit: Exit,
            to_a_or_b: AOrB(State1, State2),
        } {
            ctx.counter += 1;
            return if (ctx.counter < 10) .to_a_or_b else .exit;
        }
    };

    /// State parameterized over two other states: it randomly hands over to
    /// `Runner`, stays in itself, or moves to `A` or `B`.
    pub fn AOrB(comptime A: type, comptime B: type) type {
        return struct {
            pub const name = getName(Example, AOrB, .{ A, B });

            /// Makes up to three draws from `ctx.prng`, in this order: a zero
            /// from [0, 5) goes to `Runner`; a non-zero from [0, 3) stays
            /// here; finally a boolean picks `A` or `B`.
            pub fn nextState(ctx: *ExampleContext) union(enum) {
                to_a: A,
                to_b: B,
                to_self: AOrB(A, B),
                to_runner: Runner(A, B, AOrB),
            } {
                const rng = ctx.prng.random();

                if (rng.intRangeLessThan(usize, 0, 5) == 0) return .to_runner;
                if (rng.intRangeLessThan(usize, 0, 3) != 0) return .to_self;

                return if (rng.boolean()) .to_a else .to_b;
            }
        };
    }

    /// State parameterized over two states and a function `Composer` that
    /// combines them; it randomly stays in itself or moves to
    /// `Composer(A, B)`.
    pub fn Runner(comptime A: type, comptime B: type, comptime Composer: fn (comptime type, comptime type) type) type {
        return struct {
            pub const name = getName(Example, Runner, .{ A, B, Composer });

            /// Draws from [0, 3): zero moves to `Composer(A, B)`, any other
            /// value stays in this state.
            pub fn nextState(ctx: *ExampleContext) union(enum) {
                to_composer: Composer(A, B),
                to_self: Runner(A, B, Composer),
            } {
                const rng = ctx.prng.random();

                return if (rng.intRangeLessThan(usize, 0, 3) == 0) .to_composer else .to_self;
            }
        };
    }
};

// Utilities:

/// Pairs the state `CurrentState` with a `transition` function that can be
/// driven through function pointers.
///
/// - `Context`: type of the data passed to every `nextState` call.
/// - `CurrentState`: the state to wrap; must declare `name` and `nextState`.
/// - `EndingState`: the state whose `transition` reports `.exit`.
pub fn WithTransition(comptime Context: type, comptime CurrentState: type, comptime EndingState: type) type {
    return struct {
        pub const name = CurrentState.name;

        /// Reports `.exit` for the ending state. For any other state it runs
        /// `nextState`, logs the chosen edge, and returns the `transition`
        /// function of the state that was selected.
        pub fn transition(ctx: *Context) TransitionResult(Context) {
            if (comptime CurrentState == EndingState) {
                return .exit;
            } else {
                switch (CurrentState.nextState(ctx)) {
                    // `inline else` makes the payload type comptime-known, so
                    // the follow-up wrapper can be instantiated per prong.
                    inline else => |next_state| {
                        const Next = WithTransition(Context, @TypeOf(next_state), EndingState);
                        std.debug.print("{s} -> {s}\n", .{ CurrentState.name, Next.name });
                        return .{ .current = Next.transition };
                    },
                }
            }
        }
    };
}

/// Builds, at compile time, a function that runs the whole state machine
/// inside a single labeled `switch` instead of following function pointers.
///
/// - `Context`: type of the data passed to every `nextState` call.
/// - `StartingState`: the state execution begins in.
/// - `EndingState`: the state that makes the returned function return. It must
///   be reachable from `StartingState`; otherwise compilation fails.
///
/// The returned function logs every transition it takes.
pub fn staticStateMachine(comptime Context: type, comptime StartingState: type, comptime EndingState: type) fn (ctx: *Context) void {
    comptime {
        // Each state is identified by its index in this list.
        const reachable_states = getReachableStates(StartingState);

        // `EndingState` is caller input, so report a readable error rather
        // than hitting `unreachable` when it cannot be found.
        const ending_state_int = std.mem.indexOfScalar(type, reachable_states, EndingState) orelse
            @compileError("staticStateMachine: ending state '" ++ @typeName(EndingState) ++
                "' is not reachable from starting state '" ++ @typeName(StartingState) ++ "'");

        const S = struct {
            pub fn traverse(ctx: *Context) void {
                // The starting state is always in the list it was built from.
                const starting_state_int = comptime std.mem.indexOfScalar(type, reachable_states, StartingState) orelse unreachable;
                sw: switch (starting_state_int) {
                    // One prong is generated per reachable state.
                    inline 0...reachable_states.len - 1 => |current_state_int| {
                        if (current_state_int == ending_state_int) {
                            return;
                        }

                        const CurrentState = reachable_states[current_state_int];

                        switch (CurrentState.nextState(ctx)) {
                            inline else => |next_state| {
                                const NextState = @TypeOf(next_state);
                                std.debug.print("{s} -> {s}\n", .{ CurrentState.name, NextState.name });
                                // Jump straight to the prong of the next state.
                                continue :sw comptime std.mem.indexOfScalar(type, reachable_states, NextState) orelse unreachable;
                            },
                        }
                    },
                    else => unreachable,
                }
            }
        };

        return S.traverse;
    }
}

/// Returns every state type reachable from `StartingState` (itself included)
/// by following the payload types of each state's `nextState` result.
pub fn getReachableStates(comptime StartingState: type) []const type {
    comptime {
        // Begin the traversal with nothing visited yet.
        const nothing_visited: []const type = &.{};
        return getReachableStatesRecursive(StartingState, nothing_visited);
    }
}

/// Depth-first helper for `getReachableStates`.
///
/// `reachable_states` holds the states visited so far. The return value is
/// that list extended with `CurrentState` and everything reachable from it, or
/// the list unchanged when `CurrentState` was already visited.
fn getReachableStatesRecursive(comptime CurrentState: type, comptime reachable_states: []const type) []const type {
    comptime {
        const already_visited = for (reachable_states) |Seen| {
            if (Seen == CurrentState) break true;
        } else false;

        if (already_visited) {
            return reachable_states;
        }

        // The fields of the union returned by `nextState` are the outgoing edges.
        const NextStates = @typeInfo(@TypeOf(CurrentState.nextState)).@"fn".return_type.?;

        var visited: []const type = reachable_states ++ [_]type{CurrentState};

        for (std.meta.fields(NextStates)) |edge| {
            visited = getReachableStatesRecursive(edge.type, visited);
        }

        return visited;
    }
}

/// Returns the display name of a state: the name under which `decl` is
/// declared in `Container`, followed by a parenthesized, comma-separated list
/// of argument names when `args` is non-empty.
///
/// An argument that is a type contributes its own `name` declaration; any
/// other argument (for example a function) is looked up in `Container`.
/// Must be called in a comptime context. If `decl` is not found among the
/// declarations reported by `std.meta.declarations(Container)`, compilation
/// fails with a message naming the container.
pub fn getName(comptime Container: type, comptime decl: anytype, comptime args: anytype) []const u8 {
    comptime {
        var res: []const u8 =
            for (std.meta.declarations(Container)) |possible_decl| {
                const name = possible_decl.name;
                const decl_value = @field(Container, name);

                // Check the types first: `==` is only evaluated when they match.
                if (@TypeOf(decl_value) == @TypeOf(decl) and decl_value == decl) {
                    break name;
                }
            } else @compileError("getName: value is not among the declarations of " ++ @typeName(Container));

        if (args.len > 0) {
            res = res ++ "(";

            for (&args, 0..) |arg, i| {
                if (i > 0) {
                    res = res ++ ", ";
                }
                const arg_name = if (@TypeOf(arg) == type) arg.name else getName(Container, arg, .{});
                res = res ++ arg_name;
            }

            res = res ++ ")";
        }

        return res;
    }
}

/// What a `transition` function reports back to the driver loop, for a
/// machine whose transitions operate on `Context`.
pub fn TransitionResult(comptime Context: type) type {
    return union(enum) {
        /// The machine has finished; there is nothing left to call.
        exit,
        /// The machine continues; call this function to perform the next step.
        current: *const fn (ctx: *Context) TransitionResult(Context),
    };
}

I wrote this before I saw your improvements with the @typeName names and universal Exit state so I didn’t incorporate them yet, but adding them should be simple.

1 Like

I don’t understand this, in my opinion the semantics of Witness is correct. Why must it be removed? From the code above, using Witness is concise and easy to understand.

Is there any reason why you must remove the Witness?

Here is a rough draft of my new Witness implementation, what do you think? Variable naming can be ignored.

Witness does not affect our construction of staticStateMachine.

Looking at this article, you probably don’t want to use Runner like this (of course you may just show a demo).
The main reason for using functions as parameters is that if the state constructed using the symbolic reasoning results is further generalized, it is inevitable to use functions as parameters. If you don’t have such a reasoning process, it is best not to do this.

Well, I’m not removing Witness, I simply renamed it to WithTransition and changed where it is used. Since the actual change produced by wrapping a state in Witness only comes into play when you need to call its transition function, it made more sense to me to just use Witness on the state in the one place where you need to transition, rather than having to manually write out Witness everywhere.

Yeah, Runner is just a demo, I mainly made it to make sure my getName function would work when a state had function arguments.

Comparing these two ways of writing, which one is the better way?

I prefer my version, which does not have a struct in the outer layer.

If you want, you can also let more people try the selection.

The semantics of the Witness here are perfect. If you know Haskell, this will look very familiar.

-- | A GADT whose constructors each fix the type index @a@: holding a value
-- of type @Witness a@ tells you which concrete type @a@ is.
data Witness a where
  IntWitness :: Witness Int
  BoolWitness :: Witness Bool
  CharWitness :: Witness Char

-- | Converts a value indexed by @a@ into one indexed by @b@ when the two
-- witnesses show that @a@ and @b@ are the same type; otherwise 'Nothing'.
--
-- The explicit signature is required: matching on GADT constructors refines
-- @a@ and @b@, which GHC only does against a known (rigid) type.
dynamicCast :: Witness a -> Witness b -> f a -> Maybe (f b)
dynamicCast IntWitness IntWitness pa = Just pa
dynamicCast BoolWitness BoolWitness pa = Just pa
dynamicCast CharWitness CharWitness pa = Just pa
dynamicCast _ _ _ = Nothing

I understand. This should not be called Witness, but Example!!!!!!!!

/// Machine-specific wrapper: states of this machine refer to each other
/// through `Example(...)` instead of calling `Witness` directly.
///
/// NOTE(review): the three-argument `Witness` and `print_st` are not shown in
/// this snippet; presumably `print_st` is a logging hook — confirm.
pub fn Example(Current: type) type {
    return Witness(GlobalState, print_st, Current);
}

/// Counting state of the `Example` machine: each visit bumps the shared
/// counter, then the machine either exits or moves on to
/// `AOrB(State1, State2)`.
pub const State1 = union(enum) {
    exit: Example(Exit),
    toAOrB: Example(AOrB(State1, State2)),

    /// Increments `gst.counter`; yields `.toAOrB` while the counter is below
    /// 10 and `.exit` from then on.
    pub fn handler(gst: *GlobalState) @This() {
        gst.counter += 1;
        return if (gst.counter < 10) .toAOrB else .exit;
    }
};

This way we can give each state machine a unique number.

This ensures that the state I designed for Example cannot be used by other state machines!

I like where this is heading. The next area that I think needs refining is the divide between the handler and conthandler functions, and the general logic of how we transfer control flow. I think we may be able to consolidate handler and conthandler into one concept, and implement the transition result logic (current, next, etc.) at a higher level. The goal of this would not only be to simplify the core logic of the library, but also to allow its users more flexibility in their program’s control flow.

The following is my general idea for how things should be structured. This is mostly off the top of my head so I’ll need to think about it more, but it should give you an idea of what I’m thinking:

  • The user constructs a container called StateMachine using their context type and their starting state type. This container stores their context, and an integer that serves as a runtime identifier for the state machine’s current state.
  • StateMachine will provide comptime functions to convert a state to an integer identifier and vice versa.
  • StateMachine will provide methods to run the state machine. There could be several of these methods, each with a different suspend condition. A suspend condition is a condition under which the state machine’s current state will be stored and control will be given back to the caller. This suspend condition could come in many forms, such as a callback function that is run before or after each transition (fn (ctx: *Context) bool), or as a set of states that will trigger a suspend if they are transitioned to. So, the point at which the state machine suspends could be determined by logic in the state transitions, by logic on the caller’s side, or by both. There could also be config options in these methods that allow you to choose things like whether to provide debug logs and whether to execute the machine statically or using function pointers.
  • Calling a method that runs the state machine will execute state transitions until the suspend condition is met, starting at StateMachine’s stored state.
  • StateMachine will provide a method that tells you whether the current state is the exit state.
  • All of this provides the same functionality as the (current, next, etc.) results, but puts the burden of implementation onto the user, allowing the core conventions enforced by the library to be simpler (users don’t need to think about handler vs conthandler if they don’t want to).
  • Having the all-encompassing StateMachine container allows us to provide methods that facilitate both simple and complicated strategies, which a user can choose depending on their use case.

Nice! I like this strategy more than the plain Witness.