Supporting associated types for dynamic polymorphism in Zig

Normally, a hand-written vtable cannot support a function like `compare`, whose type would be `*const fn (*const anyopaque, *const anyopaque) bool`. It takes two `anyopaque` pointers, but a normal vtable can only vouch for one of them (the receiver): nothing guarantees that the second pointer refers to the same concrete type.

I have studied how many different languages support this kind of usage, and I found Zig’s `opaque` types very well suited to it.

Here is an example showing how to use two or more opaque pointers that share one erased type:

const std = @import("std");

/// Type-erased "comparable" value whose comparator takes two erased pointers.
///
/// The value pointer and both comparator arguments are expressed in terms of
/// the single opaque type `Value`. As long as a `Comparable` is built through
/// `init` (which generates the comparator for exactly the type `U` of the
/// stored value), every `*const Value` handed to `compare` really points to a `U`.
const Comparable = struct {
    /// Opaque stand-in ("associated type") for the concrete value type given to `init`.
    const Value = opaque {};

    /// Borrowed, type-erased pointer to the wrapped value.
    value: *const Value,
    /// Type-erased comparator generated by `init` for the wrapped value's type.
    compare: *const fn (*const Value, *const Value) bool,

    /// Produce, at comptime, a thunk that turns both erased arguments back
    /// into `*const U` and forwards them to `Cmp`.
    fn makeCompareAdapter(
        comptime U: type,
        comptime Cmp: *const fn (*const U, *const U) bool,
    ) *const fn (*const Value, *const Value) bool {
        const Thunk = struct {
            fn call(lhs: *const Value, rhs: *const Value) bool {
                // `Value` is opaque, so the original alignment of `U` must be restored.
                const left: *const U = @ptrCast(@alignCast(lhs));
                const right: *const U = @ptrCast(@alignCast(rhs));
                return Cmp(left, right);
            }
        };
        return &Thunk.call;
    }

    /// Wrap `value_ptr` together with a comparator for its concrete type `U`.
    /// Only the pointer is stored, so the pointee must outlive the result.
    pub fn init(
        comptime U: type,
        value_ptr: *const U,
        comptime Cmp: *const fn (*const U, *const U) bool,
    ) Comparable {
        const erased: *const Value = @ptrCast(value_ptr);
        return Comparable{
            .value = erased,
            .compare = makeCompareAdapter(U, Cmp),
        };
    }

    /// Pass the erased value and its comparator to `functor.useComparable`,
    /// supplying `Value` as the associated type so the functor can stay generic.
    pub fn callFunctor(self: *const Comparable, functor: anytype) bool {
        const outcome: bool = functor.useComparable(Value, self.value, self.compare);
        return outcome;
    }
};

/// Stateless functor that compares a (possibly type-erased) value with itself.
const SelfCompareFunctor = struct {
    /// Return `compare(value, value)` for whatever associated type `T` the caller supplies.
    pub fn useComparable(
        self: SelfCompareFunctor,
        comptime T: type,
        value: *const T,
        compare: *const fn (*const T, *const T) bool,
    ) bool {
        _ = self; // the functor carries no state
        const lhs = value;
        const rhs = value;
        return compare(lhs, rhs);
    }
};

/// Strict less-than on two `i32` values behind pointers.
fn compareInt(a: *const i32, b: *const i32) bool {
    const lhs = a.*;
    const rhs = b.*;
    return lhs < rhs;
}

/// Strict less-than on two `f64` values behind pointers (`a < b`).
fn compareDouble(a: *const f64, b: *const f64) bool {
    return b.* > a.*;
}

/// Lexicographic (byte-wise) strict less-than on two strings behind pointers.
fn compareStr(a: *const []const u8, b: *const []const u8) bool {
    const ordering = std.mem.order(u8, a.*, b.*);
    return ordering == .lt;
}

