I'd like to use Lightning to do the training of a PyTorch transformer model. So I wrap the transformer model in a LightningModule. Before training, the model is initialized from a pre-trained model by use of hparams. I use default saving with Lightning. However, the transformer itself has a very specific save/load method. Ideally we would like the transformer model to save using its own method instead of relying on Lightning. My questions are:
```python
import transformers
from pytorch_lightning import LightningModule, Trainer

class Transformer(LightningModule):
    def __init__(self, hparams):
        ...
        # Initialize the PyTorch model (dependent on an external pre-trained model)
        self.transformer = transformers.AutoModel.from_pretrained(hparams.transformer_name)
        # note: self.transformer has a method save_pretrained that saves it to a directory,
        # so ideally we would like it to be saved with its own method instead of the
        # default one provided by Lightning

model = Transformer(hparams)
trainer = Trainer(...)
trainer.fit(model)
```
I tried the default saving with Lightning, but I am not sure if this is correct.
Hey I think this is the only way I found out too.
https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html
I would use save/load just weights, or any other suggestion @PyTorchLightning/core-contributors?
@Borda, thanks for your suggestion. I am aware of this save/load of weights. The issue is that the model inside the LightningModule has customized save/load functions, `from_pretrained` and `save_pretrained`, that take care of saving/loading the weights. Ideally, I am looking for something like:

- When Lightning auto-saves the LightningModule to a checkpoint location: call `self.model.save_pretrained(the checkpoint location)`
- When Lightning initializes the model from a checkpoint location: call `self.model.from_pretrained(the checkpoint location)`

In this way, the save/load is passed on from the LightningModule (`self`) to its member (`self.model`). Could core developers chime in and offer a recommendation? Thanks a lot!
why do you need these methods if you have trained and saved the model and want to load the saved checkpoint?
@rohitgr7 The reason I need to use these methods is that the LightningModule (`self`) is a wrapper around the pre-trained transformer model (`self.model`). In addition, `self.model` has to be loaded from an external location using its method `from_pretrained` and saved using its own method `save_pretrained`. I have copied some code below for your reference:
```python
import transformers
from pytorch_lightning import LightningModule

class Transformer(LightningModule):
    def __init__(self, hparams):
        ...
        # Initialize the PyTorch model (dependent on an external pre-trained model)
        self.transformer = transformers.AutoModel.from_pretrained(hparams.transformer_name)
        # note: self.transformer has a method save_pretrained that saves it to a directory,
        # so ideally we would like it to be saved with its own method instead of the
        # default one provided by Lightning
```
Thanks and looking forward to your comments/suggestions!
Ok, `from_pretrained` is fine since you want to load pre-trained weights, but `save_pretrained` doesn't do anything special apart from saving weights, which Lightning also does.
@rohitgr7 thanks for your comments.
I reloaded the Lightning model that was saved automatically and found the weights inside `self.model` changed, which led me to think there might be something inside `save_pretrained` that is different from just saving the weights. Therefore I was hoping that I could customize the saving myself to respect `save_pretrained`.
This could also be an issue with `from_pretrained` inside `__init__`: in the first run, the weights were loaded from the external data. When there is a need to reload the model from Lightning, I was not sure which happens first, the loading from the checkpoint file or the loading from the pre-trained weights?
Really appreciate your insight!!!
`save_pretrained` is not doing anything different (source code). In your current use-case, when you load the checkpoint, it first loads the pre-trained weights via `from_pretrained` and then loads the checkpoint weights. In short, the weights loaded by `from_pretrained` are simply replaced by the weights from the checkpoint. When you reloaded the LightningModule checkpoint, they will definitely be different from the pre-trained weights if you trained the backbone or model in an unfrozen state.
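The order of operations described above can be illustrated with a small, self-contained sketch (plain PyTorch, no Lightning required; the `FakePretrained` class and the hard-coded checkpoint dict are hypothetical stand-ins, for illustration only):

```python
import torch

# Minimal sketch of what loading a Lightning checkpoint does conceptually:
# 1) construct the module, which runs from_pretrained inside __init__
# 2) overwrite all parameters with the state_dict stored in the checkpoint

class FakePretrained(torch.nn.Module):
    """Stand-in for a Hugging Face model (hypothetical, for illustration)."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # pretend these constants are the externally loaded pre-trained weights
        torch.nn.init.constant_(self.linear.weight, 1.0)

model = FakePretrained()                       # step 1: "from_pretrained" weights
checkpoint = {"linear.weight": torch.zeros(4, 4),
              "linear.bias": torch.zeros(4)}   # pretend this came from a .ckpt file
model.load_state_dict(checkpoint)              # step 2: checkpoint weights win

print(model.linear.weight.sum().item())  # 0.0 -- the pre-trained 1.0s were replaced
```

So whatever `from_pretrained` loads in `__init__` is only a starting point; the checkpoint always has the last word.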
@rohitgr7 @awaelchli Thanks for looking into the source code. You are right that in this case there is nothing fancy about `save_pretrained` in terms of saving the weights. However, for some transformer models, for example the Albert transformer, the weights are shared across many layers, and Albert's load/save functions take advantage of that, so the saved model is much smaller than with the general save from PyTorch or the checkpoint from Lightning (roughly 14MB vs 140MB).
This goes back to my point: we might want to respect the customized saving/loading functions of a model inside a LightningModule and provide a way for the model to call its own saving/loading functions when Lightning checkpoints.
I'd like to hear your thoughts. Thanks.
@junwen-austin which Albert version are you referring to here? Also, which optimizer are you using?
If you are using Adam or some other adaptive optimizer that stores the state of the gradient or squared gradient for adaptive optimization, then `ModelCheckpoint` will save that optimizer state too by default; that's why the checkpoint size is bigger than the original model size. You can use the `save_weights_only` parameter in `ModelCheckpoint` to disable it and save only the weights.
@rohitgr7 the base of Albert, which is about 14MB
I am using AdamW for the optimizer. Is it possible that the additional saving Lightning does (optimizer state and others) becomes as large as more than 100MB? Thanks.
I checked Albert albert-base-v1 and it's around 45mb.
> becomes as large as more than 100MB
yes, the state in the Adam optimizer stores both the gradient moving average and the squared-gradient moving average for all the parameters, so yeah, it can be.
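This is easy to check with a small, self-contained example (plain PyTorch): after one optimizer step, Adam's state holds two extra tensors (`exp_avg` and `exp_avg_sq`) per parameter, i.e. roughly twice the model's own size.

```python
import torch

# Adam keeps two moment tensors per parameter, so its state is roughly
# 2x the size of the model's parameters once the first step has run.
model = torch.nn.Linear(100, 100)
opt = torch.optim.Adam(model.parameters())

model(torch.randn(8, 100)).sum().backward()
opt.step()  # the first step materializes the optimizer state

n_params = sum(p.numel() for p in model.parameters())
n_state = sum(t.numel()
              for s in opt.state.values()
              for t in s.values()
              if torch.is_tensor(t) and t.dim() > 0)  # skip the scalar 'step'
print(n_params, n_state)  # n_state is 2 * n_params
```

For a ~45MB Albert base, that alone accounts for ~90MB of extra checkpoint size, consistent with the numbers above.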
@rohitgr7 You are right. Sorry about that. For this case, then it is fine we have 100MB of additional states.
I really do not want to bug you further, but do you have a general suggestion as to how to save a LightningModule that wraps a PyTorch model with a customized saving function, so that the PyTorch model has a chance to call it? Thanks.
not sure why you want that, since both are doing exactly the same thing. But still, one solution I can suggest is to override `Trainer.save_checkpoint`:
```python
import os
from pytorch_lightning import Trainer

class SubTrainer(Trainer):
    def save_checkpoint(self, filepath, weights_only=False):
        if self.is_global_zero:
            dirpath = os.path.split(filepath)[0]
            lightningmodel = self.get_model()
            lightningmodel.transformer.save_pretrained(dirpath)

model = Transformer()
trainer = SubTrainer(**trainer_params)  # with checkpoint callback
trainer.fit(model)
```
if you don't want to use `ModelCheckpoint` and want to save it after every epoch or something, you can do the same in `training_epoch_end` or `validation_epoch_end` as per your requirement.
@junwen-austin you're trying to save the best model at the end of the training using the method from Hugging Face transformers correct? I encountered a similar problem at work, and my suggestion is to avoid mixing PL's serialization and Hugging Face's serialization. You're just asking for more trouble. More concretely, I suggest you do something like the following:
```python
import transformers
from pytorch_lightning import LightningModule, Trainer

class Transformer(LightningModule):
    def __init__(self, hparams):
        ...
        # Initialize the PyTorch model (dependent on an external pre-trained model)
        self.transformer = transformers.AutoModel.from_pretrained(hparams.transformer_name)

model = Transformer(hparams)
trainer = Trainer(...)
trainer.fit(model)

# After training, save each of the top-k checkpoints with the
# transformer's own serialization method.
for i, (path, _) in enumerate(trainer.checkpoint_callback.best_k_models.items()):
    m = Transformer.load_from_checkpoint(path)
    m.transformer.save_pretrained(f'{i}th_best.pt')
```
Yes, you're not calling save_pretrained() during training automatically, but you can always load your lightning module from a vanilla PL checkpoint, then call save_pretrained() directly. This is more explicit and easier to understand rather than trying to be smart with callbacks.
@rohitgr7 @yukw777 Thanks for the suggestions! Very helpful!
Yes, I was trying to save the best model at the end of the training. What you suggested is also a very good way of extracting the transformer model from Lightning checkpoints.