Idiomatic Weight Sharing in Jax

Created on 6 May 2019 · 2 Comments · Source: google/jax

Hello,

I'm currently trying to implement a trellis network, an architecture that requires a lot of weight sharing. I'm a little confused about how to implement this in Jax without writing my own layer definition (I know #430 partly addresses this, but I'm wondering whether a workaround exists). Thanks!

Comments

Thanks for raising this! We don't currently have an idiomatic way to do weight sharing in jax.experimental.stax, but as you might have already seen in #430 people are exploring how to do it. At the moment you can either adopt one of those solutions, or else just ignore jax.experimental.stax and figure out another solution, perhaps starting with the question, "How would I implement this on top of raw NumPy?"

Suggestions and advice welcome!
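Following the suggestion to start from "raw NumPy", here is a minimal sketch of weight sharing written directly against jax.numpy, with no layer library. One parameter dict is created once and passed to every layer, and jax.grad automatically sums the gradient contributions from each reuse. The names init_params, shared_layer, and apply_network are hypothetical, and the dense tanh layer with input injection is a simplified stand-in loosely inspired by trellis networks, not the actual TrellisNet architecture (which ties the weights of temporal convolutions across depth).

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim, hidden_dim):
    """Create ONE set of weights that every layer will reuse (hypothetical helper)."""
    k_w, k_u = jax.random.split(key)
    return {
        "W": jax.random.normal(k_w, (hidden_dim, hidden_dim)) * 0.1,  # hidden-to-hidden
        "U": jax.random.normal(k_u, (in_dim, hidden_dim)) * 0.1,      # input injection
        "b": jnp.zeros(hidden_dim),
    }

def shared_layer(params, z, x):
    """One layer; the same params are passed in at every depth."""
    return jnp.tanh(z @ params["W"] + x @ params["U"] + params["b"])

def apply_network(params, x, depth):
    """Apply the same layer `depth` times, i.e. weights tied across depth."""
    z = jnp.zeros((x.shape[0], params["W"].shape[0]))
    for _ in range(depth):
        z = shared_layer(params, z, x)  # same params object every iteration
    return z

def loss(params, x, y, depth=4):
    return jnp.mean((apply_network(params, x, depth) - y) ** 2)

key = jax.random.PRNGKey(0)
params = init_params(key, in_dim=3, hidden_dim=5)
x = jnp.ones((2, 3))
y = jnp.zeros((2, 5))

# The gradient for each shared array is the sum of its contributions from all layers.
grads = jax.grad(loss)(params, x, y)
print(grads["W"].shape)  # (5, 5)
```

The parameter count stays the same regardless of depth, which is the point of weight tying. Because the weights are an ordinary pytree, the same functions can be wrapped in jax.jit, and for large depths the Python loop could be replaced with jax.lax.scan.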

Thanks for the reply! Yeah, I closed the issue because I wanted to come back after trying to write a layer library first. 😄
