I have been interested in the same idea and prototyped it last year, but the benefit was limited until I could improve my metaprogramming for comptime graph analysis. I am still working on that and plan to revisit it in the context of Zigrad. If you remain interested, let me know!
Related post from an old account, perhaps useful for you: Comptime memory references and side effects
As far as I understand, the Accelerate library Zigrad uses on Apple silicon is not exactly running on a regular CPU. If memory serves, early benchmarking showed Accelerate approaching GPU-level throughput rather than CPU-level (especially considering the CPU specs of the test machine and manual GEMM implementations). I am not knowledgeable enough about Apple hardware-software co-design to say more, but hopefully someone on this forum can chime in. Apple does officially refer to Accelerate as a CPU library, though.
PyTorch is not really a second-class citizen on macOS; Meta has direct support from Apple engineers:
https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
“In collaboration with the Metal engineering team at Apple, we are excited to announce support for GPU-accelerated PyTorch training on Mac”
Torch links against native Apple APIs for optimal performance, so it is not a second-class citizen in my opinion (and the benchmarks reflect this: torch is quite fast on Apple silicon).
Using the Torch MPS backend, I did not see performance improvements (in fact, it is slow, and I am not sure why). I would be more motivated to support MPS if Torch MPS showed a significant benefit. Using MLX is also slow. That said, my MLX code was a translation of torch code written for the sake of responding to your comment, and I suspect it could be optimized. The MLX-trained model was also quite inaccurate, and again I am not sure why. In general, I do not have a particular interest in optimizing Zigrad for Apple hardware unless there are motivating use cases and demand (happy to field feature requests), but it would be wise to optimize this script for benchmarking (and potentially write a Swift script that would accurately represent the performance ceiling).
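For reference, selecting the MPS backend in torch looks roughly like this; a minimal sketch with a placeholder model and dummy data, not the actual benchmark script:

```python
import torch
import torch.nn as nn

# Prefer the MPS backend when available, falling back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(64, 784, device=device)        # dummy batch
y = torch.randint(0, 10, (64,), device=device)  # dummy labels

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
if device.type == "mps":
    torch.mps.synchronize()  # MPS dispatches asynchronously; sync before timing
```

Note the explicit synchronize: MPS kernels run asynchronously, so timing without it understates the real cost.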
Optimal performance requires platform specialization, so I would highly recommend using official Apple SDKs such as MLX. Despite the results of the benchmark below, no one will be able to beat Apple's own code on Apple hardware.
Aside: the same goes for NVIDIA.
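For completeness, an idiomatic MLX training step looks roughly like the following (again a minimal sketch with placeholder shapes, not my translated benchmark code). MLX is lazy, so the mx.eval at the end forces the update; forgetting it is one way a naive torch translation can go wrong:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = optim.SGD(learning_rate=0.1)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

# Builds a function returning (loss, grads) w.r.t. the model's parameters.
loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((64, 784))       # dummy batch
y = mx.random.randint(0, 10, (64,))   # dummy labels

loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)  # MLX is lazy; force the update
```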
Microbenchmark results using this script. Please do not take the results too seriously: as mentioned above, this is a translation of torch code and is not optimized for how MLX expects to be written. I have suspicions as to why this is slower, but this post is already quite long. Suffice it to say: “Python.”
--- Zigrad Summary ---
Metadata: platform: darwin-23.6.0, python: 3.12.1
Avg loss: 0.191304
Std loss: 0.198011
Avg ms/sample: 0.002119
Std ms/sample: 0.000190
--- MLX Summary ---
Metadata: batch_size: 64, epochs: 3, lr: 0.1, batches: 938, platform: darwin-23.6.0, python: 3.12.1
Avg loss: 1.738327
Std loss: 0.258141
Avg ms/sample: 0.009428
Std ms/sample: 0.001928
Speedup
count 2813.000000
mean 4.471025
std 0.887820
min 1.721319
25% 3.810588
50% 4.278214
75% 5.197537
max 14.615357
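The Speedup block looks like standard pandas describe() output over per-batch speedups. A sketch of how such a comparison could be put together, with dummy step callables standing in for real Zigrad and MLX training steps (the names here are illustrative, not from the actual script):

```python
import time
import pandas as pd

def ms_per_sample(step, n_batches, batch_size):
    """Time a training-step callable; return per-batch ms/sample."""
    out = []
    for _ in range(n_batches):
        t0 = time.perf_counter()
        step()  # for MPS/MLX, synchronize/eval inside step() so timings are honest
        out.append((time.perf_counter() - t0) * 1e3 / batch_size)
    return out

# Placeholder steps; the real script would run one Zigrad/MLX training batch here.
zigrad_ms = ms_per_sample(lambda: sum(range(1_000)), n_batches=100, batch_size=64)
mlx_ms = ms_per_sample(lambda: sum(range(5_000)), n_batches=100, batch_size=64)

speedup = pd.Series(mlx_ms) / pd.Series(zigrad_ms)
print(speedup.describe())  # count/mean/std/min/25%/50%/75%/max, as above
```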