Stream: show and tell

Topic: autograd in Roc


view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 07:32):

Had some good fun implementing autograd in Roc:
https://gist.github.com/ayazhafiz/88aec4c50b6403a9f343b13998a41886
The TLDR behind autograd for those not already in the know: you build up some computational expression (e.g. any function y depending on any number of variables), and then you can run two passes - a "forward" pass that computes the value of y given the values of the variables, and a "backward" pass that computes the derivative of y with respect to each variable. If y is a function you want to minimize (a loss function), you can then use those derivatives to update the variables and incrementally push the value of y closer and closer to zero. This is the basis for a lot of neural networks (there's a simple multi-layer NN in the example above that tries to fit y = x^2 + z^2 with a neural net)
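
For concreteness, a minimal hand-rolled sketch (not the gist's code) of one training update for y = x^2 + z^2, with the gradients dy/dx = 2x and dy/dz = 2z written out by hand rather than derived from a computation graph:

# Minimal sketch, not the gist's API: one gradient-descent step for y = x^2 + z^2.
# The gradients are hard-coded here; the autograd version derives them from the graph.
step : { x : F64, z : F64 }, F64 -> { x : F64, z : F64 }
step = \{ x, z }, lr ->
    dydx = 2 * x # dy/dx of y = x^2 + z^2
    dydz = 2 * z # dy/dz
    { x: x - lr * dydx, z: z - lr * dydz }

# Each step moves x and z toward 0, which is where y is minimized.
expect
    after = step { x: 3, z: 4 } 0.1
    after.x < 3 && after.z < 4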

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 07:33):

This is quite a simple implementation: it's not optimized, probably has bugs, and doesn't support tensors (matrices) like a real library would

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 07:41):

A few thoughts:

view this post on Zulip Sam Mohr (Dec 31 2024 at 07:43):

Seemingly all positive!

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 07:50):

A few things I struggled with:

view this post on Zulip Sam Mohr (Dec 31 2024 at 07:58):

On the bad type comparisons, I think there was a discussion of doing a git diff-like delta, e.g.

It seems like you have a type mismatch:

  List {
>     item: Str,
<     item: U64,
  }

view this post on Zulip Sam Mohr (Dec 31 2024 at 07:59):

Which might not be directly related to the [..] is not equal to [..] issue, but a rework of that code as part of a diffing feature push should fix it as well

view this post on Zulip Sam Mohr (Dec 31 2024 at 08:00):

The LSP definitely feels ad hoc at the moment. The main thing I use it for is displaying warnings/errors, but I also need to keep a "sixth sense" as to when it has crashed and for what reason, and remember to restart it in those cases

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 08:01):

I'd love folks' feedback on how to make this faster. In general autograd benefits heavily from in-place mutation - you do many rounds of training, and both within a training round and between them you'd like to re-use the same memory and mutate in-place. One particular problem with my implementation is stuff like this:

            Mul m n ->
                # x = m*n, dy/dm = dy/dx * dx/dm, dx/dm = n
                mval = forward m xs # ouch
                nval = forward n xs # ouch
                grads1 = go m (dydx * nval) grads
                grads2 = go n (dydx * mval) grads1
                grads2

Basically, when computing the gradients for x and y for a function z = x*y, we know that dz/dx = y and dz/dy = x. That means we need to compute the values of x and y, which is why forward is called. But this is wasteful, because forward is always called before backward, so there's no new information gained by doing this. Especially if the computation graph/equation is very large beneath the particular node (pretend for example we are doing x * y where x and y are themselves equations of a large # of variables), this can get expensive really quickly.

The alternative I was thinking of is to have a representation like

Graph : [
    Const F64 F64, # forward pass value, constant value
    Var F64 Str, # forward pass value, variable name
    ... # etc
]

and change the signature of forward to Graph, Vars -> Graph, i.e. directly compute the value of the forward pass and save it in the computation graph, so that the value is pre-cached when you go to back-propagate the gradients. The issue with this is twofold. The first issue is trivial: I think you would ideally want to lift the forward-pass value out and have a record like Node = {forward_pass: F64, op: Op}, but that runs into the record recursion mentioned before, though that seems solvable.
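
A hedged sketch of that direction (illustrative only - it assumes Vars is a Dict Str F64 and adds a Mul tag to the cached Graph shape above; these are not the gist's actual definitions):

# Assumed shape (extending the Graph above): each tag's first payload is the
# cached forward-pass value.
Graph : [
    Const F64 F64,       # cached value, constant value
    Var F64 Str,         # cached value, variable name
    Mul F64 Graph Graph, # cached value, operands
]

# Read the cached forward value off any node.
valOf : Graph -> F64
valOf = \node ->
    when node is
        Const v _ -> v
        Var v _ -> v
        Mul v _ _ -> v

# Rebuild the graph with the caches filled in, so backward never has to call forward.
forward : Graph, Dict Str F64 -> Graph
forward = \node, vars ->
    when node is
        Const _ c -> Const c c
        Var _ name ->
            v = Dict.get vars name |> Result.withDefault 0
            Var v name
        Mul _ m n ->
            mf = forward m vars
            nf = forward n vars
            Mul (valOf mf * valOf nf) mf nf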

The second is that reconstructing the entire AST needs to re-use memory, or else it will be quite expensive to do this over and over (imagine, on the low end, a 100K-parameter graph with 100K+ training rounds). And I'm not sure there's a way to guarantee that. This gets more problematic once you introduce higher-dimensional tensors (i.e. actual matrices, not just the scalars used here).

Anyway, curious if anyone has thoughts or ideas

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 08:01):

Sam Mohr said:

Seemingly all positive!

Yep! it was great how well this worked

view this post on Zulip Eli Dowling (Dec 31 2024 at 18:35):

Yeah any time code with recursive or nested types starts coming into play the language server is pretty well useless.

In my experience, this is because the compiler has a lot of bugs that cause hangs in that domain, and so the language server constantly ends up stuck with a hanging compiler.

We should be able to terminate the hanging process, and it is set up to do that, but I think I don't know enough about async Rust and I'm not yielding correctly or something, so it can never actually kill the hanging process

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 19:04):

ah okay, one issue i was running into with the LS was #compiler development > bug: Outstanding references to the derived module @ 💬 - removing the main.roc fixed it

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 19:24):

Basically, when computing the gradients for x and y for a function z = x*y, we know that dz/dx = y and dz/dy = x. That means we need to compute the values of x and y, which is why forward is called.

Couldn't you make forward take an optional grad dictionary (or a different variant of forward that takes grad)? Then just set the values during the forward pass?

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 19:25):

Also, when you have a Dict.get followed by a Dict.insert, you can get some extra efficiency by using Dict.update. Avoids looking up the key twice.
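
A hedged sketch of that pattern for accumulating gradients (the exact callback shape of Dict.update has varied across builtins versions; the [Present v, Missing] form is assumed here):

# Accumulate into a gradient entry with a single lookup, instead of Dict.get
# followed by Dict.insert. Assumes Dict.update's callback takes and returns
# [Present v, Missing]; adjust if your builtins use Result v [Missing].
addGrad : Dict Str F64, Str, F64 -> Dict Str F64
addGrad = \grads, name, dydx ->
    Dict.update grads name \entry ->
        when entry is
            Present old -> Present (old + dydx)
            Missing -> Present dydx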

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 19:45):

Or, I guess it would be a value cache instead of a gradient cache

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 19:52):

Couldn't you make forward take an optional grad dictionary (or a different variant of forward that takes grad)? Then just set the values during the forward pass?

That would work, but I believe unique IDs would then have to be created for each node in the graph somehow. Because, for example, if I'm at x = a*b where b = e^c (but being just at x, I don't know that the second operand is b = e^c), I need the result of b - and b is an intermediate value, not a variable I can index by name

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 20:12):

https://gist.github.com/bhansconnect/effa61cb21e879e28b6cc816fbb2850e

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 20:12):

This just makes a new graph for going backwards

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 20:12):

Solid perf gains. Definitely could be made cleaner.

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 20:16):

That said, making new nodes instead of mutating in place definitely is not as nice as it could be.

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 20:19):

Using your last expect and increasing to 1000 rounds of training, I see:

Summary
  ./grad-new ran
    1.53 ± 0.04 times faster than ./grad

view this post on Zulip Ayaz Hafiz (Dec 31 2024 at 20:43):

yeah that def works. my only concern is that it creates a new tree for each forward pass. but yeah, def better

view this post on Zulip Brendan Hansknecht (Dec 31 2024 at 21:34):

Yeah, not sure the best way to map this into roc.

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:11):

another interesting thing is that the program spends much of its time in Dict operations. for small graphs (small # of vars) it looks like an association list is faster
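
For reference, a minimal sketch of what such an association list might look like (illustrative only, not the code that was measured):

# A plain list of key/value pairs, searched linearly. No hashing, and the pairs
# sit contiguously in memory, which is why it can beat Dict for a few variables.
assocGet : List (k, v), k -> Result v [KeyNotFound] where k implements Eq
assocGet = \pairs, key ->
    when List.findFirst pairs (\(k, _) -> k == key) is
        Ok (_, v) -> Ok v
        Err NotFound -> Err KeyNotFound

assocInsert : List (k, v), k, v -> List (k, v) where k implements Eq
assocInsert = \pairs, key, value ->
    when List.findFirstIndex pairs (\(k, _) -> k == key) is
        Ok i -> List.set pairs i (key, value)
        Err NotFound -> List.append pairs (key, value)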

view this post on Zulip Brendan Hansknecht (Jan 01 2025 at 22:23):

Makes sense. No hashing.

view this post on Zulip Brendan Hansknecht (Jan 01 2025 at 22:23):

Also denser memory.

view this post on Zulip Brendan Hansknecht (Jan 01 2025 at 22:23):

Not to mention, our dict is two loads due to being an index map (hash -> index -> list of kv).

view this post on Zulip Brendan Hansknecht (Jan 01 2025 at 22:24):

Also, equality can fail fast; hashing cannot.

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:24):

yeah

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:43):

oh, another thing that came up: I was thinking about how to support arbitrary differentiable operations, like if someone wants to write a custom Sigmoid op or something. I think the best API would be an interface (e.g. abilities). But you can't hold on to opaque values of abilities (i.e. values hidden behind a pointer, where the concrete type is never materialized), so this doesn't work
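
A hedged sketch of what that ability-based API might look like (the names here are made up; defining and implementing the ability is fine, the problem is that a graph node cannot store "some value of any type that implements DiffOp"):

# Hypothetical ability for custom differentiable ops (illustrative names only).
DiffOp implements
    evalOp : op, F64 -> F64 where op implements DiffOp
    gradOp : op, F64 -> F64 where op implements DiffOp

# One concrete implementation. This part works fine...
Square := {} implements [DiffOp { evalOp: squareEval, gradOp: squareGrad }]

squareEval : Square, F64 -> F64
squareEval = \@Square {}, x -> x * x

squareGrad : Square, F64 -> F64
squareGrad = \@Square {}, x -> 2 * x

# ...but there is no way to write a graph node that stores "any DiffOp value"
# without naming the concrete type, which is the limitation described above.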

view this post on Zulip Richard Feldman (Jan 01 2025 at 22:52):

@Ayaz Hafiz do you think we should support that?

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:53):

idk. i think it might make sense at some point, but probably not right now

view this post on Zulip Richard Feldman (Jan 01 2025 at 22:53):

seems like it would require doing something conceptually similar to lambda sets, where we make a tag union behind the scenes of all the different instantiations that come up in practice
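
A tiny illustration of that idea, reusing the hypothetical DiffOp/Square sketch above plus an imagined Sigmoid implementation: if those were the only implementations appearing in a program, the compiler could conceptually lower "some DiffOp value" to a closed tag union behind the scenes.

# Hypothetical compiler-generated union: one tag per concrete DiffOp implementation
# that shows up in the program, dispatched much like lambda sets are today.
AnyDiffOp : [OpSquare Square, OpSigmoid Sigmoid]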

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:53):

yeah, or you could compile it to the boxed representation

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:55):

i also feel like there must be some generalization of reset-reuse to support in-place mutation of trees, e.g. when doing something like

Node : [
    Const U64,
    Add Node Node,
    Mul Node Node,
]

somewalk : Node -> Node
somewalk = \node ->
    when node is
        Const x -> Const (x + 1)
        Add m n -> Add (somewalk m) (somewalk n)
        Mul m n -> Mul (somewalk m) (somewalk n)

which would make it a non-issue to create and tear down trees between passes, since they're almost always unique (and this is a common operation in anything graph shaped, be it ML or compilers or whatever)

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:56):

however, i'm struggling to find an intuition that holds up.

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 22:59):

oh wait, im silly

view this post on Zulip Ayaz Hafiz (Jan 01 2025 at 23:00):

this works fine under reset-reuse, so it is updated in place

view this post on Zulip Norbert Hajagos (Jan 04 2025 at 09:13):

Ayaz Hafiz said:

I have a branch that solves this in the dev backend that got silently left behind while implementing the llvm backend. I'll revive it today as I should have done quite some time ago.

view this post on Zulip Ayaz Hafiz (Jan 04 2025 at 17:01):

:thinking: the bug occurs during typechecking though, i believe

view this post on Zulip Norbert Hajagos (Jan 05 2025 at 13:48):

Oh, yes... I thought you meant you couldn't write tail recursive functions using tag unions where the recursive data (op) was inside a struct, not directly in the payload of a tag union. Realized we're talking about a completely different problem.

view this post on Zulip shua (Jan 17 2025 at 18:28):

I haven't really used Enzyme, but I wonder what the tradeoffs are between doing autograd on the Roc source vs doing it on the LLVM IR (or a similar low-level IR)?

https://enzyme.mit.edu/

