Simple ECS implementation inspired by Zig's MultiArrayList

Thought I’d share a little ECS implementation I wrote for a game I’m working on. It’s inspired by Zig’s MultiArrayList, although it’s based not on struct fields but on a plain list of component types. Component access is gated on a “valid entity” bit and an “entity type mask”, so that you can a) keep a list of allocated entities in the world and b) check which components are set for an entity.

I’m sharing this because I “Godbolted” the implementation, and I’m really happy with how clean and fast-looking the machine code generated from the Zig code is.

The API is basically:

const max_entities = 1000;
// Create an ECS state container with room for up to `max_entities` (1000) entities.
var ecs = try Ecs(&[_]type{ Position, Velocity }).init(allocator, max_entities);
defer ecs.deinit();
// Set the "Position" component for entity id 0 (this also marks id 0 as a valid entity).
ecs.set(0, Position{ .pos = .{ 1, 3 } });
// Get the "Position" component for entity id 0; null if it has not been set.
const p0 = ecs.get(Position, 0);
try std.testing.expect(p0.?.pos[0] == 1 and p0.?.pos[1] == 3);

// Removing the entity clears its valid bit and its component mask.
ecs.removeEntity(0);
try std.testing.expect(ecs.get(Position, 0) == null);

Compiling to ReleaseFast produces pretty clean & fast assembly for the setters and getters:

/// 2-D position component holding two f32 lanes.
const Position = struct {
    pos: @Vector(2, f32),
};
/// Age component: a packed struct backed by a single u16.
const Age = packed struct {
    age: u16,
};
/// ECS where `Age` is component #0 and `Position` is component #1.
const AgePosEcs = Ecs(&[_]type{ Age, Position });

/// Give entity `idx` a Position of `p`, then an Age of 15.
/// `Ecs.set` also marks `idx` as a valid entity.
export fn setTest(ecs: *AgePosEcs, idx: usize, p: @Vector(2, f32)) void {
    const position: Position = .{ .pos = p };
    const age: Age = .{ .age = 15 };
    ecs.set(idx, position);
    ecs.set(idx, age);
}

/// Return a pointer to entity `idx`'s Age, or null when the entity is not
/// valid or has no Age component.
export fn getTest(ecs: *AgePosEcs, idx: usize) ?*Age {
    const maybe_age: ?*Age = ecs.get(Age, idx);
    return maybe_age;
}

/// Write the id of every valid entity that has a Position component into
/// `out`, in increasing id order.
/// `out` must have room for one u32 per matching entity; at most the ECS capacity.
export fn iterateEntities(ecs: *AgePosEcs, out: [*]u32) void {
    var it = ecs.queryEntities(&[_]type{Position});
    var idx: usize = 0;
    // Advance the output slot after each write; without this every match
    // overwrote out[0].
    while (it.next()) |c| : (idx += 1) {
        out[idx] = c;
    }
}

Here’s what the assembly for setTest/getTest/iterateEntities looks like (full Compiler Explorer link: not included in this copy):

