Hello, I want to build a small matrix type that can operate on arbitrary element types. However, one problem is that different types do not all provide the same operations. For example, primitive types support the `+`
operator, while a custom type may instead provide an `add`
method.
I don’t want to bloat my matrix implementation with a bunch of comptime checks to see which methods are available for a given type (if that is even possible).
So here is my solution. My `SquareMatrix`
type requires the caller to specify both the item type and a second type, which I call `Scalar`,
that provides a table of the required operations.
What do you think of this approach? Do you know a better way to achieve this? Also, do you see any runtime drawback to this approach?
Here is a small example:
const std = @import("std");
/// Scalar-operations table for primitive numeric types such as `i32` or `f32`.
///
/// Every operation maps directly onto the built-in operator for `T`, so `add`
/// behaves exactly like `+` (including overflow traps for integers in safe
/// build modes).
pub fn PrimitiveScalar(comptime T: type) type {
    return struct {
        /// Additive identity of `T`.
        pub const zero = @as(T, 0);
        /// Multiplicative identity of `T`.
        pub const one = @as(T, 1);

        /// True when `x` compares equal to `zero`.
        pub fn isZero(x: T) bool {
            return x == zero;
        }

        /// True when `x` compares equal to `one`.
        pub fn isOne(x: T) bool {
            return x == one;
        }

        /// Sum of `a` and `b` using the built-in `+` operator.
        pub fn add(a: T, b: T) T {
            return a + b;
        }
    };
}
/// A single binary digit. The explicit `u1` tag makes `zero == 0` and
/// `one == 1`, which `BitScalar.add` relies on.
pub const Bit = enum(u1) { zero, one };

/// Scalar-operations table for `Bit`, implementing arithmetic modulo 2.
pub const BitScalar = struct {
    /// Additive identity.
    pub const zero = Bit.zero;
    /// Multiplicative identity.
    pub const one = Bit.one;

    /// True when `bit` is `.zero`.
    pub fn isZero(bit: Bit) bool {
        return bit == .zero;
    }

    /// True when `bit` is `.one`.
    pub fn isOne(bit: Bit) bool {
        return bit == .one;
    }

    /// Addition modulo 2 (XOR): `one + one == zero`.
    pub fn add(lhs: Bit, rhs: Bit) Bit {
        // Enums never coerce to or from integers (so `@as` cannot be used here);
        // convert through the u1 tag explicitly.
        return @enumFromInt(@intFromEnum(lhs) ^ @intFromEnum(rhs));
    }
};
/// Square `size` x `size` matrix whose entries have type `T`.
///
/// All arithmetic on entries goes through the scalar table `S`, which must
/// declare `zero`, `one`, `isZero`, `isOne`, and `add` for `T`.
/// Entry `(i, j)` is stored at `items[i * size + j]`.
pub fn SquareMatrix(comptime T: type, comptime S: type, comptime size: usize) type {
    return struct {
        const Self = @This();
        const Items = [size * size]T;

        items: Items,

        /// Builds the identity matrix: `S.one` on the diagonal, `S.zero` elsewhere.
        pub fn initIdentity() Self {
            var result: Self = undefined;
            @memset(&result.items, S.zero);
            for (0..size) |d| {
                result.items[d * size + d] = S.one;
            }
            return result;
        }

        /// Reports whether this matrix is the identity, judged by `S.isOne`
        /// on diagonal entries and `S.isZero` on all other entries.
        pub fn isIdentity(self: Self) bool {
            for (0..size) |row| {
                for (0..size) |col| {
                    const entry = self.items[row * size + col];
                    const matches = if (row == col) S.isOne(entry) else S.isZero(entry);
                    if (!matches) return false;
                }
            }
            return true;
        }

        /// Returns the element-wise sum of `lhs` and `rhs`, combining entries with `S.add`.
        pub fn add(lhs: Self, rhs: Self) Self {
            var result: Self = undefined;
            for (&result.items, lhs.items, rhs.items) |*out, a, b| {
                out.* = S.add(a, b);
            }
            return result;
        }
    };
}
// The 2x2 integer identity should be laid out as [1, 0, 0, 1].
test "SquareMatrix(i32).initIdentity" {
    const Mat = SquareMatrix(i32, PrimitiveScalar(i32), 2);
    const identity = Mat.initIdentity();
    try std.testing.expectEqualSlices(i32, &[_]i32{ 1, 0, 0, 1 }, &identity.items);
}
// Adding two float identities doubles the diagonal and leaves zeros elsewhere.
test "SquareMatrix(f32).add" {
    const Mat = SquareMatrix(f32, PrimitiveScalar(f32), 2);
    const sum = Mat.initIdentity().add(Mat.initIdentity());
    try std.testing.expectEqualSlices(f32, &[_]f32{ 2, 0, 0, 2 }, &sum.items);
}
// A matrix built by initIdentity over the custom Bit scalar must be recognised as identity.
test "SquareMatrix(Bit).isIdentity" {
    const BitMatrix = SquareMatrix(Bit, BitScalar, 2);
    const is_identity = BitMatrix.initIdentity().isIdentity();
    try std.testing.expect(is_identity);
}