Hey,
I've been trying to convert the winter2summer checkpoint to an ONNX model, so I tried adding the following to test.py and running it:
dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, dummy_input, "./cycleGAN.onnx")
I got the following error:
Traceback (most recent call last):
  File "test.py", line 69, in <module>
    torch.onnx.export(model, dummy_input, "./cycle.onnx")
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 131, in export
    strip_doc_string=strip_doc_string)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 363, in _export
    _retain_param_name, do_constant_folding)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 266, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 217, in _trace_and_get_graph_from_model
    orig_state_dict_keys = _unique_state_dict(model).keys()
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 238, in _unique_state_dict
    state_dict = module.state_dict(keep_vars=True)
AttributeError: 'TestModel' object has no attribute 'state_dict'
Any and all help is much appreciated.
Thanks!
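The traceback ends in module.state_dict(keep_vars=True). The exporter traces whatever it is given and expects a torch.nn.Module, which always has state_dict(). An object that only holds networks as attributes has no such method. The sketch below reproduces this with a hypothetical WrapperModel standing in for TestModel (its real internals are not shown in this thread) and a tiny nn.Sequential as a placeholder generator.

```python
import torch.nn as nn


class WrapperModel:
    """Hypothetical stand-in for a TestModel-style wrapper (not an nn.Module)."""

    def __init__(self):
        # Placeholder "generator"; the real netG is a much larger network
        self.netG = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())


model = WrapperModel()
# The wrapper itself is a plain Python object, so tracing it fails on state_dict
print(isinstance(model, nn.Module), hasattr(model, "state_dict"))  # False False
# The network it holds is an nn.Module, which is what torch.onnx.export needs
print(isinstance(model.netG, nn.Module), hasattr(model.netG, "state_dict"))  # True True
```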
Looks like I was feeding it the wrong model instance: TestModel is a wrapper rather than a torch.nn.Module, so it has no state_dict. Passing in the generator network (model.netG) instead fixed the issue.
This worked for me:
dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(model.netG, dummy_input, "./cycleGAN.onnx")
I'll be testing the model now.
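A slightly more defensive version of the same fix, as a sketch only. The helper names resolve_exportable and export_generator are hypothetical, and the 1x3x256x256 input size is carried over from the snippet above rather than confirmed for every checkpoint. The code picks the real nn.Module out of a wrapper with a netG attribute, unwraps nn.DataParallel if present (its .module attribute holds the wrapped network), switches to eval mode, and creates the dummy input on the same device as the generator's weights.

```python
import torch
import torch.nn as nn


def resolve_exportable(model):
    """Return the nn.Module that torch.onnx.export should trace.

    Assumption: like TestModel, `model` may be a plain wrapper object
    that stores the generator in a `netG` attribute.
    """
    net = model if isinstance(model, nn.Module) else getattr(model, "netG", None)
    if not isinstance(net, nn.Module):
        raise TypeError(f"{type(model).__name__} is not an nn.Module and has no nn.Module 'netG'")
    # If the generator was wrapped in DataParallel, export the inner network
    if isinstance(net, nn.DataParallel):
        net = net.module
    return net


def export_generator(model, path, size=256):
    """Export the generator using a 1x3xSIZExSIZE dummy input (size is an assumption)."""
    net = resolve_exportable(model).eval()  # export inference behaviour
    device = next(net.parameters()).device  # keep the input on the weights' device
    dummy_input = torch.randn(1, 3, size, size, device=device)
    with torch.no_grad():
        torch.onnx.export(net, dummy_input, path,
                          input_names=["input"], output_names=["output"])
    return path
```

In test.py this would be called once the model is set up, e.g. export_generator(model, "./cycleGAN.onnx"). The netG fallback covers only the wrapper layout described in this thread; other model classes may name their networks differently.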