/// Demo: wrap values of three different types as `Comparable`, store them in a
/// single homogeneous array, and dispatch `SelfCompareFunctor` over each one.
pub fn main() !void {
    // `Comparable.init` only needs `*const` pointers, so none of the wrapped
    // values has to be mutable; `const` states that intent (and matches `s`).
    const i: i32 = 1;
    const d: f64 = 3.14;
    const s: []const u8 = "abc";

    const comp_int = Comparable.init(i32, &i, compareInt);
    const comp_double = Comparable.init(f64, &d, compareDouble);
    const comp_string = Comparable.init([]const u8, &s, compareStr);

    // One element type for the array, even though the wrapped values differ in type.
    const comps = [_]Comparable{ comp_int, comp_double, comp_string };

    for (comps, 0..) |c, idx| {
        const result = c.callFunctor(SelfCompareFunctor{});
        std.debug.print("comps[{d}]: compare(self, self) = {}\n", .{ idx, result });
    }
}

In the example, as long as a `Comparable` is only created through `init` and only used through `callFunctor`, it is always type-safe. (We could enforce this by keeping `.value` and `.compare` out of callers’ reach, but to keep the example simple we don’t.) It is also genuine dynamic dispatch, because `comps` holds `Comparable`s that wrap values of different types. Here is a more complex case:


const std = @import("std");

/// Minimal container for the sorting demo: wraps a mutable slice of `i32`
/// (it has no `deinit`, so it never frees that memory).
const IntArray = struct { data: []i32 };

/// Number of elements currently viewed by `arr`.
fn int_array_get_size(arr: *const IntArray) usize {
    const elems = arr.data;
    return elems.len;
}

/// Mutable pointer to element `index` of `arr`; the slice index is
/// bounds-checked in safe build modes.
fn int_array_get_elem(arr: *IntArray, index: usize) *i32 {
    const slot: *i32 = &arr.data[index];
    return slot;
}

/// Exchange the two `i32` values pointed to by `a` and `b`.
fn int_swap(a: *i32, b: *i32) void {
    std.mem.swap(i32, a, b);
}

/// Ordering predicate for an ascending sort.
///
/// Returns true when `a` is strictly greater than `b`, i.e. when the pair is
/// out of ascending order. `BubbleSortFunctor` swaps two adjacent elements
/// exactly when this returns true.
fn int_compare_asc(a: *const i32, b: *const i32) bool {
    return b.* < a.*;
}

/// Print one `i32` in decimal via `std.debug.print` (no separator, no newline).
fn int_print(e: *const i32) void {
    const value = e.*;
    std.debug.print("{d}", .{value});
}

