Transformers: Cannot load optimizer and lr_scheduler states with TPU training

Created on 12 Jun 2020 · 8 Comments · Source: huggingface/transformers

🐛 Bug

When restarting training and loading optimizer.pt and scheduler.pt, training crashes because the existing code does not know how to map them onto the TPU (XLA) device.

Information

The stack trace:

Exception in device=TPU:5: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)
Traceback (most recent call last):
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 231, in _start_fn
    fn(gindex, *args)
  File "/home/saurabh/<retracted>", line 334, in _mp_fn
    main()
  File "/home/saurabh/<retracted>", line 303, in main
    trainer.train(model_path=model_path)
  File "/home/saurabh/venv/lib/python3.6/site-packages/transformers/trainer.py", line 386, in train
    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 584, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 764, in _legacy_load
    result = unpickler.load()
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 720, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 802, in restore_location
    return default_restore_location(storage, str(map_location))
  File "/home/saurabh/venv/lib/python3.6/site-packages/torch/serialization.py", line 179, in default_restore_location
    + location + ")")
RuntimeError: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)
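
For context, the error comes from torch.load itself: its map_location machinery only knows how to restore storages to CPU and CUDA locations, so a map_location pointing at an XLA device falls through to the RuntimeError above. A minimal sketch of the failure (assuming a working torch_xla install; newer torch_xla releases may handle this themselves):

import torch
import torch_xla.core.xla_model as xm

# Save an ordinary checkpoint, then ask torch.load to map it onto an XLA device.
torch.save({"weight": torch.zeros(3)}, "/tmp/optimizer.pt")
state = torch.load("/tmp/optimizer.pt", map_location=xm.xla_device())
# -> RuntimeError: don't know how to restore data location of torch.FloatStorage (tagged with xla:0)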

This happens when loading a partially trained model (i.e. resuming from a checkpoint).

A reference implementation is here:
https://github.com/pytorch-tpu/fairseq/blob/tpu/fairseq/trainer.py#L195
with a discussion at https://github.com/pytorch/xla/issues/1343

Model I am using (Bert, XLNet ...): any model

Language I am using the model on (English, Chinese ...):

The problem arises when using:

  • [x] the official example scripts: (give details below)
  • [x] my own modified scripts: (give details below)

The task I am working on is:

  • [x] an official GLUE/SQuAD task: (give the name)
  • [x] my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Train any model on TPU and wait for a checkpoint to be written.
  2. Move the tokenizer files into the checkpoint directory. (This is a separate bug: the Trainer expects the tokenizer configs to be present in the checkpoint directory, but they are only written at the very end of training, not at the earlier checkpoints.)
  3. Restart training from the checkpoint on TPU.

Expected behavior

The Trainer loads the optimizer and scheduler states onto the TPU and resumes training.

Environment info

  • transformers version: 2.11.0 (master)
  • Platform: Linux-5.3.0-1026-gcp-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.6.0a0+6bdfd6a (False)
  • Tensorflow version (GPU?): 2.2.0 (False)
  • Using GPU in script?: False
  • Using distributed or parallel set-up in script?: yes, 8 way with xla_spawn.py

Most helpful comment

@LysandreJik Any updates on this bug? this prevents resuming training from a checkpoint on TPUs

All 8 comments

Thanks @misrasaurabh1, I'll look into it.
We'll have to add a test for this to the TPU CI once we have it (sooner rather than later).

Any updates on fixing this error?
It makes TPUs hard to use, because preemptible machines on Google Cloud are not practical if there is no way to resume from checkpoints.

Thanks!

I encountered the same issue. It comes down to the script not being able to map the optimizer onto the proper TPU device; here's the line in question:
https://github.com/huggingface/transformers/blob/d088d744adb4e5aa45262a34acab3ae9e81de169/src/transformers/trainer.py#L403

My solution was to replace

optimizer.load_state_dict(
    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)

by:

if is_torch_tpu_available():
    # load the state_dict on CPU and then transfer the object to the XLA device
    optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
    xm.send_cpu_data_to_device(optimizer, xm.xla_device())
else:
    optimizer.load_state_dict(
        torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
    )

That seemed to do the trick with torch-xla-nightly. Hope this helps.
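
If the scheduler checkpoint is loaded through the same map_location code path (the issue title suggests it is), the same guard would presumably apply to scheduler.pt. A sketch under that assumption; the LR scheduler state dict is usually plain Python values, so a CPU load is typically enough:

if is_torch_tpu_available():
    # scheduler state is mostly plain Python values, so loading on CPU is typically sufficient
    scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
else:
    scheduler.load_state_dict(
        torch.load(os.path.join(model_path, "scheduler.pt"), map_location=self.args.device)
    )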

(Quoting the `is_torch_tpu_available()` workaround from the previous comment.)

I tried:

print(device)
# BUG: can't simply map to the XLA device at the moment
if TPU_ACCELERATOR:
    # load state_dict on CPU and then transfer object to XLA device
    net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
    xm.send_cpu_data_to_device(net, device)
else:
    net.load_state_dict(torch.load(model_load_file, map_location=device))

and it failed with:

xla:4
<ipython-input-4-bbc8442c330c> in <module>
     96             # load state_dict on CPU and then transfer object to XLA device
     97             net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
---> 98             xm.send_cpu_data_to_device(net, device)
     99         else:
    100             net.load_state_dict(torch.load(model_load_file, map_location=device))

/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in send_cpu_data_to_device(data, device)
    629     return type(v) == torch.Tensor and v.device.type == 'cpu'
    630 
--> 631   return ToXlaTensorArena(convert_fn, select_fn).transform(data)
    632 
    633 

/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in transform(self, inputs)
    312     self._collect_tensors(inputs)
    313     self._convert()
--> 314     return self._replace_tensors(inputs)
    315 
    316 

/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in _replace_tensors(self, inputs)
    306 
    307     return xu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x),
--> 308                                         convert_fn)
    309 
    310   def transform(self, inputs):

/opt/conda/lib/python3.7/site-packages/torch_xla/utils/utils.py in for_each_instance_rewrite(value, select_fn, fn)
    197 def for_each_instance_rewrite(value, select_fn, fn):
    198   rwmap = dict()
--> 199   return _for_each_instance_rewrite(value, select_fn, fn, rwmap)
    200 
    201 

/opt/conda/lib/python3.7/site-packages/torch_xla/utils/utils.py in _for_each_instance_rewrite(value, select_fn, fn, rwmap)
    188     rwmap[id(value)] = result
    189     for k in result.__dict__.keys():
--> 190       v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
    191       result.__dict__[k] = v
    192   else:

/opt/conda/lib/python3.7/site-packages/torch_xla/utils/utils.py in _for_each_instance_rewrite(value, select_fn, fn, rwmap)
    189     for k in result.__dict__.keys():
    190       v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
--> 191       result.__dict__[k] = v
    192   else:
    193     rwmap[id(value)] = result

TypeError: 'mappingproxy' object does not support item assignment

How about something simpler like

# load state_dict on CPU and then transfer object to XLA device
net.load_state_dict(torch.load(model_load_file, map_location="cpu"))
net.to(device)

?

I think it does the job just fine.
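
Note that `.to(device)` only works for the model: a torch.optim.Optimizer has no `.to()` method, so for optimizer.pt one would load on CPU and move the state tensors explicitly. A minimal sketch of that pattern (the helper name is illustrative, not part of any library):

import os
import torch

def load_optimizer_state_on_device(optimizer, model_path, device):
    # Load the checkpoint on CPU first, then restore it into the optimizer.
    state_dict = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
    optimizer.load_state_dict(state_dict)
    # Move every tensor in the optimizer state to the target (XLA) device explicitly,
    # so the placement does not depend on how load_state_dict casts values in a given
    # PyTorch version.
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)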

(Quoting the `is_torch_tpu_available()` workaround above.)

This works; however, the progress bar starts from 0 and then takes a long time to reach the step where the checkpoint was saved. How can that be tackled? I am training on a Cloud TPU (v3-8) and using the xla_spawn script to distribute training across cores.
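
On the progress-bar question: around v2.11 the Trainer resumes by parsing the global step from the checkpoint folder name and then iterating the dataloader while skipping the already-trained batches, which is why the bar restarts at 0 and spends time catching up to the checkpointed step. A rough sketch of that pattern (variable names are illustrative, not the Trainer's exact code):

# parse e.g. ".../checkpoint-500" -> 500
global_step = int(model_path.split("-")[-1].split("/")[0])
steps_per_epoch = len(train_dataloader) // gradient_accumulation_steps
epochs_trained = global_step // steps_per_epoch
steps_trained_in_current_epoch = global_step % steps_per_epoch

for step, batch in enumerate(epoch_iterator):
    # already-trained batches are still fetched from the dataloader, just not trained on,
    # so resuming spends time fast-forwarding through them
    if steps_trained_in_current_epoch > 0:
        steps_trained_in_current_epoch -= 1
        continue
    # ... regular training step ...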

@LysandreJik Any updates on this bug? this prevents resuming training from a checkpoint on TPUs

I am also having the same problem loading a checkpoint on TPU and resuming training. Any solutions?
