Pytorch-lightning: Tensorboard logging by epoch instead of by step

Created on 8 Jun 2020  ·  12 Comments  ·  Source: PyTorchLightning/pytorch-lightning

Short question concerning the tensorboard logging:

I am using it like this:

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        tensorboard_logs = {'train/loss': avg_loss}
        for name in self.metrics:
            tensorboard_logs['train/{}'.format(name)] = torch.stack([x['metr'][name] for x in outputs]).mean()

        return {'loss': avg_loss, 'log': tensorboard_logs}

It works very well, but in the plots the x-axis shows the step, so each batch counts as one tick. Is it possible to have the x-axis show the epoch instead, since I want to plot the metrics only once per epoch?

Logger question


All 12 comments

Use a different logger; TensorBoard can't do this. WandB, for example, can.

Thank you for your answer, but I don't think it's true that TensorBoard can't do it.

I am currently migrating from Torchbearer, where I also used TensorBoard, and there the x-axis shows the epoch instead of the step, so it is possible.

To clarify: I am referring only to changing the x-axis numbering from steps to epochs, so I don't have to work out in which epoch something happened.
[screenshot: TensorBoard plot with the epoch on the x-axis]

Yes, because they are using the global step as the epoch, right?

I guess you could try this:

  1. return {'log': {'global_step': self.current_epoch}} and nothing else for logging in your training_step
  2. return your metrics in training_epoch_end for logging (see the sketch below)

(have not tried)
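
A minimal sketch of that suggestion (untested, as noted; it assumes the dict-return logging API of Lightning 0.x, and compute_loss is a hypothetical placeholder for your own loss computation):

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper; replace with your loss
        # remap the logger's step to the epoch; log nothing else per batch
        return {'loss': loss, 'log': {'global_step': self.current_epoch}}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        # the epoch-level metrics should then land on an epoch-valued x-axis
        return {'loss': avg_loss, 'log': {'train/loss': avg_loss}}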

Thank you, that was a good hint. I debugged it now.

It's possible to pass a 'step' entry, which the logger then uses as the x-axis value, so I can set it to the current epoch like this:

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        tensorboard_logs = {'train/loss': avg_loss}
        for name in self.metrics:
            tensorboard_logs['train/{}'.format(name)] = torch.stack([x['metr'][name] for x in outputs]).mean()
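        # the special 'step' key overrides the x-axis value used by the logger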
        tensorboard_logs['step'] = self.current_epoch

        return {'loss': avg_loss, 'log': tensorboard_logs}

OK, great. Did you also try to additionally log in the training_step method? I think that would not work; just saying, in case someone else finds this issue later on.

No, I am only logging in the epoch-end method and didn't try otherwise.

Hi,

It is possible to track both the steps and the epochs using TensorBoard. Here is an example; it is quite straightforward.

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        labels_hat = torch.argmax(y_hat, dim=1)
        n_correct_pred = torch.sum(y == labels_hat).item()

        loss = F.cross_entropy(y_hat, y.long())
        # per-step metrics are plotted against the global step, as usual
        tensorboard_logs = {'train_acc_step': n_correct_pred / len(y), 'train_loss_step': loss}

        return {'loss': loss, 'n_correct_pred': n_correct_pred, 'n_pred': len(y), 'log': tensorboard_logs}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        train_acc = sum(x['n_correct_pred'] for x in outputs) / sum(x['n_pred'] for x in outputs)
        # the 'step' key makes these epoch-level metrics plot against the epoch
        tensorboard_logs = {'train_acc': train_acc, 'train_loss': avg_loss, 'step': self.current_epoch}

        return {'loss': avg_loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y.long())
        labels_hat = torch.argmax(y_hat, dim=1)
        n_correct_pred = torch.sum(y == labels_hat).item()

        return {'val_loss': loss, 'n_correct_pred': n_correct_pred, 'n_pred': len(y)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        val_acc = sum(x['n_correct_pred'] for x in outputs) / sum(x['n_pred'] for x in outputs)
        tensorboard_logs = {'val_loss': avg_loss, 'val_acc': val_acc, 'step': self.current_epoch}

        return {'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y.long())
        labels_hat = torch.argmax(y_hat, dim=1)
        n_correct_pred = torch.sum(y == labels_hat).item()

        return {'test_loss': loss, 'n_correct_pred': n_correct_pred, 'n_pred': len(y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        test_acc = sum(x['n_correct_pred'] for x in outputs) / sum(x['n_pred'] for x in outputs)
        tensorboard_logs = {'test_loss': avg_loss, 'test_acc': test_acc, 'step': self.current_epoch}

        return {'log': tensorboard_logs}

Here is what it looks like.

[screenshot: TensorBoard curves with the epoch on the x-axis]

Cheers!

I did not know! Thanks @adeboissiere. So, in summary, we can override the step with {'step': self.current_epoch} or whatever value we want.

Does the solution still work with the latest version, 0.10.0?

I solved this using the code below. Note that, in my case, I may run with either one or multiple validation datasets:

    # is_tensor comes from torch
    from torch import is_tensor

    def validation_epoch_end(self, outputs):
        # skip the sanity-check pass so it does not pollute the logs
        if not self.trainer.running_sanity_check:
            for dataset_result in outputs:
                # single validation dataset: outputs is a list of metric dicts
                if isinstance(dataset_result, dict):
                    for key, val in dataset_result.items():
                        if is_tensor(val):
                            dataset_result[key] = val.cpu().detach()

                    self.logger.agg_and_log_metrics(
                        dataset_result, step=self.current_epoch)

                # multiple validation datasets: outputs is a list of lists of dicts
                else:
                    for metrics in dataset_result:
                        for key, val in metrics.items():
                            if is_tensor(val):
                                metrics[key] = val.cpu().detach()

                        self.logger.agg_and_log_metrics(
                            metrics, step=self.current_epoch)
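
One note on agg_and_log_metrics: as far as I can tell, it aggregates metrics logged for the same step before writing them out, so calling it several times with step=self.current_epoch yields one aggregated entry per epoch rather than overlapping points.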
