If a user passes a traced model to one of the create_supervised* methods and specifies a device, the following error is raised:
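import torch
from ignite.engine import create_supervised_evaluator
traced_model = torch.jit.trace(model, z)  # model: an nn.Module, z: an example input
evaluator = create_supervised_evaluator(traced_model, metrics, device='cuda')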
> RuntimeError: to is not supported on TracedModules
We need to check whether the model has the .to method...
We should probably wait until this issue is resolved.
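One way to act on this check is a small guard around the device move. The sketch below is hypothetical (move_model_to_device is not part of ignite): it assumes a duck-typed model, calls .to only when it exists, and also catches a RuntimeError like the one above, in case .to exists but refuses to run.

```python
import warnings


def move_model_to_device(model, device=None):
    """Hypothetical helper: move `model` to `device` only if the model supports it.

    Assumes a duck-typed model: anything with a callable `.to` (such as
    torch.nn.Module) is moved. Objects without `.to`, or whose `.to` raises
    RuntimeError (as traced modules did in the error above), are returned
    unchanged and a warning is emitted.
    """
    if device is None:
        # Nothing to do when no device was requested
        return model
    to = getattr(model, "to", None)
    if not callable(to):
        warnings.warn(f"{type(model).__name__} has no .to method; leaving it unchanged")
        return model
    try:
        moved = to(device)
    except RuntimeError as err:
        warnings.warn(f"could not move model to {device!r}: {err}")
        return model
    # nn.Module.to returns the module itself; fall back to `model` if `.to` returns None
    return model if moved is None else moved
```

Warning and continuing is a design choice: if the model stays on the CPU while batches are sent to 'cuda', the forward pass will fail later, so re-raising with a clearer message may be preferable.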
Now that pytorch/pytorch#15340 has been merged, the create_supervised* methods will work with ScriptModule and traced models.
We need to check this with tests on CI...
@TheCodez if you would like to contribute, you could add a test that passes traced and scripted models to the create_supervised* methods with device 'cpu'. The test should pass on the pytorch-nightly build (which can be checked with torch.__version__).
> the test should pass on the pytorch-nightly build (can be checked with torch.__version__)
Do you mean only running the test on the nightly build, or checking that it raises a RuntimeError when not on the nightly build?
@TheCodez yes, something like that.
@vfdev-5 which torch.__version__ value should I check for the nightly build?
@TheCodez you can do it like this:
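import torch

# Check if torch-nightly: nightly builds include "dev" in their version string
if "dev" in torch.__version__:
    ...  # nightly: create_supervised* should accept traced/scripted models
else:
    ...  # stable release: expect RuntimeError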
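The branch above can be wrapped into a reusable check that implements the behaviour discussed earlier in the thread: on a nightly build the evaluator must be created without error, otherwise the RuntimeError is expected. This is a minimal sketch; check_device_support and is_nightly_build are hypothetical names, and make_evaluator is assumed to be a zero-argument callable wrapping something like create_supervised_evaluator(traced_model, metrics, device='cpu').

```python
from typing import Callable


def is_nightly_build(version: str) -> bool:
    # Same rule as above: nightly builds carry "dev" in the version string
    return "dev" in version


def check_device_support(make_evaluator: Callable[[], object], version: str) -> None:
    """Hypothetical test helper.

    On a nightly build, `make_evaluator` must succeed. On any other build,
    it is expected to raise RuntimeError; an AssertionError is raised if it
    does not.
    """
    if is_nightly_build(version):
        # Must not raise: any exception propagates and fails the test
        make_evaluator()
        return
    try:
        make_evaluator()
    except RuntimeError:
        # Expected on builds without pytorch/pytorch#15340
        return
    raise AssertionError(f"expected RuntimeError on non-nightly torch {version}")
```

The same helper can be called once with a traced model and once with a scripted model, passing torch.__version__ as version. The version-based expectation only holds while the fix is available on nightly builds alone; once a stable release includes it, the RuntimeError branch should be dropped.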