Comptime command-line builder

I am writing a painting program. During development I decided I wanted a command line for running a subset of the program's functions: things like setting the brush size and color, loading, and saving. Kind of like Vim's commands.

To avoid repeating myself, I built this compile-time reflection thing. So far it all works, but I would like to know if there is a better way to do this sort of thing.

I have included a toy version of it below.

The code
const std = @import("std");

/// Mutable program state that the commands read and modify.
const State = struct {
    /// RGBA color; starts as all-zero (transparent black).
    color: [4]f32 = [_]f32{0} ** 4,
    /// Brush size.
    size: f32 = 1,
    /// Offset accumulated by the `move` command; starts at (0, 0).
    origin: [2]f32 = [_]f32{0} ** 2,
};

/// Colors that the `set_color` command accepts by name.
const NamedColor = enum {
    red,
    green,
    blue,
    yellow,
    purple,
    orange,

    /// Return this color as fully opaque RGBA components in the range 0..1.
    fn fcolor(self: @This()) [4]f32 {
        return switch (self) {
            .red => .{ 1, 0, 0, 1 },
            .green => .{ 0, 1, 0, 1 },
            .blue => .{ 0, 0, 1, 1 },
            .yellow => .{ 1, 1, 0, 1 },
            .purple => .{ 1, 0, 1, 1 },
            // Half-intensity green; (1, 1, 0) is yellow, not orange.
            .orange => .{ 1, 0.5, 0, 1 },
        };
    }
};

/// Functions exposed as commands. Every public declaration here is turned into a
/// command named after it; the first parameter receives the state, and the rest are
/// filled from the typed arguments. Do not add non-command public declarations.
pub const CommandFuncs = struct {
    /// Set the color from four components, each clamped into [0, 1].
    pub fn set_color_rgba(state: *State, r: f32, g: f32, b: f32, a: f32) void {
        state.color = .{
            std.math.clamp(r, 0.0, 1.0),
            std.math.clamp(g, 0.0, 1.0),
            std.math.clamp(b, 0.0, 1.0),
            std.math.clamp(a, 0.0, 1.0),
        };
    }

    /// Set the brush size verbatim (no validation).
    pub fn set_size(state: *State, s: f32) void {
        state.size = s;
    }

    /// Set the color from a `NamedColor` tag name; prints an error for unknown names.
    pub fn set_color(state: *State, color: []const u8) void {
        const named = std.meta.stringToEnum(NamedColor, color) orelse {
            print("! not a color\n");
            return;
        };
        state.color = named.fcolor();
    }

    /// Shift the origin by (x, y).
    pub fn move(state: *State, x: f32, y: f32) void {
        state.origin[0] += x;
        state.origin[1] += y;
    }
};

/// One entry of the generated command table.
const Command = struct {
    /// Name typed by the user (the function's declaration name).
    name: []const u8,
    /// Type-erased wrapper: parses the string arguments and calls the real function.
    func: *const fn (state: *State, args: []const []const u8) void,
    /// Number of string arguments expected (the state parameter is not counted).
    n_args: usize,
};

/// Number of public declarations in `CommandFuncs`; each must be a command function.
const commandcount = @typeInfo(CommandFuncs).@"struct".decls.len;

/// Comptime-generated command table with one entry per public function of `CommandFuncs`.
///
/// Each entry wraps the function in an adapter that converts the string arguments into
/// the function's parameter types (`u32`, `f32` or `[]const u8`) and calls it with the
/// state as the first argument. The adapter indexes `args` without a length check, so
/// callers must pass exactly `n_args` arguments. Invalid command signatures are rejected
/// with a compile error.
pub const commands: [commandcount]Command = blk: {
    var c: [commandcount]Command = undefined;

    for (@typeInfo(CommandFuncs).@"struct".decls, 0..) |decl, i| {
        const func = @field(CommandFuncs, decl.name);
        const func_type = @TypeOf(func);

        const fn_info = switch (@typeInfo(func_type)) {
            .@"fn" => |info| info,
            else => @compileError("CommandFuncs." ++ decl.name ++ " is not a function"),
        };
        const params = fn_info.params;

        // Every command receives the mutable state as its first argument.
        if (params.len == 0 or params[0].type == null or params[0].type.? != *State)
            @compileError("command '" ++ decl.name ++ "' needs a *State as its first parameter");

        // Generic (anytype) parameters have no concrete type to parse into; reject them
        // here instead of leaving that argument undefined in the wrapper.
        for (params[1..]) |p| {
            if (p.type == null)
                @compileError("command '" ++ decl.name ++ "' has a generic parameter, which is not supported");
        }

        const Wrapper = struct {
            pub fn wrapper_func(state: *State, args: []const []const u8) void {
                var func_args: std.meta.ArgsTuple(func_type) = undefined;
                func_args[0] = state;
                inline for (params[1..], 0..) |p, j| {
                    func_args[j + 1] = switch (p.type.?) {
                        u32 => std.fmt.parseInt(u32, args[j], 10) catch {
                            print("! bad u32\n");
                            return;
                        },
                        f32 => std.fmt.parseFloat(f32, args[j]) catch {
                            print("! bad f32\n");
                            return;
                        },
                        []const u8 => args[j],
                        else => |T| @compileError("type not supported: " ++ @typeName(T)),
                    };
                }

                @call(.auto, func, func_args);
            }
        };

        c[i] = .{
            .name = decl.name,
            .func = Wrapper.wrapper_func,
            .n_args = params.len - 1,
        };
    }

    break :blk c;
};