/// Type-erased view of a sortable container with two associated types.
///
/// `Container` stands in for the concrete collection type and `Value` for its
/// element type. Every stored function pointer is phrased in terms of these
/// opaque types, so multi-argument operations such as `swap` and `compare` can
/// be stored at all. `init` generates every adapter from the same concrete pair
/// `(C, E)`, so values built through `init` only ever mix pointers of one type.
const Sortable = struct {
    /// Erased stand-in for the concrete container type `C` given to `init`.
    const Container = opaque {};
    /// Erased stand-in for the concrete element type `E` given to `init`.
    const Value = opaque {};

    /// Borrowed pointer to the wrapped container (stored, never freed here).
    container: *Container,
    get_size: *const fn (*const Container) usize,
    get_elem: *const fn (*Container, usize) *Value,
    swap: *const fn (*Value, *Value) void,
    compare: *const fn (*const Value, *const Value) bool,
    print_value: *const fn (*const Value) void,
    /// Label forwarded unchanged to functors.
    type_name: []const u8,

    /// Erase `container_ptr` and build one adapter per operation, all for the
    /// same concrete container type `C` and element type `E`.
    pub fn init(
        comptime C: type,
        comptime E: type,
        container_ptr: *C,
        comptime get_size_fn: *const fn (*const C) usize,
        comptime get_elem_fn: *const fn (*C, usize) *E,
        comptime swap_fn: *const fn (*E, *E) void,
        comptime compare_fn: *const fn (*const E, *const E) bool,
        comptime print_fn: *const fn (*const E) void,
        comptime type_name: []const u8,
    ) Sortable {
        const erased: *Container = @ptrCast(container_ptr);
        return Sortable{
            .container = erased,
            .get_size = makeGetSizeAdapter(C, get_size_fn),
            .get_elem = makeGetElemAdapter(C, E, get_elem_fn),
            .swap = makeSwapAdapter(E, swap_fn),
            .compare = makeCompareAdapter(E, compare_fn),
            .print_value = makePrintAdapter(E, print_fn),
            .type_name = type_name,
        };
    }

    /// Hand the erased container, every erased operation, and both associated
    /// types to `functor.useSortable`. Any value the functor returns is discarded.
    pub fn callFunctor(self: *const Sortable, functor: anytype) void {
        _ = functor.useSortable(
            Container,
            Value,
            self.container,
            self.get_size,
            self.get_elem,
            self.swap,
            self.compare,
            self.print_value,
            self.type_name,
        );
    }

    /// Wrap `F` so it accepts an erased `*const Container`.
    fn makeGetSizeAdapter(
        comptime C: type,
        comptime F: *const fn (*const C) usize,
    ) *const fn (*const Container) usize {
        const Thunk = struct {
            fn call(erased: *const Container) usize {
                const concrete: *const C = @ptrCast(@alignCast(erased));
                return F(concrete);
            }
        };
        return &Thunk.call;
    }

    /// Wrap `F` so it takes an erased container and returns an erased element pointer.
    fn makeGetElemAdapter(
        comptime C: type,
        comptime E: type,
        comptime F: *const fn (*C, usize) *E,
    ) *const fn (*Container, usize) *Value {
        const Thunk = struct {
            fn call(erased: *Container, index: usize) *Value {
                const concrete: *C = @ptrCast(@alignCast(erased));
                const elem: *E = F(concrete, index);
                const out: *Value = @ptrCast(elem);
                return out;
            }
        };
        return &Thunk.call;
    }

    /// Wrap `F` so it swaps two erased element pointers.
    fn makeSwapAdapter(
        comptime E: type,
        comptime F: *const fn (*E, *E) void,
    ) *const fn (*Value, *Value) void {
        const Thunk = struct {
            fn call(x: *Value, y: *Value) void {
                const px: *E = @ptrCast(@alignCast(x));
                const py: *E = @ptrCast(@alignCast(y));
                F(px, py);
            }
        };
        return &Thunk.call;
    }

    /// Wrap `F` so it compares two erased element pointers.
    fn makeCompareAdapter(
        comptime E: type,
        comptime F: *const fn (*const E, *const E) bool,
    ) *const fn (*const Value, *const Value) bool {
        const Thunk = struct {
            fn call(lhs: *const Value, rhs: *const Value) bool {
                const left: *const E = @ptrCast(@alignCast(lhs));
                const right: *const E = @ptrCast(@alignCast(rhs));
                return F(left, right);
            }
        };
        return &Thunk.call;
    }

    /// Wrap `F` so it prints an erased element pointer.
    fn makePrintAdapter(
        comptime E: type,
        comptime F: *const fn (*const E) void,
    ) *const fn (*const Value) void {
        const Thunk = struct {
            fn call(erased: *const Value) void {
                const concrete: *const E = @ptrCast(@alignCast(erased));
                F(concrete);
            }
        };
        return &Thunk.call;
    }
};

