When restarting training and loading optimizer.pt and scheduler.pt, training crashes because the existing code does not know how to load these checkpoints on a TPU device.
The stack trace:
Exception in device=TPU:5: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)
Traceback (most recent call last):
File "/home/saurabh/venv/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 231, in _start_fn
fn(gindex, *args)
File "/home/saurabh/<retracted>", line 334, in _mp_fn
main()
File "/home/saurabh/<retracted>", line 303, in main
trainer.train(model_path=model_path)
File "/home/saurabh/venv/lib/python3.6/site-packages/transformers/trainer.py", line 386, in train
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 584, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 764, in _legacy_load
result = unpickler.load()
File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 720, in persistent_load
deserialized_objects[root_key] = restore_location(obj, location)
File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 802, in restore_location
return default_restore_location(storage, str(map_location))
File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 179, in default_restore_location
+ location + ")")
RuntimeError: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)
This happens when loading a partially trained model.
A reference implementation is here:
https://github.com/pytorch-tpu/fairseq/blob/tpu/fairseq/trainer.py#L195
with a related discussion here:
https://github.com/pytorch/xla/issues/1343
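For reference, the crash comes from the map_location that trainer.py passes to torch.load: torch.serialization only knows how to restore storage locations that have a registered deserializer ("cpu", "cuda:N", ...), so the "xla:0" tag falls through to the RuntimeError above. A minimal illustration of the difference (model_path is a placeholder; the workaround from the references above is sketched in the comments):

import os
import torch

# What trainer.py does today, which fails on TPU because self.args.device
# stringifies to something like "xla:0" and torch.serialization has no
# registered deserializer for that tag:
#   torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)

# The pattern used in the references above: deserialize onto CPU first, then
# move the restored state to the XLA device in a separate step.
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")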
Model I am using (Bert, XLNet ...): any model
Steps to reproduce the behavior:
Resume training from an existing checkpoint: the Trainer loads the optimizer and scheduler onto the TPU device and starts training.
transformers version: 2.11.0 (master)
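A minimal way to trigger it, assuming a checkpoint directory such as output/checkpoint-500 (a placeholder name) written by an earlier TPU run and the script launched through the examples' xla_spawn.py; model and dataset setup are omitted:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="output")

# model / train_dataset are whatever was used for the original run (placeholders here).
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

# Trainer.train() then calls torch.load(..., map_location=self.args.device) on
# optimizer.pt from the checkpoint directory, which raises the RuntimeError above on TPU.
trainer.train(model_path="output/checkpoint-500")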
Thanks @misrasaurabh1, I'll look into it.
We'll have to add such a test to the TPU CI once we have it (sooner rather than later).
Any updates on getting rid of this error?
It makes TPUs hard to use, because preemptible machines on Google Cloud can't really be used if there is no way to resume from checkpoints.
Thanks!
I encountered the same issue; it comes down to the script not being able to map the optimizer to the proper TPU device. Here is the line in question:
https://github.com/huggingface/transformers/blob/d088d744adb4e5aa45262a34acab3ae9e81de169/src/transformers/trainer.py#L403
My solution was to replace

optimizer.load_state_dict(
    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)

by:

if is_torch_tpu_available():
    # load state_dict on CPU and then transfer the object to the XLA device
    optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
    xm.send_cpu_data_to_device(optimizer, xm.xla_device())
else:
    optimizer.load_state_dict(
        torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
    )
That seemed to do the trick with torch-xla-nightly. Hope this helps.
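One caveat worth noting: xm.send_cpu_data_to_device returns the converted structure rather than guaranteeing an in-place update of the object you pass in, so a slightly more defensive variant of the same idea is to convert the loaded state dict and pass the returned value to load_state_dict (again a sketch, not a tested patch; model_path and optimizer come from the surrounding code):

import os
import torch
import torch_xla.core.xla_model as xm

# Deserialize on CPU, move the plain state dict to the XLA device,
# and only then hand the returned copy to the optimizer.
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
optimizer_state = xm.send_cpu_data_to_device(optimizer_state, xm.xla_device())
optimizer.load_state_dict(optimizer_state)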
Following the suggestion above, I tried:
print(device)
# BUG: can't simply map to the XLA device at the moment
if TPU_ACCELERATOR:
    # load state_dict on CPU and then transfer the object to the XLA device
    net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
    xm.send_cpu_data_to_device(net, device)
else:
    net.load_state_dict(torch.load(model_load_file, map_location=device))
and it failed with:
xla:4
<ipython-input-4-bbc8442c330c> in <module>
96 # load state_dict on CPU and then transfer object to XLA device
97 net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
---> 98 xm.send_cpu_data_to_device(net, device)
99 else:
100 net.load_state_dict(torch.load(model_load_file, map_location=device))
/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in send_cpu_data_to_device(data, device)
629 return type(v) == torch.Tensor and v.device.type == 'cpu'
630
--> 631 return ToXlaTensorArena(convert_fn, select_fn).transform(data)
632
633
/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in transform(self, inputs)
312 self._collect_tensors(inputs)
313 self._convert()
--> 314 return self._replace_tensors(inputs)
315
316
/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in _replace_tensors(self, inputs)
306
307 return xu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x),
--> 308 convert_fn)
309
310 def transform(self, inputs):
/opt/conda/lib/python3.7/site-packages/torch_xla/utils/utils.py in for_each_instance_rewrite(value, select_fn, fn)
197 def for_each_instance_rewrite(value, select_fn, fn):
198 rwmap = dict()
--> 199 return _for_each_instance_rewrite(value, select_fn, fn, rwmap)
200
201
/opt/conda/lib/python3.7/site-packages/torch_xla/utils/utils.py in _for_each_instance_rewrite(value, select_fn, fn, rwmap)
188 rwmap[id(value)] = result
189 for k in result.__dict__.keys():
--> 190 v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
191 result.__dict__[k] = v
192 else:
/opt/conda/lib/python3.7/site-packages/torch_xla/utils/utils.py in _for_each_instance_rewrite(value, select_fn, fn, rwmap)
189 for k in result.__dict__.keys():
190 v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
--> 191 result.__dict__[k] = v
192 else:
193 rwmap[id(value)] = result
TypeError: 'mappingproxy' object does not support item assignment
How about something simpler, like the following?
# load state_dict on CPU and then transfer object to XLA device
net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
net.to(device)
I think it does the job just fine.
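For the optimizer there is no .to(), but a similar shortcut may be enough, since torch's Optimizer.load_state_dict casts the loaded state tensors to the device of the matching parameters (a sketch, assuming the model and its optimizer parameters already live on the XLA device; model_path is a placeholder):

import os
import torch

# The optimizer's parameters already live on the XLA device, so load_state_dict's
# internal cast moves the CPU-loaded state tensors onto that device for us.
optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
optimizer.load_state_dict(optimizer_state)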
The suggested fix works; however, the progress bar starts from 0 and then takes a long time to reach the step where the checkpoint was saved. How can that be tackled? I am training on a Cloud TPU v3-8 and using the xla_spawn script to distribute training across cores.
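On the progress bar: the slow fast-forward is expected with this version, because on resume the Trainer replays the dataloader and skips already-trained batches one at a time instead of jumping straight to the saved step. Roughly (a simplified sketch of the logic, not the exact trainer.py code):

steps_per_epoch = len(train_dataloader) // gradient_accumulation_steps
epochs_trained = global_step // steps_per_epoch
steps_trained_in_current_epoch = global_step % steps_per_epoch

for step, inputs in enumerate(epoch_iterator):
    # Each skipped batch is still fetched from the dataloader, which is why the
    # progress bar crawls from 0 up to the checkpointed step before training resumes.
    if steps_trained_in_current_epoch > 0:
        steps_trained_in_current_epoch -= 1
        continue
    # ... normal training step ...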
@LysandreJik Any updates on this bug? This prevents resuming training from a checkpoint on TPUs.
I am also having the same problem when loading a model on TPU and resuming training. Any solutions?