/// Interactive command loop.
///
/// Reads one line at a time from stdin and looks up its first word in `commands`. The
/// remaining words are passed as arguments, and the state is printed after each
/// successful call. Returns normally at end of input.
pub fn main() !void {
    const allocator = std.heap.page_allocator;
    const stdin = std.io.getStdIn().reader();

    var state = State{};

    CommandFuncs.set_size(&state, 2);

    // Reused for every line (cleared, not reallocated) and freed once on exit.
    var arg_list: std.ArrayListUnmanaged([]const u8) = .empty;
    defer arg_list.deinit(allocator);

    // Lines are read into a fixed buffer, so nothing per-line needs freeing.
    var line_buf: [256]u8 = undefined;

    while (true) {
        print(">>> ");

        const maybe_line = stdin.readUntilDelimiterOrEof(&line_buf, '\n') catch |err| switch (err) {
            error.StreamTooLong => {
                // Discard the rest of the over-long line so it is not run as a new command.
                try stdin.skipUntilDelimiterOrEof('\n');
                print("! line too long\n");
                continue;
            },
            else => return err,
        };
        // End of input (e.g. Ctrl-D or a closed pipe): leave the loop cleanly.
        const line = maybe_line orelse return;

        // Tabs and the trailing '\r' of CRLF line endings also act as separators.
        var tokens = std.mem.tokenizeAny(u8, line, " \t\r");

        const first_token = tokens.next() orelse {
            print("! no input\n");
            continue;
        };

        var found: ?Command = null;
        for (commands) |candidate| {
            if (std.mem.eql(u8, first_token, candidate.name)) {
                found = candidate;
                break;
            }
        }
        const command = found orelse {
            print("! command not found\n");
            continue;
        };

        arg_list.clearRetainingCapacity();
        while (tokens.next()) |tok| try arg_list.append(allocator, tok);

        if (arg_list.items.len != command.n_args) {
            print("! wrong number of args\n");
            continue;
        }

        command.func(&state, arg_list.items);
        print_state(&state);
    }
}

/// Write `s` verbatim to stdout; a failed write is treated as fatal.
fn print(s: []const u8) void {
    const stdout = std.io.getStdOut();
    stdout.writeAll(s) catch {
        @panic("failed to write to stdout");
    };
}

/// Print every field of `state` to stdout as `.field = value` lines; panics if stdout
/// cannot be written. Takes a const pointer because the state is only read.
fn print_state(state: *const State) void {
    const writer = std.io.getStdOut().writer();
    // `.origin` is included so the effect of the `move` command is visible.
    writer.print(".color = {d}\n.size = {d}\n.origin = {d}\n", .{ state.color, state.size, state.origin }) catch @panic("failed to write to stdout");
}

1 Like

It looks pretty solid for testing purposes. It doesn’t handle \r\n line endings, which means your program may not work on Windows as-is, but if it’s just a personal utility then that’s fine.

I took a shot at trying to implement it in a slightly more robust/reusable way:

const std = @import("std");

/// Mutable program state that the commands read and modify.
const State = struct {
    /// RGBA color; starts as all-zero (transparent black).
    color: [4]f32 = [_]f32{0} ** 4,
    /// Brush size.
    size: f32 = 1,
    /// Offset accumulated by the `move` command; starts at (0, 0).
    origin: [2]f32 = [_]f32{0} ** 2,
};

