Sub switch pattern

I found an interesting pattern for writing common code shared by a subset of enum values.

// Shared handling for every binary operator. The prong is `inline`, so the
// captured `op` is comptime-known inside the body.
switch (instruction) {
    inline .add,
    .subtract,
    .multiply,
    .divide,
    => |op| {
        // `b` is popped before `a`; `a` is used as the left-hand operand below.
        // A null pop result aborts with error.Runtime.
        const b = vm.stack.popOrNull() orelse return error.Runtime;
        const a = vm.stack.popOrNull() orelse return error.Runtime;
        // Inner switch selects the arithmetic for the specific operator.
        const result = switch (op) {
            .add => a.inner + b.inner,
            .subtract => a.inner - b.inner,
            .multiply => a.inner * b.inner,
            .divide => a.inner / b.inner,
            // Compile-time guard: a tag listed in the outer prong without a
            // case above would reach this line and fail the build.
            else => comptime unreachable,
        };
        vm.stack.appendAssumeCapacity(.{ .inner = result });
    },
    // ... some other non-binary operators
}

If I add a new binary instruction like so:

// Same shared prong as before, with `.power` added to the outer list but no
// matching case in the inner switch.
switch (instruction) {
    inline .add,
    .subtract,
    .multiply,
    .divide,
    .power, // new: listed here, but not handled in the inner switch
    => |op| {
        // `b` is popped before `a`; `a` is used as the left-hand operand below.
        const b = vm.stack.popOrNull() orelse return error.Runtime;
        const a = vm.stack.popOrNull() orelse return error.Runtime;
        const result = switch (op) {
            .add => a.inner + b.inner,
            .subtract => a.inner - b.inner,
            .multiply => a.inner * b.inner,
            .divide => a.inner / b.inner,
            else => comptime unreachable, // compile error: `.power` reaches this prong
        };
        vm.stack.appendAssumeCapacity(.{ .inner = result });
    },
    // ... some other non-binary operators
}

It will result in a compile error, since the new tag reaches the `comptime unreachable` in the inner switch. And that’s pretty nice.

But what if I remove a binary operator:

// Same shared prong, with `.divide` removed from the outer list while its
// case is left behind in the inner switch.
switch (instruction) {
    inline .add,
    .subtract,
    .multiply, // who needs division anyway
    => |op| {
        // `b` is popped before `a`; `a` is used as the left-hand operand below.
        const b = vm.stack.popOrNull() orelse return error.Runtime;
        const a = vm.stack.popOrNull() orelse return error.Runtime;
        const result = switch (op) {
            .add => a.inner + b.inner,
            .subtract => a.inner - b.inner,
            .multiply => a.inner * b.inner,
            // Stale case: `.divide` is no longer in the outer list, yet this
            // is reported to compile without any diagnostic.
            .divide => a.inner / b.inner,
            else => comptime unreachable,
        };
        vm.stack.appendAssumeCapacity(.{ .inner = result });
    },
    // ... some other non-binary operators
}

This will compile :frowning:
It’s not really a big problem, since it’s rare to remove instructions, but maybe there is some clever solution?

This should only compile if you have an else clause in the outer switch.
I’ve been using this pattern as well, it can be really great to split a single loop that contains branches into many branchless loops, but you have to be careful about code bloat.

I wasn’t aware of this inline switch syntax. At first glance it looks like |op| would be the tagged union payload, but I guess it captures just the enum tag instead? I guess it narrows down the valid choices to those that were used in the parent switch? Pretty neat!

So, I have to admit this nerd-sniped me a little bit, and I’ve come up with a solution that would allow one to write code like:

/// Example operator set; each operator belongs to exactly one `Arity` group.
const Operator = enum {
    halt,
    neg,
    inc,
    add,
    sub,

    /// Grouping key returned by `arity`.
    pub const Arity = enum { nullary, unary, binary };

    /// Classifies `self` into its `Arity` group. The switch has no `else`,
    /// so adding an operator without classifying it is a compile error.
    pub fn arity(self: Operator) Arity {
        switch (self) {
            .halt => return .nullary,
            .neg, .inc => return .unary,
            .add, .sub => return .binary,
        }
    }
};

