Comptime tensor

While I was learning about neural networks and the ways to represent weights and biases, I saw that matrices are usually used, but one word kept popping up: “tensor”.
I looked it up, and it’s an object that generalizes scalars, vectors, and matrices to higher dimensions. I noticed that the “shape” of a tensor, in how it’s used, is fixed and known ahead of time, which makes it a perfect candidate for Zig’s comptime, so I spent the next few hours making this:

const std = @import("std");
const Allocator = std.mem.Allocator;

/// Helpers for building and tearing down nested-slice "views" over a flat
/// element buffer, so a rank-N tensor can be indexed as `v[i][j]...[k]`.
const view = struct {
    /// Returns the view type for `d` dimensions of element type `T`:
    /// a single-item pointer `*T` for d == 0, `[]T` for d == 1, and
    /// `[]Type(T, d - 1)` for d > 1 (e.g. `[][]T` for d == 2).
    fn Type(T: type, comptime d: usize) type {
        return @Pointer(switch (d) {
            0 => .one,
            else => .slice,
        }, .{}, switch (d) {
            0, 1 => T,
            else => Type(T, d - 1),
        }, null);
    }

    /// Builds the view levels for `shape[idx..]` over `data`.
    ///
    /// The innermost level is `data` itself (it is sliced, not copied), so
    /// only the outer `shape.len - 1 - idx` levels are allocated from `gpa`.
    /// Release them with `free(gpa, shape.len, idx, T, result)`.
    ///
    /// `T` is expected to be `Type(elem, shape.len - idx)`, and `data.len`
    /// the product of `shape[idx..]`. On error, every slice allocated by this
    /// call (including nested ones) is freed before the error is returned.
    fn build(
        gpa: Allocator,
        comptime shape: []const usize,
        T: type,
        elem: type,
        comptime idx: usize,
        data: []elem,
    ) !T {
        if (idx >= shape.len) {
            @compileError("view.build: idx must be less than shape.len");
        } else if (idx + 1 == shape.len) {
            // Innermost dimension: the view is the flat buffer itself.
            return data;
        } else {
            const Child = @typeInfo(T).pointer.child;

            // Elements spanned by one child view. Derived from the shape at
            // comptime instead of `data.len / shape[idx]`, which divides by
            // zero when `shape[idx] == 0`.
            const child_len = comptime blk: {
                var prod: usize = 1;
                for (shape[idx + 1 ..]) |s| prod *= s;
                break :blk prod;
            };

            const res = try gpa.alloc(Child, shape[idx]);
            errdefer gpa.free(res);

            // Count of children built so far. If a later child fails, the
            // nested allocations of the earlier ones are released too
            // (previously only `res` itself was freed, leaking them).
            var built: usize = 0;
            errdefer for (res[0..built]) |child| free(gpa, shape.len, idx + 1, Child, child);

            for (res, 0..) |*r, i| {
                r.* = try build(
                    gpa,
                    shape,
                    Child,
                    elem,
                    idx + 1,
                    data[i * child_len .. (i + 1) * child_len],
                );
                built += 1;
            }

            return res;
        }
    }

    /// Frees the view levels allocated by `build`. Call it with `idx == 0` and
    /// `T == Type(elem, dims)`.
    ///
    /// The innermost level aliases the flat element buffer, and a rank-0 view
    /// is a single pointer into it. Neither is freed here; the flat buffer
    /// stays owned by the caller.
    ///
    /// Recursion is driven by `dims`/`idx` rather than by the element type.
    /// Elements that are themselves pointers or slices are therefore never
    /// read, and the innermost elements are not iterated at all.
    fn free(gpa: Allocator, comptime dims: usize, comptime idx: usize, T: type, data: T) void {
        if (idx + 1 < dims) {
            const Child = @typeInfo(T).pointer.child;
            for (data) |child| free(gpa, dims, idx + 1, Child, child);
            gpa.free(data);
        }
    }
};

/// Returns a tensor type with comptime-known `shape` and element type `elem`.
///
/// Elements live in one flat heap buffer (`_data`, `len` items). `view` is a
/// nested-slice view over that buffer, so elements are accessed as
/// `t.view[i][j]...`, with the last index varying fastest. For an empty shape
/// (rank 0), `view` is a single pointer to the one element. `init` leaves
/// the element contents undefined.
fn Tensor(comptime shape: []const usize, elem: type) type {
    return struct {
        /// Rank (number of dimensions).
        const dims = shape.len;
        const ViewType = view.Type(elem, dims);
        /// Total element count: the product of `shape` (1 for rank 0).
        const len: usize = blk: {
            var prod: usize = 1;
            for (shape) |s| prod *= s;
            break :blk prod;
        };

        _data: []elem,
        view: ViewType,

        /// Allocates the element buffer and the view's slice levels from `gpa`.
        /// If any allocation fails, everything allocated so far is freed.
        /// Release a successfully created tensor with `deinit`, passing the
        /// same allocator.
        fn init(gpa: Allocator) !@This() {
            const data = try gpa.alloc(elem, len);
            // Without this, a failure while building the view leaked `data`.
            errdefer gpa.free(data);

            const v = if (dims == 0)
                &data[0]
            else
                try view.build(gpa, shape, ViewType, elem, 0, data);

            return .{
                ._data = data,
                .view = v,
            };
        }

        /// Frees the view's slice levels first, then the flat element buffer.
        fn deinit(self: @This(), gpa: Allocator) void {
            view.free(gpa, dims, 0, ViewType, self.view);
            gpa.free(self._data);
        }
    };
}

I did it mostly because I wanted to deepen my understanding of how Zig’s comptime and memory management work.

If it isn’t clear, I’m new to Zig (it’s my second language after C) and to programming in general, so I would appreciate any comments on the quality of the code.

When I finished, imagine my surprise when I found a much easier and more efficient way to do this using strides.

4 Likes