/// Colors that the `set_color` command accepts by name.
const NamedColor = enum {
    red,
    green,
    blue,
    yellow,
    purple,
    orange,

    /// Return this color as fully opaque RGBA components in the range 0..1.
    fn fcolor(self: @This()) [4]f32 {
        return switch (self) {
            .red => .{ 1, 0, 0, 1 },
            .green => .{ 0, 1, 0, 1 },
            .blue => .{ 0, 0, 1, 1 },
            .yellow => .{ 1, 1, 0, 1 },
            .purple => .{ 1, 0, 1, 1 },
            // Half-intensity green; (1, 1, 0) is yellow, not orange.
            .orange => .{ 1, 0.5, 0, 1 },
        };
    }
};

/// Functions callable from the command line by name. The first parameter receives the
/// state; the remaining parameters are parsed from the typed arguments.
pub const CommandFuncs = struct {
    /// Set the color from four components, each clamped into [0, 1].
    pub fn set_color_rgba(state: *State, r: f32, g: f32, b: f32, a: f32) void {
        state.color = .{
            std.math.clamp(r, 0.0, 1.0),
            std.math.clamp(g, 0.0, 1.0),
            std.math.clamp(b, 0.0, 1.0),
            std.math.clamp(a, 0.0, 1.0),
        };
    }

    /// Set the brush size verbatim (no validation).
    pub fn set_size(state: *State, s: f32) void {
        state.size = s;
    }

    /// Set the color from a `NamedColor` tag name; prints an error for unknown names.
    pub fn set_color(state: *State, color: []const u8) void {
        const named = std.meta.stringToEnum(NamedColor, color) orelse {
            print("! not a color\n");
            return;
        };
        state.color = named.fcolor();
    }

    /// Shift the origin by (x, y).
    pub fn move(state: *State, x: f32, y: f32) void {
        state.origin[0] += x;
        state.origin[1] += y;
    }
};

/// Build an enum with one tag per public function declaration of `T`.
///
/// Starts from `std.meta.DeclEnum(T)` (one tag per public declaration) and drops every
/// tag whose declaration is not a function. The surviving tags keep their original
/// values and the enum's tag type.
fn FunctionDeclEnum(comptime T: type) type {
    comptime {
        const all_decls = @typeInfo(std.meta.DeclEnum(T)).@"enum";

        var kept: []const std.builtin.Type.EnumField = &.{};
        for (all_decls.fields) |field| {
            if (!std.meta.hasFn(T, field.name)) continue;
            kept = kept ++ &[_]std.builtin.Type.EnumField{field};
        }

        return @Type(.{ .@"enum" = .{
            .tag_type = all_decls.tag_type,
            .fields = kept,
            .decls = all_decls.decls,
            .is_exhaustive = all_decls.is_exhaustive,
        } });
    }
}

/// Map a function name to its tag in `FunctionDeclEnum(Container)`, or return null if
/// `Container` has no public function with that name.
///
/// The lookup table is a `std.StaticStringMap` built once at compile time for each
/// `Container`.
fn functionEnumFromString(comptime Container: type, function_name: []const u8) ?FunctionDeclEnum(Container) {
    const FunctionEnum = FunctionDeclEnum(Container);

    const function_map: std.StaticStringMap(FunctionEnum) = comptime .initComptime(blk: {
        // Iterate the enum's own fields, not every declaration of `Container`.
        // FunctionDeclEnum drops non-function declarations, so a public constant in
        // `Container` would otherwise make `@field(FunctionEnum, name)` fail to compile.
        const enum_fields = @typeInfo(FunctionEnum).@"enum".fields;

        const KeyValuePair = struct { []const u8, FunctionEnum };
        var kv_pairs: [enum_fields.len]KeyValuePair = undefined;

        for (&kv_pairs, enum_fields) |*kv_pair, field| {
            kv_pair.* = .{ field.name, @field(FunctionEnum, field.name) };
        }

        break :blk kv_pairs;
    });

    return function_map.get(function_name);
}

