When using the classic SGD optimizer with momentum together with sparse embeddings, memory is repeatedly allocated and garbage collected, which slows training down and eventually causes an out-of-memory error. Here is a minimal example to reproduce the issue.
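The original reproduction script is not included above; the following is a minimal sketch of the setup being described (a sparse `nn.Embedding` trained with `torch.optim.SGD` and momentum). The layer sizes, learning rate, and momentum value are placeholders, not values from the original report.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Sparse embedding: backward() produces sparse gradients for emb.weight.
emb = nn.Embedding(10000, 128, sparse=True)

# SGD with momentum; the momentum buffer is created from the (sparse) gradient.
opt = optim.SGD(emb.parameters(), lr=0.1, momentum=0.9)

idx = torch.randint(0, 10000, (32,))
for step in range(100):
    opt.zero_grad()
    loss = emb(idx).sum()
    loss.backward()      # sparse gradient for the embedding weight
    opt.step()           # momentum update on that sparse gradient
```

With `momentum=0` or `sparse=False`, the reported memory growth does not occur.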
The issue disappears when momentum is not used, or when the embeddings are not sparse.
I'm using the latest PyTorch version on conda: '0.2.0_4'.
I tried out your script with momentum 0.1 on master; it takes roughly 10800 MB of GPU memory at peak. This is caused by the use of a sparse momentum buffer. I'm sending out a PR for this.