Pytorch-lightning: Custom training (without back propagation)

Created on 28 Nov 2019 · 10 Comments · Source: PyTorchLightning/pytorch-lightning

What is your question?

I have a custom nn.Module implementing a Gaussian Mixture Model (GMM) classifier based on _sklearn's_, which is trained with the EM algorithm instead of backpropagation.

I'm trying to use this framework for training, evaluation and checkpointing, but I can't work out how to override the optimizer and backpropagation behavior.

This custom classifier trains with its .fit method, so I want to call that method in each training step instead of doing a forward pass, computing a loss value and passing it to the optimizer.
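Outside Lightning, "training" such a classifier is just one EM fit per class followed by a likelihood comparison. A minimal sketch, assuming scikit-learn's `GaussianMixture`, features shaped `(samples, frames, features)`, and hypothetical helper names `fit_gmm_per_class` and `predict`; the small `n_components=4` and `covariance_type="diag"` are arbitrary choices for the toy data, not the poster's 64-component setup:

```python
import numpy as np
from sklearn.mixture import GaussianMixture


def fit_gmm_per_class(frames, labels, num_classes, n_components=4, seed=0):
    """Train one GMM per class with EM; no gradients or optimizer are involved.

    frames: array of shape (samples, frames, features)
    labels: array of shape (samples,) holding class ids 0..num_classes-1
    """
    n_features = frames.shape[2]
    gmms = []
    for c in range(num_classes):
        # Pool all frames of all tracks labelled c into one (n, features) matrix.
        class_frames = frames[labels == c].reshape(-1, n_features)
        gmm = GaussianMixture(n_components=n_components, covariance_type="diag",
                              random_state=seed)
        gmm.fit(class_frames)  # the EM iterations happen here
        gmms.append(gmm)
    return gmms


def predict(gmms, frames):
    """Return, per track, the class whose GMM gives the highest summed frame log-likelihood."""
    n_samples, n_frames, n_features = frames.shape
    flat = frames.reshape(-1, n_features)
    # score_samples gives one log-likelihood per row, i.e. per frame.
    scores = np.stack([g.score_samples(flat) for g in gmms], axis=1)  # (samples*frames, classes)
    scores = scores.reshape(n_samples, n_frames, len(gmms))
    return scores.sum(axis=1).argmax(axis=1)


# Toy data: two "singers" whose 2-D features are centred at -3 and +3.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-3.0, 1.0, (5, 20, 2)),
                    rng.normal(3.0, 1.0, (5, 20, 2))])
y = np.array([0] * 5 + [1] * 5)

gmms = fit_gmm_per_class(x, y, num_classes=2)
preds = predict(gmms, x)
print(preds)
```

Nothing in this loop produces a loss or a gradient, which is why the optimizer and backward steps of a standard training loop have nothing to act on for this model.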

Code

```python
import torch
from torch import nn
import pytorch_lightning as ptl

# GaussianMixture is the poster's GMM class; its import is not shown in the original.


class GMMClassifier(nn.Module):
    def __init__(self, num_classes):
        super(GMMClassifier, self).__init__()
        n_features = 128
        # Note: objects kept in a plain Python list are not registered as
        # submodules, so they do not appear in self.state_dict().
        self.gmm_list = []
        for _ in range(num_classes):
            # one GMM per singer, as in the SID work of Tsai, Fujihara, Mesaros et al.
            self.gmm_list.append(GaussianMixture(n_components=64, n_features=n_features))

    def forward_score(self, x):
        """
        :param x: MFCC of a track with shape (samples, frames, coefficients)
        :return: The log-likelihood of each track and frame under every GMM (one per singer / class):
            log likelihood = log p(X_t | P_i)
            where t is the time frame and
                i is the singer GMM
            with shape (sample, frame, gmm_prediction)
        """
        n_samples = x.size(0)
        # assume that all samples have the same number of frames
        n_frames = x.size(1)
        n_features = x.size(2)
        # optimization: flatten (samples, frames, features) to (samples*frames, features) once
        x = x.reshape(-1, n_features)
        score_tensors_per_gmm = []
        for gmm in self.gmm_list:
            # score every frame of every sample
            log_prob = gmm.score(x)  # output shape is (samples*frames, )
            log_prob = torch.unsqueeze(log_prob, dim=1)  # output shape is (samples*frames, 1)
            score_tensors_per_gmm.append(log_prob)
        # join the per-GMM columns: output shape is (samples*frames, gmm)
        score_tensors_per_gmm = torch.cat(score_tensors_per_gmm, dim=1)
        y = score_tensors_per_gmm.reshape(
            n_samples,
            n_frames,
            len(self.gmm_list)
        )  # reshape to (sample, frame, gmm)
        return y

    def forward(self, x):
        """
        :param x: MFCC of a track with shape (samples, frames, coefficients)
        :return: The prediction for each track, calculated as:
            Singer_id = arg max_i [1/T sum^T_t=1 [log p(X_t | P_i)]]
            where t is the time frame and
                i is the singer GMM
            with shape (n_samples, )
        """
        x = self.forward_score(x)  # shape is (sample, frame, gmm)
        # sum the scores over the frames (same arg max as the 1/T average)
        x = x.sum(dim=1)  # shape is (sample, gmm)
        x = x.argmax(dim=1)  # shape is (sample, )
        return x


class L_GMMClassifier(ptl.LightningModule):
    """
    Sample model to show how to define a template
    """

    def __init__(self, hparams, num_classes, train_dataset, eval_dataset, test_dataset):
        super(L_GMMClassifier, self).__init__()
        self.hparams = hparams
        # self.loss = torch.nn.CrossEntropyLoss()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.test_dataset = test_dataset
        # build model
        self.model = GMMClassifier(num_classes)
        # self.optimizer = torch.optim.Adam(self.model.parameters(), lr=hparams.learning_rate)

    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, x):
        """
        No special modification required for lightning, define as you normally would
        :param x:
        :return:
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop
        :param batch:
        :return:
        """
        x, y = batch['x'], batch['y']
        n_features = x.size(2)
        for gmm_idx, gmm in enumerate(self.model.gmm_list):
            # Assumption: class ids match GMM indices, so each GMM is fitted on
            # the frames of its own class (the original called gmm.fit() with no data).
            gmm.fit(x[y == gmm_idx].reshape(-1, n_features))
```
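The decision rule in `forward` (Singer_id = arg max over i of the average frame log-likelihood) can be checked on a tiny hand-made score tensor. Dividing by T does not change which column is largest, so the `sum` used in the code and the 1/T mean in the docstring pick the same class. The numbers and the helper name `predict_from_scores` are illustrative only:

```python
import torch


def predict_from_scores(scores: torch.Tensor) -> torch.Tensor:
    """Map per-frame log-likelihoods (sample, frame, gmm) to one class id per sample."""
    # Mean over the T frames, i.e. the 1/T sum in the forward() docstring.
    return scores.mean(dim=1).argmax(dim=1)


# Hand-made log-likelihoods for 2 tracks x 3 frames x 2 GMMs (illustrative values).
scores = torch.tensor([
    [[-1.0, -4.0], [-2.0, -3.0], [-9.0, -1.0]],  # track 0
    [[-1.0, -5.0], [-2.0, -4.0], [-1.0, -6.0]],  # track 1
])

print(predict_from_scores(scores))  # tensor([1, 0])
# Summing instead of averaging (as forward() does) picks the same classes.
print(torch.equal(scores.sum(dim=1).argmax(dim=1), predict_from_scores(scores)))  # True
```

Track 0 also shows that this rule is not a per-frame majority vote: GMM 0 wins two of the three frames, but one very unlikely frame under GMM 0 outweighs them.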

Labels: question, tutorial / example, won't fix

All 10 comments

What happens if you return a dummy optimizer and override all the hooks that do gradient manipulation (basically make them do nothing)?

I've done exactly that with a dummy optimizer. It works. It's quite a bit of shimming but it will achieve what you're going for.

I'm unable to share our actual code, but it would be something like:

```python
def training_step(self, batch, batch_nb, optimizer_idx=0):
    tqdm = {
        "batch_nb": batch_nb,
        "mode": "train",
        "optimizer_idx": optimizer_idx,
        "global_step": self.global_step,
    }

    gmm.fit()  # placeholder: fit your GMM(s) on this batch here

    # This takes the place of your real loss
    shim = torch.FloatTensor([0.0])
    shim.requires_grad = True
    return {"loss": shim, "progress_bar": tqdm}
```
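Why the shim needs `requires_grad = True`: the training loop calls `backward()` on the returned loss, and PyTorch refuses to backpropagate from a tensor that is not part of an autograd graph. A minimal check in plain PyTorch (no Lightning involved); `torch.zeros(1, requires_grad=True)` is an equivalent, more compact way to build the same shim:

```python
import torch

# Shim "loss": a leaf tensor that requires grad, so backward() has something to act on.
shim = torch.zeros(1, requires_grad=True)  # same as FloatTensor([0.0]) + requires_grad = True
shim.backward()
print(shim.grad)  # tensor([1.])

# Without requires_grad, backward() raises because there is no graph to differentiate.
plain = torch.zeros(1)
raised = False
try:
    plain.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```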

@Vichoko did you manage to find a workaround? I'm interested.

Thanks for the help.
I'll update this issue as soon as I manage to get the model working, training and evaluating on my data.

I managed to train (fit) and evaluate the sklearn GMM instances by using a dummy optimizer (I just used a torch.optim.Adam over an empty tensor) and overriding the backward method:

```python
def backward(self, use_amp, loss, optimizer):
    # Skip backpropagation entirely; the GMMs are fitted with EM instead.
    return
```
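The dummy-optimizer part of this workaround can be checked in plain PyTorch. A minimal sketch, assuming (as described above) an `Adam` optimizer over a single empty tensor; `dummy` is a hypothetical name. Because the overridden `backward` never fills in a gradient, `Adam` skips the parameter and `step()` does nothing. Note that the `backward(self, use_amp, loss, optimizer)` signature above belongs to the Lightning release used in this thread; later releases changed it and also offer manual optimization (`self.automatic_optimization = False`), so check the documentation of the version you use.

```python
import torch
from torch import nn

# A parameter holding an empty tensor: it satisfies the optimizer's
# constructor but carries no model state.
dummy = nn.Parameter(torch.empty(0))
optimizer = torch.optim.Adam([dummy], lr=1e-3)

# With backward() overridden to do nothing, dummy.grad stays None, and Adam
# skips parameters whose grad is None, so step() is a no-op.
optimizer.step()
optimizer.zero_grad()
print(dummy.grad)  # None

# An optimizer with no parameters at all is rejected at construction time,
# which is why a (dummy) parameter is needed in the first place.
rejected = False
try:
    torch.optim.Adam([], lr=1e-3)
except ValueError as err:
    rejected = True
    print("ValueError:", err)
```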

Now I want to save (and load) the model after training one epoch, but I haven't been able to either save it or restore it.

I've tried to make a checkpoint with the following setup:

```python
checkpoint_callback = ModelCheckpoint(
    filepath=save_dir,
    save_best_only=True,
    verbose=True,
    monitor='val_acc',
    mode='max',
    prefix=''
)
logger = ptl.logging.TestTubeLogger(
    save_dir=save_dir,
    version=1  # An existing version with a saved checkpoint
)
self.trainer = ptl.Trainer(
    gpus=hyperparams.gpus,
    distributed_backend=hyperparams.distributed_backend,
    logger=logger,
    checkpoint_callback=checkpoint_callback,
    default_save_path=save_dir,
    early_stop_callback=None,
    max_nb_epochs=1  # 2
)
```

The logger does not save any checkpoints, only metrics and media. The checkpoint callback did write a _ckpt_epoch_1.ckpt file, but I'm not sure how to restore it.
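A `.ckpt` file written by Lightning is an ordinary `torch.save` dictionary whose `"state_dict"` entry holds the module's weights, so it can be inspected and restored by hand. A minimal sketch with a stand-in `nn.Linear` (hypothetical; substitute your own module) and an in-memory buffer in place of the real checkpoint file; the other keys in a real checkpoint vary between Lightning versions:

```python
import io

import torch
from torch import nn

# Hypothetical stand-in for the trained module.
model = nn.Linear(3, 2)

# Simulate a checkpoint file: the weights are stored under "state_dict".
buffer = io.BytesIO()
torch.save({"epoch": 1, "state_dict": model.state_dict()}, buffer)
buffer.seek(0)

# Restoring: load the dictionary, then copy the weights into a fresh module.
checkpoint = torch.load(buffer, map_location="cpu")
restored = nn.Linear(3, 2)
restored.load_state_dict(checkpoint["state_dict"])
print(checkpoint["epoch"], torch.equal(restored.weight, model.weight))  # 1 True
```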

Edit: I opened this .ckpt file and it contains a dictionary without the important GMM parameters.
Is it possible to save the checkpoints as pickles?
Any advice on overriding the checkpointing behaviour would be useful too.
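One likely reason (an assumption, since the GMM class itself is not shown): `GMMClassifier` keeps its mixtures in a plain Python list, and `nn.Module` only tracks submodules, parameters and buffers assigned directly as attributes or held in containers such as `nn.ModuleList`. If the mixtures are `nn.Module` subclasses whose statistics are registered as parameters or buffers, a plain list still hides them from `state_dict()`, which is what the checkpoint stores. If they are sklearn estimators, as stated further up, they have no tensors to register at all and must be saved separately. A minimal demonstration with hypothetical `nn.Linear` stand-ins:

```python
from torch import nn


class PlainListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules inside a plain Python list are not registered.
        self.parts = [nn.Linear(2, 2) for _ in range(3)]


class ModuleListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every submodule, so their tensors reach state_dict().
        self.parts = nn.ModuleList([nn.Linear(2, 2) for _ in range(3)])


plain_keys = list(PlainListModel().state_dict())
registered_keys = list(ModuleListModel().state_dict())
print(plain_keys)            # []
print(registered_keys[:2])   # ['parts.0.weight', 'parts.0.bias']
print(len(registered_keys))  # 6
```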

@Vichoko subclass checkpoint callback and modify the behavior you want?

Would love a tutorial on this for people who want to do this with Lightning.

Still working on it but I will happily share the process when I get it working nice and clean ☺.

I've checked the pytorch_lightning.callbacks.ModelCheckpoint callback, and overriding it could solve the problem of saving the model.
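Besides subclassing `ModelCheckpoint`, another option is to carry the sklearn GMMs inside the checkpoint dictionary itself. `LightningModule` has `on_save_checkpoint(checkpoint)` and `on_load_checkpoint(checkpoint)` hooks that receive that dictionary (availability depends on the Lightning version, so check yours). The sketch below imitates those two hooks on a plain, hypothetical `Holder` class so it runs without Lightning, and stores each fitted estimator as pickled bytes; pickling is the persistence route scikit-learn itself documents:

```python
import pickle

import numpy as np
from sklearn.mixture import GaussianMixture


class GMMCheckpointMixin:
    """Hypothetical mixin; the method names mirror Lightning's checkpoint hooks."""

    def on_save_checkpoint(self, checkpoint):
        # Store each fitted sklearn GMM as pickled bytes inside the checkpoint dict.
        checkpoint["gmm_pickles"] = [pickle.dumps(g) for g in self.gmm_list]

    def on_load_checkpoint(self, checkpoint):
        # Rebuild the fitted GMMs from the pickled bytes.
        self.gmm_list = [pickle.loads(b) for b in checkpoint["gmm_pickles"]]


class Holder(GMMCheckpointMixin):
    def __init__(self, gmm_list):
        self.gmm_list = gmm_list


rng = np.random.default_rng(0)
fitted = GaussianMixture(n_components=2, random_state=0).fit(rng.normal(size=(200, 3)))

checkpoint = {}  # Lightning would pass its own checkpoint dict here
Holder([fitted]).on_save_checkpoint(checkpoint)

restored = Holder([])
restored.on_load_checkpoint(checkpoint)
print(np.allclose(restored.gmm_list[0].means_, fitted.means_))  # True
```

In `L_GMMClassifier` the list lives on `self.model.gmm_list`, so the hook bodies would reference that attribute instead.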

@williamFalcon Regarding the model loading behavior, can I get some guidance on this issue?
ModelCheckpoint only handles saving (dumping) the model after an epoch, and I couldn't find where the _load_ logic lives in the framework.

I mean a way to automatically load the best checkpoint when initializing the trainer. This would be useful for resuming training after stopping, or for running tests on the best model without loading it explicitly with boilerplate code.

@Vichoko still working on this issue?

We do not support loading the best checkpoint automatically.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
