Vision: Simplify dimension checks on functional_tensor.py

Created on 11 Dec 2020 · 11 Comments · Source: pytorch/vision

🚀 Feature

The functional_tensor.py file uses the private method _is_tensor_a_torch_image in every public operator to check the input tensor's dimensions:
https://github.com/pytorch/vision/blob/dab475720f116c42fc80e437b25496dfc94d2a8a/torchvision/transforms/functional_tensor.py#L10-L11

Examples:
https://github.com/pytorch/vision/blob/dab475720f116c42fc80e437b25496dfc94d2a8a/torchvision/transforms/functional_tensor.py#L146-L147
https://github.com/pytorch/vision/blob/dab475720f116c42fc80e437b25496dfc94d2a8a/torchvision/transforms/functional_tensor.py#L166-L167
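To make the repetition concrete, here is a simplified sketch of the pattern the linked lines follow. The `hflip`/`vflip` bodies are illustrative stand-ins rather than the upstream implementations, and the rank check is assumed to be `ndim >= 2`; the point is that every public operator opens with the same two-line guard.

```python
import torch
from torch import Tensor


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2  # assumed rank check


def hflip(img: Tensor) -> Tensor:
    # The same guard is repeated verbatim in every public operator...
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")
    return torch.flip(img, [-1])  # flip along the width axis


def vflip(img: Tensor) -> Tensor:
    # ...and again here.
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")
    return torch.flip(img, [-2])  # flip along the height axis
```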

This check is repetitive and reduces code readability. We should fix this by using assertions (the original proposal was to use decorators; see https://github.com/pytorch/vision/pull/3123#discussion_r540840674 for details).

cc @vfdev-5

Labels: good first issue, transforms

All 11 comments

FYI, I'm not sure decorators are supported in TorchScript.


The following works

import torch
print(torch.__version__)
import typing


def check_tensor(func: typing.Callable) -> typing.Callable:

    def wrapper(t: torch.Tensor, p: int = 5):
        if not isinstance(t, torch.Tensor):
            raise TypeError()
        print("Inside decorator")
        return func(t, p)

    return wrapper


@check_tensor
def foo(t: torch.Tensor, p: int = 5) -> torch.Tensor:
    for i in range(10):
        if i < p:
            continue
        t += i
    return t


print(foo(torch.tensor([0.0])))
sfoo = torch.jit.script(foo)
print(sfoo(torch.tensor([0.0])))

> 1.8.0.dev20201211
> Inside decorator
> tensor([35.])
> Inside decorator
> tensor([35.])

Great!

Is anybody working on this feature?

@avijit9 Not at the moment. If you are interested in sending a PR that would be awesome!

How should we deal with multiple arguments? For example, the crop function takes several arguments, but if the decorator uses *args, **kwargs, I get the following error:

FAILED test/test_decorator.py::Tester::test_crop - torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:

Here is the decorator I implemented:

import typing


def check_tensor_torch_image(f: typing.Callable) -> typing.Callable:
    # Generic wrapper: works in eager mode, but TorchScript rejects *args/**kwargs.
    def wrapper(*args, **kwargs):
        if args[0].ndim < 2:
            raise TypeError("Tensor is not a torch image.")
        return f(*args, **kwargs)
    return wrapper
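One way to see why the earlier fixed-signature decorator scripts and this one does not: as the error above says, compiled functions cannot take a variable number of arguments, and a `*args, **kwargs` wrapper has no fixed parameter list. The pure-Python sketch below (no TorchScript involved; `check_fixed`, `check_variadic` and `kinds` are hypothetical names) uses `inspect.signature` to show the parameter kinds each wrapper exposes.

```python
import inspect
import typing


def check_fixed(func: typing.Callable) -> typing.Callable:
    # Wrapper mirrors the wrapped function's exact parameters.
    def fixed_wrapper(t, p: int = 5):
        return func(t, p)
    return fixed_wrapper


def check_variadic(func: typing.Callable) -> typing.Callable:
    # Wrapper forwards anything: the shape TorchScript refuses to compile.
    def variadic_wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return variadic_wrapper


def foo(t, p: int = 5):
    return t


def kinds(fn: typing.Callable) -> typing.List[str]:
    # Collect the kind of every parameter the wrapper exposes.
    return [param.kind.name for param in inspect.signature(fn).parameters.values()]


print(kinds(check_fixed(foo)))     # ['POSITIONAL_OR_KEYWORD', 'POSITIONAL_OR_KEYWORD']
print(kinds(check_variadic(foo)))  # ['VAR_POSITIONAL', 'VAR_KEYWORD']
```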

The unit test is as follows:

    def test_crop(self):
        scripted_fn = torch.jit.script(F_t.crop)
        shape = (10,)
        tensor = torch.rand(*shape, dtype=torch.float, device=self.device)
        with self.assertRaises(Exception) as context:
            # crop needs (img, top, left, height, width)
            scripted_fn(tensor, 0, 0, 2, 2)
        self.assertIn('Tensor is not a torch image.', str(context.exception))

This is a known issue: https://github.com/pytorch/pytorch/issues/29637

cc @datumbox

Minimal code to reproduce this:

import torch
print(torch.__version__)
import typing


def check_tensor(func: typing.Callable) -> typing.Callable:

    def wrapper(*args, **kwargs):
        if not isinstance(args[0], torch.Tensor):
            raise TypeError()
        print("Inside decorator")
        return func(*args, **kwargs)

    return wrapper


@check_tensor
def foo(t: torch.Tensor, p: int = 5) -> torch.Tensor:
    for i in range(10):
        if i < p:
            continue
        t += i
    return t


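print(foo(torch.tensor([0.0])))   # eager call works: prints "Inside decorator", then tensor([35.])
sfoo = torch.jit.script(foo)      # fails with the NotSupportedError shown above
print(sfoo(torch.tensor([0.0])))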

@avijit9 Thank you for the detailed analysis!

@vfdev-5 Do you have any suggestions for working around this limitation? Since most methods in functional_tensor.py take a limited number of arguments, I suppose we could implement a few versions of the decorator, but I think that would make the code messier, which defeats the purpose of your original recommendation. Thoughts?
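For comparison, the option of writing several fixed-signature versions of the decorator would look roughly like this: each decorator spells out its parameters, so the wrapper has no `*args`. `check_image_crop_like` is a hypothetical name, the rank check is assumed to be `ndim >= 2`, and the sketch only exercises eager mode; it illustrates how the number of decorators grows with the number of distinct signatures.

```python
import typing

import torch
from torch import Tensor


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2  # assumed rank check


def check_image_crop_like(func: typing.Callable) -> typing.Callable:
    # Only usable for functions shaped (img, int, int, int, int) -> Tensor;
    # every other signature would need its own copy of this decorator.
    def wrapper(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
        if not _is_tensor_a_torch_image(img):
            raise TypeError("tensor is not a torch image.")
        return func(img, top, left, height, width)

    return wrapper


@check_image_crop_like
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    return img[..., top:top + height, left:left + width]
```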

@datumbox Yes, it seems to be a blocker for using decorators. On the other hand, the goal of this issue is to refactor the current code base a bit by removing explicit type checks like:

if not _is_tensor_a_torch_image(img):
    raise TypeError('tensor is not a torch image.')

Using a decorator, we remove 2 lines and add 1. I propose to just create simple helper methods and reuse them, for a similar reduction in line count:

from torch import Tensor

# _is_tensor_a_torch_image is the existing private check in functional_tensor.py


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """PRIVATE METHOD. Crop the given Image Tensor.
    """
    _assert_image_tensor(img)

    return img[..., top:top + height, left:left + width]

What do you think?
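A quick way to check that the helper-based version still compiles: unlike the variadic decorator, `_assert_image_tensor` is an ordinary function with a fixed signature, so `torch.jit.script` should be able to compile it recursively along with `crop`. The sketch again assumes `_is_tensor_a_torch_image` is the `ndim >= 2` check. Depending on the PyTorch version, an exception raised inside scripted code may surface as a `torch.jit.Error` rather than a plain `TypeError`, so the check below only looks at the message.

```python
import torch
from torch import Tensor


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2  # assumed rank check


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    _assert_image_tensor(img)
    return img[..., top:top + height, left:left + width]


# The helper has a fixed signature, so it is compiled together with crop.
scripted_crop = torch.jit.script(crop)
out = scripted_crop(torch.rand(3, 8, 8), 2, 2, 4, 4)
print(out.shape)  # torch.Size([3, 4, 4])

try:
    scripted_crop(torch.rand(10), 0, 0, 2, 2)
except Exception as exc:  # exact exception class may vary across versions
    print("rejected:", "tensor is not a torch image." in str(exc))
```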

@vfdev-5 agreed, that's still an improvement.

@avijit9 Would you still be interested in sending a PR based on the above proposal?

@datumbox Sure! I'll send a PR soon.
