Hi! The feature is quite simple: a straightforward way to write batch-level transforms without having to put them directly in training_step.
Motivation comes from a project of mine where I want to change the color space of my input images directly on the batch rather than on individual images, so that I save some time. The problem is that I didn't find a way to automate that with Lightning.

While going through the code and documentation, I even stumbled upon something that confused me a lot: in the documentation for LightningDataModule, there are numerous examples where transforms are passed directly to a PyTorch DataLoader constructor. However, we can't do that; there is no transform argument in the PyTorch DataLoader. I looked for an override somewhere in Lightning's source code and tried to find a place where the batch could get modified based on transforms, but I couldn't find any. So if I missed something, please tell me and I guess this issue can be closed. Otherwise, I'm pretty sure this feature could come in handy for others besides me.
The idea is to allow the user to specify batch transforms in a LightningDataModule (or somewhere in the Trainer) and have them called automatically after the dataloader is iterated. For me, the natural place would be right after the batch is transferred to the GPU in training_forward. Further details certainly need to be discussed, as the implementation will depend on the coding philosophy of the library.
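To make this a bit more concrete, here is a purely hypothetical sketch of what it could look like from the user side; the transform_batch_on_device hook and the batch_transform attribute do not exist in Lightning, they are just assumptions for illustration:

```python
import torch
import pytorch_lightning as pl


def rgb_to_hsv_batch(images: torch.Tensor) -> torch.Tensor:
    # placeholder for a real batched color-space conversion
    return images


class ColorSpaceDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        # hypothetical attribute: a callable applied to the whole batch at once
        self.batch_transform = rgb_to_hsv_batch

    # hypothetical hook: the trainer would call this right after the batch
    # has been moved to the GPU and right before training_step
    def transform_batch_on_device(self, batch, dataloader_idx):
        x, y = batch
        return self.batch_transform(x), y
```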
For now, the alternative I use is to manually insert the batch transform in training_step (sketched below). This could also be automated by adding a hook that allows batch modification right before training_step. It would allow for broader use and could be what is used behind the scenes for a higher-level batch transform feature. Again, the details probably need to be discussed.
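For reference, a minimal sketch of that workaround; batch_color_transform is my own placeholder, not anything provided by Lightning:

```python
import torch
import pytorch_lightning as pl


def batch_color_transform(images: torch.Tensor) -> torch.Tensor:
    # placeholder for the actual batched color-space conversion
    return images


class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        # batch-level transform applied by hand at the top of training_step,
        # since there is currently no hook for it
        x, y = batch
        x = batch_color_transform(x)
        ...
```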
I'll add a screenshot of the confusing part of the documentation I was referring to earlier.

I'll also add that I am willing to work on this feature myself once details are set.
> I'll add a screenshot of the confusing part of the documentation I was referring to earlier.
I think that is a bug in the docs: transforms are supposed to be passed to the dataset, not the DataLoader.
You could pass the transforms to your LightningModule or get the LightningDataModule transforms via self.trainer.datamodule.train_transforms.
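A rough sketch of what that could look like inside the LightningModule, assuming the attached datamodule exposes its train transform as a callable via train_transforms:

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        # fetch the transform stored on the attached datamodule and apply it
        # to the whole batch (assumes it is a callable that accepts a batch)
        transform = self.trainer.datamodule.train_transforms
        if transform is not None:
            x = transform(x)
        ...
```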
> This could also be automated by adding a hook that allows batch modification right before training_step. It would allow for broader use and could be what is used behind the scenes for a higher-level batch transform feature.
I do think that Lightning could call a prepare_batch(batch) function before calling each *_step. Instead of calling it manually:

```python
def prepare_batch(self, batch):
    x, y = do_something_with_batch(batch)
    return x, y

def training_step(self, batch):
    x, y = self.prepare_batch(batch)
    ...
```
Lightning would then call the hook itself, so the user would only override it:

```python
def prepare_batch(self, batch):
    # super().prepare_batch(batch) is the identity
    x, y = do_something_with_batch(batch)
    return x, y

def training_step(self, batch):
    # lightning has already called prepare_batch()
    x, y = batch
    ...
```
@nateraw can you fix this documentation error?
The stuff w/ transforms being passed to dataloader is definitely wrong...I'll update docs!
As for batch transforms...I think philosophically that feature is already covered by the torch Dataset object, right? The DataModule is meant to extend it. I think a great way to solve your issue is to:

1. implement a custom PyTorch Dataset that applies the transforms
2. implement a DataModule that initializes that Dataset for each split, passing the transforms to your Dataset's init while keeping them tethered to the DataModule itself for reproducibility/modularity purposes (rough sketch below)
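A rough sketch of that layout, with made-up names (MyDataset, MyDataModule) and the transforms applied per item in __getitem__, which is what a Dataset naturally supports:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class MyDataModule(pl.LightningDataModule):
    def __init__(self, train_samples, val_samples,
                 train_transform=None, val_transform=None):
        super().__init__()
        # keep the transforms on the datamodule for reproducibility/modularity
        self.train_samples = train_samples
        self.val_samples = val_samples
        self.train_transform = train_transform
        self.val_transform = val_transform

    def setup(self, stage=None):
        self.train_set = MyDataset(self.train_samples, self.train_transform)
        self.val_set = MyDataset(self.val_samples, self.val_transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=32)
```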
I suggested something similar here: #3399.
@carmocca That is pretty much what I had in mind; this kind of hook is probably the easiest way to go about it.
@nateraw Thanks for the doc fix! Also, isn't the Dataset's purpose to handle items individually rather than at the batch level? The only actual solution I found for batch-level transforms using the PyTorch API alone is to use a custom collate function in the DataLoader (rough sketch below), but I don't find it very handy, hence this suggestion.
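For completeness, here is a rough sketch of that collate function alternative; batch_color_transform and the (image, label) batch structure are assumptions of mine:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


def batch_color_transform(images: torch.Tensor) -> torch.Tensor:
    # placeholder for the real batched color-space conversion
    return images


def collate_with_batch_transform(items):
    # build the batch the usual way, then transform it as a whole
    x, y = default_collate(items)
    return batch_color_transform(x), y


# hypothetical usage, assuming train_set yields (image, label) pairs:
# loader = DataLoader(train_set, batch_size=32, collate_fn=collate_with_batch_transform)
```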
@rohitgr7 Indeed, I didn't see your post, but this is exactly what I had in mind. Separating the transforms applied before and after the batch is moved to the device is also a good idea I hadn't thought of. This is close to the solution @carmocca proposed; basically anything along these lines is fine with me.
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!