Hi,
As discussed in issue #1621, GPT-2 tokenization is very slow, even on a ~50 MB dataset. I'm using run_lm_finetuning.py, and here are the steps to reproduce the problem:
Run run_lm_finetuning.py to train (fine-tune) on the dataset. Here are my parameters:

```
--train_data_file "/train/datafile" \
--eval_data_file "/eval/datafile" \
--output_dir "/train/model" \
--model_type gpt2 \
--model_name_or_path distilgpt2 \
--cache_dir "/train/cache" \
--do_train \
--evaluate_during_training \
--per_gpu_train_batch_size 1 \
--per_gpu_eval_batch_size 1 \
--gradient_accumulation_steps 5 \
--overwrite_output_dir \
--seed 99
```
I dug into the huggingface/transformers codebase and profiled the tokenization process. It is obvious that this summation drains the time:
https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L644
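For context, that line flattens the per-token lists with sum() and an empty-list start value. Stripped down, the pattern is essentially this (a simplified sketch, not the exact source):

```python
# sum(list_of_lists, []) concatenates with "+", building a brand-new list
# on every step, so flattening n sublists costs O(n^2) element copies.
def flatten_with_sum(list_of_lists):
    return sum(list_of_lists, [])

print(flatten_with_sum([["he", "llo"], ["wor", "ld"]]))  # ['he', 'llo', 'wor', 'ld']
```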
I ran the profiler and here is the result:

```
73791524 function calls in 1566.379 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
27157 0.083 0.000 0.109 0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
27157 0.065 0.000 0.128 0.000 _bootlocale.py:33(getpreferredencoding)
81471 0.070 0.000 0.327 0.000 locale.py:589(setlocale)
27157 0.422 0.000 0.876 0.000 locale.py:647(getpreferredencoding)
27157 0.363 0.000 5.997 0.000 regex.py:328(findall)
27157 0.662 0.000 1.682 0.000 regex.py:434(_compile)
4815114 8.744 0.000 16.641 0.000 tokenization_gpt2.py:139(bpe)
2030532 1.116 0.000 1.887 0.000 tokenization_gpt2.py:149(<lambda>)
27157 22.702 0.001 110.038 0.004 tokenization_gpt2.py:180(_tokenize)
25459527 5.702 0.000 5.702 0.000 tokenization_gpt2.py:194(<genexpr>)
10242602 1.764 0.000 1.764 0.000 tokenization_gpt2.py:195(<genexpr>)
1377876 1.678 0.000 1.975 0.000 tokenization_gpt2.py:91(get_pairs)
95205 0.526 0.000 0.910 0.000 tokenization_utils.py:1043(special_tokens_map)
95205 0.932 0.000 1.987 0.000 tokenization_utils.py:1055(all_special_tokens)
1 0.119 0.119 1566.379 1566.379 tokenization_utils.py:615(tokenize)
40789 0.099 0.000 0.169 0.000 tokenization_utils.py:623(split_on_token)
1 0.287 0.287 1566.260 1566.260 tokenization_utils.py:641(split_on_tokens)
54417 0.698 0.000 112.123 0.002 tokenization_utils.py:659(<genexpr>)
27157 0.063 0.000 0.063 0.000 {built-in method _locale.nl_langinfo}
81471 0.252 0.000 0.252 0.000 {built-in method _locale.setlocale}
761640 0.384 0.000 0.384 0.000 {built-in method builtins.getattr}
54314 0.022 0.000 0.022 0.000 {built-in method builtins.hasattr}
516605 0.150 0.000 0.150 0.000 {built-in method builtins.isinstance}
1821447 0.159 0.000 0.159 0.000 {built-in method builtins.len}
472563 3.469 0.000 5.355 0.000 {built-in method builtins.min}
1 1453.081 1453.081 1565.204 1565.204 {built-in method builtins.sum}
2043214 0.297 0.000 0.297 0.000 {method 'add' of 'set' objects}
456488 0.055 0.000 0.055 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4815114 1.169 0.000 1.169 0.000 {method 'encode' of 'str' objects}
5550977 16.572 0.000 18.336 0.000 {method 'extend' of 'list' objects}
27157 3.952 0.000 3.952 0.000 {method 'findall' of '_regex.Pattern' objects}
2057689 0.784 0.000 0.784 0.000 {method 'get' of 'dict' objects}
735863 0.233 0.000 0.233 0.000 {method 'index' of 'tuple' objects}
4894984 38.307 0.000 44.010 0.000 {method 'join' of 'str' objects}
1 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
4855903 1.365 0.000 1.365 0.000 {method 'split' of 'str' objects}
68048 0.009 0.000 0.009 0.000 {method 'strip' of 'str' objects}
95205 0.024 0.000 0.024 0.000 {method 'values' of 'dict' objects}
```
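For anyone who wants to reproduce a profile like the one above, a minimal harness along these lines should do it (the model name and file path are placeholders, and my actual script may have differed):

```python
# Hypothetical minimal profiling harness; adjust the model name and
# data path to your setup.
import cProfile

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

with open("/train/datafile", encoding="utf-8") as f:
    text = f.read()

# cProfile's default sort order matches the "Ordered by: standard name"
# output shown above.
cProfile.run("tokenizer.tokenize(text)")
```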
Nearly all of that time, 1453 of the 1566 seconds, is spent inside builtins.sum itself. So I turned the expression into this by removing sum():

```python
(self._tokenize(token, **kwargs) if token not in self.added_tokens_encoder
 and token not in self.all_special_tokens else [token]
 for token in tokenized_text)
```
and here is the profiler result:

```
73275678 function calls in 121.030 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
27157 0.058 0.000 0.076 0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
27157 0.041 0.000 0.084 0.000 _bootlocale.py:33(getpreferredencoding)
81471 0.058 0.000 0.211 0.000 locale.py:589(setlocale)
27157 0.330 0.000 0.625 0.000 locale.py:647(getpreferredencoding)
27157 0.267 0.000 4.996 0.000 regex.py:328(findall)
27157 0.434 0.000 1.160 0.000 regex.py:434(_compile)
4815114 9.797 0.000 18.875 0.000 tokenization_gpt2.py:139(bpe)
2030532 1.270 0.000 2.100 0.000 tokenization_gpt2.py:149(<lambda>)
27157 24.693 0.001 119.272 0.004 tokenization_gpt2.py:180(_tokenize)
25459527 6.204 0.000 6.204 0.000 tokenization_gpt2.py:194(<genexpr>)
10242602 1.975 0.000 1.975 0.000 tokenization_gpt2.py:195(<genexpr>)
1377876 2.002 0.000 2.328 0.000 tokenization_gpt2.py:91(get_pairs)
68050 0.287 0.000 0.475 0.000 tokenization_utils.py:1043(special_tokens_map)
68050 0.507 0.000 1.061 0.000 tokenization_utils.py:1055(all_special_tokens)
1 0.031 0.031 121.030 121.030 tokenization_utils.py:615(tokenize)
27263 0.077 0.000 0.158 0.000 tokenization_utils.py:623(split_on_token)
1 0.178 0.178 120.999 120.999 tokenization_utils.py:641(split_on_tokens)
1 0.330 0.330 120.350 120.350 tokenization_utils.py:659(<listcomp>)
27157 0.043 0.000 0.043 0.000 {built-in method _locale.nl_langinfo}
81471 0.148 0.000 0.148 0.000 {built-in method _locale.setlocale}
544400 0.188 0.000 0.188 0.000 {built-in method builtins.getattr}
54314 0.014 0.000 0.014 0.000 {built-in method builtins.hasattr}
407985 0.092 0.000 0.092 0.000 {built-in method builtins.isinstance}
1807921 0.181 0.000 0.181 0.000 {built-in method builtins.len}
472563 3.992 0.000 6.092 0.000 {built-in method builtins.min}
2043214 0.326 0.000 0.326 0.000 {method 'add' of 'set' objects}
456488 0.064 0.000 0.064 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4815114 1.259 0.000 1.259 0.000 {method 'encode' of 'str' objects}
5550977 18.064 0.000 20.040 0.000 {method 'extend' of 'list' objects}
27157 3.569 0.000 3.569 0.000 {method 'findall' of '_regex.Pattern' objects}
2057689 0.839 0.000 0.839 0.000 {method 'get' of 'dict' objects}
735863 0.273 0.000 0.273 0.000 {method 'index' of 'tuple' objects}
4894984 41.821 0.000 48.026 0.000 {method 'join' of 'str' objects}
1 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
4842377 1.597 0.000 1.597 0.000 {method 'split' of 'str' objects}
54522 0.007 0.000 0.007 0.000 {method 'strip' of 'str' objects}
68050 0.012 0.000 0.012 0.000 {method 'values' of 'dict' objects}
```
You can see it: 121 seconds vs. 1566 seconds, roughly 13x faster without sum().
Okay, let's discuss: do we even need sum() here? Actually, no. sum() is only being used to flatten a list of lists, and there are far more efficient ways to do that; see this answer on StackOverflow. Also, as the official Python docs point out, sum() is intended for numbers rather than for concatenating sequences. The standalone benchmark below illustrates how badly it scales.
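(This is a small self-contained illustration, not code from the repo; absolute timings will vary by machine.)

```python
# Quadratic vs. linear flattening of 10,000 small lists.
# sum() re-copies the accumulated result on every addition (quadratic),
# while chain.from_iterable walks each element exactly once (linear).
import itertools
import timeit

nested = [["tok"] * 5 for _ in range(10_000)]

t_sum = timeit.timeit(lambda: sum(nested, []), number=1)
t_chain = timeit.timeit(lambda: list(itertools.chain.from_iterable(nested)), number=1)

print(f"sum(..., []):             {t_sum:.4f}s")
print(f"chain.from_iterable(...): {t_chain:.4f}s")
```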
So I replaced sum() with list(itertools.chain.from_iterable(...)) as follows and ran the profiler again:

```python
return list(itertools.chain.from_iterable(
    (self._tokenize(token, **kwargs)
     if token not in self.added_tokens_encoder and token not in self.all_special_tokens
     else [token])
    for token in tokenized_text))
```
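For what it's worth, chain.from_iterable produces exactly the same flattened list as sum(..., []); it just avoids the repeated copying:

```python
import itertools

nested = [["he", "llo"], ["<eos>"], ["wor", "ld"]]
assert list(itertools.chain.from_iterable(nested)) == sum(nested, [])
```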
Here is the result:

```
73791524 function calls in 114.720 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
27157 0.045 0.000 0.060 0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
27157 0.035 0.000 0.067 0.000 _bootlocale.py:33(getpreferredencoding)
81471 0.045 0.000 0.159 0.000 locale.py:589(setlocale)
27157 0.277 0.000 0.502 0.000 locale.py:647(getpreferredencoding)
27157 0.237 0.000 4.258 0.000 regex.py:328(findall)
27157 0.346 0.000 0.929 0.000 regex.py:434(_compile)
4815114 8.703 0.000 16.973 0.000 tokenization_gpt2.py:139(bpe)
2030532 1.171 0.000 1.923 0.000 tokenization_gpt2.py:149(<lambda>)
27157 22.988 0.001 112.449 0.004 tokenization_gpt2.py:180(_tokenize)
25459527 5.708 0.000 5.708 0.000 tokenization_gpt2.py:194(<genexpr>)
10242602 1.755 0.000 1.755 0.000 tokenization_gpt2.py:195(<genexpr>)
1377876 1.595 0.000 1.900 0.000 tokenization_gpt2.py:91(get_pairs)
95205 0.345 0.000 0.565 0.000 tokenization_utils.py:1043(special_tokens_map)
95205 0.581 0.000 1.236 0.000 tokenization_utils.py:1055(all_special_tokens)
1 0.022 0.022 114.720 114.720 tokenization_utils.py:615(tokenize)
40789 0.103 0.000 0.182 0.000 tokenization_utils.py:623(split_on_token)
1 0.583 0.583 114.698 114.698 tokenization_utils.py:641(split_on_tokens)
54417 0.248 0.000 113.314 0.002 tokenization_utils.py:659(<genexpr>)
27157 0.032 0.000 0.032 0.000 {built-in method _locale.nl_langinfo}
81471 0.111 0.000 0.111 0.000 {built-in method _locale.setlocale}
761640 0.219 0.000 0.219 0.000 {built-in method builtins.getattr}
54314 0.012 0.000 0.012 0.000 {built-in method builtins.hasattr}
516605 0.097 0.000 0.097 0.000 {built-in method builtins.isinstance}
1821447 0.166 0.000 0.166 0.000 {built-in method builtins.len}
472563 3.855 0.000 5.777 0.000 {built-in method builtins.min}
1 0.000 0.000 0.000 0.000 {built-in method from_iterable}
2043214 0.305 0.000 0.305 0.000 {method 'add' of 'set' objects}
456488 0.058 0.000 0.058 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
4815114 1.104 0.000 1.104 0.000 {method 'encode' of 'str' objects}
5550977 17.434 0.000 19.189 0.000 {method 'extend' of 'list' objects}
27157 3.092 0.000 3.092 0.000 {method 'findall' of '_regex.Pattern' objects}
2057689 0.759 0.000 0.759 0.000 {method 'get' of 'dict' objects}
735863 0.243 0.000 0.243 0.000 {method 'index' of 'tuple' objects}
4894984 41.030 0.000 46.738 0.000 {method 'join' of 'str' objects}
1 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
4855903 1.396 0.000 1.396 0.000 {method 'split' of 'str' objects}
68048 0.009 0.000 0.009 0.000 {method 'strip' of 'str' objects}
95205 0.013 0.000 0.013 0.000 {method 'values' of 'dict' objects}
```
It significantly improves the speed: 114 seconds now vs. the original 1566 seconds.
I'm going to create a pull request, if everything is clear.
Thank you for your effort.
Thank you for taking the time to look into this and opening a pull request!
You are welcome!
This project takes a lot of effort, and I'm always happy to support it as a community member :).
I would appreciate it if you merged the PR, since I want to use this change straight from the master branch for my ongoing project.
Thanks!
Your PR was merged, you can now use it from the master branch :)
Feel free to open other issues if you find other sub-optimal processes.