/// Functor that prints every element of a type-erased sortable on one line,
/// formatted as `[<type_name>] = [e0, e1, ...]` followed by a newline.
const PrintAllFunctor = struct {
    /// Read-only pass over indices `0..get_size(container)`. Each element is
    /// emitted through `print_value`; `swap` and `compare` are unused.
    pub fn useSortable(
        self: @This(),
        comptime Cont: type,
        comptime Val: type,
        container: *Cont,
        get_size: *const fn (*const Cont) usize,
        get_elem: *const fn (*Cont, usize) *Val,
        swap: *const fn (*Val, *Val) void,
        compare: *const fn (*const Val, *const Val) bool,
        print_value: *const fn (*const Val) void,
        type_name: []const u8,
    ) void {
        _ = self;
        _ = swap;
        _ = compare;

        const count = get_size(container);
        std.debug.print("[{s}] = [", .{type_name});
        for (0..count) |idx| {
            // Separator precedes every element except the first.
            if (idx != 0) std.debug.print(", ", .{});
            print_value(get_elem(container, idx));
        }
        std.debug.print("]\n", .{});
    }
};

/// Functor that sorts a type-erased sortable in place with bubble sort.
///
/// Adjacent elements `(j, j + 1)` are swapped whenever `compare` returns true
/// for them, so `compare` must mean "this pair is out of order".
const BubbleSortFunctor = struct {
    /// Run the full bubble sort; `print_value` and `type_name` are unused.
    pub fn useSortable(
        self: @This(),
        comptime Cont: type,
        comptime Val: type,
        container: *Cont,
        get_size: *const fn (*const Cont) usize,
        get_elem: *const fn (*Cont, usize) *Val,
        swap: *const fn (*Val, *Val) void,
        compare: *const fn (*const Val, *const Val) bool,
        print_value: *const fn (*const Val) void,
        type_name: []const u8,
    ) void {
        _ = self;
        _ = print_value;
        _ = type_name;

        const len = get_size(container);
        if (len <= 1) return;

        // Pass `pass` scans the first `len - pass` elements. Assuming `compare`
        // is a consistent ordering, the tail beyond that is already in place.
        for (0..len) |pass| {
            for (0..len - pass - 1) |j| {
                const left = get_elem(container, j);
                const right = get_elem(container, j + 1);
                if (compare(left, right)) swap(left, right);
            }
        }
    }
};

/// Demo: wrap an `IntArray` as a type-erased `Sortable`, then print it, sort it
/// in place, and print it again, all through the same erased interface.
pub fn main() !void {
    // Backing storage is mutable because the sort writes through the slice.
    var int_data = [_]i32{ 5, 2, 8, 1, 9, 3 };
    // `var` because `Sortable.init` takes a mutable `*IntArray`.
    var int_array = IntArray{ .data = int_data[0..] };

    // `callFunctor` only needs `*const Sortable`, so the wrapper itself is `const`.
    const s_int = Sortable.init(
        IntArray,
        i32,
        &int_array,
        int_array_get_size,
        int_array_get_elem,
        int_swap,
        int_compare_asc,
        int_print,
        "int",
    );

    s_int.callFunctor(PrintAllFunctor{});
    s_int.callFunctor(BubbleSortFunctor{});
    s_int.callFunctor(PrintAllFunctor{});
}

1 Like

This pattern is called type erasure. The C++ community has many great videos about it, but the idea is really simple. As the title of Niklaus Wirth’s famous book says, Algorithms + Data Structures = Programs. In other words, any computation can be performed with some data plus a recipe for what to do with that data, which is a function.
Zig uses this everywhere: all those interfaces are forms of type erasure.
The name “type erasure” comes from the fact that you erase all type information. The object becomes just a bundle of data, typically in the form of a pointer, and a bundle of functions, typically in the form of a vtable.

1 Like

It’s not just type erasure; it’s type erasure that supports associated types. As I mentioned above, C++ cannot support `compare` as a virtual function: when `compare` is virtual, it is not type-safe at all, because nothing ensures that both arguments have the same concrete type.

The same goes for Rust, which has the object-safety (dyn-compatibility) limitation.

(What’s more, because C++ already has virtual functions, people never write code like this in C++, even though they could.)