Pytorch-lightning: How do you use the hidden state from the previous training step in the next step?

Created on 2 Jan 2020  ·  5 Comments  ·  Source: PyTorchLightning/pytorch-lightning

I want to train an LSTM Language Model using Lightning, but I need to pass the previous step's hidden state to the next training step. The raw training loop looks something like this:

model.train()
hidden = model.init_hidden(batch_size)
for data in train_iter:
  text, target = data.text, data.target
  model.zero_grad()
  output, hidden = model(text, hidden)
  # Detach so gradients stop at the step boundary (truncated BPTT);
  # otherwise the graph keeps growing across the whole epoch.
  hidden = detach_tensors(hidden)
  loss = criterion(output.view(-1, vocab_size), target.view(-1))
  loss.backward()
  optimizer.step()

How can I achieve this with Lightning? Thanks!

question

All 5 comments

I think you can define an attribute in the model's initialization to save the hidden state from the previous step. For example:

class Model(LightningModule):
    def __init__(self, args):
        super().__init__()
        self.previous_hidden = None  # cache for the previous step's hidden state

    def training_step(self, batch, batch_idx):
        text, target = batch
        output, hidden = self(text, self.previous_hidden)
        self.previous_hidden = detach_tensors(hidden)  # detach for truncated BPTT
        ...
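
For reference, a fuller self-contained sketch of that caching approach is below. Only the cached self.previous_hidden pattern comes from the comment above; the model architecture, sizes, optimizer, and the detach_tensors helper are illustrative assumptions, and returning the loss tensor directly follows the later 1.x-style API.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule


def detach_tensors(hidden):
    # An LSTM hidden state is an (h, c) tuple; detach each tensor so the
    # graph from earlier steps can be freed (truncated BPTT).
    return tuple(h.detach() for h in hidden)


class LSTMLanguageModel(LightningModule):
    def __init__(self, vocab_size=10000, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)
        self.previous_hidden = None  # carried across training steps

    def forward(self, text, hidden=None):
        emb = self.embedding(text)
        output, hidden = self.lstm(emb, hidden)  # hidden=None starts from zeros
        return self.decoder(output), hidden

    def training_step(self, batch, batch_idx):
        text, target = batch
        output, hidden = self(text, self.previous_hidden)
        self.previous_hidden = detach_tensors(hidden)
        loss = F.cross_entropy(output.view(-1, self.vocab_size), target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1.0)

Note that this assumes sequential, equal-size batches as in standard language-model training; with shuffled or unevenly sized batches the cached state will not line up with the new batch.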

Correct!

Would it be fair to say that this advice has been superseded by the truncated_bptt_steps argument in Trainer?

This issue pops up as one of the first results for "pytorch lightning bptt" on Google, so a pointer to the updated method may be helpful to anyone stumbling upon it like I did.

Yes! 0.9.0 makes this very clear. Check out the latest master docs.
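
For anyone landing here from search, here is a sketch of the built-in route, based on the 0.9-era API where truncated_bptt_steps was a Trainer argument (it later moved to a LightningModule property and was removed in 2.0). It reuses the LSTMLanguageModel sketch above; only training_step changes, and the chunk length of 5 is an arbitrary example.

from pytorch_lightning import Trainer


class LSTMLanguageModel(LightningModule):
    ...

    def training_step(self, batch, batch_idx, hiddens):
        # With truncated_bptt_steps set, Lightning splits each batch along the
        # time dimension and passes the previous chunk's hiddens in here
        # (None for the first chunk).
        text, target = batch
        output, hiddens = self(text, hiddens)
        loss = F.cross_entropy(output.view(-1, self.vocab_size), target.view(-1))
        # Returning 'hiddens' hands them to the next chunk; detach them
        # explicitly if your version does not do it for you.
        return {'loss': loss, 'hiddens': hiddens}


trainer = Trainer(truncated_bptt_steps=5)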

Thanks for the clarity!
