I want to train an LSTM Language Model using Lightning, but I need to pass the previous step's hidden state to the next training step. The raw training loop looks something like this:
```python
model.train()
hidden = model.init_hidden(batch_size)
for data in train_iter:
    text, target = data.text, data.target
    model.zero_grad()
    output, hidden = model(text, hidden)
    hidden = detach_tensors(hidden)  # detach for truncated BPTT
    loss = criterion(output.view(-1, vocab_size), target.view(-1))
    loss.backward()
    optimizer.step()
```
How can I achieve this with Lightning? Thanks!
I think you can define a variable in the model initialization for saving the hidden state of the previous step. For example:
```python
from pytorch_lightning import LightningModule
import torch.nn.functional as F

class Model(LightningModule):
    def __init__(self, args):
        super().__init__()
        self.previous_hidden = None  # cache for the previous step's hidden state

    def training_step(self, batch, batch_idx):
        text, target = batch.text, batch.target
        output, hidden = self(text, self.previous_hidden)  # None on the very first step
        self.previous_hidden = detach_tensors(hidden)  # detach for truncated BPTT
        loss = F.cross_entropy(output.view(-1, output.size(-1)), target.view(-1))
        return loss
```
Correct!
Would it be fair to say that this advice has been superseded by the truncated_bptt_steps argument in Trainer?
This issue pops up as one of the first results for "pytorch lightning bptt" on Google, so a pointer to the updated method may be helpful to anyone stumbling upon it like I did.
yes! 0.9.0 makes this very clear. check out the latest master docs
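For anyone landing here from search, here is a minimal sketch of the Trainer-flag approach mentioned above, assuming the 0.9-era API in which setting truncated_bptt_steps makes Lightning split each batch along the time dimension and call training_step once per split, passing the previous split's hidden state in as a hiddens argument. The model class, layer sizes, and the (text, target) batch layout below are made up for illustration; later releases changed or removed this flag, so check the docs for your Lightning version.

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMLanguageModel(pl.LightningModule):
    """Hypothetical LSTM LM, only to illustrate the truncated_bptt_steps flag."""

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, text, hiddens=None):
        emb = self.embed(text)
        out, hiddens = self.lstm(emb, hiddens)  # hiddens=None -> zero state
        return self.head(out), hiddens

    # With truncated_bptt_steps set, Lightning splits each batch along the time
    # dimension and passes the previous split's hidden state in as `hiddens`.
    def training_step(self, batch, batch_idx, hiddens):
        text, target = batch  # assumed (text, target) tensors of shape [B, T]
        output, hiddens = self(text, hiddens)
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), target.reshape(-1))
        # Detach so the graph does not grow across splits.
        hiddens = tuple(h.detach() for h in hiddens)
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# The Trainer flag discussed above (0.9-era API); backprop is truncated every 2 steps.
trainer = pl.Trainer(truncated_bptt_steps=2, max_epochs=1)
```

With this flag the manual previous_hidden caching from the earlier comment is no longer needed, since Lightning owns the split-and-carry loop.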
Thanks for the clarity!