Mmdetection: Export model architecture

Created on 25 Oct 2020 · 4 comments · Source: open-mmlab/mmdetection

Hi,

Let's say I want to modify or change many things in the model, for example, get the tensor value from a specific layer or cut a residual connection. Is there any way to extract the model as a class, similar to importing a model from torchvision?
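Incidentally, since build_detector returns an ordinary torch.nn.Module, intermediate tensors can often be captured with forward hooks rather than by editing the class. A minimal sketch on a stand-in model (TinyNet and the hooked layer are illustrative, not mmdetection names; with mmdetection you would hook a real submodule such as model.backbone):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Stand-in for a detector built by build_detector."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Conv2d(8, 2, 1)

    def forward(self, x):
        return self.head(self.conv(x))

captured = {}

def save_output(module, inputs, output):
    # Called every time self.conv runs; stash a detached copy.
    captured['conv'] = output.detach()

model = TinyNet().eval()
handle = model.conv.register_forward_hook(save_output)
with torch.no_grad():
    model(torch.randn(1, 3, 16, 16))
handle.remove()
print(tuple(captured['conv'].shape))  # → (1, 8, 16, 16)
```

This leaves the model class untouched, so the same checkpoint and config keep working.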

I tried it with torch.onnx by modifying the file "mmdetection/tools/test.py" at line 161 to be:

model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
# DetectoRS
dummy_input = torch.randn(3, 1333, 800, device='cuda')
torch.onnx.export(
    model,
    [dummy_input],
    '/home/detectors.onnx',
    export_params=True,
    keep_initializers_as_inputs=True,
    verbose=False)
    # opset_version=11)

And it raises the error:

Traceback (most recent call last):
  File "./tools/test.py", line 225, in <module>
    main()
  File "./tools/test.py", line 176, in main
    verbose = False)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/onnx/__init__.py", line 208, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/onnx/utils.py", line 366, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/onnx/utils.py", line 319, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/jit/__init__.py", line 338, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/jit/__init__.py", line 426, in forward
    self._force_outplace,
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/jit/__init__.py", line 412, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/nn/modules/module.py", line 720, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/torch/nn/modules/module.py", line 704, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/vcl/anaconda3/envs/giang/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 84, in new_func
    return old_func(*args, **kwargs)
TypeError: forward() missing 1 required positional argument: 'img_metas'

I guess this is because the model's forward function was overridden with an additional required parameter, img_metas.
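That guess matches what the traceback shows: torch.onnx.export traces the model with tensors only, while the detector's forward also expects img_metas. mmdetection's own export helper (mmdet/core/export/pytorch2onnx.py) works around this by binding the extra arguments with functools.partial before tracing. The mechanism, sketched on a hypothetical stand-in class (FakeDetector and its return value are illustrative, not real mmdet code):

```python
from functools import partial

class FakeDetector:
    # Like mmdet's BaseDetector.forward(img, img_metas, return_loss=...),
    # this forward needs non-tensor arguments that torch.onnx.export will
    # never supply, since the exporter only passes tensors positionally.
    def forward(self, img, img_metas, return_loss=True):
        return {'img': img, 'metas': img_metas, 'loss': return_loss}

model = FakeDetector()
img_metas = [{'img_shape': (1333, 800, 3), 'scale_factor': 1.0}]

# Bind the extra arguments up front so the wrapped forward takes only
# the image tensor; this is the pattern mmdet's export helper uses.
model.forward = partial(model.forward, img_metas=img_metas, return_loss=False)

out = model.forward('dummy_img')
print(out['loss'])  # → False
```

After this rebinding, the exporter can call model.forward with just the dummy image and tracing proceeds.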


All 4 comments

Hi @drcut, please help resolve this issue.


Hi,
torch.onnx.export only accepts a tensor, or a list/tuple of tensors, as the input it calls the model with, but an MMDet model's forward takes a tensor (img) and a dict (img_meta) by default. So before converting the model with torch.onnx.export, we first have to wrap the forward function. Please see https://github.com/open-mmlab/mmdetection/blob/eb7bfbc62658b60e5be8bb00a40d5e3018971f78/mmdet/core/export/pytorch2onnx.py#L48 for details.
Thanks and good luck


Sure, done

Thank you for your answer
