How to enforce Function Call Ordering

I’m writing a DBUS library for Zig. DBUS lets you call methods by sending messages over a socket. Method calls encode their parameter types in a signature string, e.g. “uus” means two u32s followed by a string.

One goal with this library is to let the caller stay in control while it’s reading data. I don’t want to force the caller to allocate memory and read the whole message into it first. We can achieve this with an API like this:

const a = try msg.readU32(reader);
const b = try msg.readU32(reader);
const string_size = try msg.readStringSize(reader);
// caller can read the string however they like, i.e.
try reader.stream(writer, string_size);
// let the library know we read the string
msg.notifyConsumed(.string);
try msg.finish(); // enforce that we've read the entire message

This kind of API has a disadvantage compared to one that reads the entire message into memory: a mistake in the caller’s code, such as skipping a parameter or reading one too many, won’t be caught until runtime. However, I found a technique to catch this mistake at comptime, check it out:

const sig = "uus";
comptime var sig_index: usize = 0;
try msg.enforceSignature(sig);

const a = try msg.readU32(sig, &sig_index)(reader);
const b = try msg.readU32(sig, &sig_index)(reader);
const string_size = try msg.readStringSize(sig, &sig_index)(reader);
// stream from reader same as before
try reader.stream(writer, string_size);
// let library know we read the string just as before
msg.notifyConsumed(.string);
// enforce we've read the entire message again, but now we'll get a
// comptime error if our code doesn't agree with the signature!
try msg.finish(sig, &sig_index)();

We’ve introduced a runtime signature check and, at the same time, enforced that our code agrees with the signature. If you remove one of the reads or add a new one, you’ll get a compile error. Note that the lines reading parameters now have two argument lists:

const a = try msg.readU32(sig, &sig_index)(reader);
                                         ^
                                         two sets of parameters

readU32 takes two comptime args that track where we are in the signature. If the read doesn’t match the next type in the signature, we hit a @compileError; otherwise, we return a function that takes the runtime-known reader and reads the value.

I tried just making readU32 take the reader as well, but for some reason Zig doesn’t like it when you mix comptime var pointers with runtime-known parameters?

Here’s a full code example to experiment with yourself:

fn ReadFn(comptime signature: []const u8, comptime index: usize) type {
    if (index >= signature.len) @compileError("signature has no more types");
    return switch (signature[index]) {
        'u' => fn (*Reader) error{ ReadFailed, EndOfStream }!u32,
        's' => fn (*Reader) error{ ReadFailed, EndOfStream }![]const u8,
        else => @compileError("unknown signature char: '" ++ signature[index .. index + 1] ++ "'"),
    };
}

// NOTE: this function returns a function that reads the value rather than just reading
//       the value itself because Zig doesn't support comptime pointers when they are
//       mixed with runtime values.
fn nextReadFn(
    comptime signature: []const u8,
    comptime signature_index: *usize,
) ReadFn(signature, signature_index.*) {
    const start = signature_index.*;
    signature_index.* = signature_index.* + 1;
    return comptime switch (signature[start]) {
        'u' => readU,
        's' => readS,
        else => unreachable,
    };
}

fn finish(comptime signature: []const u8, comptime index: usize) void {
    if (index != signature.len) @compileError("the remaining signature has not been read: " ++ signature[index..]);
}

fn readU(r: *Reader) error{ ReadFailed, EndOfStream }!u32 {
    return r.takeInt(u32, .big);
}
fn readS(r: *Reader) error{ ReadFailed, EndOfStream }![]const u8 {
    return r.take(11);
}

pub fn main() !void {
    var r: Reader = .fixed("\x12\x34\x56\x78" ++ "\x9a\xbc\xde\xf0" ++ "hello there");
    try example(&r);
}
fn example(r: *Reader) !void {
    // This is an example signature that represents two u32 values and a string.
    // This API enforces that's always what's read in that order at compile time.
    const signature = "uus";
    comptime var signature_index: usize = 0;

    // If you comment out this read (or the other reads below), then you'll get
    // a compile error.
    const first_u32: u32 = try nextReadFn(signature, &signature_index)(r);
    std.debug.assert(first_u32 == 0x12345678);

    const second_u32: u32 = try nextReadFn(signature, &signature_index)(r);
    std.debug.assert(second_u32 == 0x9abcdef0);

    const string: []const u8 = try nextReadFn(signature, &signature_index)(r);
    std.debug.assert(std.mem.eql(u8, string, "hello there"));

    finish(signature, signature_index);
}
const std = @import("std");
const Reader = std.Io.Reader;


&sig_index is brilliant! Implementing type state by using literal state is 🤌


Thanks for this. I have just started looking at implementing a dbus library, and this pops up.


This is awesome! It’s a great use of Zig’s comptime to strengthen compile-time checks.

You can extend this into a precondition/postcondition system. Have a comptime CallCtx struct carry the call history and various flags. Each receiving function calls CallCtx.pre(...some requirement...) on entry and then updates CallCtx to reflect its role. The function can call other functions and then call CallCtx.post(...) to check the postcondition. All of this happens at comptime, and any violation raises a @compileError during compilation.
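A minimal sketch of that idea, under some assumptions: the open/write/close protocol, the `Log` type, `require`, and every method name here are hypothetical, chosen only for illustration. It reuses the trick from the first post, where each comptime step returns the runtime work as a function, and it uses a `comptime_int` field so the context can only live in a `comptime var`.

```zig
const std = @import("std");

// Hypothetical runtime resource (stands in for a socket or message).
// It only records which steps ran.
const Log = struct {
    buf: [64]u8 = undefined,
    len: usize = 0,

    fn push(self: *Log, s: []const u8) void {
        @memcpy(self.buf[self.len..][0..s.len], s);
        self.len += s.len;
    }
};

// Comptime assertion helper: fails the build with a message.
fn require(comptime ok: bool, comptime what: []const u8) void {
    if (!ok) @compileError("condition failed: " ++ what);
}

// Hypothetical comptime call context. The comptime_int field makes this a
// comptime-only type, so it must be declared as a `comptime var`.
const CallCtx = struct {
    opened: bool = false,
    closed: bool = false,
    writes: comptime_int = 0,

    // Each step checks its precondition, records itself, and hands back the
    // runtime work as a function (same trick as nextReadFn in the first post).
    fn open(comptime self: *CallCtx) fn (*Log) void {
        require(!self.opened, "open() called twice");
        self.opened = true;
        return struct {
            fn run(log: *Log) void {
                log.push("open;");
            }
        }.run;
    }

    fn write(comptime self: *CallCtx) fn (*Log) void {
        require(self.opened and !self.closed, "write() outside open()/close()");
        self.writes += 1;
        return struct {
            fn run(log: *Log) void {
                log.push("write;");
            }
        }.run;
    }

    fn close(comptime self: *CallCtx) fn (*Log) void {
        require(self.opened and !self.closed, "close() without a matching open()");
        self.closed = true;
        return struct {
            fn run(log: *Log) void {
                log.push("close;");
            }
        }.run;
    }

    // Postcondition, checked once the caller is done (like finish()).
    fn post(comptime self: CallCtx) void {
        require(self.closed, "context never closed");
        require(self.writes > 0, "nothing was written");
    }
};

// Usage: calling write() before open(), or forgetting close(), fails to compile.
fn example(log: *Log) void {
    comptime var ctx: CallCtx = .{};
    ctx.open()(log);
    ctx.write()(log);
    ctx.write()(log);
    ctx.close()(log);
    ctx.post();
}
```

Keeping the flags in one struct means new rules only touch the step that owns them, while the call sites stay as short as in the `SignatureReaderEnumerator` variant.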


I changed it to a comptime var struct with comptime methods. Otherwise I think it should behave like the original, but I find the ergonomics at the use site a bit simpler:

fn readU(r: *Reader) error{ ReadFailed, EndOfStream }!u32 {
    return r.takeInt(u32, .big);
}
fn readS(r: *Reader) error{ ReadFailed, EndOfStream }![]const u8 {
    return r.take(11);
}

pub const SignatureReaderEnumerator = struct {
    signature: []const u8,
    index: usize,

    pub fn init(comptime signature: []const u8) SignatureReaderEnumerator {
        return .{ .signature = signature, .index = 0 };
    }

    pub fn next(comptime self: *SignatureReaderEnumerator) self.Impl() {
        defer self.index += 1;
        return self.impl();
    }

    pub fn finish(comptime self: SignatureReaderEnumerator) void {
        if (self.index != self.signature.len) @compileError("the remaining signature has not been read: " ++ self.signature[self.index..]);
    }

    fn Impl(comptime self: SignatureReaderEnumerator) type {
        if (self.index >= self.signature.len) @compileError("signature has no more types");
        return switch (self.signature[self.index]) {
            'u' => fn (*Reader) error{ ReadFailed, EndOfStream }!u32,
            's' => fn (*Reader) error{ ReadFailed, EndOfStream }![]const u8,
            else => @compileError("unknown signature char: '" ++ self.signature[self.index .. self.index + 1] ++ "'"),
        };
    }

    fn impl(comptime self: SignatureReaderEnumerator) self.Impl() {
        return switch (self.signature[self.index]) {
            'u' => readU,
            's' => readS,
            else => unreachable,
        };
    }
};

pub fn main() !void {
    var r: Reader = .fixed("\x12\x34\x56\x78" ++ "\x9a\xbc\xde\xf0" ++ "hello there");
    try example(&r);
}
fn example(r: *Reader) !void {
    // This is an example signature that represents two u32 values and a string.
    // This API enforces that's always what's read in that order at compile time.
    comptime var signature_reader: SignatureReaderEnumerator = .init("uus");

    // If you comment out this read (or the other reads below), then you'll get
    // a compile error.
    const first_u32: u32 = try signature_reader.next()(r);
    std.debug.assert(first_u32 == 0x12345678);

    const second_u32: u32 = try signature_reader.next()(r);
    std.debug.assert(second_u32 == 0x9abcdef0);

    const string: []const u8 = try signature_reader.next()(r);
    std.debug.assert(std.mem.eql(u8, string, "hello there"));

    signature_reader.finish();
}
const std = @import("std");
const Reader = std.Io.Reader;

I tried to use comptime defer signature_reader.finish(); but it seems comptime defer isn’t a valid language construct (which makes me wonder whether it could or should be).



I’m starting to wonder if I shouldn’t have shared that subscriber link on HN…

Is there a strong reason to roll your own signature format instead of using a struct type? The following lets you read a struct one field at a time:

const std = @import("std");

fn StructReader(comptime Struct: type) type {
    return struct {
        pub const FieldName = std.meta.FieldEnum(Struct);
        nextFieldIndex: comptime_int = 0,

        pub fn field(self: *@This(), comptime name: FieldName) type {
            if (@intFromEnum(name) != self.nextFieldIndex) {
                const expected: FieldName = @enumFromInt(self.nextFieldIndex);
                @compileError("." ++ @tagName(expected) ++ " is expected, but received ." ++ @tagName(name));
            }
            self.nextFieldIndex += 1;
            const FT = FieldType(name);
            return struct {
                pub fn read(ptr: [*]const u8) FT {
                    const offset = @offsetOf(Struct, @tagName(name));
                    const field_value_ptr: *align(1) const FT = @ptrCast(&ptr[offset]);
                    return field_value_ptr.*;
                }
            };
        }

        fn FieldType(comptime name: FieldName) type {
            return @FieldType(Struct, @tagName(name));
        }
    };
}

pub fn main() void {
    comptime var reader: StructReader(extern struct {
        a: u32,
        b: u32,
        size: usize,
    }) = .{};
    const buffer: [16]u8 = .{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };
    const ptr: [*]const u8 = &buffer;
    const result_a = reader.field(.a).read(ptr);
    const result_b = reader.field(.b).read(ptr);
    std.debug.print("a = {x}, b = {x}\n", .{ result_a, result_b });
}
a = 4030201, b = 8070605

If you change .b in the line for const result_b to .size, you’d get the compile error “.b is expected, but received .size”.

comptime_int is used for the index so that, if you forget to declare reader as comptime var, the compiler tells you it needs to be comptime.


> Is there a strong reason to roll your own signature format instead of using a struct type?

This actually isn’t my own signature format; it’s the DBUS type signature format. Every DBUS message carries this type signature on the wire, and we verify at runtime that it’s what we expect. So the signature is a given; the only question is how to ensure our code stays in agreement with it. My example above shows one way to do this.
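The runtime half of that check can be sketched as a plain string comparison. This is only an assumption about its shape: `checkSignature` and `error.SignatureMismatch` are hypothetical names, not the library’s `enforceSignature`, whose implementation isn’t shown in the thread. The point is that one comptime literal can serve both checks: at runtime against the wire, and at comptime against the sequence of reads.

```zig
const std = @import("std");

// Hypothetical runtime guard. `received` would come from the message header
// on the wire; `expected` is the literal the reading code was written for.
fn checkSignature(comptime expected: []const u8, received: []const u8) error{SignatureMismatch}!void {
    if (!std.mem.eql(u8, expected, received)) return error.SignatureMismatch;
}
```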

This question also brings up an interesting aspect. Note that I ended up using a string literal to represent this “tree” of composable types rather than an actual tree of type nodes, something like this:

const Type = union(enum) {
    u32,
    string,
    array: *const Type,
};
const string: Type = .string;
const string_array: Type = .{ .array = &string };
const array_of_string_arrays: Type = .{ .array = &string_array };

It’s a bit harder to follow because now we need indirection to compose types. With string literals you can compose all you want without indirection: the three constants above correspond to the signatures s, as, and aas. You could argue that it might not be worth creating a custom syntax to represent a tree of types, but in this case DBUS already did that for us 🙂
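Even though the signature is flat, it still has to be walked as a tree. A sketch of that walk, assuming only the DBUS rules that `a` is followed by exactly one complete type and `(`…`)` wraps one or more complete types; `completeTypeEnd` and `countCompleteTypes` are hypothetical helpers, dict entries (`{`…`}`) are left out, and the spec’s length and nesting limits are not enforced.

```zig
const std = @import("std");

const Error = error{InvalidSignature};

/// Returns the index just past the single complete type starting at `start`.
fn completeTypeEnd(sig: []const u8, start: usize) Error!usize {
    if (start >= sig.len) return error.InvalidSignature;
    switch (sig[start]) {
        // Single-character complete types (basic types plus variant 'v').
        'y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 'h', 's', 'o', 'g', 'v' => return start + 1,
        // An array is 'a' followed by exactly one complete element type.
        'a' => return completeTypeEnd(sig, start + 1),
        // A struct is '(' + one or more complete types + ')'.
        '(' => {
            var i = start + 1;
            if (i < sig.len and sig[i] == ')') return error.InvalidSignature; // empty struct
            while (i < sig.len and sig[i] != ')') i = try completeTypeEnd(sig, i);
            if (i >= sig.len) return error.InvalidSignature; // missing ')'
            return i + 1;
        },
        else => return error.InvalidSignature,
    }
}

/// Number of top-level complete types, i.e. the number of message arguments.
fn countCompleteTypes(sig: []const u8) Error!usize {
    var n: usize = 0;
    var i: usize = 0;
    while (i < sig.len) : (n += 1) i = try completeTypeEnd(sig, i);
    return n;
}
```

Because both functions only take a slice, they can also be called in a comptime context on a literal signature, which is where a reader like `nextReadFn` would need them once container types are involved.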


You want defer comptime ... instead


Try it; it doesn’t work. It seems to run too early, before the .next() calls are processed, so you get the compile error about the signature not being fully read.

I think defer comptime expr basically means immediately evaluate expr at comptime and then defer the resulting expr, but for this example to work you need to defer the comptime expr evaluation itself to the end of the function. I guess you could turn it into a runtime error, but that would be sad when you can get a compile error by manually writing it at the end.


Ah, I understand what you were after now. Yeah, I don’t think that’s possible in Zig currently.


> I think defer comptime expr basically means immediately evaluate expr at comptime and then defer the resulting expr, but for this example to work you need to defer the comptime expr evaluation itself to the end of the function.

The actual problem is that the deferred expression gets emitted at the function’s end, as well as on each of the try expressions (because try is a conditional return in disguise). If you remove the error handling, it compiles:

fn readU(r: *Reader) u32 {
    return r.takeInt(u32, .big) catch @panic("PANIK");
}
fn readS(r: *Reader) []const u8 {
    return r.take(11) catch @panic("PANIK");
}

pub const SignatureReaderEnumerator = struct {
    signature: []const u8,
    index: usize,

    pub fn init(comptime signature: []const u8) SignatureReaderEnumerator {
        return .{ .signature = signature, .index = 0 };
    }

    pub fn next(comptime self: *SignatureReaderEnumerator) self.Impl() {
        defer self.index += 1;
        return self.impl();
    }

    pub fn finish(comptime self: SignatureReaderEnumerator) void {
        if (self.index != self.signature.len) @compileError("the remaining signature has not been read: " ++ self.signature[self.index..]);
    }

    fn Impl(comptime self: SignatureReaderEnumerator) type {
        if (self.index >= self.signature.len) @compileError("signature has no more types");
        return switch (self.signature[self.index]) {
            'u' => fn (*Reader) u32,
            's' => fn (*Reader) []const u8,
            else => @compileError("unknown signature char: '" ++ self.signature[self.index .. self.index + 1] ++ "'"),
        };
    }

    fn impl(comptime self: SignatureReaderEnumerator) self.Impl() {
        return switch (self.signature[self.index]) {
            'u' => readU,
            's' => readS,
            else => unreachable,
        };
    }
};

pub fn main() !void {
    var r: Reader = .fixed("\x12\x34\x56\x78" ++ "\x9a\xbc\xde\xf0" ++ "hello there");
    try example(&r);
}
fn example(r: *Reader) !void {
    // This is an example signature that represents two u32 values and a string.
    // This API enforces that's always what's read in that order at compile time.
    comptime var signature_reader: SignatureReaderEnumerator = .init("uus");
    defer comptime signature_reader.finish();

    // If you comment out this read (or the other reads below), then you'll get
    // a compile error.
    const first_u32: u32 = signature_reader.next()(r);
    std.debug.assert(first_u32 == 0x12345678);

    const second_u32: u32 = signature_reader.next()(r);
    std.debug.assert(second_u32 == 0x9abcdef0);

    const string: []const u8 = signature_reader.next()(r);
    std.debug.assert(std.mem.eql(u8, string, "hello there"));
}
const std = @import("std");
const Reader = std.Io.Reader;

But you shouldn’t actually do that – when you think about it, putting signature_reader.finish() at the end of the function is the correct logic, because you only care about the signature being handled correctly when the reads were actually successful.
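The “try is a conditional return” point can be seen in isolation with a small runtime sketch (`mayFail` and `body` are hypothetical names): the deferred statement runs both on the normal exit and on the early return that `try` performs, which is why a deferred `finish()` also gets emitted at each `try`, before the signature is fully consumed.

```zig
const std = @import("std");

fn mayFail(fail: bool) error{Boom}!void {
    if (fail) return error.Boom;
}

// `try x` behaves like `x catch |err| return err`, so it is an exit point of
// the function, and pending `defer`s run there as well as at the end.
fn body(runs: *usize, fail: bool) error{Boom}!void {
    defer runs.* += 1;
    try mayFail(fail);
    // Only reached when mayFail succeeded.
}
```

If cleanup should only happen on the error path, Zig’s `errdefer` is the tool for that; for the signature check, the success-path-only semantics argued for above is exactly what a plain call at the end of the function gives you.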


In my mind, the hypothetical comptime defer would run when the comptime interpreter is done generating code for the body, so it would observe the mutations of the comptime vars. I haven’t thought deeply about it, though, so there are probably good arguments against it. There is no way to use try at comptime (as far as I know), so a comptime defer would not be influenced by runtime control flow, because you are deferring within the comptime control flow of the function. That distinction could be confusing, which is probably one of the reasons not to have it.


To me that sounds exactly equivalent to putting the statement at the end of the block. Also, a potential footgun – I can definitely see myself accidentally writing defer comptime instead of comptime defer and/or vice versa.

You can always use the signature to double-check the struct type. My issue with using it to enforce order is that it doesn’t yield meaningful results. So the functions are called in the right order. Does that mean the fields are retrieved in the right order? Nope.
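Double-checking a struct type against the signature could look like the sketch below. It is an assumption-laden illustration: `typeCode` and `signatureOf` are hypothetical, only `u32` and `[]const u8` are mapped, and it verifies the struct’s shape, not the order in which fields are later read, which is exactly the limitation raised above.

```zig
const std = @import("std");

// Map a Zig field type to its DBUS type code. Only the two types used in
// this thread are covered; anything else is rejected at compile time.
fn typeCode(comptime T: type) u8 {
    if (T == u32) return 'u';
    if (T == []const u8) return 's';
    @compileError("no DBUS type code for " ++ @typeName(T));
}

// Build the signature for a struct: one type code per field, in declaration
// order. The comptime buffer is copied into a const before its address
// escapes, so no pointer to comptime var memory reaches runtime.
fn signatureOf(comptime T: type) *const [std.meta.fields(T).len]u8 {
    return comptime blk: {
        const fields = std.meta.fields(T);
        var buf: [fields.len]u8 = undefined;
        for (fields, 0..) |f, i| buf[i] = typeCode(f.type);
        const final = buf;
        break :blk &final;
    };
}
```

A `comptime std.debug.assert(std.mem.eql(u8, signatureOf(Msg), "uus"))` next to the struct declaration then turns any drift between the struct and the wire signature into a compile error.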

Yes, this looks like an indexed monad.

The usage here is closer to Atkey’s indexed monad:

class IMonad m where
  ireturn  ::  a -> m i i a
  ibind    ::  m i j a -> (a -> m j k b) -> m i k b

This is more suitable for sequential program structures.

If you want to pass type information through branches, you need a stronger monad, McBride’s indexed monad:

{-# LANGUAGE RankNTypes, TypeOperators #-}

type a ~> b = forall i. a i -> b i

class IMonad m where
  ireturn :: a ~> m a
  ibind :: (a ~> m b) -> (m a ~> m b)
