TPU Trainer does not seem to support --evaluate_during_training. When the training loop reaches the logging step, the whole process hangs and training stalls. The same code and dataset work well on a multi-GPU setup.
I am trying to move my company to Hugging Face, so I want to train models on TPUs on our own dataset, and that training hung during the logging step. I was able to reproduce the behavior with run_language_modeling.py; the steps to reproduce are shown below.
Other observations:
The multiprocessing approach to TPU training seems to waste a lot of CPU memory: with large datasets you need a machine with hundreds of GBs of RAM, because the features are replicated 8 times in memory (once per process).
Another bug: TPU training generates 8 WandB runs, which creates a lot of clutter. One fix would be to log to wandb from a single process only. If generating 8 wandb runs is unavoidable, tag all of them as belonging to a single 'group', which organizes the runs much better (https://docs.wandb.com/library/advanced/grouping).
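A minimal sketch of both suggestions, assuming each process knows its own index (on TPU, torch_xla.core.xla_model.get_ordinal() returns it). The helper name wandb_init_kwargs is hypothetical; it only builds the keyword arguments (group, name) that wandb.init() accepts, and decides whether a given process should create a run at all.

```python
from typing import Optional


def wandb_init_kwargs(ordinal: int, group: str, log_all_processes: bool = False) -> Optional[dict]:
    """Hypothetical helper: decide whether and how this process calls wandb.init().

    ordinal: index of this process (0..num_cores-1).
    group: shared name so all runs of one job are grouped in the wandb UI.
    """
    if ordinal != 0 and not log_all_processes:
        # Only the master process logs, so one job produces one wandb run.
        return None
    # Either the master process, or every process when per-core runs are
    # wanted; in the latter case all runs share the same group.
    return {"group": group, "name": f"{group}-proc{ordinal}"}
```

Each process would then call wandb.init(**kwargs) only when the helper returns a dict; with the default log_all_processes=False, only process 0 creates a run.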
Model I am using (Bert, XLNet ...): RoBERTa with run_language_modeling.py to reproduce, T5 with our internal data.
Language I am using the model on (English, Chinese ...): English
Steps to reproduce the behavior:
Run conda activate torch-xla-nightly and start a v2-8 TPU in the us-central1-c zone. Set the TPU environment variables, then run:
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
python xla_spawn.py --num_cores 8 language_modelling/run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
--evaluate_during_training \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4
When it hangs, the tqdm counter is stuck at step 499 (with 500 as the logging interval) and nothing happens. When I interrupt it with Ctrl+C (KeyboardInterrupt), I get this stack trace:
    main()
  File "../../../vendor/transformers/examples/xla_spawn.py", line 68, in main
    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 296, in spawn
    start_method=start_method)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 78, in join
    timeout=timeout,
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
KeyboardInterrupt
Expected behavior: being able to log the validation-set loss during training.
transformers version: 2.11.0
Had the same problem with TPU. --logging_steps seems to freeze everything.
I removed logging and evaluate after training instead.
Hi, I can't reproduce this on master following your steps. Can you try pulling from master and let me know whether the issue is resolved? If it's not, I'll take a deeper look.
You can set --logging_steps=10 to reduce the time it takes to reach the hang.
I can, however, reproduce the issue with wandb. I'm looking into it now.
Interesting: I retried the same instructions on master with --logging_steps=50. It evaluated the first time, but then got stuck again at the second evaluation attempt, at step 99. Something is flaky and not right...
Also, now that I have at least one evaluation step working, I notice that it prints 8 different eval_loss values, one per process. I'm not sure how to interpret this. I haven't looked into the logic, but it looks like the evaluator also splits eval_data into 8 parts and computes eval_loss on each part separately, without aggregating them into a single eval_loss for the whole eval dataset. This defeats the purpose of evaluating during training.
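A sketch of the missing aggregation step, assuming each process reports its mean loss together with the number of batches it evaluated (aggregate_eval_loss is a hypothetical helper, not a Trainer function). Weighting by batch count keeps the result correct when one shard is smaller than the others. If I read the torch_xla docs correctly, xm.mesh_reduce(tag, data, reduce_fn) hands reduce_fn the list of every process's data, so a function of this shape could serve as the reduce_fn.

```python
from typing import Sequence, Tuple


def aggregate_eval_loss(shards: Sequence[Tuple[float, int]]) -> float:
    """Combine per-process (mean_loss, num_batches) pairs into one global mean loss."""
    total_batches = sum(n for _, n in shards)
    if total_batches == 0:
        raise ValueError("no evaluation batches were run")
    # Weighted mean: each shard's loss counts in proportion to its batch count.
    return sum(loss * n for loss, n in shards) / total_batches


# Two processes: one evaluated 1 batch (loss 2.0), the other 3 batches (loss 4.0).
print(aggregate_eval_loss([(2.0, 1), (4.0, 3)]))  # 3.5
```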
Indeed, something's not right. I'm taking a look.
This was working well on 26-27 May. I tried going back to that commit but got the same error. Maybe something with XLA?
I don't really know. For some reason it no longer hangs, although it did hang the first time this morning. Even with a clean environment, it doesn't hang anymore on my side.
I'm still investigating
Another really weird bug: setting --logging_steps to 0 makes training hang at step 99. I reproduced this behavior in two different setups. I used this option to disable logging, hoping to bypass the bug above through this condition at trainer.py line 493:
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
    self.global_step == 1 and self.args.logging_first_step
):
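The condition above can be checked in isolation. This sketch rewrites it as a free function (should_log is a hypothetical name; the boolean expression is copied from the quoted Trainer code) to show that --logging_steps=0 disables this branch entirely: the `and` short-circuits, so the modulo by zero is never evaluated and nothing is logged. That suggests the hang at step 99 comes from somewhere other than this condition.

```python
def should_log(global_step: int, logging_steps: int, logging_first_step: bool = False) -> bool:
    # Same expression as the Trainer condition quoted above.
    # With logging_steps == 0, "logging_steps > 0" is False, so
    # "global_step % logging_steps" is never evaluated (no ZeroDivisionError).
    return (logging_steps > 0 and global_step % logging_steps == 0) or (
        global_step == 1 and logging_first_step
    )


print(any(should_log(step, 0) for step in range(1, 1000)))  # False
```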
I believe this is what causes the bug:
if os.getenv("WANDB_WATCH") != "false":
    wandb.watch(
        self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
    )
Setting WANDB_WATCH=false fixed the bug, and evaluation during training now works too. Starting a PR...
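A small sketch of why the observed hang points line up with the wandb.watch call above. It assumes (this is not confirmed in the thread) that wandb's gradient hook fires on every log_freq-th training step, with one backward pass per step and no gradient accumulation. gradient_log_steps is a hypothetical helper that only reproduces the log_freq=max(100, logging_steps) expression.

```python
from typing import List


def gradient_log_steps(logging_steps: int, total_steps: int) -> List[int]:
    # Mirrors log_freq=max(100, self.args.logging_steps) from the snippet above.
    log_freq = max(100, logging_steps)
    # Assumption: gradients are logged at every multiple of log_freq.
    return [step for step in range(1, total_steps + 1) if step % log_freq == 0]


for logging_steps in (0, 50, 500):
    print(logging_steps, gradient_log_steps(logging_steps, 600)[0])
# 0 100
# 50 100
# 500 500
```

Under that assumption, the first gradient log lands on step 100 for --logging_steps of 0 or 50 and on step 500 for --logging_steps=500, matching the reports of tqdm stalling at 99 and 499 while the first evaluation at step 50 still succeeds.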
Great.
But could there be something in XLA that prevents the WandB gradients from being logged and makes training freeze?
I am not sure whether wandb supports logging gradients with PyTorch/XLA. I reached out to wandb to ask about this and should get a reply by tomorrow. It is also possible that PyTorch/XLA does not support gradient logging. I looked at the XLA GitHub repo and couldn't find any mention of gradient logging on TPUs. I am unfamiliar with the XLA interface to wandb and not keen on digging deeper into this. Hopefully wandb offers more clarity soon.
I'm one of the founders of wandb. We're digging into the root cause of this now. We're planning to issue a new release ASAP to ensure users can never get into this hung state. I'll update the thread here. For anyone finding this thread online and hitting the issue, you can add the following code to disable the gradient monitoring in wandb with huggingface.
import os
os.environ["WANDB_WATCH"] = "false"
Or if you're shelling out to a python script:
export WANDB_WATCH=false
python your_script.py
Thank you Chris for looking into this!
@vanpelt The wandb gradient logging has been disabled with PR https://github.com/huggingface/transformers/pull/4926 . Once wandb fixes gradient logging for PyTorch/XLA, we can re-enable it.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.