Pytorch-lightning: incorrect batch_sizes when Dataloader returns a dict with multiple tensors.

Created on 26 Sep 2020 · 21 Comments · Source: PyTorchLightning/pytorch-lightning

🐛 Bug

Tracked batch sizes in result object are incorrect when a Dataloader returns a dict with multiple tensors.

To Reproduce

Create a data loader that returns a dict, e.g. batch = {'batchA': tensor_A, 'batchB': tensor_B}.
Both entries have batch size N with N != 2.
For this example a batch size of 2 will be logged, since len(batch) == 2.

https://github.com/PyTorchLightning/pytorch-lightning/blob/05e5f03fd7c851b06ca5e34b39eb660857b8f00c/pytorch_lightning/trainer/evaluation_loop.py#L147-L150
https://github.com/PyTorchLightning/pytorch-lightning/blob/05e5f03fd7c851b06ca5e34b39eb660857b8f00c/pytorch_lightning/trainer/training_loop.py#L304-L306
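
A minimal sketch of the mismatch described above. Plain nested lists stand in for tensor_A and tensor_B so the snippet runs without torch. N = 4 is an assumed value, and first_entry_len is a hypothetical helper, not Lightning code.

```python
# Stand-ins for tensor_A and tensor_B: N = 4 samples each (assumed values).
N = 4
batch = {
    "batchA": [[0.0, 0.0, 0.0] for _ in range(N)],
    "batchB": [[1.0] for _ in range(N)],
}

# len() on a dict counts its keys, not the samples inside each entry.
tracked = len(batch)


def first_entry_len(b):
    """Hypothetical helper: length of the first value in a dict batch."""
    return len(next(iter(b.values())))


actual = first_entry_len(batch)
print(tracked, actual)  # 2 4
```

The tracked value depends only on the number of keys, so it stays at 2 for any N.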

Expected behavior


Log correct batch size.
I'm not sure what can be defined as the 'correct' batch size when there are multiple tensors, but I expect that each tensor in the dict has the same batch_size. So, maybe something like:

if is_result_obj:
    if isinstance(batch, dict):
        batch = batch[list(batch.keys())[0]]
    result_obj.track_batch_size(len(batch))
Priority P0 bug / fix

All 21 comments

I think just using len(batch) is still wrong: if the batch is a tuple or some custom batch type, len(batch) will not be the batch size. Even for the basic MNIST example it gives 2, which is wrong.

This should probably catch most things. Might be a bit much though.
It returns 1 if it fails to determine the batch size to prevent issues with weighted averaging in reduce_on_epoch_end.

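from collections.abc import Iterable

import torch

# maybe add as staticmethod to ResultObj?
def unpack_batchsize(sample):
    """
    Recursively unpack sample to find a torch.Tensor.
    Returns len(tensor) when found, or 1 when it hits an empty or non-iterable value.
    """
    if isinstance(sample, torch.Tensor):
        # 0-dim tensors have no len(); count them as a single sample
        return len(sample) if sample.dim() > 0 else 1
    if isinstance(sample, (str, bytes)):
        # strings are iterable but iterating one yields strings again,
        # so recursing into them would never terminate
        return 1
    if isinstance(sample, dict):
        first = next(iter(sample.values()), None)
    elif isinstance(sample, Iterable):
        first = next(iter(sample), None)
    else:
        # non-iterable value (int, float, custom object, ...)
        return 1

    if first is None:
        # empty container
        return 1
    return unpack_batchsize(first)

# in the training / evaluation loop:
if is_result_obj:
    result_obj.track_batch_size(unpack_batchsize(batch))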

I suggest adding a function batch_len_fx to the LightningModule, which defaults to len if it is not overridden. Anything could be a batch, and Lightning shouldn't have the responsibility of supporting every batch type.
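
A rough sketch of what such a hook could look like. ModuleSketch and DictBatchModule are hypothetical stand-ins, not the real LightningModule, and batch_len_fx is only the name proposed in this comment; plain lists again stand in for tensors.

```python
class ModuleSketch:
    """Hypothetical stand-in for a LightningModule; not the real class."""

    def batch_len_fx(self, batch):
        # Proposed default: fall back to len(batch).
        return len(batch)


class DictBatchModule(ModuleSketch):
    def batch_len_fx(self, batch):
        # User override for dict batches: every entry is assumed to share
        # the same leading dimension, so inspect the first one.
        return len(next(iter(batch.values())))


batch = {"batchA": [0] * 8, "batchB": [1] * 8}
default_size = ModuleSketch().batch_len_fx(batch)
custom_size = DictBatchModule().batch_len_fx(batch)
print(default_size, custom_size)  # 2 8
```

With this design the framework never inspects the batch type itself; the user who defines the batch also defines how to measure it.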

Exactly what I had in mind @carmocca. Or maybe simply ask users to pass batch_size to .log itself if on_epoch=True?

.log('some_metric', metric_value, on_epoch=True, batch_size=batch_size)
.log('some_metric', metric_value, on_epoch=False)

Lightning currently defaults to weighted_mean for reduction on epoch end by substituting the reduction method if it is torch.mean:

https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L389-L390

If this is the desired behaviour, I think Lightning should at least attempt getting a reasonable estimate for the batch size. In most use cases the dataloader will return multiple tensors, resulting in an incorrect batch estimate if len is the default. (e.g. any supervised method has at least (X, y) in its batch, producing len(batch)=2 as mentioned by @rohitgr7)
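
To illustrate why the tracked size matters for a weighted mean, here is a small arithmetic sketch. The weighted_mean function below is a plain-Python stand-in, not the implementation in step_result.py, and the losses and batch sizes are made-up values.

```python
def weighted_mean(values, weights):
    """Weighted average: sum(v * w) / sum(w)."""
    total = sum(weights)
    return sum(v * w for v, w in zip(values, weights)) / total


# Assumed per-batch losses; the last batch is smaller than the others.
losses = [1.0, 1.0, 4.0]
true_sizes = [32, 32, 8]
tracked_sizes = [2, 2, 2]  # len((X, y)) for every batch

correct = weighted_mean(losses, true_sizes)    # (32 + 32 + 32) / 72
skewed = weighted_mean(losses, tracked_sizes)  # (2 + 2 + 8) / 6
print(round(correct, 4), skewed)  # 1.3333 2.0
```

When every batch is tracked with the same wrong size, the weighted mean degenerates to the unweighted mean, so a small final batch is over-represented.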

This could still be done using batch_len_fx though. On the first call, if the method is not overridden, replace batch_len_fx with a reasonable estimate based on the type of batch (e.g. the len of the tensor, or of the first value in an Iterable).

Exactly what I had in mind @carmocca. Or maybe simply ask users to pass batch_size to .log itself if on_epoch=True?

.log('some_metric', metric_value, on_epoch=True, batch_size=batch_size)
.log('some_metric', metric_value, on_epoch=False)

this should work too. Probably default to 1 if not provided since len is likely to be wrong.
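
A sketch of how an explicit batch_size argument with a default of 1 could feed the epoch-end reduction. EpochLogSketch is a hypothetical accumulator, not Lightning's Result object, and its log signature only mirrors the proposal above.

```python
from collections import defaultdict


class EpochLogSketch:
    """Hypothetical accumulator; not Lightning's Result object."""

    def __init__(self):
        self._sums = defaultdict(float)
        self._weights = defaultdict(float)

    def log(self, name, value, batch_size=1):
        # batch_size defaults to 1, which yields an unweighted mean over steps.
        self._sums[name] += float(value) * batch_size
        self._weights[name] += batch_size

    def epoch_mean(self, name):
        return self._sums[name] / self._weights[name]


logger = EpochLogSketch()
logger.log("val_loss", 1.0, batch_size=32)
logger.log("val_loss", 4.0, batch_size=8)
logger.log("val_acc", 0.5)  # no batch_size given: weight 1
logger.log("val_acc", 1.0)
```

Here val_loss is weighted by the sizes the user supplied, while val_acc falls back to a plain average because no size was given.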

@gerardsn I have a problem exactly with this weighted_mean function.
I'm working with the latest Lightning version from master.

https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L369

It gets outputs = [{'checkpoint_on': tensor(28.3303, device='cuda:0'), 'val_loss': tensor(28.3303, device='cuda:0'), 'val_precision1': 0.12652068126520682}].
Because I have only one batch per epoch in validation.

Lightning tries to reduce on epoch end.
It feeds
result = tensor([27.8364], device='cuda:0'), weights=tensor([2]), into the weighted_mean function and I get an error here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L897
AttributeError: 'list' object has no attribute 'device'
I think it's related to this issue. It would be nice to not reduce anything if it's just one batch per epoch.

@fogside in your example result is a tensor, so result.device should not throw an error.

But it's a list with a tensor inside.

result = tensor([27.8364], device='cuda:0'), weights=tensor([2])

You're referring to this, right?

Sorry, I just realized, that I was mistaken.
Actually it calls this method twice for some reason.
I added prints at the beginning of weighted_mean function and at the reduce_on_epoch_end (I also changed the number of batches in this example)

Result[k] tensor([23.6331, 26.0617, 24.0941, 25.3255], device='cuda:0')
result:  tensor([23.6331, 26.0617, 24.0941, 25.3255], device='cuda:0')
weights:  tensor([2, 2, 2, 2])
Result[k] [0.14285714285714285, 0.06451612903225806, 0.056179775280898875, 0.13793103448275862]
result:  [0.14285714285714285, 0.06451612903225806, 0.056179775280898875, 0.13793103448275862]
weights:  tensor([2, 2, 2, 2])

And on the second time it gives me the error.

Are you logging non-tensor values? Maybe calling .item() somewhere in the logs? If not, can you post your .log statements here?

Yes, I was calculating precision in numpy.. Isn't it possible to log non-tensor values?

No. Also, to calculate precision or any other metric you can try the pl.metrics package, which computes these metrics on the current device itself.

or you can just do torch.tensor(numpy_value) in .log

I see. Thank you!
Actually I was trying to work with pytorch-metric-learning and used a function for topk precision estimation from there. But it looks quite tough to merge these 2 frameworks. I see now that topk precision should be calculated in PyTorch. Another thing is that I need as big a batch as possible to get a good topk estimate (ideally the whole val set), which is why I found it hard to make these estimates in validation_step. Maybe I should look into some Callbacks?
But it's not related to this issue.

Already working on topk accuracy. Maybe I will add topk precision and recall to pl.metrics as well. Can you point me to the implementation of topk precision in the pytorch-metric-learning package? It would be helpful. Thanks :)

It's great!
Sure. I used the class AccuracyCalculator like this

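from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# include must be a tuple of metric names, so the trailing comma is required
self.accuracy_calculator = AccuracyCalculator(include=("mean_average_precision_at_r",), k=5)
# query and reference are the same embeddings, hence the final True
accuracies = self.accuracy_calculator.get_accuracy(embeddings,
                                                   embeddings,
                                                   labels,
                                                   labels,
                                                   True)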

Implementation:
https://github.com/KevinMusgrave/pytorch-metric-learning/blob/10bed5ee8719a543827aa32ea658603c2fcb0130/src/pytorch_metric_learning/utils/accuracy_calculator.py#L45

So I guess 2 things should be fixed:

  • Track correct batch_size
  • Allow non-tensor numeric values in .log(...)

@gerardsn I have a problem exactly with this weighted_mean function.
I'm working with the latest Lightning version from master.

https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L369

It gets outputs = [{'checkpoint_on': tensor(28.3303, device='cuda:0'), 'val_loss': tensor(28.3303, device='cuda:0'), 'val_precision1': 0.12652068126520682}].
Because I have only one batch per epoch in validation.

Lightning tries to reduce on epoch end.
It feeds
result = tensor([27.8364], device='cuda:0'), weights=tensor([2]), into the weighted_mean function and I get an error here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L897

AttributeError: 'list' object has no attribute 'device'
I think it's related to this issue. It would be nice to not reduce anything if it's just one batch per epoch.

this is fixed on master

ok, making changes to this today.

What do we want as the default behavior? doesn't the custom reduce function solve the problem of custom batches etc?

The batches are not tracked correctly.
