Zigrad: Deep learning faster than PyTorch

Zigrad is a deep learning framework built on a tensor-valued autograd engine, written in Zig (of course). It is 2.5x faster than PyTorch on Apple Silicon and 1.5x faster on x86. Detailed benchmarks and some getting-started instructions are available in the README.

Zigrad has been extensively benchmarked throughout development; you can train real AI models with it, faster than PyTorch, TensorFlow, and Tinygrad.

There are similar efforts out there, but a few characteristics I think make Zigrad different: its focus on performance without sacrificing extensibility or usability, along with its small binaries, speed, and memory utilization patterns.

Distributed training, graph compilation, and CUDA support are a few of the more exciting upcoming features, among many others being planned.

If you want to get started with Zig, neural networks, GPU programming, HPC, etc., please let me know! There are many low-hanging tasks for newcomers to get their feet wet, as well as more complex system-architecture-level decisions that deserve discussion from more experienced engineers.

If you have a project that you think has overlap (e.g. @AndrewCodeDev’s Metaphor or a dataloader package), drop a link to your project and let’s work together.

Thank you to @nurpax for zigrograd which was an early inspiration.


Hey @Marco, this is awesome!

As can be seen from my git history, I’ve simply run out of time to work on Metaphor between work and other projects. I’d love to collaborate and would happily be a contributor instead of an owner of a project. If you’re interested in working together, send me a message!


Not sure if I have time to collaborate on this project, but I want to see it grow. Not sure what it would take to convince the pointy hats to move our existing PyTorch stuff to Zig…


Great to hear, I would love to work together! How would you like to connect? GitHub Discussions or Discord could work (I do not think there is a way to DM here).

My job is entirely in PyTorch as well. Zigrad is intended to serve a different purpose (e.g. edge-deployed online learning) and is not supposed to be a Torch competitor. If you want to work in Python, I can release bindings that I wrote at one point.

Which Torch modules do you use most often (e.g. Conv2d, BatchNorm2d, Dropout)? I can prioritize accordingly.


I’ll send you a DM here. You should see a notification on your icon up at the top.


Nice project! I wanted to eventually do something similar but with comptime graphs - it might be an interesting idea for later :)

You mention Apple Silicon, which probably means just CPU, right? But how does it compare to MLX? PyTorch is a second-class citizen on macOS, but MLX seems to be very active and usable (I just did a simple PoC today).
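
For reference, a minimal MLX training step looks roughly like the sketch below (a sketch only; the model, shapes, and hyperparameters are made up, not my actual PoC):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(784, 128)
        self.l2 = nn.Linear(128, 10)

    def __call__(self, x):
        return self.l2(nn.relu(self.l1(x)))

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

model = MLP()
optimizer = optim.SGD(learning_rate=0.1)
loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((64, 784))        # fake batch
y = mx.random.randint(0, 10, (64,))    # fake labels
for _ in range(10):
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)  # force lazy evaluation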


Mostly image-based stuff, so, yes, Conv2d, Dropout, BatchNorm2d.
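
Concretely, the kind of block those show up in (a generic sketch, not code from our actual models):

import torch
from torch import nn

# Typical conv block built from the modules mentioned above.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Dropout(p=0.25),
)

x = torch.randn(8, 3, 32, 32)   # made-up batch of 8 RGB 32x32 images
print(block(x).shape)           # torch.Size([8, 16, 32, 32])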


I have been interested in the same idea and prototyped it last year, but there was limited benefit until I could improve my metaprogramming for comptime graph analysis. I am still improving it and plan to revisit this in the context of Zigrad. If you remain interested, let me know!

Related post from an old account, perhaps useful for you: Comptime memory references and side effects

As far as I understand, the Accelerate library Zigrad uses on Apple Silicon is not exactly running on a regular CPU. If my memory serves, early benchmarking showed Accelerate approaching GPU-level throughput rather than CPU-level (especially considering the CPU specs of the test machine and manual GEMM implementations). I am not knowledgeable enough about Apple hardware-software co-design to say more, but hopefully someone on this forum can chime in. Apple does officially refer to Accelerate as a CPU library, though.
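
One rough way to sanity-check the throughput claim is to time a large single-precision GEMM through NumPy; whether NumPy actually goes through Accelerate depends on the build (recent Apple Silicon wheels do, as far as I know), so treat this as a sketch:

import time
import numpy as np

np.show_config()  # check which BLAS this NumPy build links (Accelerate vs. OpenBLAS)

n = 2048
a = np.random.rand(n, n).astype(np.float32)
b = np.random.rand(n, n).astype(np.float32)

_ = a @ b  # warm-up
t0 = time.perf_counter()
_ = a @ b
dt = time.perf_counter() - t0
print(f"{2 * n**3 / dt / 1e9:.1f} GFLOP/s")  # single-precision GEMM throughput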

PyTorch is not really a second-class citizen on macOS; Meta has direct support from Apple engineers:

https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
“In collaboration with the Metal engineering team at Apple, we are excited to announce support for GPU-accelerated PyTorch training on Mac”

Torch links against native Apple APIs for optimal performance, so it is not a second-class citizen in my opinion (and the benchmarks reflect this, as Torch is quite fast on Apple Silicon).

Using the Torch MPS backend, I did not see performance improvements (it is actually slower, and I am not sure why). I would be more motivated to support MPS if Torch MPS showed a significant benefit. Using MLX is also slow. That said, my MLX code was a translation of the Torch code written for the sake of responding to your comment, and I suspect it could be optimized. The MLX-trained model was also quite inaccurate, and again I am not sure why. In general, I do not have a particular interest in optimizing Zigrad for Apple hardware unless there are motivating use cases and demand (happy to field feature requests), but it would be wise to optimize this script for benchmarking (and potentially write a Swift script that would accurately represent the performance ceiling).
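
For anyone who wants to reproduce the MPS comparison, device selection in Torch follows the usual pattern (a minimal sketch, not the benchmark script itself):

import torch

# Standard device selection for the PyTorch MPS backend.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = torch.nn.Linear(784, 10).to(device)
x = torch.randn(64, 784, device=device)
loss = model(x).sum()
loss.backward()  # gradients are computed on the MPS device when available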

Optimal performance requires platform specialization. I would highly recommend using official Apple SDKs, such as MLX, for optimal performance. Despite the results of the benchmark below, no one is going to beat Apple’s code on Apple hardware.

Aside: Same comments go for NVIDIA as well.

Microbenchmark results using this script. Please do not take the results too seriously; as mentioned above, this is a translation of Torch code and is not optimized for how MLX expects to be written. I have suspicions as to why this is slower, but this post is already quite long. Suffice it to say: “Python.”

--- Zigrad Summary ---
Metadata: platform: darwin-23.6.0, python: 3.12.1
Avg loss: 0.191304
Std loss: 0.198011
Avg ms/sample: 0.002119
Std ms/sample: 0.000190

--- MLX Summary ---
Metadata: batch_size: 64, epochs: 3, lr: 0.1, batches: 938, platform: darwin-23.6.0, python: 3.12.1
Avg loss: 1.738327
Std loss: 0.258141
Avg ms/sample: 0.009428
Std ms/sample: 0.001928

--- Speedup ---
count    2813.000000
mean        4.471025
std         0.887820
min         1.721319
25%         3.810588
50%         4.278214
75%         5.197537
max        14.615357
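
The Speedup block above has the shape of a pandas Series.describe() over per-batch ratios of ms/sample; something like the snippet below (made-up numbers, the real script is linked above):

import pandas as pd

# Made-up per-batch timings (ms/sample), just to show how a table like the
# one above is typically produced.
zigrad_ms = pd.Series([0.0021, 0.0022, 0.0020, 0.0023])
mlx_ms = pd.Series([0.0095, 0.0090, 0.0098, 0.0102])

speedup = mlx_ms / zigrad_ms   # element-wise, one ratio per batch
print(speedup.describe())      # count / mean / std / min / quartiles / max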

It might have changed since the last time I checked (~1 year ago), but Torch MPS was basically unusable for any kind of exploratory work (homebrew architectures, e.g. RWKV).

It’s not just about performance; some things were unimplemented even for the CPU. And then there was a strange middle ground where some parts of the graph were accelerated only to be handed off to CPU-only operations before continuing, so you can imagine it was very slow - sometimes no faster than CPU-only.
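
(For anyone who runs into this: the CPU fallback is opt-in via an environment variable, and it is exactly the slow copy-to-CPU-and-back path described above; a minimal sketch:)

import os

# Typically set before importing torch; routes ops that have no MPS kernel
# to the CPU instead of raising NotImplementedError.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

if torch.backends.mps.is_available():
    x = torch.randn(4, 4, device="mps")
    # Any op without an MPS kernel now copies tensors to the CPU, runs there,
    # and copies the result back, hence the slow middle ground.
    print(x.sum().item())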

MLX is macOS-native, and it looks like Apple ML researchers are using it for their own work.

About Accelerate: yes, I think some of it might run on the GPU, but if you want max perf, you need Metal kernels (llama.cpp has a few of those, so you could get inspired).

Sorry, I can’t run any benchmarks right now, but maybe in the future I can try.

About comptime: yes, I am interested, but I have a lot on my plate through at least the end of the year, so don’t count on me for now.


Good point about it being poor for experimentation; I got this impression as well. From what I can recall, it looked like a static graph API (or at least had that feel), which is a paradigm that seemingly only engineers appreciate but makes AI research challenging.

Do you do any deep learning on Apple hardware? Curious what spurred the interest.

I have an M3 Max with 128 GB of memory, so naturally I wanted to give it a try. So far I’m just going through my list of stupid ideas and trying some of them locally to see if anything leads anywhere. All I can say is that it’s both fast and flexible enough for such experiments, but I’ve not yet run it for a longer period of time.

What type of hardware are you typically targeting?

You mean for bigger things? Whatever I can find on Lambda Labs for a reasonable price, typically single-machine, and I don’t have a big budget, so an A100 40GB. I was also hoping for Battlemage, which should have >24 GB, so I could buy something for home, but I have not heard any new info in a long time.

Depends on the project: from desktops/servers with GPUs to edge computing nodes with tiny Arm CPUs.

I was asking @daredemo, but that’s interesting, thanks for sharing. I’m trying to prioritize platform support, and NVIDIA is at the top of the list, of course. Judging by the GPU requirements, it seems like you are interested in LLMs, is that accurate?

Cool, thanks! I will assume NVIDIA GPUs. ARM support is already in, but I have not added support for freestanding builds, and the image-related ops are entirely unoptimized.

Sorry, I mis-replied :)
LLMs, yes.