; setTest(ecs = rdi, idx = rsi, p = xmm0)
; NOTE(review): the offsets into *ecs are inferred from how each is used here and
; presumably map to: +32 Age array, +40 Position array, +72 valid-bitset words,
; +80 tags.ptr -- confirm against the compiler's chosen struct layout.
setTest:
        mov     rax, qword ptr [rdi + 72] ; rax = valid-bitset word array
        mov     rcx, rsi
        shr     rcx, 6 ; rcx = idx / 64 (index of the 64-bit word holding idx's bit)
        mov     edx, 1
        shlx    rdx, rdx, rsi ; rdx = 1 << (idx % 64) (64-bit shlx masks the count to 6 bits)
        or      qword ptr [rax + 8*rcx], rdx ; valid.set(idx) from ecs.set(idx, Position{...})
        mov     rax, qword ptr [rdi + 80]
        or      dword ptr [rax + 4*rsi], 2 ; tags[idx] |= 2 -- Position is component #1, so its tag bit is 1 << 1
        mov     rax, qword ptr [rdi + 40]
        vmovlps qword ptr [rax + 8*rsi], xmm0 ; store the two f32 lanes of p as Position array[idx] (8-byte stride)
        mov     rax, qword ptr [rdi + 72]
        or      qword ptr [rax + 8*rcx], rdx ; valid.set(idx) again from ecs.set(idx, Age{...}) -- redundant, the bit is already set
        mov     rax, qword ptr [rdi + 80]
        or      dword ptr [rax + 4*rsi], 1 ; this is actually redundant, it could be just one bitwise-or of 3 to the tags array (Age is component #0, tag bit 1 << 0)
        mov     rax, qword ptr [rdi + 32]
        mov     word ptr [rax + 2*rsi], 15 ; Age array[idx] = 15 (2-byte stride: Age is a packed u16)
        ret

; getTest(ecs = rdi, idx = rsi) -> rax = ?*Age (0 encodes null)
getTest:
        mov     rcx, qword ptr [rdi + 72] ; rcx = valid-bitset word array
        mov     rdx, rsi
        shr     rdx, 6 ; rdx = idx / 64 (word index)
        mov     rcx, qword ptr [rcx + 8*rdx] ; rcx = the word that holds idx's valid bit
        bt      rcx, rsi ; CF = bit (idx % 64) of that word
        jae     .LBB1_3 ; CF == 0: entity is not valid -> return null
        mov     rax, rsi
        mov     rcx, qword ptr [rdi + 80] ; rcx = tags.ptr
        test    byte ptr [rcx + 4*rsi], 1 ; tags[idx] & 1: Age is component #0, tag bit 1 << 0
        jne     .LBB1_4 ; Age is set -> return its address
.LBB1_3: ; return null if the entity is not valid or has no Age component
        xor     eax, eax
        ret
.LBB1_4:
        add     rax, rax ; rax = idx * 2, i.e. the byte offset of Age[idx] (Age is a packed u16)
        add     rax, qword ptr [rdi + 32] ; + start of the Age array -> pointer to the "Age" object
        ret

; loops through all valid entities and writes out the ids of entities that have
; all the desired components set (here only Position, tag bit 2).
; iterateEntities(ecs = rdi, out = rsi)
iterateEntities:
        mov     rax, qword ptr [rdi + 64] ; rax = bit length of the valid set (presumably; it is rounded up to words below)
        test    rax, rax
        je      .LBB2_1 ; zero length -> nothing to iterate
        mov     rcx, qword ptr [rdi + 72] ; rcx = valid-bitset word array
        add     rax, 63
        shr     rax, 6 ; rax = number of 64-bit words, ceil(bit_length / 64)
        mov     rdx, qword ptr [rcx] ; rdx = first word: bits still to visit
        add     rcx, 8 ; rcx -> next word
        dec     rax ; rax = words remaining after the current one
        jmp     .LBB2_3
.LBB2_1:
        movabs  rcx, -6148914691236517206 ; 0xAAAAAAAAAAAAAAAA, looks like Zig's `undefined` fill; never dereferenced on this path
        xor     eax, eax ; no words remaining
        xor     edx, edx ; no bits in the current word
.LBB2_3:
        mov     r8, qword ptr [rdi + 80] ; r8 = tags.ptr
        xor     edi, edi ; rdi = bit offset of the current word (starts at 0)
.LBB2_4: ; loop head
        test    rdx, rdx
        jne     .LBB2_7 ; current word still has set bits -> visit the lowest one
        sub     rax, 1
        jb      .LBB2_9 ; borrow: no words left -> done
        mov     rdx, qword ptr [rcx] ; load the next word
        add     rcx, 8
        add     rdi, 64 ; advance the bit offset by one word
        jmp     .LBB2_4
.LBB2_7:
        tzcnt   r9, rdx ; r9 = position of the lowest set bit in the word
        add     r9, rdi ; r9 = entity index
        blsr    rdx, rdx ; clear that bit so it is not visited again
        test    byte ptr [r8 + 4*r9], 2 ; does tags[idx] include Position (bit 2)?
        je      .LBB2_4 ; no -> keep scanning
        mov     dword ptr [rsi], r9d ; always stores to out[0]: the Zig source this was compiled from never increments `idx`
        jmp     .LBB2_4
.LBB2_9:
        ret

Here’s the full implementation. Use it if you like… or suggest improvements if something can be done better (I started writing Zig just a couple of weeks ago) :slight_smile:

const std = @import("std");
const Allocator = std.mem.Allocator;

/// Build a struct-of-arrays entity store for the given component types.
///
/// Each component type gets its own densely packed array with one slot per
/// entity id (capacity fixed by `init`). Two pieces of per-entity metadata gate
/// access to those arrays:
///   * `valid`: one bit per entity id, set while the id is in use;
///   * `tags`:  one `u32` per entity id; bit `i` is set once component
///              `Components[i]` has been assigned for that entity.
///
/// At most 32 component types are supported because tags are `u32` masks.
pub fn Ecs(comptime Components: []const type) type {
    if (Components.len > 32) {
        @compileError("Ecs supports at most 32 component types (tags are u32 bit masks)");
    }

    return struct {
        const Self = @This();

        /// Iterates over valid entity ids whose tag mask contains every bit of `tag_mask`.
        pub const QueryEntitiesIterator = struct {
            /// Not read by `next`; `queryEntities` always initializes it to 0.
            iter_idx: usize,
            valid_it: std.bit_set.DynamicBitSet.Iterator(.{}),
            tag_mask: u32,
            tags: []u32,

            /// Return the next matching entity id, or null once every valid id
            /// has been visited.
            pub fn next(self: *@This()) ?u32 {
                const tag_mask = self.tag_mask;
                while (self.valid_it.next()) |idx| {
                    if ((self.tags[idx] & tag_mask) == tag_mask) {
                        return @intCast(u32, idx);
                    }
                }
                return null;
            }
        };

        /// Largest alignment among all component types. The single backing
        /// allocation uses it so that the first (most-aligned) array is aligned.
        /// Starts at 1 so an empty component list never yields `align(0)`.
        const max_alignment = blk: {
            var a = 1;
            for (Components) |c| {
                a = @max(a, @alignOf(c));
            }
            break :blk a;
        };

        /// Index of `C` in `Components`; a compile error if `C` is not a component.
        fn findSlot(comptime C: type) usize {
            inline for (Components, 0..) |c, i| {
                if (c == C) {
                    return i;
                }
            }
            @compileError("unknown component type: " ++ @typeName(C));
        }

        /// Tag bit for component `C` (`1 << findSlot(C)`).
        fn compTagBits(comptime C: type) u32 {
            return 1 << comptime findSlot(C);
        }

        allocator: Allocator,
        /// Start of the single allocation that backs every component array.
        bytes: [*]align(max_alignment) u8 = undefined,
        /// Capacity: number of entity slots, which is also the length of every component array.
        len: usize = 0,
        /// Start of each component's array, indexed like `Components`.
        /// Only the first array in memory is guaranteed `max_alignment`. Every
        /// array is aligned to its own component's alignment, and `items`
        /// applies that alignment.
        soa_ptrs: [Components.len][*]u8,

        // validity and component masks
        valid: std.DynamicBitSet,
        tags: []u32,

        // `sizes.bytes` is an array of @sizeOf each component type, sorted by alignment, descending.
        // `sizes.comps` is an array mapping from `sizes.bytes` array index to component index.
        const sizes = blk: {
            const Data = struct {
                size: usize,
                size_index: usize,
                alignment: usize,
            };
            var data: [Components.len]Data = undefined;
            for (Components, 0..) |comp, i| {
                data[i] = .{
                    .size = @sizeOf(comp),
                    .size_index = i,
                    .alignment = @alignOf(comp),
                };
            }
            const Sort = struct {
                fn lessThan(context: void, lhs: Data, rhs: Data) bool {
                    _ = context;
                    return lhs.alignment > rhs.alignment;
                }
            };
            std.sort.sort(Data, &data, {}, Sort.lessThan);
            var sizes_bytes: [Components.len]usize = undefined;
            var comp_indexes: [Components.len]usize = undefined;
            for (data, 0..) |elem, i| {
                sizes_bytes[i] = elem.size;
                comp_indexes[i] = elem.size_index;
            }
            break :blk .{
                .bytes = sizes_bytes,
                .comps = comp_indexes,
            };
        };

        /// Bytes needed to store `capacity` elements of every component.
        fn capacityInBytes(capacity: usize) usize {
            comptime var elem_bytes: usize = 0;
            inline for (sizes.bytes) |size| elem_bytes += size;
            return elem_bytes * capacity;
        }

        fn allocatedBytes(self: Self) []align(max_alignment) u8 {
            return self.bytes[0..capacityInBytes(self.len)];
        }

        /// Allocate storage for `len` entities, all initially invalid with no components.
        /// Release it with `deinit`. Nothing leaks if any allocation fails.
        pub fn init(allocator: Allocator, len: usize) !Self {
            const mem = try allocator.alignedAlloc(u8, max_alignment, capacityInBytes(len));
            errdefer allocator.free(mem);

            // Lay the arrays out back to back in descending-alignment order.
            // Each component's size is a multiple of its alignment, so every
            // array starts at an offset that is a multiple of its own alignment.
            // That offset is not necessarily a multiple of `max_alignment`, so
            // no cast to `max_alignment` is done here.
            var ptr: [*]u8 = mem.ptr;
            var soa_ptrs: [Components.len][*]u8 = undefined;
            for (sizes.bytes, sizes.comps) |comp_size, i| {
                soa_ptrs[i] = ptr;
                ptr += comp_size * len;
            }

            var valid = try std.DynamicBitSet.initEmpty(allocator, len);
            errdefer valid.deinit();

            const tags = try allocator.alloc(u32, len);
            std.mem.set(u32, tags, 0);

            return .{
                .allocator = allocator,
                .bytes = mem.ptr,
                .len = len,
                .soa_ptrs = soa_ptrs,
                .valid = valid,
                .tags = tags,
            };
        }

        /// The full array (one element per entity slot) for component `C`.
        fn items(self: *Self, comptime C: type) []C {
            const comp_idx = comptime findSlot(C);
            // Safety-checked in Debug/ReleaseSafe; holds by construction (see `init`).
            const base = @alignCast(@alignOf(C), self.soa_ptrs[comp_idx]);
            return @ptrCast([*]C, base)[0..self.len];
        }

        /// Claim the lowest unused entity id, mark it valid with no components, and return it.
        /// Panics when every id is in use.
        pub fn newEntity(self: *Self) u32 {
            var unset_it = self.valid.iterator(.{ .kind = .unset });
            const idx = while (unset_it.next()) |idx| {
                break idx;
            } else {
                @panic("out of entities -- shouldn't get here");
            };
            self.valid.set(idx);
            self.tags[idx] = 0;
            return @intCast(u32, idx);
        }

        /// Mark `id` invalid and clear its component mask. Component data is left in place.
        pub fn removeEntity(self: *Self, id: u32) void {
            self.valid.unset(id);
            self.tags[id] = 0;
        }

        /// Pointer to entity `idx`'s `C` component, or null when the entity is
        /// not valid or `C` has not been set for it.
        pub fn get(self: *Self, comptime C: type, idx: usize) ?*C {
            const mask = comptime compTagBits(C);
            if (self.valid.isSet(idx) and (self.tags[idx] & mask) != 0) {
                return &self.items(C)[idx];
            }
            return null;
        }

        /// Store component `c` for entity `idx`, marking the entity valid and
        /// setting the tag bit for `@TypeOf(c)`.
        pub fn set(self: *Self, idx: usize, c: anytype) void {
            const C = @TypeOf(c);
            self.valid.set(idx);
            self.tags[idx] |= comptime compTagBits(C);
            self.items(C)[idx] = c;
        }

        /// Iterate over valid entities that have every component in `Comps`.
        /// An empty `Comps` gives mask 0, so every valid entity matches.
        pub fn queryEntities(self: *Self, comptime Comps: []const type) QueryEntitiesIterator {
            const tag_mask = comptime blk: {
                var m: u32 = 0;
                for (Comps) |c| {
                    m |= compTagBits(c);
                }
                break :blk m;
            };
            return .{
                .valid_it = self.valid.iterator(.{}),
                .iter_idx = 0,
                .tag_mask = tag_mask,
                .tags = self.tags,
            };
        }

        /// Free everything allocated by `init`. `self` is unusable afterwards.
        pub fn deinit(self: *Self) void {
            self.valid.deinit();
            self.allocator.free(self.tags);
            self.allocator.free(self.allocatedBytes());
            self.* = undefined;
        }
    };
}

/// 2-D position component.
const Position = struct { pos: @Vector(2, f32) };

/// 2-D velocity component.
const Velocity = struct { v: @Vector(2, f32) };

test "ecs simple" {
    const testing = std.testing;
    var world = try Ecs(&[_]type{ Position, Velocity }).init(testing.allocator, 10);
    defer world.deinit();

    // Entity 0 gets both components; both must read back unchanged.
    world.set(0, Position{ .pos = .{ 1, 3 } });
    world.set(0, Velocity{ .v = .{ 0.5, 0 } });

    const p0 = world.get(Position, 0);
    const v0 = world.get(Velocity, 0);
    try testing.expect(p0.?.pos[0] == 1 and p0.?.pos[1] == 3);
    try testing.expect(v0.?.v[0] == 0.5 and v0.?.v[1] == 0);
}

test "ecs alignment" {
    const testing = std.testing;
    const Age = packed struct { age: u16 };

    var world = try Ecs(&[_]type{ Age, Position }).init(testing.allocator, 3);
    defer world.deinit();

    // Entity 0: both components set, mixed alignments must read back intact.
    world.set(0, Position{ .pos = .{ 1, 3 } });
    world.set(0, Age{ .age = 13 });
    const p0 = world.get(Position, 0);
    const a0 = world.get(Age, 0);
    try testing.expect(p0.?.pos[0] == 1 and p0.?.pos[1] == 3);
    try testing.expect(a0.?.age == 13);

    // Entity 1: starts with nothing, then only gets a Position.
    try testing.expect(world.get(Age, 1) == null);
    try testing.expect(world.get(Position, 1) == null);
    world.set(1, Position{ .pos = .{ 1, 1 } });
    try testing.expect(world.get(Position, 1).?.pos[0] == 1);
    try testing.expect(world.get(Age, 1) == null);
}
3 Likes