Support sync batch norm as a plugin so we don't duplicate code across accelerators
It'll look very similar to this: https://github.com/PyTorchLightning/pytorch-lightning/pull/4285
import torch
from pytorch_lightning.core.lightning import LightningModule


class SyncBatchNormPlugin(object):
    """
    Plugin to link a custom sync batchnorm implementation to any arbitrary accelerator.

    Example::

        class MySyncBN(SyncBatchNormPlugin):

            def configure_sync_batchnorm(self, model):
                model = MySyncBNWrapper(model)
                return model

        my_sync_bn = MySyncBN()
        trainer = Trainer(sync_batchnorm=True, plugins=[my_sync_bn])
    """

    def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
        """
        By default: adds global batchnorm for a model spread across multiple GPUs and nodes.

        Override this to synchronize batchnorm between specific process groups instead
        of the whole world, or use a different sync_bn like `apex`'s version.

        Args:
            model: pointer to current :class:`LightningModule`.

        Return:
            LightningModule with batchnorm layers synchronized between process groups
        """
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
        return model
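As the docstring suggests, the hook could also be overridden to restrict synchronization to a subset of ranks. A minimal sketch of what that might look like (the subclass name and the ranks argument are hypothetical, and it assumes torch.distributed has already been initialized):

import torch
import torch.distributed as dist


class SubsetSyncBN(SyncBatchNormPlugin):
    """Hypothetical plugin: sync batchnorm stats only across the given ranks."""

    def __init__(self, ranks):
        self.ranks = ranks

    def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
        # Build a process group from the chosen ranks and sync BN only within it,
        # instead of across the whole world (process_group=None).
        group = dist.new_group(ranks=self.ranks)
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=group)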
Do we even need the sync_batchnorm flag on the Trainer with this plugin?
This reminds me of the checkpoint_callback + callbacks list situation for ModelCheckpoint. The difference here is that there can be only one plugin of this type.
If we have the flag, we can use sync_batchnorm=True and get the default implementation of the plugin without creating and passing it to the trainer.
If we drop the flag, then to use sync batchnorm you must create and pass a plugin (of any type that extends SyncBatchNormPlugin); Lightning would still provide the default implementation, but you would have to instantiate it and pass it to the trainer yourself.
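To make the trade-off concrete, here is a rough sketch of the two options under discussion (neither reflects the current Trainer API; MySyncBN is the hypothetical subclass from the docstring example above):

from pytorch_lightning import Trainer

# Option A: keep the flag -- Lightning creates the default SyncBatchNormPlugin for you
trainer = Trainer(sync_batchnorm=True)

# Option B: drop the flag -- you always construct and pass a plugin yourself,
# either the default implementation or a custom subclass
trainer = Trainer(plugins=[SyncBatchNormPlugin()])
trainer = Trainer(plugins=[MySyncBN()])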