While I was learning about neural networks and how weights and biases are represented (usually with matrices), one word kept popping up: “tensor”.
I looked it up: it’s an object that generalizes scalars, vectors, and matrices to higher dimensions. I noticed that the “shape” of a tensor, in how it’s used, is fixed and known ahead of time, which makes it a perfect candidate for Zig’s comptime, so I spent the next few hours making this:
const std = @import("std");
const Allocator = std.mem.Allocator;
/// Helpers that build and tear down a nested-slice "view" over one flat
/// element buffer, so that a tensor can be indexed as `v[i][j][k]`.
///
/// Only the intermediate pointer tables are allocated by these helpers; the
/// innermost slices are sub-slices of the caller's flat buffer.
const view = struct {
    /// Returns the view type for a `d`-dimensional tensor of `T`:
    /// `*T` for `d == 0`, `[]T` for `d == 1`, `[][]T` for `d == 2`, and so on.
    fn Type(T: type, comptime d: usize) type {
        return @Pointer(switch (d) {
            0 => .one,
            else => .slice,
        }, .{}, switch (d) {
            0, 1 => T,
            else => Type(T, d - 1),
        }, null);
    }

    /// Recursively builds the view for dimensions `idx..` of `shape` over `data`.
    ///
    /// - `gpa`:   allocator used for the intermediate pointer tables.
    /// - `shape`: full tensor shape (comptime).
    /// - `T`:     view type for the current level (see `Type`).
    /// - `elem`:  scalar element type.
    /// - `idx`:   dimension currently being built.
    /// - `data`:  flat storage covering this sub-tensor.
    ///
    /// On failure every table allocated by this call and by its already
    /// completed recursive calls is released before the error is returned.
    fn build(
        gpa: Allocator,
        comptime shape: []const usize,
        T: type,
        elem: type,
        comptime idx: usize,
        data: []elem,
    ) !T {
        // Last dimension: the view is the flat storage itself, nothing to allocate.
        if (idx + 1 == shape.len) return data;

        const info = @typeInfo(T).pointer;

        // Number of elements covered by one child. Computed from the remaining
        // extents rather than `data.len / shape[idx]` so that a zero extent
        // does not turn into a division by zero.
        const child_len = comptime blk: {
            var n: usize = 1;
            for (shape[idx + 1 ..]) |s| n *= s;
            break :blk n;
        };

        const res = try gpa.alloc(info.child, shape[idx]);
        // Count the children that were fully built so the error path can
        // release their tables too, not just `res`.
        var built: usize = 0;
        errdefer {
            for (res[0..built]) |child| free(gpa, shape.len, idx + 1, info.child, child);
            gpa.free(res);
        }

        for (res, 0..) |*r, i| {
            r.* = try build(
                gpa,
                shape,
                info.child,
                elem,
                idx + 1,
                data[i * child_len .. (i + 1) * child_len],
            );
            built += 1;
        }
        return res;
    }

    /// Releases the pointer tables created by `build` for the view `data`,
    /// which sits at level `idx` of a `dims`-dimensional tensor.
    ///
    /// The innermost level (`idx + 1 == dims`) is never freed here because
    /// `build` returns the flat buffer itself at that level.
    fn free(gpa: Allocator, comptime dims: usize, comptime idx: usize, T: type, data: T) void {
        // Comptime guard: levels that are not pointer tables own nothing, so
        // there is no reason to walk their elements.
        if (idx + 1 >= dims) return;

        const info = @typeInfo(T).pointer;
        for (data) |d| free(gpa, dims, idx + 1, info.child, d);
        gpa.free(data);
    }
};
/// Returns a tensor type whose shape is fixed at compile time.
///
/// - `shape`: extent of each dimension; an empty shape gives a single element.
/// - `elem`:  element type.
///
/// The returned struct owns a flat buffer of `len` elements plus a
/// nested-slice `view` over it (type `view.Type(elem, shape.len)`).
fn Tensor(comptime shape: []const usize, elem: type) type {
    return struct {
        /// Number of dimensions.
        const dims = shape.len;
        /// Nested pointer/slice type used to index the tensor.
        const ViewType = view.Type(elem, dims);
        /// Total element count: the product of all extents (1 for an empty shape).
        const len: usize = blk: {
            var prod = 1;
            for (shape) |s| prod *= s;
            break :blk prod;
        };

        /// Flat backing storage of `len` elements.
        _data: []elem,
        /// Nested view over `_data`.
        view: ViewType,

        /// Allocates the backing storage and builds the view over it.
        /// Elements are left uninitialized. Release with `deinit`, passing
        /// the same allocator.
        fn init(gpa: Allocator) !@This() {
            const data = try gpa.alloc(elem, len);
            // Without this the backing buffer leaks when `view.build` fails.
            errdefer gpa.free(data);

            const v = if (dims == 0)
                &data[0]
            else
                try view.build(gpa, shape, ViewType, elem, 0, data);

            return .{
                ._data = data,
                .view = v,
            };
        }

        /// Frees the view's pointer tables, then the backing storage.
        /// `gpa` must be the allocator that was passed to `init`.
        fn deinit(self: @This(), gpa: Allocator) void {
            view.free(gpa, dims, 0, ViewType, self.view);
            gpa.free(self._data);
        }
    };
}
I did it mostly because I wanted to deepen my understanding of how Zig’s comptime and memory work.
If it isn’t clear, I’m new to Zig (it’s my second language after C) and to programming in general, so I would appreciate any comments on the quality of the code.
When I finished, imagine my surprise when I found a much easier and more efficient way using strides.