Hi,
I've been trying to add Gumbel noise as done here, but with no success.
It looks like several modules are still missing (e.g., basic random-variable modules such as nn.Uniform()), or am I wrong? How would you implement, for example, the following lines in PyTorch:
``` -- Create noise ε sample module
local noiseModule = nn.Sequential()
noiseModule:add(nn.Uniform(0, 1)) -- Sample from U(0, 1)
-- Transform uniform sample to Gumbel sample
noiseModule:add(nn.AddConstant(1e-9, true)) -- Improve numerical stability
noiseModule:add(nn.Log())
noiseModule:add(nn.MulConstant(-1, true))
noiseModule:add(nn.AddConstant(1e-9, true)) -- Improve numerical stability
noiseModule:add(nn.Log())
noiseModule:add(nn.MulConstant(-1, true))
-- Create sampler q(z) = G(z) = softmax((log(π) + ε)/τ) (reparametrization trick)
local sampler = nn.Sequential()
local samplerInternal = nn.ConcatTable()
samplerInternal:add(nn.Identity()) -- Unnormalised log probabilities log(π)
samplerInternal:add(noiseModule) -- Create noise ε
sampler:add(samplerInternal)
sampler:add(nn.CAddTable())
self.temperature = nn.MulConstant(1 / self.tau, true) -- Temperature 蟿 for softmax
sampler:add(self.temperature)
sampler:add(nn.View(-1, self.k)) -- Resize to work over k
sampler:add(nn.SoftMax())
sampler:add(nn.View(-1, self.N * self.k)) -- Resize back
```
Here you go. Much more readable and no modules required:
```
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def sampler(input, tau):
    # Sample Gumbel noise: -log(-log(U + eps) + eps), U ~ U(0, 1)
    noise = torch.rand(input.size())
    noise.add_(1e-9).log_().neg_()
    noise.add_(1e-9).log_().neg_()
    noise = Variable(noise)
    # Reparametrization trick: softmax((log(pi) + eps) / tau)
    x = (input + noise) / tau
    x = F.softmax(x.view(input.size(0), -1))
    return x.view_as(input)
```
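As a side note (this is beyond the original answer): recent PyTorch versions ship this transform as a built-in, `torch.nn.functional.gumbel_softmax`; the shapes and `tau` value below are illustrative. A minimal sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)  # batch of 4, k = 10 categories (illustrative)

# Soft, differentiable samples: each row sums to 1
y_soft = F.gumbel_softmax(logits, tau=0.5)

# Straight-through variant: one-hot in the forward pass,
# soft gradients in the backward pass
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True)
```

With `hard=True` the forward output is discrete (useful at inference or for downstream lookups) while the gradient still flows through the soft sample.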
We're using GitHub for bug reports only; if you have questions, please post them on our forums.
Thanks for the prompt answer! And sure, I'll write on the forum next time.
Thanks for the answer!