Transformers: GPT2 tokenizer is so slow because of sum()

Created on 14 Nov 2019 · 3 comments · Source: huggingface/transformers

🐛 Bug

Hi,
As discussed in issue #1621, GPT2 tokenization is very slow, even on a dataset of only 50MB.

I'm using run_lm_finetuning.py and here are the steps to reproduce the problem:

  • Have a dataset; it doesn't need to be big. 20MB is enough.
  • Call run_lm_finetuning.py to fine-tune on the dataset. Here are my parameters:
--train_data_file "/train/datafile" \
--eval_data_file "/eval/datafile" \
--output_dir "/train/model" \
--model_type gpt2 \
--model_name_or_path distilgpt2 \
--cache_dir "/train/cache" \
--do_train \
--evaluate_during_training \
--per_gpu_train_batch_size 1 \
--per_gpu_eval_batch_size 1 \
--gradient_accumulation_steps 5 \
--overwrite_output_dir \
--seed 99
  • You'll see that it takes 20+ minutes (depending on your CPU) to tokenize just a 50MB text file.

I dug into the huggingface/transformers codebase and profiled the tokenization process. It is obvious that this summation is what drains the time:
https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L644
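
For reference, the line in question uses sum() purely to flatten the list of per-token lists into one flat list. Paraphrased from the linked source (reconstructed from the snippets below, so treat the exact formatting as approximate), it looks roughly like this:

return sum((self._tokenize(token, **kwargs)
            if token not in self.added_tokens_encoder and token not in self.all_special_tokens
            else [token]
            for token in tokenized_text), [])  # sum(..., []) concatenates all the per-token lists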

I ran the profiler and here is the result:

73791524 function calls in 1566.379 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    27157    0.083    0.000    0.109    0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
    27157    0.065    0.000    0.128    0.000 _bootlocale.py:33(getpreferredencoding)
    81471    0.070    0.000    0.327    0.000 locale.py:589(setlocale)
    27157    0.422    0.000    0.876    0.000 locale.py:647(getpreferredencoding)
    27157    0.363    0.000    5.997    0.000 regex.py:328(findall)
    27157    0.662    0.000    1.682    0.000 regex.py:434(_compile)
  4815114    8.744    0.000   16.641    0.000 tokenization_gpt2.py:139(bpe)
  2030532    1.116    0.000    1.887    0.000 tokenization_gpt2.py:149(<lambda>)
    27157   22.702    0.001  110.038    0.004 tokenization_gpt2.py:180(_tokenize)
 25459527    5.702    0.000    5.702    0.000 tokenization_gpt2.py:194(<genexpr>)
 10242602    1.764    0.000    1.764    0.000 tokenization_gpt2.py:195(<genexpr>)
  1377876    1.678    0.000    1.975    0.000 tokenization_gpt2.py:91(get_pairs)
    95205    0.526    0.000    0.910    0.000 tokenization_utils.py:1043(special_tokens_map)
    95205    0.932    0.000    1.987    0.000 tokenization_utils.py:1055(all_special_tokens)
        1    0.119    0.119 1566.379 1566.379 tokenization_utils.py:615(tokenize)
    40789    0.099    0.000    0.169    0.000 tokenization_utils.py:623(split_on_token)
        1    0.287    0.287 1566.260 1566.260 tokenization_utils.py:641(split_on_tokens)
    54417    0.698    0.000  112.123    0.002 tokenization_utils.py:659(<genexpr>)
    27157    0.063    0.000    0.063    0.000 {built-in method _locale.nl_langinfo}
    81471    0.252    0.000    0.252    0.000 {built-in method _locale.setlocale}
   761640    0.384    0.000    0.384    0.000 {built-in method builtins.getattr}
    54314    0.022    0.000    0.022    0.000 {built-in method builtins.hasattr}
   516605    0.150    0.000    0.150    0.000 {built-in method builtins.isinstance}
  1821447    0.159    0.000    0.159    0.000 {built-in method builtins.len}
   472563    3.469    0.000    5.355    0.000 {built-in method builtins.min}
        1 1453.081 1453.081 1565.204 1565.204 {built-in method builtins.sum}
  2043214    0.297    0.000    0.297    0.000 {method 'add' of 'set' objects}
   456488    0.055    0.000    0.055    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
  4815114    1.169    0.000    1.169    0.000 {method 'encode' of 'str' objects}
  5550977   16.572    0.000   18.336    0.000 {method 'extend' of 'list' objects}
    27157    3.952    0.000    3.952    0.000 {method 'findall' of '_regex.Pattern' objects}
  2057689    0.784    0.000    0.784    0.000 {method 'get' of 'dict' objects}
   735863    0.233    0.000    0.233    0.000 {method 'index' of 'tuple' objects}
  4894984   38.307    0.000   44.010    0.000 {method 'join' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'keys' of 'dict' objects}
  4855903    1.365    0.000    1.365    0.000 {method 'split' of 'str' objects}
    68048    0.009    0.000    0.009    0.000 {method 'strip' of 'str' objects}
    95205    0.024    0.000    0.024    0.000 {method 'values' of 'dict' objects}

I removed sum() and turned the line into this:

(self._tokenize(token, **kwargs)
 if token not in self.added_tokens_encoder and token not in self.all_special_tokens
 else [token]
 for token in tokenized_text)
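
Note that dropping sum() on its own also changes the return value: instead of one flat list of tokens you get a lazy generator of per-token lists, so the flattening still has to happen somewhere; that is what the itertools version further below takes care of. A toy illustration (made-up token lists, not real tokenizer output):

# Toy illustration with made-up token lists (not real tokenizer output):
per_token_lists = [["hello"], ["wor", "ld"], ["!"]]

flat = sum(per_token_lists, [])                # old behaviour: one flat list
lazy = (tokens for tokens in per_token_lists)  # bare generator: still yields nested sub-lists

print(flat)        # ['hello', 'wor', 'ld', '!']
print(list(lazy))  # [['hello'], ['wor', 'ld'], ['!']]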

Here is the profiler result:

   73275678 function calls in 121.030 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    27157    0.058    0.000    0.076    0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
    27157    0.041    0.000    0.084    0.000 _bootlocale.py:33(getpreferredencoding)
    81471    0.058    0.000    0.211    0.000 locale.py:589(setlocale)
    27157    0.330    0.000    0.625    0.000 locale.py:647(getpreferredencoding)
    27157    0.267    0.000    4.996    0.000 regex.py:328(findall)
    27157    0.434    0.000    1.160    0.000 regex.py:434(_compile)
  4815114    9.797    0.000   18.875    0.000 tokenization_gpt2.py:139(bpe)
  2030532    1.270    0.000    2.100    0.000 tokenization_gpt2.py:149(<lambda>)
    27157   24.693    0.001  119.272    0.004 tokenization_gpt2.py:180(_tokenize)
 25459527    6.204    0.000    6.204    0.000 tokenization_gpt2.py:194(<genexpr>)
 10242602    1.975    0.000    1.975    0.000 tokenization_gpt2.py:195(<genexpr>)
  1377876    2.002    0.000    2.328    0.000 tokenization_gpt2.py:91(get_pairs)
    68050    0.287    0.000    0.475    0.000 tokenization_utils.py:1043(special_tokens_map)
    68050    0.507    0.000    1.061    0.000 tokenization_utils.py:1055(all_special_tokens)
        1    0.031    0.031  121.030  121.030 tokenization_utils.py:615(tokenize)
    27263    0.077    0.000    0.158    0.000 tokenization_utils.py:623(split_on_token)
        1    0.178    0.178  120.999  120.999 tokenization_utils.py:641(split_on_tokens)
        1    0.330    0.330  120.350  120.350 tokenization_utils.py:659(<listcomp>)
    27157    0.043    0.000    0.043    0.000 {built-in method _locale.nl_langinfo}
    81471    0.148    0.000    0.148    0.000 {built-in method _locale.setlocale}
   544400    0.188    0.000    0.188    0.000 {built-in method builtins.getattr}
    54314    0.014    0.000    0.014    0.000 {built-in method builtins.hasattr}
   407985    0.092    0.000    0.092    0.000 {built-in method builtins.isinstance}
  1807921    0.181    0.000    0.181    0.000 {built-in method builtins.len}
   472563    3.992    0.000    6.092    0.000 {built-in method builtins.min}
  2043214    0.326    0.000    0.326    0.000 {method 'add' of 'set' objects}
   456488    0.064    0.000    0.064    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
  4815114    1.259    0.000    1.259    0.000 {method 'encode' of 'str' objects}
  5550977   18.064    0.000   20.040    0.000 {method 'extend' of 'list' objects}
    27157    3.569    0.000    3.569    0.000 {method 'findall' of '_regex.Pattern' objects}
  2057689    0.839    0.000    0.839    0.000 {method 'get' of 'dict' objects}
   735863    0.273    0.000    0.273    0.000 {method 'index' of 'tuple' objects}
  4894984   41.821    0.000   48.026    0.000 {method 'join' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'keys' of 'dict' objects}
  4842377    1.597    0.000    1.597    0.000 {method 'split' of 'str' objects}
    54522    0.007    0.000    0.007    0.000 {method 'strip' of 'str' objects}
    68050    0.012    0.000    0.012    0.000 {method 'values' of 'dict' objects}

You can see 121 seconds vs. 1566 seconds: more than 12x faster without sum(). So, do we actually need sum() here? No. It is only being used to flatten a list of lists, which it does in just about the least efficient way possible; there are far better alternatives (see the StackOverflow discussion on flattening a list of lists).
Also, as the official Python documentation notes, sum() is intended for numbers, and itertools.chain() is recommended for concatenating a series of iterables.
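
To see this in isolation, here is a small standalone toy benchmark (my own sketch, not code from transformers) that flattens the same list of lists both ways:

import time
from itertools import chain

# Toy data: many small per-token lists, roughly the shape _tokenize produces.
nested = [["tok"] * 3 for _ in range(20_000)]

start = time.perf_counter()
flat_sum = sum(nested, [])                       # list + list copies the growing result every time -> O(n^2)
print("sum():               %.2f s" % (time.perf_counter() - start))

start = time.perf_counter()
flat_chain = list(chain.from_iterable(nested))   # single linear pass over all elements -> O(n)
print("chain.from_iterable:  %.2f s" % (time.perf_counter() - start))

assert flat_sum == flat_chain                    # identical flat output, very different cost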

So I replaced sum() with list(itertools.chain.from_iterable(...)) as follows and ran the profiler again.

return list(itertools.chain.from_iterable(
    self._tokenize(token, **kwargs)
    if token not in self.added_tokens_encoder and token not in self.all_special_tokens
    else [token]
    for token in tokenized_text))
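
This keeps the return type exactly as before, a single flat list of tokens; the only other change needed is an import itertools at the top of tokenization_utils.py, if it is not already there.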

Here is the result:

         73791524 function calls in 114.720 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    27157    0.045    0.000    0.060    0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
    27157    0.035    0.000    0.067    0.000 _bootlocale.py:33(getpreferredencoding)
    81471    0.045    0.000    0.159    0.000 locale.py:589(setlocale)
    27157    0.277    0.000    0.502    0.000 locale.py:647(getpreferredencoding)
    27157    0.237    0.000    4.258    0.000 regex.py:328(findall)
    27157    0.346    0.000    0.929    0.000 regex.py:434(_compile)
  4815114    8.703    0.000   16.973    0.000 tokenization_gpt2.py:139(bpe)
  2030532    1.171    0.000    1.923    0.000 tokenization_gpt2.py:149(<lambda>)
    27157   22.988    0.001  112.449    0.004 tokenization_gpt2.py:180(_tokenize)
 25459527    5.708    0.000    5.708    0.000 tokenization_gpt2.py:194(<genexpr>)
 10242602    1.755    0.000    1.755    0.000 tokenization_gpt2.py:195(<genexpr>)
  1377876    1.595    0.000    1.900    0.000 tokenization_gpt2.py:91(get_pairs)
    95205    0.345    0.000    0.565    0.000 tokenization_utils.py:1043(special_tokens_map)
    95205    0.581    0.000    1.236    0.000 tokenization_utils.py:1055(all_special_tokens)
        1    0.022    0.022  114.720  114.720 tokenization_utils.py:615(tokenize)
    40789    0.103    0.000    0.182    0.000 tokenization_utils.py:623(split_on_token)
        1    0.583    0.583  114.698  114.698 tokenization_utils.py:641(split_on_tokens)
    54417    0.248    0.000  113.314    0.002 tokenization_utils.py:659(<genexpr>)
    27157    0.032    0.000    0.032    0.000 {built-in method _locale.nl_langinfo}
    81471    0.111    0.000    0.111    0.000 {built-in method _locale.setlocale}
   761640    0.219    0.000    0.219    0.000 {built-in method builtins.getattr}
    54314    0.012    0.000    0.012    0.000 {built-in method builtins.hasattr}
   516605    0.097    0.000    0.097    0.000 {built-in method builtins.isinstance}
  1821447    0.166    0.000    0.166    0.000 {built-in method builtins.len}
   472563    3.855    0.000    5.777    0.000 {built-in method builtins.min}
        1    0.000    0.000    0.000    0.000 {built-in method from_iterable}
  2043214    0.305    0.000    0.305    0.000 {method 'add' of 'set' objects}
   456488    0.058    0.000    0.058    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
  4815114    1.104    0.000    1.104    0.000 {method 'encode' of 'str' objects}
  5550977   17.434    0.000   19.189    0.000 {method 'extend' of 'list' objects}
    27157    3.092    0.000    3.092    0.000 {method 'findall' of '_regex.Pattern' objects}
  2057689    0.759    0.000    0.759    0.000 {method 'get' of 'dict' objects}
   735863    0.243    0.000    0.243    0.000 {method 'index' of 'tuple' objects}
  4894984   41.030    0.000   46.738    0.000 {method 'join' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'keys' of 'dict' objects}
  4855903    1.396    0.000    1.396    0.000 {method 'split' of 'str' objects}
    68048    0.009    0.000    0.009    0.000 {method 'strip' of 'str' objects}
    95205    0.013    0.000    0.013    0.000 {method 'values' of 'dict' objects}

It significantly improves the speed: 114 seconds compared to the original 1566 seconds.

I'm going to create a pull request, if that's okay with you.
Thank you for your effort.

All 3 comments

Thank you for taking the time to look into this and opening a pull request!

You're welcome!
This took quite a bit of effort, and I'm always happy to support the project as a community member :).
I'd appreciate it if you merged the PR, because I'd like to use it straight from the master branch for my ongoing project.
Thanks!

Your PR was merged, you can now use it from the master branch :)
Feel free to open other issues if you find other sub-optimal processes.
