Transformers: Gradient checkpointing with GPT2DoubleHeadsModel

Created on 24 Jan 2020 · 4 Comments · Source: huggingface/transformers

โ“ Questions & Help

I've been trying to fine-tune GPT2DoubleHeadsModel using gpt2-large and gpt2-xl on the Topical-Chat dataset.

I'm finding that even a single example is hard to fit into memory with the larger versions of GPT-2. I found this Medium post by @thomwolf that suggests gradient checkpointing would be effective in this situation.

Is there a gradient-checkpointed version of the code in GPT2DoubleHeadsModel or the underlying GPT2Model that could be used as-is? I'm trying to do this myself by editing modeling_gpt2.py, but I'm facing issues.

https://github.com/huggingface/transformers/blob/babd41e7fa07bdd764f8fe91c33469046ab7dbd1/src/transformers/modeling_gpt2.py#L478-L480

Specifically, I wrapped the block call on the lines linked above in a checkpoint, like this:

from torch.utils.checkpoint import checkpoint

outputs = checkpoint(block, hidden_states, layer_past, attention_mask, head_mask[i])

NOTE: I had to drop the keyword names, since checkpoint appears to accept only positional arguments, not keyword arguments. This might lead to compatibility issues; I'd love to hear thoughts on this as well.
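One common workaround for the keyword-argument limitation is to bind the keyword arguments inside a small closure, so that checkpoint() itself only ever receives positional inputs. Below is a minimal sketch, assuming a PyTorch release whose checkpoint() accepts a use_reentrant flag. ToyBlock and run_blocks are hypothetical stand-ins for GPT-2's Block and the loop in GPT2Model, not the library's actual code. The current block and head mask are bound as default arguments on purpose: the reentrant implementation calls the function a second time during the backward pass, after the loop has already moved on to later blocks.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for GPT-2's Block; forward() takes keyword arguments."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, layer_past=None, attention_mask=None, head_mask=None):
        out = torch.tanh(self.linear(hidden_states))
        if attention_mask is not None:
            out = out * attention_mask
        if head_mask is not None:
            out = out * head_mask
        return (out,)  # a tuple, which the reentrant checkpoint can return


def run_blocks(blocks, hidden_states, attention_mask=None, head_mask=None, use_checkpoint=True):
    """Hypothetical GPT2Model-style loop over blocks, optionally checkpointing each one."""
    if head_mask is None:
        head_mask = [None] * len(blocks)
    for i, block in enumerate(blocks):
        if use_checkpoint:
            # Bind block and head_mask[i] as defaults so the recomputation in
            # backward uses this block, not whichever block the loop ended on.
            def custom_forward(h, block=block, mask_i=head_mask[i]):
                return block(h, layer_past=None, attention_mask=attention_mask, head_mask=mask_i)

            # Only the hidden states are passed positionally; activations inside
            # the block are not stored and are recomputed during backward.
            outputs = checkpoint(custom_forward, hidden_states, use_reentrant=True)
        else:
            outputs = block(hidden_states, layer_past=None,
                            attention_mask=attention_mask, head_mask=head_mask[i])
        hidden_states = outputs[0]
    return hidden_states
```

With use_reentrant=True this mirrors the original checkpoint implementation discussed in this issue. Newer PyTorch releases also offer use_reentrant=False, which can pass keyword arguments through directly; check the documentation for the version you are using.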

This uses the official PyTorch checkpoint (torch.utils.checkpoint). I'm also considering this other checkpoint implementation, since I read that it is supposed to be faster than the official one.

With the official PyTorch implementation, I'm getting the following error:
CheckpointFunctionBackward.forward: expected Variable (got list) for return value 0.

This thread on the PyTorch forums seems to suggest that this error arises when attempting to use torch.utils.checkpoint with modules that return a variable number of tensors, which is the case with Block within GPT2Model.

Could @thomwolf, @LysandreJik or anyone else in the Hugging Face team please help with this? Thanks!

wontfix


All 4 comments

I think I figured this out: it looks like I'll have to change Block to return its outputs as a tuple instead of a list:

https://github.com/huggingface/transformers/blob/babd41e7fa07bdd764f8fe91c33469046ab7dbd1/src/transformers/modeling_gpt2.py#L238

i.e., change the line above to return tuple(outputs) so that checkpointing the blocks inside GPT2Model works.
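If editing the library source is not an option, the same tuple(outputs) conversion can be applied from the outside by wrapping the block. This is a sketch under assumptions, for a PyTorch release whose checkpoint() accepts a use_reentrant flag: ListBlock, TupleOutputs, and checkpointed_call are hypothetical names, and ListBlock only imitates a block whose forward() returns a Python list.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ListBlock(nn.Module):
    """Hypothetical block whose forward() returns a list, like the Block in this issue."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        h = torch.tanh(self.linear(hidden_states))
        # The second entry loosely stands in for the extra outputs a real block returns.
        return [h, h.mean(dim=-1)]


class TupleOutputs(nn.Module):
    """Hypothetical wrapper that converts the wrapped block's list output into a tuple."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, *inputs):
        return tuple(self.block(*inputs))


def checkpointed_call(block, hidden_states):
    # The reentrant checkpoint runs the module without storing activations in
    # forward and recomputes it during backward; returning a tuple makes each
    # tensor a separate output that autograd can track.
    return checkpoint(TupleOutputs(block), hidden_states, use_reentrant=True)
```

Note that with the reentrant variant at least one input tensor should have requires_grad=True; otherwise the checkpointed output does not require grad and the block's parameters receive no gradients.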

@thomwolf @LysandreJik Would explicitly converting the outputs to a tuple have any unexpected downstream effects? If not, I think the change should be made in the repo as well, since the README says that every model's forward() method always returns a tuple.

However, I'm finding that checkpointing the blocks still doesn't let a single example fit into memory with gpt2-xl. A checkpointed version of these classes would be really helpful!

Bumping this. I'm training a TensorFlow ALBERT model, and with long sequences (length 512) it's tough to get a large enough batch size: currently I'm limited to 8 or 16 per GPU. Built-in gradient checkpointing support via tf.recompute_grad() would be a godsend :)

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

I am using GPT2Model and would also find this very useful.
