Okay, update time!
I’m sticking with the file generation approach because so far, it’s working out great.
There’s a three-step process for generating files that makes them convenient to use and link from Zig. Here’s the overview of what I’m working with:
Step 1 - generate overloads from cuda source
My source files are marked with stand-in replaceable types that allow me to write direct cuda with full lsp support and then generate overloads for compiling the library. Here’s how that step happens…
In the file_gen.zig file, there’s a list of replacement types and their size precedence that’s used to create valid type combinations. The declaration looks like this:
// level relates to the validity of a cast
// higher levels cannot result in lower levels
const Replacer = struct {
    symbol: []const u8,
    level: usize,
};

pub const ReplacerSet = struct {
    indicator: []const u8,
    replacers: []const Replacer,
};

const replacer_sets = [_]ReplacerSet {
    ReplacerSet { // real number replacers
        .indicator = "RScalar",
        .replacers = &.{
            Replacer{ .symbol = "r16", .level = MIN_LEVEL + 0 },
            Replacer{ .symbol = "r32", .level = MIN_LEVEL + 1 },
            Replacer{ .symbol = "r64", .level = MIN_LEVEL + 2 },
        }
    },
    ReplacerSet { // complex number replacers
        .indicator = "CScalar",
        .replacers = &.{
            Replacer{ .symbol = "c16", .level = MIN_LEVEL + 0 },
            Replacer{ .symbol = "c32", .level = MIN_LEVEL + 1 },
            Replacer{ .symbol = "c64", .level = MIN_LEVEL + 2 },
        }
    },
    ReplacerSet { // real tensor replacers
        .indicator = "RTensor",
        .replacers = &.{
            Replacer{ .symbol = "RTensor16", .level = MIN_LEVEL + 0 },
            Replacer{ .symbol = "RTensor32", .level = MIN_LEVEL + 1 },
            Replacer{ .symbol = "RTensor64", .level = MIN_LEVEL + 2 },
        }
    },
    ReplacerSet { // complex tensor replacers
        .indicator = "CTensor",
        .replacers = &.{
            Replacer{ .symbol = "CTensor16", .level = MIN_LEVEL + 0 },
            Replacer{ .symbol = "CTensor32", .level = MIN_LEVEL + 1 },
            Replacer{ .symbol = "CTensor64", .level = MIN_LEVEL + 2 },
        }
    },
};
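To make the level rule concrete, here is a minimal sketch of how levels could prune type combinations for a binary kernel. It is written in C++ purely for illustration and is not the code in file_gen.zig. It assumes that "higher levels cannot result in lower levels" means the output level must be at least as high as every input level; `Combination` and `validCombinations` are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Mirrors the Replacer struct above: a concrete type name and its precedence.
struct Replacer {
    std::string symbol;
    std::size_t level;
};

// One (input a, input b, output) type assignment for a binary kernel.
struct Combination {
    std::string a, b, out;
};

// Hypothetical helper: keep only the combinations whose output level is at
// least as high as both input levels (the assumed reading of the level rule).
std::vector<Combination> validCombinations(const std::vector<Replacer>& set) {
    std::vector<Combination> result;
    for (const Replacer& a : set)
        for (const Replacer& b : set)
            for (const Replacer& out : set)
                if (out.level >= std::max(a.level, b.level))
                    result.push_back({a.symbol, b.symbol, out.symbol});
    return result;
}
```

Under that assumption, the three real scalar replacers keep 14 of the 27 possible (a, b, out) combinations. For example, r16 + r32 may produce r32 or r64, but never r16.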
Here’s an example of a cuda kernel:
__global__ void __kernel_addition_RScalar(
    const RScalar *dev_a,
    const RScalar *dev_b,
    RScalar *dev_c,
    len_t N
) {
    const len_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;

    if (tid < N)
        dev_c[tid] = dev_a[tid] + dev_b[tid];
}

extern "C" void launch_addition_RScalar(
    const RScalar* a,
    const RScalar* b,
    RScalar* c,
    len_t N
) {
    __kernel_addition_RScalar<<<GRID_1D(N), 32>>>(a, b, c, N);
}
In a header file, I have RScalar, CScalar, RTensor, and CTensor defined as their 32-bit types (e.g. RScalar is float). This lets me write cuda code with full assistance from the lsp, and the file generator then replaces the stand-ins with their final types.
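The replacement itself can be as simple as a textual substitution of the indicator with each concrete symbol. Below is a minimal sketch in C++ (the real generator lives in file_gen.zig and may work differently); `replaceAll` is a hypothetical helper, and it assumes no indicator appears inside an unrelated longer identifier.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Replace every occurrence of `indicator` in `source` with `symbol`,
// e.g. "RScalar" -> "r16". Kernel and launcher names are rewritten too,
// which is how launch_addition_RScalar becomes launch_addition_r16.
std::string replaceAll(std::string source,
                       const std::string& indicator,
                       const std::string& symbol) {
    assert(!indicator.empty()); // an empty indicator would never advance
    std::size_t pos = 0;
    while ((pos = source.find(indicator, pos)) != std::string::npos) {
        source.replace(pos, indicator.size(), symbol);
        pos += symbol.size(); // skip past the text just inserted
    }
    return source;
}
```

Skipping past the inserted text matters for the tensor sets, where the symbol (RTensor16) contains the indicator (RTensor) and would otherwise be matched again.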
Step 2 - generate C-style declarations
This step is short, but important - we gather each of the extern "C" declarations and push them into a header during the generation process. The file looks like this:
/* GENERATED FILE */

#include "../tensor_types.h"

#if defined(__cplusplus)
    #define EXTERN_C extern "C"
#else
    #define EXTERN_C extern
#endif

EXTERN_C void launch_hadamard_reverse_r16(
    r16 *grads_a,
    const r16 *value_b,
    const r16 *grads_c,
    len_t N
);
EXTERN_C void launch_hadamard_reverse_c16(
    c16 *grads_a,
    const c16 *value_b,
    const c16 *grads_c,
    len_t N
);
EXTERN_C void launch_hadamard_reverse_r32(
    r32 *grads_a,
    const r32 *value_b,
    const r32 *grads_c,
    len_t N
);
EXTERN_C void launch_hadamard_reverse_c32(
    c32 *grads_a,
    const c32 *value_b,
    const c32 *grads_c,
    len_t N
);
EXTERN_C void launch_hadamard_reverse_r64(
    r64 *grads_a,
    const r64 *value_b,
    const r64 *grads_c,
    len_t N
);
EXTERN_C void launch_hadamard_reverse_c64(
    c64 *grads_a,
    const c64 *value_b,
    const c64 *grads_c,
    len_t N
);
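Gathering the declarations can be sketched as a scan for each extern "C" marker, copying the signature up to its closing parenthesis. This is a C++ illustration only, not the generator's actual code; `collectDeclarations` is a hypothetical name, and the sketch assumes parameter lists contain no nested parentheses (no function-pointer parameters).

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Collect each `extern "C"` function signature in `source` (up to its closing
// parenthesis) and rewrite it as an `EXTERN_C ...;` header declaration.
std::vector<std::string> collectDeclarations(const std::string& source) {
    const std::string marker = "extern \"C\"";
    std::vector<std::string> decls;
    std::size_t pos = 0;
    while ((pos = source.find(marker, pos)) != std::string::npos) {
        const std::size_t start = pos + marker.size();
        const std::size_t close = source.find(')', start);
        if (close == std::string::npos) break; // malformed input: stop scanning
        decls.push_back("EXTERN_C" + source.substr(start, close - start + 1) + ";");
        pos = close + 1; // continue after this signature
    }
    return decls;
}
```

Running this over the already-substituted source for each type yields one declaration per launcher, which can then be written into the generated header.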
Step 3 - generate overload sets
Here’s where this all comes together - a while ago I started a thread about OverloadSets: Making Overloaded Function Sets Using Comptime.
I continued to workshop that idea, and with @Sze’s input we were able to build an OverloadSet that does best matching on const pointer parameters. I can now automatically generate function overloads from the C/Cuda back end as part of the kernel generation process. The include path is generated automatically as well:
const OverloadSet = @import("overloadset.zig").OverloadSet;

const decls = @cImport(
    @cInclude("/home/andrew/ZigCode/Metaphor/src/nvcc_target/kernel_decls.h"),
);

pub const kernel_hadamard_reverse = OverloadSet(.{
    decls.launch_hadamard_reverse_r16,
    decls.launch_hadamard_reverse_c16,
    decls.launch_hadamard_reverse_r32,
    decls.launch_hadamard_reverse_c32,
    decls.launch_hadamard_reverse_r64,
    decls.launch_hadamard_reverse_c64,
});
pub const kernel_subtraction = OverloadSet(.{
    decls.launch_subtraction_r16,
    decls.launch_subtraction_c16,
    decls.launch_subtraction_r32,
    decls.launch_subtraction_c32,
    decls.launch_subtraction_r64,
    decls.launch_subtraction_c64,
});
pub const kernel_fill = OverloadSet(.{
    decls.launch_fill_r16,
    decls.launch_fill_c16,
    decls.launch_fill_r32,
    decls.launch_fill_c32,
    decls.launch_fill_r64,
    decls.launch_fill_c64,
});
pub const kernel_permutate = OverloadSet(.{
    decls.launch_permutate_r16,
    decls.launch_permutate_c16,
    decls.launch_permutate_r32,
    decls.launch_permutate_c32,
    decls.launch_permutate_r64,
    decls.launch_permutate_c64,
});
pub const kernel_addition = OverloadSet(.{
    decls.launch_addition_r16,
    decls.launch_addition_c16,
    decls.launch_addition_r32,
    decls.launch_addition_c32,
    decls.launch_addition_r64,
    decls.launch_addition_c64,
});
pub const kernel_addition_reverse = OverloadSet(.{
    decls.launch_addition_reverse_r16,
    decls.launch_addition_reverse_c16,
    decls.launch_addition_reverse_r32,
    decls.launch_addition_reverse_c32,
    decls.launch_addition_reverse_r64,
    decls.launch_addition_reverse_c64,
});
pub const kernel_subtraction_reverse = OverloadSet(.{
    decls.launch_subtraction_reverse_r16,
    decls.launch_subtraction_reverse_c16,
    decls.launch_subtraction_reverse_r32,
    decls.launch_subtraction_reverse_c32,
    decls.launch_subtraction_reverse_r64,
    decls.launch_subtraction_reverse_c64,
});
pub const kernel_hadamard = OverloadSet(.{
    decls.launch_hadamard_r16,
    decls.launch_hadamard_c16,
    decls.launch_hadamard_r32,
    decls.launch_hadamard_c32,
    decls.launch_hadamard_r64,
    decls.launch_hadamard_c64,
});
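Since every overload set follows the same pattern, emitting this file is mostly string assembly. The sketch below shows one way that could look, again in C++ for illustration rather than as the actual Zig generator; `emitOverloadSet` is a hypothetical helper, and the suffix list is assumed to come from the replacer symbols.

```cpp
#include <string>
#include <vector>

// Build the Zig text for one overload set from a kernel name and the type
// suffixes its launchers were generated for.
std::string emitOverloadSet(const std::string& kernel,
                            const std::vector<std::string>& suffixes) {
    std::string out = "pub const kernel_" + kernel + " = OverloadSet(.{\n";
    for (const std::string& suffix : suffixes)
        out += "    decls.launch_" + kernel + "_" + suffix + ",\n";
    out += "});\n";
    return out;
}
```

Calling it once per kernel name found in Step 2, with the suffixes r16, c16, r32, c32, r64, c64, would reproduce the layout of the file above.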
How it’s used
In my operations file for my torch-style library, I can now just do the following:
pub fn additionForward(x: anytype, y: anytype, z: anytype) void {
    const x_values = x.values();
    const y_values = y.values();
    const z_values = z.values();

    overloads.kernel_addition.call(.{
        x_values.ptr, y_values.ptr, z_values.ptr, z_values.len
    });
}

pub fn additionReverseArg0(X: anytype, _: anytype, Z: anytype) void {
    const x_grads = UT.assertGrads(X);
    const z_grads = UT.assertGrads(Z);

    overloads.kernel_addition_reverse.call(.{
        x_grads.ptr, z_grads.ptr, z_grads.len
    });
}

pub fn additionReverseArg1(_: anytype, Y: anytype, Z: anytype) void {
    const y_grads = UT.assertGrads(Y);
    const z_grads = UT.assertGrads(Z);

    overloads.kernel_addition_reverse.call(.{
        y_grads.ptr, z_grads.ptr, z_grads.len
    });
}

pub const AddImpl = CallbackBuilder(
    additionForward, .{
        .{ additionReverseArg0, 0 },
        .{ additionReverseArg1, 1 }
    }, NoCleanup
);
This creates a callback via CallbackBuilder that the computation graph can use to call forward and reverse for gradient back-prop.
So far it all works. I’m not supporting mixed precision/category operations yet, but the generator can create those overloads. Basically, you just have to write kernels now: the rest gets parsed and built automatically, and Zig beautifully picks up all the generated declarations and overloads them. It’s really simple, actually.
Anyhow, that’s what I’m going with… I’ll post a version of my library hopefully soon… getting there quickly now!