Hi,
As discussed in issue #1621, GPT2 tokenization is very slow, even on a 50MB dataset.
I'm using run_lm_finetuning.py, and here are the steps to reproduce the problem:
Run run_lm_finetuning.py to train (fine-tune) on the dataset. Here are my parameters:
python run_lm_finetuning.py \
--train_data_file "/train/datafile" \
--eval_data_file "/eval/datafile" \
--output_dir "/train/model" \
--model_type gpt2 \
--model_name_or_path distilgpt2 \
--cache_dir "/train/cache" \
--do_train \
--evaluate_during_training \
--per_gpu_train_batch_size 1 \
--per_gpu_eval_batch_size 1 \
--gradient_accumulation_steps 5 \
--overwrite_output_dir \
--seed 99
I dug into the huggingface/transformers codebase and profiled the tokenization process. It is obvious that this summation drains the time:
https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L644
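For reference, that line flattens the per-token results with sum() starting from an empty list, roughly like this (paraphrased from that commit, not verbatim):
# sum() with start=[] concatenates one sub-list per token,
# copying the growing accumulator on every step.
return sum(
    (self._tokenize(token, **kwargs)
     if token not in self.added_tokens_encoder and token not in self.all_special_tokens
     else [token]
     for token in tokenized_text),
    [],
)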
I ran the profiler and here is the result:
73791524 function calls in 1566.379 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
27157 0.083 0.000 0.109 0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
27157 0.065 0.000 0.128 0.000 _bootlocale.py:33(getpreferredencoding)
81471 0.070 0.000 0.327 0.000 locale.py:589(setlocale)
27157 0.422 0.000 0.876 0.000 locale.py:647(getpreferredencoding)
27157 0.363 0.000 5.997 0.000 regex.py:328(findall)
27157 0.662 0.000 1.682 0.000 regex.py:434(_compile)
4815114 8.744 0.000 16.641 0.000 tokenization_gpt2.py:139(bpe)
2030532 1.116 0.000 1.887 0.000 tokenization_gpt2.py:149(<lambda>)
27157 22.702 0.001 110.038 0.004 tokenization_gpt2.py:180(_tokenize)
25459527 5.702 0.000 5.702 0.000 tokenization_gpt2.py:194(<genexpr>)
10242602 1.764 0.000 1.764 0.000 tokenization_gpt2.py:195(<genexpr>)
1377876 1.678 0.000 1.975 0.000 tokenization_gpt2.py:91(get_pairs)
95205 0.526 0.000 0.910 0.000 tokenization_utils.py:1043(special_tokens_map)
95205 0.932 0.000 1.987 0.000 tokenization_utils.py:1055(all_special_tokens)
1 0.119 0.119 1566.379 1566.379 tokenization_utils.py:615(tokenize)
40789 0.099 0.000 0.169 0.000 tokenization_utils.py:623(split_on_token)
1 0.287 0.287 1566.260 1566.260 tokenization_utils.py:641(split_on_tokens)
54417 0.698 0.000 112.123 0.002 tokenization_utils.py:659(<genexpr>)
27157 0.063 0.000 0.063 0.000 {built-in method _locale.nl_langinfo}
81471 0.252 0.000 0.252 0.000 {built-in method _locale.setlocale}
761640 0.384 0.000 0.384 0.000 {built-in method builtins.getattr}
54314 0.022 0.000 0.022 0.000 {built-in method builtins.hasattr}
516605 0.150 0.000 0.150 0.000 {built-in method builtins.isinstance}
1821447 0.159 0.000 0.159 0.000 {built-in method builtins.len}
472563 3.469 0.000 5.355 0.000 {built-in method builtins.min}
1 1453.081 1453.081 1565.204 1565.204 {built-in method builtins.sum}
2043214 0.297 0.000 0.297 0.000 {method 'add' of 'set' objects}
456488 0.055 0.000 0.055 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4815114 1.169 0.000 1.169 0.000 {method 'encode' of 'str' objects}
5550977 16.572 0.000 18.336 0.000 {method 'extend' of 'list' objects}
27157 3.952 0.000 3.952 0.000 {method 'findall' of '_regex.Pattern' objects}
2057689 0.784 0.000 0.784 0.000 {method 'get' of 'dict' objects}
735863 0.233 0.000 0.233 0.000 {method 'index' of 'tuple' objects}
4894984 38.307 0.000 44.010 0.000 {method 'join' of 'str' objects}
1 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
4855903 1.365 0.000 1.365 0.000 {method 'split' of 'str' objects}
68048 0.009 0.000 0.009 0.000 {method 'strip' of 'str' objects}
95205 0.024 0.000 0.024 0.000 {method 'values' of 'dict' objects}
I removed sum() and turned it into this:
(self._tokenize(token, **kwargs)
 if token not in self.added_tokens_encoder and token not in self.all_special_tokens
 else [token]
 for token in tokenized_text)
Here is the profiler result:
73275678 function calls in 121.030 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
27157 0.058 0.000 0.076 0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
27157 0.041 0.000 0.084 0.000 _bootlocale.py:33(getpreferredencoding)
81471 0.058 0.000 0.211 0.000 locale.py:589(setlocale)
27157 0.330 0.000 0.625 0.000 locale.py:647(getpreferredencoding)
27157 0.267 0.000 4.996 0.000 regex.py:328(findall)
27157 0.434 0.000 1.160 0.000 regex.py:434(_compile)
4815114 9.797 0.000 18.875 0.000 tokenization_gpt2.py:139(bpe)
2030532 1.270 0.000 2.100 0.000 tokenization_gpt2.py:149(<lambda>)
27157 24.693 0.001 119.272 0.004 tokenization_gpt2.py:180(_tokenize)
25459527 6.204 0.000 6.204 0.000 tokenization_gpt2.py:194(<genexpr>)
10242602 1.975 0.000 1.975 0.000 tokenization_gpt2.py:195(<genexpr>)
1377876 2.002 0.000 2.328 0.000 tokenization_gpt2.py:91(get_pairs)
68050 0.287 0.000 0.475 0.000 tokenization_utils.py:1043(special_tokens_map)
68050 0.507 0.000 1.061 0.000 tokenization_utils.py:1055(all_special_tokens)
1 0.031 0.031 121.030 121.030 tokenization_utils.py:615(tokenize)
27263 0.077 0.000 0.158 0.000 tokenization_utils.py:623(split_on_token)
1 0.178 0.178 120.999 120.999 tokenization_utils.py:641(split_on_tokens)
1 0.330 0.330 120.350 120.350 tokenization_utils.py:659(<listcomp>)
27157 0.043 0.000 0.043 0.000 {built-in method _locale.nl_langinfo}
81471 0.148 0.000 0.148 0.000 {built-in method _locale.setlocale}
544400 0.188 0.000 0.188 0.000 {built-in method builtins.getattr}
54314 0.014 0.000 0.014 0.000 {built-in method builtins.hasattr}
407985 0.092 0.000 0.092 0.000 {built-in method builtins.isinstance}
1807921 0.181 0.000 0.181 0.000 {built-in method builtins.len}
472563 3.992 0.000 6.092 0.000 {built-in method builtins.min}
2043214 0.326 0.000 0.326 0.000 {method 'add' of 'set' objects}
456488 0.064 0.000 0.064 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4815114 1.259 0.000 1.259 0.000 {method 'encode' of 'str' objects}
5550977 18.064 0.000 20.040 0.000 {method 'extend' of 'list' objects}
27157 3.569 0.000 3.569 0.000 {method 'findall' of '_regex.Pattern' objects}
2057689 0.839 0.000 0.839 0.000 {method 'get' of 'dict' objects}
735863 0.273 0.000 0.273 0.000 {method 'index' of 'tuple' objects}
4894984 41.821 0.000 48.026 0.000 {method 'join' of 'str' objects}
1 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
4842377 1.597 0.000 1.597 0.000 {method 'split' of 'str' objects}
54522 0.007 0.000 0.007 0.000 {method 'strip' of 'str' objects}
68050 0.012 0.000 0.012 0.000 {method 'values' of 'dict' objects}
You can see 121 seconds versus 1566 seconds: roughly 13x faster without sum() (1566 / 121 ≈ 12.9). So, do we actually need sum()? No. Here sum() is only used to flatten a list of lists, and it is one of the slowest ways to do that: each addition copies the entire accumulator into a brand-new list, so the flattening is quadratic in the total number of tokens. There are far more efficient approaches; see this answer on StackOverflow.
Also, as stated in the official Python documentation, sum() is intended for numbers; for concatenating a series of iterables, the docs recommend itertools.chain().
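To make the difference concrete, here is a small standalone micro-benchmark (mine, not code from transformers) that flattens a list of lists in three ways:
import itertools
import timeit

chunks = [["tok"] * 10 for _ in range(5000)]  # 5,000 small lists, 50,000 items total

def with_sum():
    # Quadratic: every step copies the whole accumulator into a new list.
    return sum(chunks, [])

def with_listcomp():
    # Linear: each element is appended exactly once.
    return [item for chunk in chunks for item in chunk]

def with_chain():
    # Linear: chain.from_iterable walks each sub-list once, no intermediate lists.
    return list(itertools.chain.from_iterable(chunks))

for fn in (with_sum, with_listcomp, with_chain):
    print(fn.__name__, timeit.timeit(fn, number=5))
The sum() variant comes out dramatically slower than the other two, and the gap grows with the input size.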
So I replaced sum() with list(itertools.chain.from_iterable(...)) as follows and ran the profiler again.
# requires `import itertools` at the top of the module
return list(itertools.chain.from_iterable(
    self._tokenize(token, **kwargs)
    if token not in self.added_tokens_encoder and token not in self.all_special_tokens
    else [token]
    for token in tokenized_text
))
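chain.from_iterable() yields from each sub-list in turn, so flattening is linear in the total number of tokens, and the outer list() preserves the list return type that callers expect.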
Here is the result:
73791524 function calls in 114.720 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
27157 0.045 0.000 0.060 0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
27157 0.035 0.000 0.067 0.000 _bootlocale.py:33(getpreferredencoding)
81471 0.045 0.000 0.159 0.000 locale.py:589(setlocale)
27157 0.277 0.000 0.502 0.000 locale.py:647(getpreferredencoding)
27157 0.237 0.000 4.258 0.000 regex.py:328(findall)
27157 0.346 0.000 0.929 0.000 regex.py:434(_compile)
4815114 8.703 0.000 16.973 0.000 tokenization_gpt2.py:139(bpe)
2030532 1.171 0.000 1.923 0.000 tokenization_gpt2.py:149(<lambda>)
27157 22.988 0.001 112.449 0.004 tokenization_gpt2.py:180(_tokenize)
25459527 5.708 0.000 5.708 0.000 tokenization_gpt2.py:194(<genexpr>)
10242602 1.755 0.000 1.755 0.000 tokenization_gpt2.py:195(<genexpr>)
1377876 1.595 0.000 1.900 0.000 tokenization_gpt2.py:91(get_pairs)
95205 0.345 0.000 0.565 0.000 tokenization_utils.py:1043(special_tokens_map)
95205 0.581 0.000 1.236 0.000 tokenization_utils.py:1055(all_special_tokens)
1 0.022 0.022 114.720 114.720 tokenization_utils.py:615(tokenize)
40789 0.103 0.000 0.182 0.000 tokenization_utils.py:623(split_on_token)
1 0.583 0.583 114.698 114.698 tokenization_utils.py:641(split_on_tokens)
54417 0.248 0.000 113.314 0.002 tokenization_utils.py:659(<genexpr>)
27157 0.032 0.000 0.032 0.000 {built-in method _locale.nl_langinfo}
81471 0.111 0.000 0.111 0.000 {built-in method _locale.setlocale}
761640 0.219 0.000 0.219 0.000 {built-in method builtins.getattr}
54314 0.012 0.000 0.012 0.000 {built-in method builtins.hasattr}
516605 0.097 0.000 0.097 0.000 {built-in method builtins.isinstance}
1821447 0.166 0.000 0.166 0.000 {built-in method builtins.len}
472563 3.855 0.000 5.777 0.000 {built-in method builtins.min}
1 0.000 0.000 0.000 0.000 {built-in method from_iterable}
2043214 0.305 0.000 0.305 0.000 {method 'add' of 'set' objects}
456488 0.058 0.000 0.058 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4815114 1.104 0.000 1.104 0.000 {method 'encode' of 'str' objects}
5550977 17.434 0.000 19.189 0.000 {method 'extend' of 'list' objects}
27157 3.092 0.000 3.092 0.000 {method 'findall' of '_regex.Pattern' objects}
2057689 0.759 0.000 0.759 0.000 {method 'get' of 'dict' objects}
735863 0.243 0.000 0.243 0.000 {method 'index' of 'tuple' objects}
4894984 41.030 0.000 46.738 0.000 {method 'join' of 'str' objects}
1 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
4855903 1.396 0.000 1.396 0.000 {method 'split' of 'str' objects}
68048 0.009 0.000 0.009 0.000 {method 'strip' of 'str' objects}
95205 0.013 0.000 0.013 0.000 {method 'values' of 'dict' objects}
This is a significant speedup: 114 seconds versus the original 1566 seconds.
If everything is clear, I'll go ahead and create a pull request.
Thank you for your effort.
Thank you for taking the time to look into this and opening a pull request!
You are welcome!
This project takes a lot of effort, and as a community member I'm always happy to support it :) .
I would appreciate it if you merged the PR, because I want to use the fix straight from the master branch for my ongoing project.
Thanks!
Your PR was merged, you can now use it from the master branch :)
Feel free to open other issues if you find other sub-optimal processes.