Pytorch-lightning: How do you use the hidden state from the previous training step in the next step?

Created on 2 Jan 2020 · 5 comments · Source: PyTorchLightning/pytorch-lightning

I want to train an LSTM Language Model using Lightning, but I need to pass the previous step's hidden state to the next training step. The raw training loop looks something like this:

model.train()
hidden = model.init_hidden(batch_size)
for data in train_iter:
    text, target = data.text, data.target
    model.zero_grad()
    output, hidden = model(text, hidden)
    hidden = detach_tensors(hidden)  # for truncated bptt
    loss = criterion(output.view(-1, vocab_size), target.view(-1))
    loss.backward()
    optimizer.step()
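(The `detach_tensors` helper above is not shown in the issue; a minimal sketch of what it presumably does is to detach the LSTM state, a tensor or tuple of tensors, from the autograd graph so gradients stop at the step boundary.)

```python
import torch

def detach_tensors(hidden):
    """Detach a tensor, or a (possibly nested) tuple of tensors,
    from the autograd graph (truncated BPTT)."""
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    return tuple(detach_tensors(h) for h in hidden)
```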

How can I achieve this with Lightning? Thanks!

question

All 5 comments

I think you can define a variable in the model initialization for saving the hidden state of the previous step. For example:
class Model(LightningModule):
    def __init__(self, args):
        super().__init__()
        self.previous_hidden = None  # cache for the previous step's hidden state

    def training_step(self, batch, batch_idx):
        text, target = batch
        output, hidden = self(text, self.previous_hidden)
        self.previous_hidden = detach_tensors(hidden)
        ...

Correct!

Would it be fair to say that this advice has been superseded by the truncated_bptt_steps argument in Trainer?

This issue pops up as one of the first results for "pytorch lightning bptt" on Google, so a pointer to the updated method may help anyone who stumbles upon it like I did.

Yes! 0.9.0 makes this very clear. Check out the latest master docs.

Thanks for the clarity!