/// Demonstrates `groupBy`: dispatch on the operator's arity group first, then
/// switch exhaustively over only the operators that belong to that group.
pub fn main() !void {
    // Declared as `var` and address-taken so the compiler does not reject it
    // as a never-mutated variable; the switch below then runs on a
    // runtime value.
    var op: Operator = .neg;
    _ = .{&op};

    switch (groupBy(op, Operator.arity)) {
        // Each capture `o` is the payload enum for that group, so the inner
        // switches need no `else` prong.
        .nullary => |o| switch (o) {
            .halt => std.debug.print("halt\n", .{}),
        },
        .unary => |o| switch (o) {
            .neg => std.debug.print("negation\n", .{}),
            .inc => std.debug.print("increment\n", .{}),
        },
        .binary => |o| switch (o) {
            .add => std.debug.print("addition\n", .{}),
            // Trailing newline added to match every other message.
            .sub => std.debug.print("subtraction\n", .{}),
        },
    }
}

The solution is heavily inspired by Tagged Union Subsets. It’s somewhat more generic, but using enums rather than unions.

The groupBy function constructs a union whose tag type is the return type of the passed function (in this case Operator.arity), and whose fields are enums whose valid values are the subsets given by the return values of that function. The value passed to `groupBy` is then converted into the corresponding union field.

The return type of groupBy(op, Operator.arity) will look something like this:

// Members of each payload enum appear in `Operator` declaration order and
// keep the integer tag values they have in `Operator`.
union(Operator.Arity) {
    nullary: enum { halt },
    unary: enum { neg, inc },
    binary: enum { add, sub },
}

The code isn’t pretty, and could do with some cleanup (and with some tweaks could be modified to handle union(enum) as well):

const Type = std.builtin.Type;

/// Builds a tagged-union type that partitions the enum `T` by `distinguisher`.
///
/// The union's tag type is the return type of `distinguisher`. For every tag
/// of that type the union has one field whose payload is a generated
/// exhaustive enum containing exactly those fields of `T` that
/// `distinguisher` maps to the tag. The generated enums reuse the tag type of
/// `T` and copy each field's name and integer value unchanged.
pub fn SubEnum(comptime T: type, comptime distinguisher: anytype) type {
    const Tags = ReturnType(distinguisher);

    var union_fields: []const Type.UnionField = &.{};
    inline for (std.meta.fields(Tags)) |group| {
        const group_tag = @field(Tags, group.name);

        // Collect the fields of `T` that fall into this group, comparing the
        // enum values directly rather than their names.
        var members: []const Type.EnumField = &.{};
        inline for (std.meta.fields(T)) |candidate| {
            if (distinguisher(@field(T, candidate.name)) == group_tag) {
                members = members ++ .{candidate};
            }
        }

        const Payload = @Type(.{ .@"enum" = .{
            .tag_type = @typeInfo(T).@"enum".tag_type,
            .fields = members,
            .decls = &.{},
            .is_exhaustive = true,
        } });
        union_fields = union_fields ++ .{Type.UnionField{
            .name = group.name,
            .type = Payload,
            .alignment = @alignOf(Payload),
        }};
    }

    return @Type(.{ .@"union" = .{
        .layout = .auto,
        .tag_type = Tags,
        .fields = union_fields,
        .decls = &.{},
    } });
}

/// Returns the declared return type of the function `f`.
///
/// Emits a descriptive compile error when `f` is not a function, or when the
/// function's type info carries no return type (`return_type` is null).
fn ReturnType(f: anytype) type {
    const F = @TypeOf(f);
    return switch (@typeInfo(F)) {
        .@"fn" => |info| info.return_type orelse
            @compileError("ReturnType: " ++ @typeName(F) ++ " has no fixed return type"),
        else => @compileError("ReturnType: expected a function, found " ++ @typeName(F)),
    };
}

/// Classifies the enum value `e` with `distinguisher` and returns it wrapped
/// in the matching field of `SubEnum(@TypeOf(e), distinguisher)`.
///
/// `distinguisher` is called exactly once. The payload is produced by
/// converting `e` to its integer value and back into the group's payload
/// enum.
pub fn groupBy(e: anytype, distinguisher: anytype) SubEnum(@TypeOf(e), distinguisher) {
    const Grouped = SubEnum(@TypeOf(e), distinguisher);

    // `inline else` generates one prong per group tag, making `group`
    // comptime-known so it can name the union field to initialise.
    return switch (distinguisher(e)) {
        inline else => |group| @unionInit(
            Grouped,
            @tagName(group),
            @enumFromInt(@intFromEnum(e)),
        ),
    };
}
1 Like

This one will work, thanks. I thought about something in that direction but went another route.