/// Call the public function of `Container` named `function_name`.
///
/// The leading parameters are taken verbatim from the `base_params` tuple (e.g.
/// `.{&state}`). The remaining ones are parsed, in order, from the space-separated words
/// of `params_string` as `u32`, `f32`, or a borrowed `[]const u8`. Any other parameter
/// type is a compile error.
///
/// Returns true when the function was invoked. Otherwise it prints a "! ..." message and
/// returns false: for an unknown name, an unparsable number, or too few or too many words.
fn runFunctionString(comptime Container: type, function_name: []const u8, base_params: anytype, params_string: []const u8) bool {
    const function_enum = functionEnumFromString(Container, function_name) orelse {
        print("! command not found\n");
        return false;
    };

    switch (function_enum) {
        // `inline else` instantiates this prong once per function, so `tag` is
        // comptime-known and `args` gets that function's exact parameter types.
        inline else => |tag| {
            const function = @field(Container, @tagName(tag));
            var args: std.meta.ArgsTuple(@TypeOf(function)) = undefined;

            // Leading arguments are passed through unchanged.
            const fixed_count = base_params.len;
            inline for (0..fixed_count) |idx| {
                args[idx] = base_params[idx];
            }

            // Each remaining parameter consumes one word of `params_string`.
            var words = std.mem.tokenizeScalar(u8, params_string, ' ');
            inline for (fixed_count..args.len) |idx| {
                const word = words.next() orelse {
                    print("! too few args\n");
                    return false;
                };

                const Param = @TypeOf(args[idx]);
                if (Param == []const u8) {
                    args[idx] = word;
                } else if (Param == u32) {
                    args[idx] = std.fmt.parseUnsigned(u32, word, 10) catch {
                        print("! bad u32\n");
                        return false;
                    };
                } else if (Param == f32) {
                    args[idx] = std.fmt.parseFloat(f32, word) catch {
                        print("! bad f32\n");
                        return false;
                    };
                } else {
                    @compileError("type not supported: '" ++ @typeName(Param) ++ "'");
                }
            }

            if (words.next()) |_| {
                print("! too many args\n");
                return false;
            }

            @call(.auto, function, args);
            return true;
        },
    }
}

/// Interactive command loop.
///
/// Reads a line from stdin and trims surrounding whitespace (including the '\r' of CRLF
/// endings). The first word is dispatched to `CommandFuncs`, and the state is printed
/// after each successful call. Returns normally at end of input.
pub fn main() !void {
    var state: State = .{};

    CommandFuncs.set_size(&state, 2);

    const reader = std.io.getStdIn().reader();
    var buf: [1024]u8 = undefined;

    while (true) {
        print(">>> ");

        const maybe_line = reader.readUntilDelimiterOrEof(&buf, '\n') catch |err| switch (err) {
            error.StreamTooLong => {
                // Discard the rest of the over-long line so it is not run as a new command.
                try reader.skipUntilDelimiterOrEof('\n');
                print("! line too long\n");
                continue;
            },
            else => return err,
        };
        // null means end of input (e.g. Ctrl-D): exit instead of re-prompting forever.
        const read_line_untrimmed = maybe_line orelse return;
        const read_line = std.mem.trim(u8, read_line_untrimmed, " \t\r\n");

        var arg_iterator = std.mem.tokenizeAny(u8, read_line, " \t");

        const first_arg = arg_iterator.next() orelse {
            // Same channel (stdout) as every other "! ..." message.
            print("! no input\n");
            continue;
        };

        if (runFunctionString(CommandFuncs, first_arg, .{&state}, arg_iterator.rest())) {
            print_state(&state);
        }
    }
}

/// Write `s` unchanged to stdout; a write failure is treated as fatal.
fn print(s: []const u8) void {
    const out = std.io.getStdOut();
    out.writeAll(s) catch @panic("failed to write to stdout");
}

/// Print every field of `state` to stdout as `.field = value` lines; panics if stdout
/// cannot be written. Takes a const pointer because the state is only read.
fn print_state(state: *const State) void {
    const writer = std.io.getStdOut().writer();
    // `.origin` is included so the effect of the `move` command is visible.
    writer.print(".color = {d}\n.size = {d}\n.origin = {d}\n", .{ state.color, state.size, state.origin }) catch @panic("failed to write to stdout");
}

I use std.StaticStringMap to look up the functions. That is a little overkill for this use case, but it’s a nice thing to know about if you need to implement name-based function calls like this in a more performant way.

1 Like

Nice! I’ve been playing around with something similar, inspired by Python libraries like Typer, but I was stumped trying to get the names of the functions' arguments.

It seems like putting the functions into a struct is the key? It is so easy to get the types of the arguments, in comparison.

You actually can’t get argument names via reflection in Zig.

Though, I was toying with a similar problem before, and “wrote” some hacky LLM-generated code to extract the parameter names by parsing the source. Combined with @embedFile, that could technically let you know the names at comptime… if it weren’t for the allocation.

Here it is if you are interested:

const std = @import("std");

