Had some good fun implementing autograd in Roc:
https://gist.github.com/ayazhafiz/88aec4c50b6403a9f343b13998a41886
The TLDR behind autograd, for those not already in the know: you build up some computational expression (e.g. any function y depending on any number of variables), and then you can run two passes - a "forward" pass that computes the value of y given the values of the variables, and a "backward" pass that computes the derivative of y with respect to each variable. If y is a function to minimize (a loss function), you can then update the variables to incrementally get the value of y closer and closer to zero. This is the basis for a lot of neural networks (there's a simple multi-layer NN in the example above that tries to solve y = x^2 + z^2 with a neural net).
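To make the two passes concrete, here's a rough sketch (not the gist's actual code - the Graph constructors below are just illustrative) of that example equation, with the forward pass, backward pass, and one gradient-descent step worked out in comments:

Graph : [
    Const F64,
    Var Str,
    Add Graph Graph,
    Mul Graph Graph,
]

# y = x^2 + z^2, built as x*x + z*z
y : Graph
y = Add (Mul (Var "x") (Var "x")) (Mul (Var "z") (Var "z"))

# Forward pass with x = 3, z = 4:  y = 3*3 + 4*4 = 25
# Backward pass:                   dy/dx = 2x = 6, dy/dz = 2z = 8
# One gradient-descent step with learning rate 0.1:
#   x <- 3 - 0.1*6 = 2.4,  z <- 4 - 0.1*8 = 3.2
# Repeating this drives y toward its minimum at x = z = 0.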
This is quite a simple implementation: it's not optimized, probably has bugs, and doesn't support tensors (matrices) like a real library would.
A few thoughts:
- ending up with bindings like ctx2, which feels a bit silly - and when it causes an error, the error is viral and the actual issue is somewhat opaque

Seemingly all positive!
A few things I struggled with:
- the LSP would often get out of sync with roc check, and would frequently need a restart. I'm not sure why this is, maybe because it's a standalone module? but idk, that's likely a red herring. Anyway, it was awesome when it worked and formatted the code automatically, but otherwise I couldn't find it very useful
- type errors that just say [..] is not equal to [..], but I know that's already a known issue.
- recursion of types - i'll get into this more later but at first I wanted a shape like Graph = { value: F64, op: ... } where op is the current definition of Graph. This doesn't really work because Graph and Op are mutually recursive, and then it runs into that bug about records not being able to be recursive without passing through a tag union (even though it would in this case)

On the bad type comparisons, I think there was a discussion of doing a git diff-like delta, e.g.
It seems like you have a type mismatch:
List {
> item: Str,
< item: U64,
}
Which might not be directly related to the [..] is not equal to [..] issue, but a rework of that code in a diffing feature push should fix that issue.
The LSP definitely feels ad hoc at the moment. The main thing I use it for is displaying warnings/errors, but I also need to keep a "sixth sense" as to when it has crashed and for what reason, and remember to restart in those cases
I'd love folks' feedback on how to make this faster. In general autograd benefits heavily from in-place mutation - you do many rounds of training, and both within a training round and between them you'd like to re-use the same memory and mutate in-place. One particular problem with my implementation is stuff like this:
Mul m n ->
    # x = m*n, dy/dm = dy/dx * dx/dm, dx/dm = n
    mval = forward m xs # ouch
    nval = forward n xs # ouch
    grads1 = go m (dydx * nval) grads
    grads2 = go n (dydx * mval) grads1
    grads2
Basically, when computing the gradients for x and y for a function z = x*y, we know that dz/dx = y and dz/dy = x. That means we need to compute the values of x and y, which is why forward is called. But this is wasteful, because forward is always called before backward, so there's no new information gained by doing this. Especially if the computation graph/equation is very large before the particular node (pretend for example we are doing x * y where x and y are themselves equations of a large # of variables), this can get expensive really quickly.
The alternative I was thinking of is to have a representation like

Graph : [
    Const F64 F64, # forward pass value, constant value
    Var F64 Str, # forward pass value, variable name
    ... # etc
]
and change the signature of forward to Graph, Vars -> Graph, i.e. directly compute the value of the forward pass and save it in the computation graph; that way the value is pre-cached when you go to back-propagate the gradients. The issue with this is twofold. The first issue is trivial: I think you would want to lift the forward-pass value out and have a record of Node = {forward_pass: F64, op: Op}, but that runs into the record recursion mentioned before - but that seems solvable.
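Concretely, the cached-value idea might look something like this (a rough sketch, not the gist's code; cachedVal and backward are made-up names, and accumulating gradients for variables that appear more than once is skipped for brevity):

Graph : [
    Const F64 F64,       # cached forward value, constant value
    Var F64 Str,         # cached forward value, variable name
    Mul F64 Graph Graph, # cached forward value, operands
]

# Read the forward-pass value cached in a node.
cachedVal : Graph -> F64
cachedVal = \g ->
    when g is
        Const v _ -> v
        Var v _ -> v
        Mul v _ _ -> v

# Back-propagate using the cached values instead of re-running forward.
backward : Graph, F64, Dict Str F64 -> Dict Str F64
backward = \g, dydx, grads ->
    when g is
        Const _ _ -> grads
        Var _ name -> Dict.insert grads name dydx
        Mul _ m n ->
            grads1 = backward m (dydx * (cachedVal n)) grads
            backward n (dydx * (cachedVal m)) grads1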
The second is that reconstruction of the entire AST needs to re-use memory, or else it will be quite expensive to do this over and over (imagine on the lower side a 100K parameter graph with 100K+ training rounds). And I'm not sure if there's a way to guarantee that. This gets more problematic once you introduce higher-dimensional tensors (i.e. actual matrices, and not just the scalars used here).
Anyway, curious if anyone has thoughts or ideas
Sam Mohr said:
Seemingly all positive!
Yep! it was great how well this worked
Yeah, any time code with recursive or nested types starts coming into play, the language server is pretty well useless.
In my experience, this is because the compiler has a lot of bugs that cause hangs in that domain, and so the language server constantly ends up stuck with a hanging compiler.
We should be able to terminate the hanging process, and it is set up to do that, but I think I don't know enough about async Rust and I'm not yielding correctly or something, so it can't ever kill the hanging process.
ah okay, one issue i was running into with the LS was with a main.roc - removing the main.roc fixed it

Ayaz Hafiz said:
Basically, when computing the gradients for x and y for a function z = x*y, we know that dz/dx = y and dz/dy = x. That means we need to compute the values of x and y, which is why forward is called.
Couldn't you make forward take an optional grad dictionary (or a different variant of forward that takes grad)? Then just set the values during the forward pass?
Also, when you have a Dict.get followed by a Dict.insert, you can get some extra efficiency by using Dict.update. Avoids looking up the key twice.
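For instance, accumulating a gradient could look roughly like this (a sketch, not code from the gist; addGrad is a made-up helper, and it assumes Dict.update's callback uses the [Present v, Missing] shape - worth checking the builtin Dict docs for the exact signature in your Roc version):

# Add `d` to the gradient stored under `name`, doing only one lookup.
addGrad : Dict Str F64, Str, F64 -> Dict Str F64
addGrad = \grads, name, d ->
    bump = \entry ->
        when entry is
            Present old -> Present (old + d)
            Missing -> Present d

    Dict.update grads name bump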
Or, I guess it would be a value cache instead of a gradient cache
Couldn't you make forward take an optional grad dictionary (or a different variant of forward that takes grad)? Then just set the values during the forward pass?
That would work, but I believe unique IDs would then have to be created for each node in the graph somehow, because for example if i'm at x = a*b where b = e^c (but being just at x, I don't know that the second operand is b = e^c), i need the result of b. but b is a resultant value, not a variable index
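One possible shape (hypothetical - none of these names are from the gist): tag every node with an ID when the graph is built, and key a value cache by that ID during the forward pass:

Graph : [
    Const U64 F64,       # node id, constant value
    Var U64 Str,         # node id, variable name
    Mul U64 Graph Graph, # node id, operands
]

nodeId : Graph -> U64
nodeId = \g ->
    when g is
        Const id _ -> id
        Var id _ -> id
        Mul id _ _ -> id

# The forward pass would then fill a Dict U64 F64 cache keyed by nodeId, and the
# backward pass would look up an operand's value with Dict.get cache (nodeId operand)
# instead of recomputing it.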
https://gist.github.com/bhansconnect/effa61cb21e879e28b6cc816fbb2850e
This just makes a new graph for going backwards
Solid perf gains. Definitely could be made cleaner.
That said, making new nodes instead of mutating in place definitely is not as nice as it could be.
Using your last expect and increasing to 1000 rounds of training, I see:

Summary
  ./grad-new ran
    1.53 ± 0.04 times faster than ./grad
yeah that def works. my only concern is that it creates a new tree for each forward pass. but yeah def better
Yeah, not sure the best way to map this into roc.
another interesting thing is that the program spends much of its time in Dict operations. for small graphs (small # of vars) it looks like an association list is faster
Makes sense. No hashing.
Also denser memory.
Not to mention, our dict is two loads due to being an index map (hash -> index -> list of kv).
Also, equality can fail fast, hashing can not.
yeah
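(For reference, the association-list version is tiny - a rough sketch with hypothetical names, assuming the builtin List.findFirst:)

# Gradients as an association list instead of a Dict: linear scan, no hashing,
# dense memory, and equality can bail out early.
Grads : List { name : Str, grad : F64 }

lookupGrad : Grads, Str -> Result F64 [NotFound]
lookupGrad = \grads, name ->
    grads
    |> List.findFirst (\entry -> entry.name == name)
    |> Result.map (\entry -> entry.grad)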
oh, another thing that came up: i was thinking of how to support arbitrary differentiable operations, like if someone wants to write a custom Sigmoid op or something. I think the best API would be an interface (e.g. abilities). But you can't hold on to opaque values of abilities (i.e. values hidden behind a pointer whose concrete type is never materialized), so this doesn't work
@Ayaz Hafiz do you think we should support that?
idk. i think it might make sense at some point, but probably not right now
seems like it would require doing something conceptually similar to lambda sets, where we make a tag union behind the scenes of all the different instantiations that come up in practice
yeah, or you could compile it to the boxed representation
i also feel like there must be some generalization of reset-reuse to support in-place mutation of trees, e.g. when doing something like
Node : [
    Const U64,
    Add Node Node,
    Mul Node Node,
]

somewalk : Node -> Node
somewalk = \node ->
    when node is
        Const x -> Const (x + 1)
        Add m n -> Add (somewalk m) (somewalk n)
        Mul m n -> Mul (somewalk m) (somewalk n)
which would make it a non-issue to create and tear down trees between passes, since they're almost always unique (and this is a common operation in anything graph-shaped, be it ML or compilers or whatever)
however, i'm struggling to find an intuition that holds up.
oh wait, im silly
this works fine under reset-reuse, so it is updated in place
Ayaz Hafiz said:
- recursion of types - i'll get into this more later but at first I wanted a shape like Graph = { value: F64, op: ... } where op is the current definition of Graph. This doesn't really work because Graph and Op are mutually recursive, and then it runs into that bug about records not being able to be recursive without passing through a tag union (even though it would in this case)
I have a branch that solves this in the dev backend that got silently left behind while implementing the llvm backend. I'll revive it today, as I should have done quite some time ago.
:thinking: the bug occurs during typechecking though, i believe
Oh, yes... I thought you meant you couldn't write tail recursive functions using tag unions where the recursive data (op) was inside a struct, not directly in the payload of a tag union. Realized we're talking about a completely different problem.
I haven't really used Enzyme, but I wonder what the tradeoffs are between doing autograd on the Roc source vs doing it on the LLVM IR (or a similar low-level IR)?