Hello,
I'm currently trying to implement a trellis network, an architecture that requires lots of weight sharing. I'm a little unsure how to implement this with JAX besides writing my own layer definition (I know #430 somewhat addresses this, but I'm wondering if a workaround exists). Thanks!
Thanks for raising this! We don't currently have an idiomatic way to do weight sharing in jax.experimental.stax, but as you might have already seen in #430 people are exploring how to do it. At the moment you can either adopt one of those solutions, or else just ignore jax.experimental.stax and figure out another solution, perhaps starting with the question, "How would I implement this on top of raw NumPy?"
Suggestions and advice welcome!
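For anyone landing here, a minimal sketch of the "raw NumPy" route might look like the following. It is not a trellis network and not a stax API: it just assumes a plain stack of dense layers that all reuse one weight matrix and bias, and the names init_params, shared_net, and loss are hypothetical ones made up for illustration. The point is only that when parameters are ordinary function arguments, sharing a weight amounts to using the same array more than once, and jax.grad sums the gradient contributions from every use.

```python
import jax
import jax.numpy as jnp


def init_params(key, dim):
    # Hypothetical helper: a single weight matrix and bias,
    # which every layer in the network will reuse.
    return {
        "W": jax.random.normal(key, (dim, dim)) * 0.1,
        "b": jnp.zeros(dim),
    }


def shared_net(params, x, num_layers=3):
    # Weight sharing is just reuse: the same (W, b) is applied at each layer.
    for _ in range(num_layers):
        x = jnp.tanh(x @ params["W"] + params["b"])
    return x


def loss(params, x):
    # Toy scalar loss so that jax.grad has something to differentiate.
    return jnp.sum(shared_net(params, x) ** 2)


key = jax.random.PRNGKey(0)
params = init_params(key, dim=4)
x = jnp.ones((2, 4))

# The gradient for "W" accumulates contributions from all three layers.
grads = jax.grad(loss)(params, x)
```

Because params is a plain dict, grads comes back as a dict with the same keys and shapes, so it can be fed straight into a hand-written update step.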
Thanks for the reply! Yeah, I closed the issue because I wanted to come back after trying to write a layer library first. 😄