/// Return the parameter names of the top-level function `func_name` in `library_source`.
///
/// Parses the source with `std.zig.Ast` and inspects only the file's root declarations,
/// so functions nested in structs or other containers are never matched. The first
/// matching declaration wins. If there is no such function, an empty slice is returned.
/// Parameters without a name token (such as a `...` varargs marker) are skipped.
///
/// Caller owns the result: free each name, then the outer slice, with `allocator`.
/// Allocation failure is returned as an error rather than hidden behind an empty result.
pub fn extractParameterNames(allocator: std.mem.Allocator, library_source: [:0]const u8, func_name: []const u8) ![][]const u8 {
    var ast = try std.zig.Ast.parse(allocator, library_source, .zig);
    defer ast.deinit(allocator);

    var param_names = std.ArrayList([]const u8).init(allocator);
    // On any error, release the names duplicated so far along with the list itself.
    errdefer {
        for (param_names.items) |name| allocator.free(name);
        param_names.deinit();
    }

    // Scratch space required by `fullFnProto` for prototypes with a single parameter.
    var proto_buf: [1]std.zig.Ast.Node.Index = undefined;

    for (ast.rootDecls()) |decl| {
        const fn_proto = ast.fullFnProto(&proto_buf, decl) orelse continue;
        const name_token = fn_proto.name_token orelse continue;
        if (!std.mem.eql(u8, ast.tokenSlice(name_token), func_name)) continue;

        var params = fn_proto.iterate(&ast);
        while (params.next()) |param| {
            const param_name_token = param.name_token orelse continue;
            const param_name = try allocator.dupe(u8, ast.tokenSlice(param_name_token));
            errdefer allocator.free(param_name);
            try param_names.append(param_name);
        }
        break;
    }

    return param_names.toOwnedSlice();
}

/// Return the token index just past the declaration that starts at `start_idx`.
///
/// A declaration ends at the `}` that closes its first outermost brace block (e.g. a
/// function body or `struct { ... }`). It also ends at a `;` seen while no braces or
/// parentheses are open. If neither is found, the scan stops at the end of `token_tags`.
fn skipTopLevelDeclaration(token_tags: []const std.zig.Token.Tag, start_idx: u32) u32 {
    var braces: u32 = 0;
    var parens: u32 = 0;
    var saw_brace = false;

    var pos = start_idx;
    while (pos < token_tags.len) : (pos += 1) {
        const tag = token_tags[pos];
        if (tag == .l_brace) {
            braces += 1;
            saw_brace = true;
        } else if (tag == .r_brace and braces > 0) {
            braces -= 1;
            // Closing the outermost block ends a block-style declaration.
            if (braces == 0 and saw_brace) return pos + 1;
        } else if (tag == .l_paren) {
            parens += 1;
        } else if (tag == .r_paren) {
            // Unbalanced `)` is ignored rather than underflowing.
            parens -|= 1;
        } else if (tag == .semicolon and braces == 0 and parens == 0) {
            return pos + 1;
        }
    }
    return pos;
}

/// Return where the parameter type expression starting at `start_idx` ends.
///
/// Scans forward while tracking (), {} and [] nesting. It returns the index of an
/// unmatched `)` (the end of the parameter list), or the index just past a `,` at zero
/// nesting depth (the start of the next parameter). If neither occurs, it returns the
/// end of `token_tags`.
fn skipExpression(token_tags: []const std.zig.Token.Tag, start_idx: u32) u32 {
    var depth_paren: u32 = 0;
    var depth_brace: u32 = 0;
    var depth_bracket: u32 = 0;

    var pos = start_idx;
    while (pos < token_tags.len) : (pos += 1) {
        const tag = token_tags[pos];
        if (tag == .r_paren) {
            // An unmatched `)` closes the enclosing parameter list.
            if (depth_paren == 0) return pos;
            depth_paren -= 1;
        } else if (tag == .comma) {
            if (depth_paren == 0 and depth_brace == 0 and depth_bracket == 0) return pos + 1;
        } else if (tag == .l_paren) {
            depth_paren += 1;
        } else if (tag == .l_brace) {
            depth_brace += 1;
        } else if (tag == .r_brace) {
            depth_brace -|= 1;
        } else if (tag == .l_bracket) {
            depth_bracket += 1;
        } else if (tag == .r_bracket) {
            depth_bracket -|= 1;
        }
    }
    return pos;
}
1 Like

Aah. I had come to that conclusion, but I thought the above code was achieving it. I guess I misunderstood something (I discovered Zig about a week ago, so that’s more than possible).

It is a shame that it isn’t possible; I think being able to access the argument names at comptime would open up interesting possibilities.

Thanks for writing this out! StaticStringMap is new to me, and it hadn’t occurred to me that I could build the function lookup table inside a function. I suppose I haven’t fully internalized the fact that comptime values are constant, so it doesn’t matter what context they are built in.