MMDetection: how does the function forward_train work?

Created on 25 Jun 2019  ·  5 comments  ·  Source: open-mmlab/mmdetection

I found the detectors directory and the base.py file; forward_train is an abstract method there and the other detectors implement it. But I cannot find out how it actually works.

Most helpful comment

This function is defined in two_stage.py:

class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):

    def forward_train(self, img, img_meta, **kwargs):
        ...

The TwoStageDetector class inherits from the BaseDetector class, so next look at BaseDetector:

class BaseDetector(nn.Module):
    ...
    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        pass
    ...

An abstract method is defined here. Because of inheritance, this forward_train resolves to the forward_train defined in TwoStageDetector. It is then called from forward:

@auto_fp16(apply_to=('img', ))
def forward(self, img, img_meta, return_loss=True, **kwargs):
    if return_loss:
        return self.forward_train(img, img_meta, **kwargs)
    else:
        return self.forward_test(img, img_meta, **kwargs)

The auto_fp16 function here takes forward_train as an argument and executes it internally:

import functools
from inspect import getfullargspec

import torch

# cast_tensor_type is a helper defined elsewhere in mmdet that recursively
# converts tensors from one dtype to another (see the sketch further below)


def auto_fp16(apply_to=None, out_fp32=False):

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper

OK, you do not need to understand every line. As soon as you see output = old_func(*new_args, **new_kwargs) followed by return output, you know that this forward_train gets executed.
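A side note on cast_tensor_type, which is used above but not shown: roughly, it walks the arguments and converts any tensor from the source dtype to the target dtype, leaving other values untouched. A simplified sketch of such a helper, written here only to illustrate the idea and not copied from mmdet:

import torch


def cast_tensor_type(inputs, src_type, dst_type):
    # simplified sketch: recursively convert tensors of src_type to dst_type
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    elif isinstance(inputs, dict):
        return {k: cast_tensor_type(v, src_type, dst_type)
                for k, v in inputs.items()}
    elif isinstance(inputs, (list, tuple)):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        # strings, numbers, None, etc. are returned unchanged
        return inputs


# example: with apply_to=('img',), the decorator would effectively run
# cast_tensor_type(img, torch.float, torch.half) before calling forward
x = torch.randn(2, 3)
y = cast_tensor_type({'img': x, 'flag': True}, torch.float, torch.half)
# y['img'].dtype is now torch.float16, y['flag'] is still True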

================ Dividing line: addendum ================

There is an error above: output = old_func(*new_args, **new_kwargs) and return output are not themselves the execution of forward_train. Since a fellow reader raised this point, I have to write a bit more. Let's continue:
We saw above that

class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):

this class inherits from the BaseDetector class, and the key lies inside BaseDetector. Look at the code:

class BaseDetector(nn.Module):
    """Base class for detectors"""
    ...
    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        pass
    ...
    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

Now it should be clear: forward in BaseDetector calls forward_train, but the forward_train declared in BaseDetector is only an abstract method whose concrete definition lives in the concrete detector implementations, so the call is dispatched to the forward_train defined in TwoStageDetector. That completes the chain.
Of course, if you configure a SingleStageDetector instead, the same mechanism will call the forward_train defined under SingleStageDetector. This code is written elegantly and is worth learning from.
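To make the dispatch concrete, here is a minimal, self-contained sketch of the same pattern with toy classes (these are not the real mmdet classes): the base class declares forward_train as abstract and calls it from forward, and whichever subclass you instantiate supplies the implementation that actually runs.

from abc import abstractmethod

import torch
import torch.nn as nn


class ToyBaseDetector(nn.Module):
    # plays the role of BaseDetector: forward only dispatches

    @abstractmethod
    def forward_train(self, img, img_meta, **kwargs):
        pass

    def forward_test(self, img, img_meta, **kwargs):
        # in mmdet, forward_test further dispatches to simple_test / aug_test
        return 'test results'

    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)


class ToyTwoStageDetector(ToyBaseDetector):
    # plays the role of TwoStageDetector: supplies the concrete forward_train

    def forward_train(self, img, img_meta, **kwargs):
        # a real detector returns a dict of losses; this is just a dummy
        return {'loss_dummy': img.sum() * 0}


model = ToyTwoStageDetector()
img = torch.zeros(1, 3, 32, 32)
# training call: forward -> self.forward_train, resolved on the subclass
losses = model(img, img_meta=[{}], return_loss=True)
# testing call: forward -> forward_test
results = model(img, img_meta=[{}], return_loss=False)

Swapping in a different subclass (say a toy single-stage detector) changes nothing in the base class; the same forward call simply resolves to that subclass's forward_train.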

All 5 comments

and also the simple_test

I'm not clear about your question, but you can read the related code.

(dream-in-night's explanation, identical to the "Most helpful comment" shown above.)

@dream-in-night thank you for your detailed explanation. I hadn't understood how forward_train operates for a week until I found this answer.

Hello, there is also a small error in my answer, and I will reply to it tomorrow.
发件人: "Wei Xia"notifications@github.com
发送时间: 2019年9月3日(星期二) 晚上11:14
收件人: "open-mmlab/mmdetection"mmdetection@noreply.github.com;
抄送: "dream-in-night"517059224@qq.com;"Mention"mention@noreply.github.com;
主题: Re: [open-mmlab/mmdetection] how does the function forward_trainwork? (#862)

@dream-in-night thank you for your explanation in detail. I don't understand the operating of forward_train for a week until foud this answer.


You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub, or mute the thread.
