Pytorch-lightning: Support sync batch norm as a plugin

Created on 25 Oct 2020  ·  2 Comments  ·  Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Support sync batch norm as a plugin so we don't duplicate code across accelerators

Motivation

It'll look very similar to this: https://github.com/PyTorchLightning/pytorch-lightning/pull/4285

Pitch

import torch
from pytorch_lightning.core.lightning import LightningModule


class SyncBatchNormPlugin(object):
    """
    Plugin to link a custom sync batchnorm implementation to any arbitrary accelerator.
    Example::
        class MySyncBN(SyncBatchNormPlugin):

            def configure_sync_batchnorm(self, model):
                model = MySyncBNWrapper(model)
                return model
        my_sync_bn = MySyncBN()
        trainer = Trainer(sync_batchnorm=True, plugins=[my_sync_bn])
    """

    def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
        """
        By default, converts the model's batchnorm layers to synchronized batchnorm
        across all GPUs and nodes. Override this to synchronize batchnorm only within
        specific process groups instead of the whole world, or to use a different
        sync BN implementation such as `apex`'s.
        Args:
            model: pointer to current :class:`LightningModule`.
        Return:
            LightningModule with batchnorm layers synchronized between process groups
        """
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

        return model
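To illustrate the override path mentioned in the docstring, a custom plugin could swap in NVIDIA Apex's converter instead of torch.nn.SyncBatchNorm. A minimal sketch, assuming Apex is installed; the class name ApexSyncBNPlugin is hypothetical and the interface is exactly the configure_sync_batchnorm hook above:

from apex.parallel import convert_syncbn_model  # requires NVIDIA Apex
from pytorch_lightning.core.lightning import LightningModule


class ApexSyncBNPlugin(SyncBatchNormPlugin):
    """Sketch: use Apex's synchronized batchnorm instead of torch.nn.SyncBatchNorm."""

    def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
        # convert_syncbn_model walks the module tree and replaces BatchNorm*D layers
        # with Apex's synchronized implementation.
        return convert_syncbn_model(model)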

Do we even need the sync_batchnorm flag on the Trainer with this plugin?

Labels: enhancement, help wanted

All 2 comments

> Do we even need the sync_batchnorm flag on the Trainer with this plugin?

This reminds me of the checkpoint_callback + callbacks list situation for ModelCheckpoint. The difference here is that there can be only one plugin of this type.

> This reminds me of the checkpoint_callback + callbacks list situation for ModelCheckpoint. The difference here is that there can be only one plugin of this type.

If we keep the flag, users can set sync_batchnorm=True and get the default plugin implementation without creating and passing one to the Trainer themselves.

If we drop the flag, then to use sync batchnorm you must create and pass a plugin yourself (any type that extends SyncBatchNormPlugin); Lightning would still provide the default implementation for you to instantiate and pass.
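To make the trade-off concrete, here is a rough sketch of how the two options would look from the user's side. The Trainer arguments mirror the pitch above; whether the flag survives is exactly the open question:

from pytorch_lightning import Trainer

# Option A: keep the flag. Lightning constructs the default SyncBatchNormPlugin internally.
trainer = Trainer(sync_batchnorm=True)

# Option B: drop the flag. The user instantiates a plugin (the default or a custom subclass)
# and passes it in explicitly.
trainer = Trainer(plugins=[SyncBatchNormPlugin()])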
