This is more a philosophical design suggestion than a feature request.
I think that the presentation of LightningModule as a torch.Module-plus-features encourages early experiment designs that don't refactor nicely as the projects using it grow.
I also think that calling self.forward directly is a torch anti-pattern, and should not be encouraged.
I'd like the official docs to suggest using a self.my_model = MyModel(...) in __init__ and y = self.my_model(x) in training_step etc.
I think that most non-research uses of lightning are going to require that the environment the model is trained in be separable from the model itself. This is most obvious when considering the infrastructure needed to load training data vs. production inference data; you're not going to want to drag along all of the libraries needed to connect to a database, decompress data, etc. in the production environment.
To do so, I'd need to be able to from some.other.package import MyModel and then self.my_model = MyModel(...) in __init__. As long as some.other.package doesn't have extra dependencies, I can ship my production model and weights to production without needing everything else that lightning, etc. depends on.
By suggesting that users have the lightning subclass _be_ the model, the set of packages that need to be present in production goes up quite a bit (speaking from experience, the pip version management becomes painful).
Another thing that this makes unclear, then, is what is actually happening when training_step gets called. The suggestion "Normally you'd call self.forward() from your training_step() method." implies that self.training_step is happening inside of a self.__call__ since torch.nn.Module.forward isn't supposed to be called directly (since it's __call__ that handles hooks, etc.), but that doesn't actually seem to be the case. Unless I'm missing something, this really feels like misuse of the torch API.
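To make the hook point concrete, here is a minimal sketch using only torch. TinyModel and the recording hook are made up for illustration. It registers a forward hook and shows that the hook fires when the module is called, but not when forward is invoked directly.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model, only here to demonstrate hook behaviour."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
hook_calls = []

# The hook records the output shape; it returns None, so the output is unchanged.
handle = model.register_forward_hook(
    lambda module, inputs, output: hook_calls.append(tuple(output.shape))
)

x = torch.rand(3, 4)

model(x)  # goes through nn.Module.__call__, so the hook runs
calls_after_call = len(hook_calls)

model.forward(x)  # bypasses __call__, so the hook is skipped
calls_after_forward = len(hook_calls)

handle.remove()
print(calls_after_call, calls_after_forward)  # 1 1
```

The same applies to self(x) versus self.forward(x) inside a training_step.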
By making it clear that your LightningModule subclass should have an instance of your model as an attribute, not _be_ the model, all of the above gets cleared up quite a bit.
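A rough sketch of the layout I have in mind, under some loud assumptions: MyModel stands in for something imported from some.other.package, and TrainingWrapper is a plain nn.Module standing in for a pl.LightningModule subclass so that the sketch only needs torch. Only the inner model's weights are saved, and they load into a fresh MyModel that knows nothing about the training code.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical production model; would live in some.other.package."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)


class TrainingWrapper(nn.Module):
    """Stand-in for a LightningModule subclass (assumption, to avoid the dependency)."""

    def __init__(self, model):
        super().__init__()
        self.my_model = model  # the model is an attribute, not the wrapper itself

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.my_model(x)  # call the module, not .forward()
        return nn.functional.mse_loss(y_hat, y)


wrapper = TrainingWrapper(MyModel())
loss = wrapper.training_step((torch.rand(8, 4), torch.rand(8, 1)), 0)

# Ship only the inner model's weights; an in-memory buffer stands in for a file.
buffer = io.BytesIO()
torch.save(wrapper.my_model.state_dict(), buffer)
buffer.seek(0)

# "Production" side: needs torch and MyModel, nothing from the training code.
production_model = MyModel()
production_model.load_state_dict(torch.load(buffer))
```

Because the state dict is taken from wrapper.my_model rather than from the wrapper, its keys carry no wrapper prefix and match MyModel directly.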
I think it's a lot cleaner and clearer to say "Normally you'd call y = self.my_model(x) from your training_step() method." and remove any suggestion of overriding self.forward() from the documentation (and I'd in fact make the default implementation of forward raise a YouAreDoingItWrongException).
As I said earlier, I think that projects that mix training and model code in the same class are going to have a difficult time refactoring things later on, and I think that the perceived simplicity early on is a mastery trap. Anyone familiar with PyTorch isn't going to have a problem defining a separate model class.
Note that I don't think there's anything preventing me from implementing models the way I think is proper right now, but I'm currently investigating whether we can use lightning for more projects in the organization, and I'd really rather not have to educate users to ignore the docs and do it the self.my_model() way instead.
At the very least, changing the documentation to say "Normally you'd call y = self(x) from your training_step() method." makes sure that hooks, etc. get called as expected.
Now, I will fully admit that I haven't dug into lightning a ton yet, so it's possible that I'm missing something that will change my understanding/perception of things. If that's the case, I think it should be articulated more clearly.
Thanks for reading.
I recently started using Lightning for a project I have been working on, and I needed to import the model from a separate module like you stated @elistevens.
In my Lightning __init__ I just instantiate my external model and override forward to return mymodel(x). This works fine; however, I agree that it might be better to have the model as an attribute as opposed to Lightning being the model.
This would also help with more complicated research projects that involve multiple models (autoencoders or GANs, for example) and make things a lot more flexible and pythonic, "pytorchic."
@darwinkim I agree with @elistevens that it would be useful to be able to "extract" a more lightweight model to ship to production. However, can you provide an example of how this would help with more complicated research projects that involve multiple models?
There's a GAN example here which shows how you can cleanly incorporate multiple models.
I wonder if there's a way we could expose a jit_export functionality which JIT compiles the model and extracts out only that which is necessary for serving inference. The LightningModel would contain everything needed for training and it could export a minimal model for inference as a post-training artifact.
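For reference, a minimal sketch of the underlying torch step such an export could build on. jit_export itself remains the hypothetical name from this comment, and MyModel is a made-up example. The sketch scripts a plain nn.Module with torch.jit.script, saves it, and loads it back without needing the Python class.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical inference model, kept free of training code."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = MyModel().eval()

# Compile to TorchScript and serialize; a BytesIO buffer stands in for a file.
scripted = torch.jit.script(model)
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# The loaded artifact carries its own code, so MyModel is not needed to run it.
served = torch.jit.load(buffer)

x = torch.rand(2, 4)
with torch.no_grad():
    same = torch.allclose(served(x), model(x))
```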
@jeremyjordan
https://arxiv.org/abs/1703.00848
Involves training six networks as two autoencoders and two GANs on two datasets
just a note: PR #1211 promotes the use of self(...) instead of self.forward in examples and docs.
@williamFalcon @PyTorchLightning/core-contributors ^^
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
@williamFalcon asked me to revisit this, so I'm adding some more thoughts. PR #1211 fixed the issue of suggesting that users call .forward() directly, but there's another layer to what I'm trying to suggest.
Essentially, I want to clearly and cleanly separate concerns, and have that clean separation be suggested by the documentation. From an OOP perspective, the documentation suggests that the training loop object and the model object _be the same object_, which mixes two separate concepts.
Put another way, if you were going to be using a stock model from torchvision, you wouldn't have class MyModel(pl.LightningModule, ResNet): you'd have class MyModel(pl.LightningModule): def __init__(self): self.model = ResNet(). The training loop and the model would be separate python objects, and you could do things with the model such that you'd never know it was trained with lightning. For example, save/load the weights, or export it to onnx, etc.
How it's suggested now, it becomes much harder to pull out my model and use it in some other context (like a different training loop). I typically try to avoid libraries with that kind of lock-in.
Thanks for adding more details!
I use lightning a lot the way you describe. What gives you the impression that you can’t use it this way?
Is there a better way to show this in the docs or examples?
First, take a look at all the bolts models. Most models in bolts have that pattern.
Second, when you are done you can load the full thing and pull out whatever parts are interesting to you (ie: just the encoder of a GAN), or make the forward use only the encoder.
But yeah, you can always drop a model into a lightningModule and use the lightningmodule purely as training loops for the model.
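A small sketch of the "pull out whatever parts are interesting" step, using only torch. AutoEncoder is a made-up composite standing in for the trained module, and the prefix filtering is just one possible way to do this, not a Lightning API.

```python
import torch
from torch import nn


class AutoEncoder(nn.Module):
    """Hypothetical composite model standing in for the full trained module."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 3)
        self.decoder = nn.Linear(3, 8)

    def forward(self, x):
        return self.decoder(self.encoder(x))


trained = AutoEncoder()
full_state = trained.state_dict()  # keys look like "encoder.weight", "decoder.bias"

# Keep only the encoder entries and strip the "encoder." prefix from their keys.
prefix = "encoder."
encoder_state = {
    key[len(prefix):]: value
    for key, value in full_state.items()
    if key.startswith(prefix)
}

# Load the extracted weights into a standalone module of the same shape.
standalone_encoder = nn.Linear(8, 3)
standalone_encoder.load_state_dict(encoder_state)
```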
https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol
Finally, we can make some example or write in the docs what you’re more clearly looking for if you’d prefer:
class ClassificationTask(LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model  # plain model, kept separate from the loops
    def training_step(...):
    def validation_step(...):
    def test_step(...):
    def configure_optimizers(...):

model = Resnet50()
loops = ClassificationTask(model)
trainer = Trainer(...)
trainer.fit(loops, train_dataloader, val_dataloader)  # fit the task, not the bare model
In fact, we can add a new section to bolts with these prebuilt loops. Classification loops, fine-tuning loop, etc...
Ok, added the following to the docs to clarify this particular use case of a lightning module.

Can we at least raise a NotImplementedError like PyTorch does? I only just now noticed that in the current version, LightningModule actually implements forward for you to return None. Why is that?
My expectation was that LightningModule behaves like a nn.Module outside the context of PL.
it does but forward is not required...
we want to separate training from inference. in training you use the *_step methods (training_step, validation_step, test_step).
if your model also happens to do inference, then it should implement forward.
this makes a clean separation between training scripts purely and models.
this removal also enables tasks which weren’t possible before.
All of this is clear. No problem with that. If you don't use forward all is good.
I suggest to raise NotImplementedError if you use self.forward anywhere, instead of just returning None.
import torch
from torch import nn
from pytorch_lightning import LightningModule

class Lightning(LightningModule):
    pass

class Torch(nn.Module):
    pass

lightning_model = Lightning()
print(lightning_model(torch.rand(2, 2)))  # does not raise, returns None, why?

torch_model = Torch()
print(torch_model(torch.rand(2, 2)))  # raises NotImplementedError, good!
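The difference can be reproduced with torch alone. In this sketch SilentForward is a hypothetical stand-in for a base class whose default forward returns None. It is not Lightning's actual implementation, only an illustration of the behaviour described above.

```python
import torch
from torch import nn


class NoForward(nn.Module):
    """Plain nn.Module subclass that never defines forward."""


class SilentForward(nn.Module):
    """Hypothetical base class whose default forward silently returns None."""

    def forward(self, *args, **kwargs):
        return None


x = torch.rand(2, 2)

# nn.Module's default forward raises, so a missing override is caught early.
try:
    NoForward()(x)
    raised = False
except NotImplementedError:
    raised = True

# A silent default hides the mistake: the call succeeds and yields None.
silent_result = SilentForward()(x)
```

With the silent default, the error only shows up later, when something tries to use the None as a tensor.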