Can I get your opinion on this approach for generic matrices?

Hello, I want to build a small matrix type that can operate on arbitrary types. However, one problem is that different types may not provide the same operations. For example, primitive types support the + operator, while a custom type may instead provide an add method.

I don’t want to bloat my matrix implementation with a bunch of comptime checks to see which methods are available for a given type (if that is even possible).

So here is my solution. My SquareMatrix type requires specifying both the item type and a second type, which I call Scalar, that provides a table of the required operations.

What do you think of this approach? Do you know a better way to achieve this? Also, do you see any runtime drawbacks to my approach?

Here is a small example:

const std = @import("std");

// Provide the Scalar interface for primitive types like i32, f32, etc
pub fn PrimitiveScalar(comptime T: type) type {
    return struct {
        pub const zero: T = 0;
        pub const one: T = 1;

        pub fn isZero(value: T) bool {
            return value == 0;
        }

        pub fn isOne(value: T) bool {
            return value == 1;
        }

        pub fn add(lhs: T, rhs: T) T {
            return lhs + rhs;
        }
    };
}

pub const Bit = enum { zero, one };

// Implement the Scalar interface for the custom type Bit.
pub const BitScalar = struct {
    pub const zero = Bit.zero;
    pub const one = Bit.one;

    pub fn isZero(bit: Bit) bool {
        return bit == .zero;
    }

    pub fn isOne(bit: Bit) bool {
        return bit == .one;
    }

    // Addition modulo 2: XOR the integer tag values, then convert back to Bit.
    pub fn add(lhs: Bit, rhs: Bit) Bit {
        return @enumFromInt(@intFromEnum(lhs) ^ @intFromEnum(rhs));
    }
};

pub fn SquareMatrix(comptime T: type, comptime S: type, comptime size: usize) type {
    return struct {
        const Self = @This();
        const Items = [size * size]T;

        items: Items,

        pub fn initIdentity() Self {
            var items: Items = undefined;
            for (&items) |*item| item.* = S.zero;
            var i: usize = 0;
            while (i < size) : (i += 1) {
                items[i + i * size] = S.one;
            }
            return .{ .items = items };
        }

        pub fn isIdentity(self: Self) bool {
            var i: usize = 0;
            while (i < size) : (i += 1) {
                var j: usize = 0;
                while (j < size) : (j += 1) {
                    const item = self.items[i * size + j];
                    if (i == j and !S.isOne(item)) {
                        return false;
                    }
                    if (i != j and !S.isZero(item)) {
                        return false;
                    }
                }
            }
            return true;
        }

        pub fn add(lhs: Self, rhs: Self) Self {
            var sum: Items = undefined;
            for (&sum, lhs.items, rhs.items) |*s, l, r| {
                s.* = S.add(l, r);
            }
            return .{ .items = sum };
        }
    };
}

test "SquareMatrix(i32).initIdentity" {
    const matrix = SquareMatrix(i32, PrimitiveScalar(i32), 2).initIdentity();
    const expected = [4]i32{ 1, 0, 0, 1 };
    try std.testing.expectEqualSlices(i32, &expected, &matrix.items);
}

test "SquareMatrix(f32).add" {
    const S = PrimitiveScalar(f32);
    const matrix1 = SquareMatrix(f32, S, 2).initIdentity();
    const matrix2 = SquareMatrix(f32, S, 2).initIdentity();
    const expected = [4]f32{ 2, 0, 0, 2 };
    try std.testing.expectEqualSlices(f32, &expected, &matrix1.add(matrix2).items);
}

test "SquareMatrix(Bit).isIdentity" {
    const matrix = SquareMatrix(Bit, BitScalar, 2).initIdentity();
    try std.testing.expect(matrix.isIdentity());
}
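On the "if that is even possible" part: checking for a declaration at comptime is possible with the `@hasDecl` builtin. Below is a minimal sketch, assuming Zig 0.11 to 0.13 like the rest of this thread, of a hypothetical `assertScalar` helper (not part of the post) that validates a Scalar table once, so a missing operation produces one readable compile error instead of an error deep inside `isIdentity` or `add`. `I32Scalar` is a made-up table used only to exercise the check.

```zig
const std = @import("std");

/// Hypothetical helper: stops compilation with a readable message if `S`
/// is missing a declaration that SquareMatrix relies on, or if `S.zero`
/// does not have the element type `T`.
fn assertScalar(comptime T: type, comptime S: type) void {
    inline for (.{ "zero", "one", "isZero", "isOne", "add" }) |name| {
        if (!@hasDecl(S, name)) {
            @compileError(@typeName(S) ++ " is missing required declaration '" ++ name ++ "'");
        }
    }
    if (@TypeOf(S.zero) != T) {
        @compileError(@typeName(S) ++ ".zero must have type " ++ @typeName(T));
    }
}

// Made-up scalar table for i32, used only to exercise the check.
const I32Scalar = struct {
    pub const zero: i32 = 0;
    pub const one: i32 = 1;

    pub fn isZero(value: i32) bool {
        return value == 0;
    }

    pub fn isOne(value: i32) bool {
        return value == 1;
    }

    pub fn add(lhs: i32, rhs: i32) i32 {
        return lhs + rhs;
    }
};
```

Adding `comptime assertScalar(T, S);` as the first statement of `SquareMatrix` would keep the check in one place, while the matrix methods themselves stay free of `@hasDecl` tests.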

Overall I think this is a good approach. There are a few things that I would change:

First of all, there is no reason for Bit and BitScalar to be separate types. I personally prefer having the functions bundled with the type, and this also lets you call them as member functions:

const Bit = enum {
    zero,
    one,

    pub fn add(lhs: Bit, rhs: Bit) Bit {...} // You can put function declarations into enums
    ...
};
...
var bit1, var bit2 = ...;
bit1 = bit1.add(bit2); // Then you can call them as member functions
...
SquareMatrix(Bit, Bit, 2); // Now creating the matrix gets a bit redundant though
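Put together, the merged type might look like the sketch below (assuming Zig 0.11 or later for `@enumFromInt`/`@intFromEnum`). Because `zero` and `one` are the enum's own tags, `Bit.zero` and `Bit.one` already serve as the `S.zero`/`S.one` constants, which is why `Bit` can be passed as both parameters. The `sum` function is a hypothetical generic helper, added only to show `Bit` standing in for a Scalar table.

```zig
const std = @import("std");

// Bit with its operations declared inside the enum. Since `zero` and `one`
// are the enum's own tags, `Bit.zero`/`Bit.one` double as the Scalar
// constants, so `Bit` can be passed wherever a Scalar table is expected.
pub const Bit = enum(u1) {
    zero,
    one,

    pub fn isZero(bit: Bit) bool {
        return bit == .zero;
    }

    pub fn isOne(bit: Bit) bool {
        return bit == .one;
    }

    // Addition modulo 2: XOR the integer tag values, then convert back.
    pub fn add(lhs: Bit, rhs: Bit) Bit {
        return @enumFromInt(@intFromEnum(lhs) ^ @intFromEnum(rhs));
    }
};

// Hypothetical generic helper that only uses the Scalar interface (S.zero, S.add).
fn sum(comptime T: type, comptime S: type, items: []const T) T {
    var acc: T = S.zero;
    for (items) |item| acc = S.add(acc, item);
    return acc;
}
```

With this, `Bit.one.add(.one)` evaluates to `.zero`, and `sum(Bit, Bit, ...)` works without a separate `BitScalar`.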

Secondly, I would add a bit of comptime logic to avoid the redundant parameter:

pub fn isPrimitive(comptime T: type) bool {
    return switch (@typeInfo(T)) {
        .Struct, .Enum, .Union, .Opaque => false,
        else => true,
    };
}

pub fn SquareMatrix(comptime T: type, comptime size: usize) type {
    const S = if (isPrimitive(T)) PrimitiveScalar(T) else T;
    return struct {...}; // The implementation remains unchanged.
}
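A self-contained sketch of that two-parameter version, assuming Zig 0.11 to 0.13 (the `@typeInfo` tag names `.Struct`, `.Enum`, ... match the snippet above). It trims the matrix to `initIdentity` and `add`, and uses the merged `Bit` enum so that `isPrimitive(Bit)` is false and `Bit` itself becomes the scalar table. `S` is resolved once per instantiation at comptime, so calls such as `S.add` are ordinary direct calls with no runtime lookup.

```zig
const std = @import("std");

pub fn PrimitiveScalar(comptime T: type) type {
    return struct {
        pub const zero: T = 0;
        pub const one: T = 1;

        pub fn isZero(value: T) bool {
            return value == 0;
        }

        pub fn isOne(value: T) bool {
            return value == 1;
        }

        pub fn add(lhs: T, rhs: T) T {
            return lhs + rhs;
        }
    };
}

// Bit carries its own operations, so it is its own Scalar table.
pub const Bit = enum(u1) {
    zero,
    one,

    pub fn isZero(bit: Bit) bool {
        return bit == .zero;
    }

    pub fn isOne(bit: Bit) bool {
        return bit == .one;
    }

    pub fn add(lhs: Bit, rhs: Bit) Bit {
        return @enumFromInt(@intFromEnum(lhs) ^ @intFromEnum(rhs));
    }
};

// Tag names as used in this thread (Zig 0.11 to 0.13).
pub fn isPrimitive(comptime T: type) bool {
    return switch (@typeInfo(T)) {
        .Struct, .Enum, .Union, .Opaque => false,
        else => true,
    };
}

pub fn SquareMatrix(comptime T: type, comptime size: usize) type {
    // Chosen once per instantiation at comptime; no runtime dispatch.
    const S = if (isPrimitive(T)) PrimitiveScalar(T) else T;
    return struct {
        const Self = @This();

        items: [size * size]T,

        pub fn initIdentity() Self {
            var items = [_]T{S.zero} ** (size * size);
            for (0..size) |i| items[i * size + i] = S.one;
            return .{ .items = items };
        }

        pub fn add(lhs: Self, rhs: Self) Self {
            var result: Self = undefined;
            for (&result.items, lhs.items, rhs.items) |*s, l, r| s.* = S.add(l, r);
            return result;
        }
    };
}
```

One design caveat: `isPrimitive` treats everything that is not a struct, enum, union or opaque (for example `bool`) as primitive, so `SquareMatrix(bool, 2)` would fail to compile inside `PrimitiveScalar`. A stricter check could match only `.Int` and `.Float`.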

SIMD optimization can get severely restricted by routing every element-wise operation through scalar functions like this. You can see some SIMD examples in a library I started a while ago and haven’t done much with lately: ZEIN/src/tensor_ops.zig (main branch of andrewCodeDev/ZEIN on GitHub). Just search for the word “SIMD”.

For fused matrix multiplication that leverages SIMD, check out this example from @cgbur in Llama2.zig: “Inference Llama2 in one file of pure Zig and a reflection on my first Zig project”.

The point being: if you want to go fast with anything tensor-, vector- or matrix-based, you want to stay SIMD-friendly or write custom kernels. Try not to rule those options out, because they are your best friends for fast data processing.
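One way to keep a SIMD-friendly path without giving up custom scalar types is to branch on the element type at comptime. The sketch below is an illustration only, assuming Zig 0.11 to 0.13 tag names: a hypothetical `addItems` function (not from either post) coerces the storage arrays to `@Vector` for integer and float elements, which makes the SIMD intent explicit instead of relying on the optimizer to auto-vectorize the scalar loop, and falls back to `S.add` for everything else. It covers element-wise addition only; a fused multiply like the linked examples needs more work.

```zig
const std = @import("std");

// Hypothetical element-wise add for the matrix storage (not from either post).
// Integer and float elements are coerced to @Vector so the addition is one
// vector operation; other element types fall back to the scalar table's S.add.
// Tag names (.Int, .Float) follow Zig 0.11 to 0.13.
pub fn addItems(
    comptime T: type,
    comptime S: type,
    comptime n: usize,
    lhs: [n]T,
    rhs: [n]T,
) [n]T {
    switch (@typeInfo(T)) {
        .Int, .Float => {
            const a: @Vector(n, T) = lhs;
            const b: @Vector(n, T) = rhs;
            // The vector result coerces back to the array return type.
            return a + b;
        },
        else => {
            var out: [n]T = undefined;
            for (&out, lhs, rhs) |*o, l, r| o.* = S.add(l, r);
            return out;
        },
    }
}
```

Inside `SquareMatrix.add`, the loop could then be replaced by `return .{ .items = addItems(T, S, size * size, lhs.items, rhs.items) };`.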


This is cool; it reminds me of a project I’m currently working on (mostly unusable for now). You’re basically returning a namespace with your functions, which is very convenient and readable.

I can’t resist showing how I would implement this in my project (as I probably will at some point):

// my module
const interfacil = @import("interfacil");

// not "Scalar" because of naming conventions, but it's the same thing
pub fn Scalable(
    // the contractor is the type responsible for providing the functions
    comptime Contractor: type,
    // clauses is an anonymous struct that contains the provided functions
    comptime clauses: anytype,
) type {
    // the contract is just a convenient interface for retrieving the functions from the clauses
    const contract = interfacil.contracts.Contract(Contractor, clauses);

    // this returns a type, but it's effectively a namespace
    return struct {
        // This lets us use convenient functions for comparing instances of `Contractor`
        pub usingnamespace interfacil.comparison.Equivalent(Contractor, .{});

        // the contract can require fields from the clauses
        pub const zero = contract.require(.zero, Contractor);
        pub const one = contract.require(.one, Contractor);

        // but it can also use defaults
        pub const isZero = contract.default(.isZero, defaultIsZero);
        fn defaultIsZero(self: Contractor) bool {
            // this is from the `Equivalent` interface ;)
            return eq(self, zero);
        }

        pub const add = contract.require(.add, fn (Contractor, Contractor) Contractor);
        pub const sub = contract.require(.sub, fn (Contractor, Contractor) Contractor);
        pub const mul = contract.require(.mul, fn (Contractor, Contractor) Contractor);
        pub const div = contract.require(.div, fn (Contractor, Contractor) ?Contractor);

        // You can also generate other functions now
        pub fn addAll(self: Contractor, others: []const Contractor) Contractor {
            var result = self;
            // accumulate into `result` (not `self`) so every element is summed
            for (others) |other| result = add(result, other);
            return result;
        }

        // and so on ...
    };
}

// Now when you define a type, you can make it a scalar by using this namespace
const MyScalar = struct {
    ...
    pub usingnamespace Scalable(MyScalar, .{
        // the required clauses must be given
        .zero = ...,
        .one = ...,
        // but for the others do what you want
        ...
    });
};

// You can now use them like methods
const two = MyScalar.one.add(MyScalar.one);
// `mulAll` stands for one of the "and so on" helpers, defined like `addAll`
const sixteen = two.mulAll(&[_]MyScalar{two, two, two});
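I don’t know how interfacil implements `contract.require` and `contract.default`, so the following is not its API. It is a minimal sketch, assuming Zig 0.11 to 0.13, of how the require/default idea can be expressed with `@hasField` and `@field` on the anonymous `clauses` struct. `requireClause`, `defaultClause`, `Mod2` and `Mod2Ops` are hypothetical names, and `usingnamespace` is left out to keep it short.

```zig
const std = @import("std");

// Hypothetical stand-in for `contract.require` (NOT interfacil's API):
// fails the build if the clause is absent, otherwise returns it as type T.
fn requireClause(comptime clauses: anytype, comptime name: []const u8, comptime T: type) T {
    if (!@hasField(@TypeOf(clauses), name)) {
        @compileError(std.fmt.comptimePrint("missing required clause '{s}'", .{name}));
    }
    return @field(clauses, name);
}

// Hypothetical stand-in for `contract.default`: uses the clause if present,
// otherwise the given fallback.
fn defaultClause(comptime clauses: anytype, comptime name: []const u8, comptime fallback: anytype) @TypeOf(fallback) {
    if (@hasField(@TypeOf(clauses), name)) {
        return @field(clauses, name);
    } else {
        return fallback;
    }
}

fn Scalable(comptime C: type, comptime clauses: anytype) type {
    return struct {
        pub const zero: C = requireClause(clauses, "zero", C);
        pub const one: C = requireClause(clauses, "one", C);
        pub const add = requireClause(clauses, "add", fn (C, C) C);
        pub const isZero = defaultClause(clauses, "isZero", defaultIsZero);

        fn defaultIsZero(value: C) bool {
            return std.meta.eql(value, zero);
        }

        // Derived helper built only from the required clauses.
        pub fn addAll(first: C, others: []const C) C {
            var result = first;
            for (others) |other| result = add(result, other);
            return result;
        }
    };
}

// Example contractor: integers modulo 2 wrapped in a struct.
const Mod2 = struct {
    value: u1,

    fn addImpl(lhs: Mod2, rhs: Mod2) Mod2 {
        return .{ .value = lhs.value ^ rhs.value };
    }
};

// No `.isZero` clause is given, so the default comparison against `zero` is used.
const Mod2Ops = Scalable(Mod2, .{
    .zero = Mod2{ .value = 0 },
    .one = Mod2{ .value = 1 },
    .add = Mod2.addImpl,
});
```

Leaving out `.add` would stop compilation with "missing required clause 'add'", while an explicit `.isZero = ...` clause would replace the default.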