@memset beaten

This beats @memset in 0.16.0.
Another vectorization regression thingy?

fn crazy_memset(dst: []u8, value: u8) void {
    const UNROLL = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);

    var i: usize = 0;
    while (i + UNROLL <= dst.len) : (i += UNROLL) {
        const p: *align(1) @Vector(UNROLL, u8) = @ptrCast(dst[i..]);
        p.* = @splat(value);
    }
    while (i < dst.len) : (i += 1) {
        dst[i] = value;
    }
}
#32091

Maybe I’ll create a PR for this at some point.

Looks good to me.
I’ve been trying for the past few hours to beat yours, but failed. My plan was to improve the second part (the byte-by-byte tail) by using intermediate store sizes, but the overhead makes it unprofitable.
How about a PR?
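
For anyone curious what the "intermediate sizes" idea might look like, here is a minimal sketch. It is not the code that was actually tried: the name memsetSteppedTail is made up, and it assumes the vector length is a power of two so that the halving sizes n/2, n/4, ..., 1 cover any remainder below n.

```zig
const std = @import("std");

// Hypothetical sketch of the "intermediate sizes" tail, not the poster's code.
fn memsetSteppedTail(dst: []u8, value: u8) void {
    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
    const splatted: [n]u8 = @splat(value);

    // Main loop: full n-byte stores, same as in the original post.
    var i: usize = 0;
    while (i + n <= dst.len) : (i += n) {
        dst[i..][0..n].* = splatted;
    }

    // Fewer than n bytes remain. Try n/2, n/4, ..., 1 in turn.
    // Assuming n is a power of two, each size is needed at most once.
    comptime var size: usize = n / 2;
    inline while (size >= 1) : (size /= 2) {
        if (dst.len - i >= size) {
            dst[i..][0..size].* = splatted[0..size].*;
            i += size;
        }
    }
}
```

Each step costs a compare and a branch, which is presumably the overhead mentioned above.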

I wonder if aligning the SIMD part to a power of two greater than or equal to UNROLL would speed it up?
E.g. another scalar loop at the start, bounded by the minimum of dst.len and the distance from dst.ptr to its alignForward address, then the vectorized loop from that aligned address up to the last full chunk, then a scalar loop for the remaining part.

(this is kinda inspired by what musl does)
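
A rough sketch of that head/body/tail split, for illustration only. The name alignedMemset is made up, it aligns to exactly the vector length rather than some larger power of two, and no claim is made here about whether it is faster.

```zig
const std = @import("std");

// Hypothetical sketch of the aligned variant suggested above.
fn alignedMemset(dst: []u8, value: u8) void {
    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
    const splatted: @Vector(n, u8) = @splat(value);

    // Scalar head: run until dst.ptr + i is n-byte aligned, or dst ends first.
    const addr = @intFromPtr(dst.ptr);
    const head = @min(std.mem.alignForward(usize, addr, n) - addr, dst.len);
    var i: usize = 0;
    while (i < head) : (i += 1) {
        dst[i] = value;
    }

    // Vector body: every store here targets an n-byte aligned address.
    while (i + n <= dst.len) : (i += n) {
        const p: *align(n) [n]u8 = @alignCast(dst[i..][0..n]);
        p.* = splatted;
    }

    // Scalar tail: whatever is left after the last full chunk.
    while (i < dst.len) : (i += 1) {
        dst[i] = value;
    }
}
```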

One implementation I currently have is this. It is also partly inspired by what musl does. Probably there is still some optimization potential.

For x86, I believe this would be a net loss. There is ample evidence that unaligned loads and stores don’t matter, and doing it this way would mean spending less time inside the most efficient loop, both at the beginning and at the end. Consider an array of size n * UNROLL that is unaligned. Doing it the way you suggest, you would need a scalar loop at the beginning and another at the end, while doing it the way @ericlang has shown, the scalar loop is avoided completely.

With function signature fixed:

const std = @import("std");

export fn improved_memset(dest: ?[*]u8, c: u8, len: usize) callconv(.c) ?[*]u8 {
    @setRuntimeSafety(false);

    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);

    var i: usize = 0;
    while (i + n <= len) : (i += n) {
        const p: *align(1) @Vector(n, u8) = @ptrCast(dest.?[i..]);
        p.* = @splat(c);
    }
    while (i < len) : (i += 1) {
        dest.?[i] = c;
    }

    return dest;
}
0000000000000000 <improved_memset>:
   0: 55                    push   rbp
   1: 48 89 e5              mov    rbp,rsp
   4: 53                    push   rbx
   5: 50                    push   rax
   6: 48 89 fb              mov    rbx,rdi
   9: 48 83 fa 40           cmp    rdx,0x40
   d: 73 04                 jae    13 <improved_memset+0x13>
   f: 31 ff                 xor    edi,edi
  11: eb 24                 jmp    37 <improved_memset+0x37>
  13: 62 f2 7d 48 7a c6     vpbroadcastb zmm0,esi
  19: 31 c0                 xor    eax,eax
  1b: 0f 1f 44 00 00        nop    DWORD PTR [rax+rax*1+0x0]
  20: 62 f1 fe 48 7f 04 03  vmovdqu64 ZMMWORD PTR [rbx+rax*1],zmm0
  27: 48 8d 78 40           lea    rdi,[rax+0x40]
  2b: 48 83 e8 80           sub    rax,0xffffffffffffff80
  2f: 48 39 d0              cmp    rax,rdx
  32: 48 89 f8              mov    rax,rdi
  35: 76 e9                 jbe    20 <improved_memset+0x20>
  37: 48 29 fa              sub    rdx,rdi
  3a: 76 0f                 jbe    4b <improved_memset+0x4b>
  3c: 48 01 df              add    rdi,rbx
  3f: 40 0f b6 f6           movzx  esi,sil
  43: c5 f8 77              vzeroupper
  46: e8 00 00 00 00        call   4b <improved_memset+0x4b>
  4b: 48 89 d8              mov    rax,rbx
  4e: 48 83 c4 08           add    rsp,0x8
  52: 5b                    pop    rbx
  53: 5d                    pop    rbp
  54: c5 f8 77              vzeroupper
  57: c3                    ret

With the illegal @ptrCast fixed:

const std = @import("std");

export fn improved_memset(dest: ?[*]u8, c: u8, len: usize) callconv(.c) ?[*]u8 {
    @setRuntimeSafety(false);

    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
    const splatted: @Vector(n, u8) = @splat(c);

    var i: usize = 0;
    while (i + n <= len) : (i += n) {
        dest.?[i..][0..n].* = splatted;
    }
    while (i < len) : (i += 1) {
        dest.?[i] = c;
    }

    return dest;
}

(same machine code, except it won’t miscompile under various conditions, and it won’t become a compile error when we make language changes to vectors)

Duff’s Device:

const std = @import("std");
const assert = std.debug.assert;

export fn improved_memset(dest: ?[*]u8, c: u8, len: usize) callconv(.c) ?[*]u8 {
    @setRuntimeSafety(false);

    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
    const splatted: @Vector(n, u8) = @splat(c);

    var i: usize = 0;
    sw: switch (len % n) {
        inline 1...n - 1 => |remainder| {
            dest.?[i] = c;
            i += 1;
            continue :sw remainder - 1;
        },
        0 => {
            assert(i <= len);
            if (i == len) return dest;
            dest.?[i..][0..n].* = splatted;
            i += n;
            continue :sw 0;
        },
        else => unreachable,
    }

    return dest;
}
0000000000000000 <improved_memset>:
   0:	55                   	push   rbp
   1:	48 89 e5             	mov    rbp,rsp
   4:	89 d1                	mov    ecx,edx
   6:	83 e1 3f             	and    ecx,0x3f
   9:	48 89 f8             	mov    rax,rdi
   c:	ff 24 cd 00 00 00 00 	jmp    QWORD PTR [rcx*8+0x0]
  13:	31 c9                	xor    ecx,ecx
  15:	e9 d7 02 00 00       	jmp    2f1 <improved_memset+0x2f1>
  1a:	31 c9                	xor    ecx,ecx
  1c:	e9 52 01 00 00       	jmp    173 <improved_memset+0x173>
  21:	31 c9                	xor    ecx,ecx
  23:	e9 7c 01 00 00       	jmp    1a4 <improved_memset+0x1a4>
  28:	31 c9                	xor    ecx,ecx
  2a:	e9 9f 01 00 00       	jmp    1ce <improved_memset+0x1ce>
  2f:	31 c9                	xor    ecx,ecx
  31:	e9 28 01 00 00       	jmp    15e <improved_memset+0x15e>
  36:	31 c9                	xor    ecx,ecx
  38:	e9 4b 01 00 00       	jmp    188 <improved_memset+0x188>
  3d:	31 c9                	xor    ecx,ecx
  3f:	e9 05 01 00 00       	jmp    149 <improved_memset+0x149>
  44:	31 c9                	xor    ecx,ecx
  46:	e9 91 01 00 00       	jmp    1dc <improved_memset+0x1dc>
  4b:	31 c9                	xor    ecx,ecx
  4d:	e9 83 01 00 00       	jmp    1d5 <improved_memset+0x1d5>
  52:	31 c9                	xor    ecx,ecx
  54:	e9 bb 01 00 00       	jmp    214 <improved_memset+0x214>
  59:	31 c9                	xor    ecx,ecx
  5b:	e9 83 01 00 00       	jmp    1e3 <improved_memset+0x1e3>
  60:	31 c9                	xor    ecx,ecx
  62:	e9 fa 01 00 00       	jmp    261 <improved_memset+0x261>
  67:	31 c9                	xor    ecx,ecx
  69:	e9 2f 01 00 00       	jmp    19d <improved_memset+0x19d>
  6e:	31 c9                	xor    ecx,ecx
  70:	e9 fa 01 00 00       	jmp    26f <improved_memset+0x26f>
  75:	31 c9                	xor    ecx,ecx
  77:	e9 6e 01 00 00       	jmp    1ea <improved_memset+0x1ea>
  7c:	31 c9                	xor    ecx,ecx
  7e:	e9 fe 00 00 00       	jmp    181 <improved_memset+0x181>
  83:	31 c9                	xor    ecx,ecx
  85:	e9 b1 00 00 00       	jmp    13b <improved_memset+0x13b>
  8a:	31 c9                	xor    ecx,ecx
  8c:	e9 f3 01 00 00       	jmp    284 <improved_memset+0x284>
  91:	31 c9                	xor    ecx,ecx
  93:	e9 e2 00 00 00       	jmp    17a <improved_memset+0x17a>
  98:	31 c9                	xor    ecx,ecx
  9a:	e9 16 02 00 00       	jmp    2b5 <improved_memset+0x2b5>
  9f:	31 c9                	xor    ecx,ecx
  a1:	e9 13 01 00 00       	jmp    1b9 <improved_memset+0x1b9>
  a6:	31 c9                	xor    ecx,ecx
  a8:	e9 80 00 00 00       	jmp    12d <improved_memset+0x12d>
  ad:	31 c9                	xor    ecx,ecx
  af:	e9 c9 01 00 00       	jmp    27d <improved_memset+0x27d>
  b4:	31 c9                	xor    ecx,ecx
  b6:	e9 d7 01 00 00       	jmp    292 <improved_memset+0x292>
  bb:	31 c9                	xor    ecx,ecx
  bd:	e9 60 01 00 00       	jmp    222 <improved_memset+0x222>
  c2:	31 c9                	xor    ecx,ecx
  c4:	e9 60 01 00 00       	jmp    229 <improved_memset+0x229>
  c9:	31 c9                	xor    ecx,ecx
  cb:	e9 f3 01 00 00       	jmp    2c3 <improved_memset+0x2c3>
  d0:	31 c9                	xor    ecx,ecx
  d2:	e9 91 01 00 00       	jmp    268 <improved_memset+0x268>
  d7:	40 88 30             	mov    BYTE PTR [rax],sil
  da:	b9 01 00 00 00       	mov    ecx,0x1
  df:	eb 3e                	jmp    11f <improved_memset+0x11f>
  e1:	31 c9                	xor    ecx,ecx
  e3:	e9 d4 01 00 00       	jmp    2bc <improved_memset+0x2bc>
  e8:	31 c9                	xor    ecx,ecx
  ea:	e9 b1 01 00 00       	jmp    2a0 <improved_memset+0x2a0>
  ef:	31 c9                	xor    ecx,ecx
  f1:	e9 d4 01 00 00       	jmp    2ca <improved_memset+0x2ca>
  f6:	31 c9                	xor    ecx,ecx
  f8:	e9 fb 00 00 00       	jmp    1f8 <improved_memset+0x1f8>
  fd:	31 c9                	xor    ecx,ecx
  ff:	e9 56 01 00 00       	jmp    25a <improved_memset+0x25a>
 104:	31 c9                	xor    ecx,ecx
 106:	e9 a7 00 00 00       	jmp    1b2 <improved_memset+0x1b2>
 10b:	31 c9                	xor    ecx,ecx
 10d:	e9 33 01 00 00       	jmp    245 <improved_memset+0x245>
 112:	31 c9                	xor    ecx,ecx
 114:	eb 10                	jmp    126 <improved_memset+0x126>
 116:	31 c9                	xor    ecx,ecx
 118:	e9 91 01 00 00       	jmp    2ae <improved_memset+0x2ae>
 11d:	31 c9                	xor    ecx,ecx
 11f:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 123:	48 ff c1             	inc    rcx
 126:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 12a:	48 ff c1             	inc    rcx
 12d:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 131:	48 ff c1             	inc    rcx
 134:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 138:	48 ff c1             	inc    rcx
 13b:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 13f:	48 ff c1             	inc    rcx
 142:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 146:	48 ff c1             	inc    rcx
 149:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 14d:	48 ff c1             	inc    rcx
 150:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 154:	48 ff c1             	inc    rcx
 157:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 15b:	48 ff c1             	inc    rcx
 15e:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 162:	48 ff c1             	inc    rcx
 165:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 169:	48 ff c1             	inc    rcx
 16c:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 170:	48 ff c1             	inc    rcx
 173:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 177:	48 ff c1             	inc    rcx
 17a:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 17e:	48 ff c1             	inc    rcx
 181:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 185:	48 ff c1             	inc    rcx
 188:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 18c:	48 ff c1             	inc    rcx
 18f:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 193:	48 ff c1             	inc    rcx
 196:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 19a:	48 ff c1             	inc    rcx
 19d:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1a1:	48 ff c1             	inc    rcx
 1a4:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1a8:	48 ff c1             	inc    rcx
 1ab:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1af:	48 ff c1             	inc    rcx
 1b2:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1b6:	48 ff c1             	inc    rcx
 1b9:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1bd:	48 ff c1             	inc    rcx
 1c0:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1c4:	48 ff c1             	inc    rcx
 1c7:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1cb:	48 ff c1             	inc    rcx
 1ce:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1d2:	48 ff c1             	inc    rcx
 1d5:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1d9:	48 ff c1             	inc    rcx
 1dc:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1e0:	48 ff c1             	inc    rcx
 1e3:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1e7:	48 ff c1             	inc    rcx
 1ea:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1ee:	48 ff c1             	inc    rcx
 1f1:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1f5:	48 ff c1             	inc    rcx
 1f8:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 1fc:	48 ff c1             	inc    rcx
 1ff:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 203:	48 ff c1             	inc    rcx
 206:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 20a:	48 ff c1             	inc    rcx
 20d:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 211:	48 ff c1             	inc    rcx
 214:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 218:	48 ff c1             	inc    rcx
 21b:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 21f:	48 ff c1             	inc    rcx
 222:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 226:	48 ff c1             	inc    rcx
 229:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 22d:	48 ff c1             	inc    rcx
 230:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 234:	48 ff c1             	inc    rcx
 237:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 23b:	48 ff c1             	inc    rcx
 23e:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 242:	48 ff c1             	inc    rcx
 245:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 249:	48 ff c1             	inc    rcx
 24c:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 250:	48 ff c1             	inc    rcx
 253:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 257:	48 ff c1             	inc    rcx
 25a:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 25e:	48 ff c1             	inc    rcx
 261:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 265:	48 ff c1             	inc    rcx
 268:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 26c:	48 ff c1             	inc    rcx
 26f:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 273:	48 ff c1             	inc    rcx
 276:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 27a:	48 ff c1             	inc    rcx
 27d:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 281:	48 ff c1             	inc    rcx
 284:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 288:	48 ff c1             	inc    rcx
 28b:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 28f:	48 ff c1             	inc    rcx
 292:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 296:	48 ff c1             	inc    rcx
 299:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 29d:	48 ff c1             	inc    rcx
 2a0:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2a4:	48 ff c1             	inc    rcx
 2a7:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2ab:	48 ff c1             	inc    rcx
 2ae:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2b2:	48 ff c1             	inc    rcx
 2b5:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2b9:	48 ff c1             	inc    rcx
 2bc:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2c0:	48 ff c1             	inc    rcx
 2c3:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2c7:	48 ff c1             	inc    rcx
 2ca:	40 88 34 08          	mov    BYTE PTR [rax+rcx*1],sil
 2ce:	48 ff c1             	inc    rcx
 2d1:	eb 1e                	jmp    2f1 <improved_memset+0x2f1>
 2d3:	66 66 66 66 2e 0f 1f 	data16 data16 data16 cs nop WORD PTR [rax+rax*1+0x0]
 2da:	84 00 00 00 00 00 
 2e0:	62 f2 7d 48 7a c6    	vpbroadcastb zmm0,esi
 2e6:	62 f1 fe 48 7f 04 08 	vmovdqu64 ZMMWORD PTR [rax+rcx*1],zmm0
 2ed:	48 83 c1 40          	add    rcx,0x40
 2f1:	48 39 d1             	cmp    rcx,rdx
 2f4:	75 ea                	jne    2e0 <improved_memset+0x2e0>
 2f6:	5d                   	pop    rbp
 2f7:	c5 f8 77             	vzeroupper
 2fa:	c3                   	ret
 2fb:	31 c9                	xor    ecx,ecx
 2fd:	e9 3c ff ff ff       	jmp    23e <improved_memset+0x23e>
 302:	31 c9                	xor    ecx,ecx
 304:	e9 6d ff ff ff       	jmp    276 <improved_memset+0x276>
 309:	31 c9                	xor    ecx,ecx
 30b:	eb 9a                	jmp    2a7 <improved_memset+0x2a7>
 30d:	31 c9                	xor    ecx,ecx
 30f:	e9 3f ff ff ff       	jmp    253 <improved_memset+0x253>
 314:	31 c9                	xor    ecx,ecx
 316:	e9 70 ff ff ff       	jmp    28b <improved_memset+0x28b>
 31b:	31 c9                	xor    ecx,ecx
 31d:	e9 77 ff ff ff       	jmp    299 <improved_memset+0x299>
 322:	31 c9                	xor    ecx,ecx
 324:	e9 d6 fe ff ff       	jmp    1ff <improved_memset+0x1ff>
 329:	31 c9                	xor    ecx,ecx
 32b:	e9 c1 fe ff ff       	jmp    1f1 <improved_memset+0x1f1>
 330:	31 c9                	xor    ecx,ecx
 332:	e9 5f fe ff ff       	jmp    196 <improved_memset+0x196>
 337:	31 c9                	xor    ecx,ecx
 339:	e9 f9 fe ff ff       	jmp    237 <improved_memset+0x237>
 33e:	31 c9                	xor    ecx,ecx
 340:	e9 7b fe ff ff       	jmp    1c0 <improved_memset+0x1c0>
 345:	31 c9                	xor    ecx,ecx
 347:	e9 00 ff ff ff       	jmp    24c <improved_memset+0x24c>
 34c:	31 c9                	xor    ecx,ecx
 34e:	e9 ef fd ff ff       	jmp    142 <improved_memset+0x142>
 353:	31 c9                	xor    ecx,ecx
 355:	e9 51 fe ff ff       	jmp    1ab <improved_memset+0x1ab>
 35a:	31 c9                	xor    ecx,ecx
 35c:	e9 d3 fd ff ff       	jmp    134 <improved_memset+0x134>
 361:	31 c9                	xor    ecx,ecx
 363:	e9 c8 fe ff ff       	jmp    230 <improved_memset+0x230>
 368:	31 c9                	xor    ecx,ecx
 36a:	e9 9e fe ff ff       	jmp    20d <improved_memset+0x20d>
 36f:	31 c9                	xor    ecx,ecx
 371:	e9 90 fe ff ff       	jmp    206 <improved_memset+0x206>
 376:	31 c9                	xor    ecx,ecx
 378:	e9 4a fe ff ff       	jmp    1c7 <improved_memset+0x1c7>
 37d:	31 c9                	xor    ecx,ecx
 37f:	e9 97 fe ff ff       	jmp    21b <improved_memset+0x21b>
 384:	31 c9                	xor    ecx,ecx
 386:	e9 cc fd ff ff       	jmp    157 <improved_memset+0x157>
 38b:	31 c9                	xor    ecx,ecx
 38d:	e9 be fd ff ff       	jmp    150 <improved_memset+0x150>
 392:	31 c9                	xor    ecx,ecx
 394:	e9 cc fd ff ff       	jmp    165 <improved_memset+0x165>
 399:	31 c9                	xor    ecx,ecx
 39b:	e9 ef fd ff ff       	jmp    18f <improved_memset+0x18f>
 3a0:	31 c9                	xor    ecx,ecx
 3a2:	e9 c5 fd ff ff       	jmp    16c <improved_memset+0x16c>

Unfortunately it was not able to jump directly from the switch condition into the inc/mov section. That would have eliminated all those pairs of jmp/xor.

Tracking issue:

https://codeberg.org/ziglang/zig/issues/35590

Couldn’t help myself looking at that assembly

I couldn’t get rid of the indirect jump (it really insists on that for a labelled switch), but I got rid of the incs:

const std = @import("std");

export fn improved_memset(dest: ?[*]u8, c: u8, len: usize) callconv(.c) ?[*]u8 {
    @setRuntimeSafety(false);

    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
    const splatted: @Vector(n, u8) = @splat(c);

    sw: switch (len % n) {
        inline 1...n - 1 => |rem| {
            dest.?[rem - 1] = c;
            continue :sw rem - 1;
        },
        0 => {},
        else => unreachable,
    }

    var i: usize = len % n;
    while (i + n <= len) : (i += n) {
        dest.?[i..][0..n].* = splatted;
    }

    return dest;
}

asm: Compiler Explorer

Edit again:

I disassembled it locally, and the indirect jump looks gone!
zig build-lib memset2.zig -OReleaseFast && objdump -d -M intel libmemset2.a:

0000000000000000 <improved_memset_nn>:
   0:	55                   	push   rbp
   1:	48 89 e5             	mov    rbp,rsp
   4:	48 89 f8             	mov    rax,rdi
   7:	89 d1                	mov    ecx,edx
   9:	83 e1 1f             	and    ecx,0x1f
   c:	89 cf                	mov    edi,ecx
   e:	ff 24 fd 00 00 00 00 	jmp    QWORD PTR [rdi*8+0x0]
  15:	40 88 30             	mov    BYTE PTR [rax],sil
  18:	40 88 74 08 e2       	mov    BYTE PTR [rax+rcx*1-0x1e],sil
  1d:	40 88 74 08 e3       	mov    BYTE PTR [rax+rcx*1-0x1d],sil
  22:	40 88 74 08 e4       	mov    BYTE PTR [rax+rcx*1-0x1c],sil
  27:	40 88 74 08 e5       	mov    BYTE PTR [rax+rcx*1-0x1b],sil
  2c:	40 88 74 08 e6       	mov    BYTE PTR [rax+rcx*1-0x1a],sil
  31:	40 88 74 08 e7       	mov    BYTE PTR [rax+rcx*1-0x19],sil
  36:	40 88 74 08 e8       	mov    BYTE PTR [rax+rcx*1-0x18],sil
  3b:	40 88 74 08 e9       	mov    BYTE PTR [rax+rcx*1-0x17],sil
  40:	40 88 74 08 ea       	mov    BYTE PTR [rax+rcx*1-0x16],sil
  45:	40 88 74 08 eb       	mov    BYTE PTR [rax+rcx*1-0x15],sil
  4a:	40 88 74 08 ec       	mov    BYTE PTR [rax+rcx*1-0x14],sil
  4f:	40 88 74 08 ed       	mov    BYTE PTR [rax+rcx*1-0x13],sil
  54:	40 88 74 08 ee       	mov    BYTE PTR [rax+rcx*1-0x12],sil
  59:	40 88 74 08 ef       	mov    BYTE PTR [rax+rcx*1-0x11],sil
  5e:	40 88 74 08 f0       	mov    BYTE PTR [rax+rcx*1-0x10],sil
  63:	40 88 74 08 f1       	mov    BYTE PTR [rax+rcx*1-0xf],sil
  68:	40 88 74 08 f2       	mov    BYTE PTR [rax+rcx*1-0xe],sil
  6d:	40 88 74 08 f3       	mov    BYTE PTR [rax+rcx*1-0xd],sil
  72:	40 88 74 08 f4       	mov    BYTE PTR [rax+rcx*1-0xc],sil
  77:	40 88 74 08 f5       	mov    BYTE PTR [rax+rcx*1-0xb],sil
  7c:	40 88 74 08 f6       	mov    BYTE PTR [rax+rcx*1-0xa],sil
  81:	40 88 74 08 f7       	mov    BYTE PTR [rax+rcx*1-0x9],sil
  86:	40 88 74 08 f8       	mov    BYTE PTR [rax+rcx*1-0x8],sil
  8b:	40 88 74 08 f9       	mov    BYTE PTR [rax+rcx*1-0x7],sil
  90:	40 88 74 08 fa       	mov    BYTE PTR [rax+rcx*1-0x6],sil
  95:	40 88 74 08 fb       	mov    BYTE PTR [rax+rcx*1-0x5],sil
  9a:	40 88 74 08 fc       	mov    BYTE PTR [rax+rcx*1-0x4],sil
  9f:	40 88 74 08 fd       	mov    BYTE PTR [rax+rcx*1-0x3],sil
  a4:	40 88 74 08 fe       	mov    BYTE PTR [rax+rcx*1-0x2],sil
  a9:	40 88 74 08 ff       	mov    BYTE PTR [rax+rcx*1-0x1],sil
  ae:	48 83 c9 20          	or     rcx,0x20
  b2:	48 39 d1             	cmp    rcx,rdx
  b5:	77 18                	ja     cf <improved_memset_nn+0xcf>
  b7:	c5 f9 6e c6          	vmovd  xmm0,esi
  bb:	c4 e2 7d 78 c0       	vpbroadcastb ymm0,xmm0
  c0:	c5 fe 7f 44 08 e0    	vmovdqu YMMWORD PTR [rax+rcx*1-0x20],ymm0
  c6:	48 83 c1 20          	add    rcx,0x20
  ca:	48 39 d1             	cmp    rcx,rdx
  cd:	76 f1                	jbe    c0 <improved_memset_nn+0xc0>
  cf:	5d                   	pop    rbp
  d0:	c5 f8 77             	vzeroupper
  d3:	c3                   	ret

I tried to improve on your Duff’s device code by changing splatted from a vector to an array and adding a couple more switch cases for n/2 and n/4. It seems to result in less generated code.

const std = @import("std");
const assert = std.debug.assert;

export fn improved_memset3(dest: ?[*]u8, c: u8, len: usize) callconv(.c) ?[*]u8 {
    @setRuntimeSafety(false);

    const n = std.simd.suggestVectorLength(u8) orelse @sizeOf(usize);
    const splatted: [n]u8 = @splat(c);
    var d = dest.?;
    const end = dest.? + len;
    sw: switch (len % n) {
        0 => {
            if (d == end) return dest;
            assert(@intFromPtr(d + n) <= @intFromPtr(end));
            d[0..n].* = splatted;
            d += n;
            continue :sw 0;
        },
        n / 2 => {
            assert(@intFromPtr(d + n / 2) <= @intFromPtr(end));
            d[0 .. n / 2].* = splatted[0 .. n / 2].*;
            d += n / 2;
            continue :sw 0;
        },
        n / 4 => {
            assert(@intFromPtr(d + n / 4) <= @intFromPtr(end));
            d[0 .. n / 4].* = splatted[0 .. n / 4].*;
            d += n / 4;
            continue :sw 0;
        },
        else => |remainder| {
            d[0] = c;
            d += 1;
            continue :sw remainder - 1;
        },
    }
    return dest;
}
$ zig build-obj -fstrip -OReleaseFast /tmp/tmp.zig -femit-bin=/tmp/tmp.o && objdump -d /tmp/tmp.o

/tmp/tmp.o:     file format elf64-x86-64


Disassembly of section .text:

0000000000000000 <improved_memset3>:
   0:   55                      push   %rbp
   1:   48 89 e5                mov    %rsp,%rbp
   4:   48 8d 0c 17             lea    (%rdi,%rdx,1),%rcx
   8:   83 e2 1f                and    $0x1f,%edx
   b:   48 89 f8                mov    %rdi,%rax
   e:   48 83 fa 11             cmp    $0x11,%rdx
  12:   73 0c                   jae    20 <improved_memset3+0x20>
  14:   ff 24 d5 00 00 00 00    jmp    *0x0(,%rdx,8)
  1b:   0f 1f 44 00 00          nopl   0x0(%rax,%rax,1)
  20:   49 89 d0                mov    %rdx,%r8
  23:   40 88 37                mov    %sil,(%rdi)
  26:   48 ff c7                inc    %rdi
  29:   48 ff ca                dec    %rdx
  2c:   48 83 fa 11             cmp    $0x11,%rdx
  30:   73 ee                   jae    20 <improved_memset3+0x20>
  32:   42 ff 24 c5 00 00 00    jmp    *0x0(,%r8,8)
  39:   00
  3a:   c5 f9 6e c6             vmovd  %esi,%xmm0
  3e:   c4 e2 79 78 c0          vpbroadcastb %xmm0,%xmm0
  43:   c5 fa 7f 07             vmovdqu %xmm0,(%rdi)
  47:   48 83 c7 10             add    $0x10,%rdi
  4b:   eb 34                   jmp    81 <improved_memset3+0x81>
  4d:   40 0f b6 d6             movzbl %sil,%edx
  51:   49 b8 01 01 01 01 01    movabs $0x101010101010101,%r8
  58:   01 01 01
  5b:   4c 0f af c2             imul   %rdx,%r8
  5f:   4c 89 07                mov    %r8,(%rdi)
  62:   48 83 c7 08             add    $0x8,%rdi
  66:   eb 19                   jmp    81 <improved_memset3+0x81>
  68:   0f 1f 84 00 00 00 00    nopl   0x0(%rax,%rax,1)
  6f:   00
  70:   c5 f9 6e c6             vmovd  %esi,%xmm0
  74:   c4 e2 7d 78 c0          vpbroadcastb %xmm0,%ymm0
  79:   c5 fe 7f 07             vmovdqu %ymm0,(%rdi)
  7d:   48 83 c7 20             add    $0x20,%rdi
  81:   48 39 cf                cmp    %rcx,%rdi
  84:   75 ea                   jne    70 <improved_memset3+0x70>
  86:   5d                      pop    %rbp
  87:   c5 f8 77                vzeroupper
  8a:   c3                      ret

I made a little benchmark script which memsets a 10M - rand() % 1000 length buffer and it seems to be very close to the same speed.

Not sure if this is useful for anything. Just thought I’d share since it partially works around the Duff’s device codegen bug. Here is the benchmark script if anyone wants to see.

That’s good in that it adds extra vectorisation, which is nice. But it doesn’t unroll the loop at all, so there is a branch after every single write, even for a single byte.

What Duff’s Device does is unroll the loop, so there is only a single branch instruction for the call, instead of one for each write. This can speed things up, as the CPU is better able to pipeline the instructions.

What you’ve written isn’t a Duff’s device. The key difference is the use of inline in the switch statement. That’s what forces the loop to be unrolled by generating a distinct code path for each value of the switch. Without it, it reverts to a conventional loop construct, like while or for.
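
To make the role of inline concrete, here is a stripped-down sketch with a fixed n of 4 instead of the vector length. The function name fillRemainder is made up for illustration. Because the prong is inline, the compiler instantiates a separate body for each of 1, 2 and 3, and each continue :sw has a comptime-known operand, so it can in principle become a direct jump or fallthrough into the next body.

```zig
const std = @import("std");

// Hypothetical helper: writes the first `len % 4` bytes of `dst` and
// returns how many bytes it wrote.
fn fillRemainder(dst: [*]u8, c: u8, len: usize) usize {
    const n = 4;
    var i: usize = 0;
    sw: switch (len % n) {
        // `inline` generates one copy of this body per value (1, 2, 3),
        // so `remainder` is comptime-known inside each copy.
        inline 1...n - 1 => |remainder| {
            dst[i] = c;
            i += 1;
            continue :sw remainder - 1;
        },
        0 => {},
        else => unreachable,
    }
    return i;
}
```

Dropping inline and capturing the value in an else prong, as in improved_memset3, makes remainder a runtime value, so every continue has to go back through the switch dispatch.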

Whether the extra vector writes make up for the loop not being unrolled would have to be benchmarked, although as the only difference is in the final ~1-30 bytes being copied, benchmarking this accurately is harder than it seems. I would be curious if unrolling is even worth it.

Nice catch, and thanks for the explanation. The missing inline was an oversight. My benchmark did suggest that the vector writes made up for some of the extra branching. I did try to simplify the benchmark script, but I feel like it’s still not great with the arg parsing and PRNG calls.