Hi everyone, I have been learning Zig for the last few weeks, and I would like feedback on some code below. I want to make it as fast as possible.
To give some context:
In my application, I need to populate a `BitSet` with i.i.d. binary variables, each with probability `prob`.
My first approach was something like this:
```zig
for (0..bitset.bit_length) |i| {
    const int_thresh: u32 = @intFromFloat(prob * std.math.maxInt(u32));
    const val = rng.random().int(u32) < int_thresh;
    bitset.setValue(i, val);
}
```
where `rng` is a Xoshiro256 instance.
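As a quick sanity check on the threshold trick itself (in Python for brevity, not the Zig code above): a uniform 32-bit integer `u` satisfies `P(u < thresh) == thresh / 2^32` exactly, so the bias from rounding `prob * maxInt(u32)` down to an integer is at most about `2^-32` (plus whatever f32 rounding contributes in the Zig version).

```python
# Sanity check of the integer-threshold trick (illustration only, not the
# author's Zig code): P(u < thresh) for uniform u32 is exactly thresh / 2**32.
MAX_U32 = 2**32 - 1

def threshold_bias(prob):
    thresh = int(prob * MAX_U32)        # mirrors @intFromFloat(prob * maxInt(u32))
    return abs(thresh / 2**32 - prob)   # exact success probability vs. target

for p in (0.1, 0.3, 0.5, 0.9):
    assert threshold_bias(p) < 1e-9, p
print("max bias < 1e-9 for tested probabilities")
```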
After some initial testing I found that calling the function with this loop was a performance bottleneck. The alternative I came up with was the following algorithm:
```
// fun(p), where p is of the form 0.b_{n-1}b_{n-2}..b_0 in binary:
//     acc = false
//     for digit d from least significant to most significant:
//         if d == 1:
//             acc = acc | coinFlip
//         else:
//             acc = acc & coinFlip
//     return acc
```
where `coinFlip` gives you an unbiased binary variable. The reason this works: an OR step maps the current success probability p to (1 + p)/2 and an AND step maps it to p/2, so processing the digits from least to most significant builds up exactly the binary expansion 0.b_{n-1}b_{n-2}..b_0.
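To convince myself the recurrence is right, here is a small exact check (Python, since it is easy to enumerate coin flips there): for p = 11/16 = 0.1011 in binary, the digits LSB-to-MSB are 1, 1, 0, 1, and enumerating all 2^4 equally likely flip outcomes should give exactly 11 successes out of 16.

```python
from itertools import product

# Exact check of the OR/AND digit recurrence for p = 11/16 = 0.1011 (binary).
# Digits are processed least significant first; each digit consumes one flip.
digits_lsb_first = [1, 1, 0, 1]

successes = 0
for flips in product([0, 1], repeat=len(digits_lsb_first)):
    acc = 0
    for d, f in zip(digits_lsb_first, flips):
        acc = (acc | f) if d else (acc & f)
    successes += acc

# 16 equally likely outcomes, so P(acc = 1) = successes / 16.
assert successes == 11
print(successes, "/ 16")  # → 11 / 16
```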
Finally, in Zig:
```zig
/// Fills `buf` with i.i.d. Bernoulli(prob) bits.
pub fn randomBitsU64(prob: f32, buf: []u64, rng: anytype) void {
    assert(prob >= 0.0 and prob <= 1.0);
    if (prob == 0.0) {
        // Needed: @ctz(0) below would otherwise produce a full-width shift.
        @memset(buf, 0);
        return;
    }
    if (prob == 1.0) {
        @memset(buf, std.math.boolMask(u64, true));
        // Without this return, the code below would corrupt the buffer.
        return;
    }
    const frexp = floatDec(prob);
    var bits = frexp.significand;
    var exp = frexp.exponent;
    // Drop trailing zeros plus the lowest set bit; that first set bit
    // corresponds to the initial unbiased fill below. Shifting in two steps
    // avoids an illegal full-width shift when prob is a power of two.
    bits >>= @ctz(bits);
    bits >>= 1;
    for (buf) |*v| {
        v.* = rng.next();
    }
    while (bits > 0) : (bits >>= 1) {
        if (bits % 2 == 1) {
            for (buf) |*v| {
                v.* |= rng.next();
            }
        } else {
            for (buf) |*v| {
                v.* &= rng.next();
            }
        }
    }
    // Each leading zero digit of prob (exponent below -1) adds one AND round.
    while (exp < -1) : (exp += 1) {
        for (buf) |*v| {
            v.* &= rng.next();
        }
    }
}
```
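For reference, here is my transcription of the same word-level algorithm to Python (using `math.frexp` on a 64-bit float in place of `floatDec` on an f32, so the constants differ), with a rough statistical check that the fraction of set bits comes out near `prob`:

```python
import math
import random

def random_bits_words(prob, n_words, rng):
    """Transcription of the word-level algorithm: each returned 64-bit word
    holds 64 (approximately) i.i.d. Bernoulli(prob) bits."""
    assert 0.0 <= prob <= 1.0
    mask = (1 << 64) - 1
    if prob == 0.0:
        return [0] * n_words
    if prob == 1.0:
        return [mask] * n_words
    m, e = math.frexp(prob)                 # prob = m * 2**e, m in [0.5, 1)
    sig = int(m * (1 << 53))                # 53-bit significand of the double
    sig >>= (sig & -sig).bit_length() - 1   # drop trailing zeros...
    sig >>= 1                               # ...and the lowest set bit
    buf = [rng.getrandbits(64) for _ in range(n_words)]  # initial coin flip
    while sig > 0:
        if sig & 1:
            buf = [v | rng.getrandbits(64) for v in buf]
        else:
            buf = [v & rng.getrandbits(64) for v in buf]
        sig >>= 1
    for _ in range(-e):                     # one AND round per leading zero digit
        buf = [v & rng.getrandbits(64) for v in buf]
    return buf

rng = random.Random(0)
words = random_bits_words(0.3, 2000, rng)
frac = sum(bin(w).count("1") for w in words) / (2000 * 64)
assert abs(frac - 0.3) < 0.01
```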
I did some quick benchmarking and it seems that the second approach can be 10x faster.
My questions:
- Did I miss anything? Was my initial approach inappropriate?
- Did I do anything ‘non-idiomatic’ in my solution? Is there room for further improvements?
- I did notice that using `std.rand.Random.float` to get a uniform variable in [0, 1) was quite slow. Maybe too many layers of indirection? `float()` → `int()` → `fill()` → `next()`?
Thanks a lot for the help!
PS: For `floatDec` I took inspiration from the standard library's `frexp`:
```zig
pub fn FloatDec(comptime T: type) type {
    return struct {
        sign: u1,
        exponent: CExpInt,
        significand: SigInt,

        const bits: comptime_int = @typeInfo(T).Float.bits;
        const exp_bits: comptime_int = std.math.floatExponentBits(T);
        const frac_bits: comptime_int = std.math.floatFractionalBits(T);
        const mant_bits: comptime_int = std.math.floatMantissaBits(T);
        const Int: type = std.meta.Int(.unsigned, bits);
        const ExpInt: type = std.meta.Int(.unsigned, exp_bits);
        const CExpInt: type = std.meta.Int(.signed, exp_bits + 1);
        const FracInt: type = std.meta.Int(.unsigned, frac_bits);
        const SigInt: type = std.meta.Int(.unsigned, frac_bits + 1);
    };
}

/// Float decomposition; for normal values,
/// fl == significand * 2^(exponent - frac_bits).
pub fn floatDec(fl: anytype) FloatDec(@TypeOf(fl)) {
    const T = @TypeOf(fl);
    const FT = FloatDec(T);
    const exp_bias: comptime_int = (1 << (FT.exp_bits - 1)) - 1;
    const v: std.meta.Int(.unsigned, FT.bits) = @bitCast(fl);
    const sign: u1 = @truncate(v >> (FT.bits - 1));
    const exponent: FT.ExpInt = @truncate(v >> FT.mant_bits);
    const frac: FT.FracInt = @truncate(v);
    // Make the implicit leading bit explicit (it is 0 only for subnormals).
    const imp_bool: FT.SigInt = @as(FT.SigInt, @intFromBool(exponent != 0)) << FT.frac_bits;
    const significand: FT.SigInt = @as(FT.SigInt, frac) + imp_bool;
    const cexp: FT.CExpInt = exponent;
    return FT{ .sign = sign, .exponent = cexp - exp_bias, .significand = significand };
}
```