Pytorch-lightning: Better/correct test for callback

Created on 10 Oct 2020 · 12 comments · Source: PyTorchLightning/pytorch-lightning

🚀 Feature

Reopen #4009 and let it merge...

Motivation

The current test with bools is weak; it does not check the call order.

Labels: enhancement, tests / CI

Most helpful comment

from unittest.mock import MagicMock, call, ANY
from pytorch_lightning import Trainer, LightningModule
from tests.base import EvalModelTemplate
from unittest import mock


@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
def test_callback_system(torch_save):
    model = EvalModelTemplate()
    # pretend to be a callback, record all calls
    callback = MagicMock()
    trainer = Trainer(callbacks=[callback], max_steps=1, num_sanity_val_steps=0)
    trainer.fit(model)

    # check if a method was called exactly once
    callback.on_fit_start.assert_called_once()

    # check how many times a method was called
    assert callback.on_train_batch_end.call_count == 1

    # check that a method was NEVER called
    callback.on_keyboard_interrupt.assert_not_called()

    # check with what a method was called
    callback.on_fit_end.assert_called_with(trainer, model)

    # check exact call order
    callback.assert_has_calls([
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, None, "fit"),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # BATCH 0
        call.on_batch_start(trainer, model),
        # here we don't care about exact values in batch, so we say ANY
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        # here we don't care about exact values in batch, so we say ANY
        call.on_train_batch_end(trainer, model, [], ANY, 0, 0),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_save_checkpoint(trainer, model),
        call.on_save_checkpoint().__bool__(),   # the hook's return value gets truth-tested by the caller; the mock records that as __bool__
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, "fit"),
    ])

Here is a simple example of how to track calls with unittest.mock.
It is very elegant, easy to understand and allows you to check that methods were called with the expected arguments and in the exact order.

Please consider testing the callbacks this way.
The same could be applied to the model hooks (#4010). It is very straightforward and also easier to read than the old test.
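These mock techniques can be tried out without Lightning at all. The following is a minimal self-contained sketch, with a hypothetical `run()` driver standing in for `Trainer.fit`, showing how each assertion style behaves:

```python
from unittest.mock import MagicMock, call, ANY


def run(callback, batches):
    # Hypothetical stand-in for Trainer.fit: fires hooks on the callback.
    callback.on_fit_start("trainer", "model")
    for i, batch in enumerate(batches):
        callback.on_train_batch_end("trainer", "model", batch, i)
    callback.on_fit_end("trainer", "model")


callback = MagicMock()
run(callback, [[1, 2, 3]])

callback.on_fit_start.assert_called_once()            # called exactly once
assert callback.on_train_batch_end.call_count == 1    # call count
callback.on_keyboard_interrupt.assert_not_called()    # never called
callback.on_fit_end.assert_called_with("trainer", "model")  # arguments
callback.assert_has_calls([                           # relative call order
    call.on_fit_start("trainer", "model"),
    call.on_train_batch_end("trainer", "model", ANY, 0),  # ANY matches the batch
    call.on_fit_end("trainer", "model"),
])
print("all assertions passed")
```

A single `MagicMock` records every hook invocation with its arguments, so one mock replaces all the hand-written boolean-flag bookkeeping of the old test.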

All 12 comments

There are some hooks in both ModelHooks and Callback that I think are redundant:
on_epoch_start, on_epoch_end, on_batch_start, on_batch_end.

These already have alternatives with the on_train_* prefix.

Should we remove them?


@awaelchli you working on this?
Also, should we deprecate/remove https://github.com/PyTorchLightning/pytorch-lightning/issues/4049#issuecomment-706575159 first and then update the test?

What I posted here is all I worked on. It is a fully functional test plus a demo of other capabilities. It can be extended a little and then replace all the custom callback tracking in the old test. Please feel free to take this code and use it. If not, I will find some time.

Also, should we deprecate/remove #4049 (comment) first and then update the test?

I believe hooks like on_epoch_start can be useful if we "redefine" them to be running on epoch start regardless of training, validation, or test. If this is not desired, I'd rather have them removed.

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

It is very elegant, easy to understand and allows you to check that methods were called with the expected arguments and in the exact order.

I like that you can also test the call arguments, but it seems that assert_has_calls is order-independent: the test passes when I shuffle the calls, or even when I replace them all with callback.assert_has_calls([]).

UPDATE: seems that we can list all methods exhaustively:

assert callback_mock.method_calls == [
    call.on_init_start(trainer),
    call.on_init_end(trainer),
]
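Comparing against method_calls with == is strict: any missing, extra, or reordered call fails, unlike assert_has_calls. A small standalone check (plain strings standing in for the trainer object):

```python
from unittest.mock import MagicMock, call

callback_mock = MagicMock()
callback_mock.on_init_start("trainer")
callback_mock.on_init_end("trainer")

# method_calls records every method invocation in order; == compares
# names, arguments, and order exhaustively.
assert callback_mock.method_calls == [
    call.on_init_start("trainer"),
    call.on_init_end("trainer"),
]

# A reordered expectation does not match.
assert callback_mock.method_calls != [
    call.on_init_end("trainer"),
    call.on_init_start("trainer"),
]
```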

That's strange, because docs say it asserts the exact order in sequence:
https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls

That's strange, because docs say it asserts the exact order in sequence:
https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.assert_has_calls

maybe some bug in the implementation, but if you simply test the replacement with an empty list, it passes too

yes it makes sense that it passes with an empty list because any number of calls can occur before or after the sequence you pass in, so for example,

before()
a()
b()
c()
after()

assert_has_calls([]) # true
assert_has_calls([a]) # true
assert_has_calls([a, b]) # true
assert_has_calls([a, b, c]) # true
assert_has_calls([b, a, c]) # FALSE!!
assert_has_calls([before, a, b, c, after]) # true
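The subsequence semantics above are easy to confirm directly with a plain MagicMock:

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m.before()
m.a()
m.b()
m.c()
m.after()

m.assert_has_calls([])                              # passes: empty sequence is trivially contained
m.assert_has_calls([call.a(), call.b(), call.c()])  # passes: in-order contiguous subsequence
m.assert_has_calls([call.before(), call.a(), call.b(), call.c(), call.after()])  # passes

try:
    m.assert_has_calls([call.b(), call.a(), call.c()])  # wrong order
except AssertionError:
    print("wrong order rejected")
```

So assert_has_calls(calls) only asserts that the given calls appear somewhere in mock_calls, consecutively and in order; it says nothing about calls before or after that window, which is why an empty list always passes.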

Thanks for taking care of this @Borda !

yes it makes sense that it passes with an empty list because any number of calls can occur before or after the sequence you pass in, so for example,

I see, it is a bit dangerous in case you miss some calls at the beginning or at the end and everything seems to be fine... 8-)
