@chung-leong I’ll try that if we can’t find something else.
@LucasSantos91 Good point. The backend is fairly extensive so I’ll do my best to show what’s happening here succinctly. Let me know if there’s something else I’m missing.
Here’s the scheme: essentially, I’m generating a whole library of C wrappers for a C++/CUDA backend that then make their way into the Zig code. All of the functions are automatically tagged with `extern "C"`, compiled to a static library, and the function declarations that Zig sees are just `extern`. So far, this has been totally fine. Here’s an example of the generated declarations that get brought in via `@cImport`:
```c
// Top of generated declarations... Zig struggled with size_t,
// so len_t is just unsigned long long:
#if defined(__cplusplus)
#define EXTERN_C extern "C"
#else
#define EXTERN_C extern
#endif

EXTERN_C void launch_reduce_ij_i_r16(
    const void* src,
    void* dst,
    double alpha,
    len_t m,
    len_t n,
    StreamContext stream
);

EXTERN_C void launch_reduce_ij_i_r32(
    const void* src,
    void* dst,
    double alpha,
    len_t m,
    len_t n,
    StreamContext stream
);
// ...
```
These get grouped into function arrays where each function array is generated as follows:
```zig
// Top of kernels file:
const decls = @import("cimport.zig").C;

fn dispatch_array(tuple: anytype) [tuple.len]*const @TypeOf(tuple[0]) {
    return tuple;
}

pub const reduce_ij_i = dispatch_array(.{
    decls.launch_reduce_ij_i_r16,
    decls.launch_reduce_ij_i_r32,
    decls.launch_reduce_ij_i_r64,
});
// and so on...
```
These functions are then called by using runtime-type information at the call site. Here’s the offending example:
```zig
const key = core.dkey(y);
core.kernels.permutate_ij_ji[key](
    y.stream(), // <--- This works, but only in this position
    x.data_ptr(),
    y.data_ptr(),
    1.0, // alpha
    xs[0], xs[1],
    // y.stream() <--- The rest of the API is in this position
);
```
There are a few things I’m going to try. Instead of using the GCC compiler (which is technically an external discrepancy, as CUDA prefers Clang), I’m going to try using Clang instead.
This probably has nothing to do with it, but instead of using `*const foo`, I may just try `foo` to see if that has anything to do with it. Edit: those are runtime values apparently, so they need to remain as pointers.
Also, these are compiled with the NVCC compiler on top of that, which is NVIDIA’s CUDA compiler. I know that Clang can compile CUDA as well, so I may just switch over to that and see if it has better luck. NVCC dispatches to a host compiler to handle the C++ code and then compiles the device code to a CUBIN file (their binary format) or to PTX instructions for the devices from there.
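If I do try that switch, Clang drives CUDA compilation with flags along these lines (a sketch only; the CUDA path and the sm_80 arch are assumptions about my target, not taken from the build above):

```shell
# Compile a CUDA translation unit with Clang instead of NVCC.
# --cuda-path and the GPU arch below are target-specific assumptions.
clang++ -x cuda --cuda-path=/usr/local/cuda \
        --cuda-gpu-arch=sm_80 -O2 -c kernels.cu -o kernels.o

# Then archive into the static library that the Zig build links, as before.
ar rcs libkernels.a kernels.o
```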
That’s the setup.