Pytorch-lightning: Results gathering with varying tensor shapes (e.g. last batch)

Created on 17 Aug 2020 · 1 Comment · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

Reducing a Results object fails when batch sizes differ, because torch.stack receives tensors of different shapes. This can happen, for example, when your dataloader returns a smaller batch on the last iteration.
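The following is a minimal reproduction. It assumes hypothetical per-batch outputs: two full batches of shape (4, 3) and a smaller last batch of shape (2, 3). torch.stack raises because it needs identical shapes. torch.cat along the batch dimension only needs the trailing dimensions to match, so it does not fail in the same way.

```python
import torch

# Hypothetical per-batch outputs: two full batches of 4 rows, last batch of 2 rows
outputs = [torch.zeros(4, 3), torch.zeros(4, 3), torch.zeros(2, 3)]

try:
    torch.stack(outputs)  # needs every tensor to have exactly the same shape
except RuntimeError as err:
    print("torch.stack failed:", err)

# torch.cat joins along the existing batch dimension (dim 0),
# so only the trailing dimensions have to agree
merged = torch.cat(outputs, dim=0)
print(merged.shape)  # torch.Size([10, 3])
```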

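from collections.abc import MutableMapping

import torch


def recursive_stack(result: MutableMapping):
    for k, v in result.items():
        if isinstance(v, dict):
            recursive_stack(v)
        # torch.stack requires every tensor in v to have the same shape,
        # so this raises a RuntimeError when the last batch is smaller
        if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            v = torch.stack(v)
            result[k] = v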
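Below is one possible workaround, sketched under stated assumptions. It is not the fix from the PR mentioned in the comments. It keeps torch.stack when all shapes match. When only the first (batch) dimension differs, it falls back to torch.cat along dim 0. Otherwise it leaves the list untouched. The function name recursive_stack_safe is hypothetical.

```python
from collections.abc import MutableMapping

import torch


def recursive_stack_safe(result: MutableMapping) -> None:
    """Hypothetical variant of recursive_stack that tolerates a smaller last batch."""
    for k, v in result.items():
        if isinstance(v, dict):
            recursive_stack_safe(v)
        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            if len({t.shape for t in v}) == 1:
                # All tensors share one shape: stack adds a new leading dimension
                result[k] = torch.stack(v)
            elif all(t.dim() > 0 for t in v) and len({t.shape[1:] for t in v}) == 1:
                # Only the batch dimension differs: concatenate along it instead
                result[k] = torch.cat(v, dim=0)
            # Otherwise the list is left as-is instead of raising
```

One caveat with this choice: stack and cat return differently shaped results. Three batches of (4, 3) stack to (3, 4, 3), while (4, 3), (4, 3), (2, 3) concatenate to (10, 3). Downstream code would therefore see a different layout depending on whether the last batch happened to be smaller. Another option is to set drop_last=True on the DataLoader, which avoids the uneven batch altogether at the cost of skipping those samples.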

Context
From a Slack discussion by @artgor:
https://pytorch-lightning.slack.com/archives/CRBLFHY79/p1597604494424600

Labels: bug / fix, help wanted

Comments

Will submit a PR within the next few hours.
