Keras: Improving the speed of fit_generator

Created on 16 May 2016 · 15 comments · Source: keras-team/keras

I have found fit_generator to be quite slow.
This makes the disk read speed a major bottleneck when the network is relatively shallow.

Using the multiprocessing module instead of threading, I was able to get a significant speedup (roughly half the training time), as shown in this gist example.

I could look into making a pull request with this new implementation if you think the speed gain is worth it.
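The gist itself isn't reproduced here, but a minimal sketch of the idea (worker processes filling a multiprocessing.Queue with ready-made batches, instead of a GIL-bound thread) looks roughly like this; the function names and shapes are illustrative, not the actual gist code:

import multiprocessing
import numpy as np

def data_generator():
    # Placeholder for the real disk-reading / augmenting generator.
    while True:
        yield np.random.rand(32, 64, 64, 3), np.random.randint(0, 10, size=32)

def _worker(queue):
    # Each process runs its own copy of the generator and pushes batches.
    for batch in data_generator():
        queue.put(batch)

def start_generator_queue(nb_worker=4, max_q_size=10):
    q = multiprocessing.Queue(maxsize=max_q_size)
    workers = []
    for _ in range(nb_worker):
        p = multiprocessing.Process(target=_worker, args=(q,))
        p.daemon = True   # workers exit when the parent process does
        p.start()
        workers.append(p)
    return q, workers

# The training loop then pulls prefetched batches with q.get(), so the GPU is
# not left waiting while the next batch is being read from disk.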

Most helpful comment

For people complaining about Keras 2 being way slower: make sure you have made the following change when fitting your model:

- model.fit_generator(train_generator, train_gen.n, 1)  (old way)
+ model.fit_generator(train_generator, train_gen.n/train_gen.batch_size, 1) (new way)

similar issue: https://github.com/fchollet/keras/issues/6406#issuecomment-308248241

All 15 comments

Sounds awesome.

Sounds interesting. How safe is it?

A PR would be welcome.

I'll write a PR asap then.

I have not run any safety tests, but I have never encountered any issues with this snippet.

A few hand-waving arguments for safety (compared to threading):

  • multiprocessing launches several independent processes. When one process crashes, it should not affect the others (whereas a crashing thread can affect the rest of the program)
  • multiprocessing effectively bypasses the GIL, which is supposedly a plus for stability

One catch I had with my approach was the need to reset the random seed for each process. That's because processes forked from the same parent process share the parent's seed. That may make reproducibility a bit harder...
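A minimal sketch of that re-seeding step, assuming NumPy-based augmentation and a worker function like the one sketched above (names are illustrative):

import os
import numpy as np

def _worker(queue):
    # Without this line, every forked worker inherits the parent's RNG state
    # and produces identical "random" augmentations. Seeding from the pid makes
    # the workers diverge, at the cost of run-to-run reproducibility.
    np.random.seed(os.getpid())
    for batch in data_generator():  # data_generator() as in the sketch above
        queue.put(batch)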

@fchollet :

I have written a new generator_queue function to implement the multiprocessing approach.

From unit tests, memory usage increases quite significantly as you add more processes (with 4 processes, memory usage is roughly 2x that of the threading approach).

I can write the PR with multiprocessing as an option, so that users can choose what best fits their system. Let me know what you think.
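Purely as an illustration of what such an option could look like (the flag name and helper are hypothetical, not the actual Keras code), the queue-filling function could switch between threads and processes:

import multiprocessing
import threading
try:
    import queue                      # Python 3
except ImportError:
    import Queue as queue             # Python 2

def generator_queue(generator, max_q_size=10, nb_worker=4, use_multiprocessing=False):
    # Hypothetical flag: processes sidestep the GIL at the cost of extra
    # memory, threads are cheaper but serialize Python-level work.
    if use_multiprocessing:
        q = multiprocessing.Queue(maxsize=max_q_size)
        worker_cls = multiprocessing.Process
    else:
        q = queue.Queue(maxsize=max_q_size)
        worker_cls = threading.Thread

    def producer():
        # Note: with processes, each fork gets its own copy of the generator
        # state, so the generator itself must be safe to duplicate.
        for item in generator:
            q.put(item)

    workers = [worker_cls(target=producer) for _ in range(nb_worker)]
    for w in workers:
        w.daemon = True
        w.start()
    return q, workers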

I replaced the imports in training.py and now it works fine:

import multiprocessing as threading   # multiprocessing.Process stands in for threading.Thread
import multiprocessing as queue       # multiprocessing.Queue stands in for queue.Queue

Also replace threading.Thread with threading.Process (which, given the alias above, is multiprocessing.Process).

Otherwise the original generator thread performs poorly: because of the GIL, you can observe periods where the queue holds no samples at all.

I have also implemented my own caching generator using multiprocessing to address this bottleneck. Any idea when this can be incorporated into head?

I can write the PR with multiprocessing as an option, so that users can choose what best fits their system. Let me know what you think.

You can just submit a PR using multiprocessing, with no further options.

@fchollet : OK, will PR asap

@bobchennan : Did you verify that you obtained a speedup? I'd be surprised if you got any without actually specifying to use multiple processes somewhere in the code.

@tdeboissiere : Regarding reproducible random numbers in threads: I didn't look at your code to see how you're re-seeding, nor whether it's been resolved, but why can't it be re-seeded in a consistent way? The simplest form (not suitable for security/cryptography applications) might be something like seed(random() + incremental_thread_id). Because the random state is shared among threads, each thread would also have to save the random state (random.getstate() or numpy.random.get_state()) or, probably preferable, use its own instance of the random class (http://stackoverflow.com/questions/5836335/consistenly-create-same-random-numpy-array/5837352#5837352). If keeping separate saved states (the most irritating way), one would probably keep a pool of random numbers to minimize the overhead of switching states, but I think this is unnecessary with separate class instances.
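For what it's worth, the per-instance approach from that Stack Overflow link would look roughly like this (the worker-id scheme is illustrative, not code from this thread):

import numpy as np

def make_worker_rng(base_seed, worker_id):
    # Each worker gets its own RandomState, so no state is shared between
    # workers and results are reproducible given (base_seed, worker_id).
    return np.random.RandomState(base_seed + worker_id)

rngs = [make_worker_rng(base_seed=42, worker_id=i) for i in range(4)]
batch_noise = rngs[0].rand(3)   # draws from worker 0's independent stream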

Keras 1.2.2 does have this multiprocessing included, but my fit_generator() is still about 4x slower than fit(). Can this be further sped up?
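As far as I remember the 1.2.x API, the multiprocessing path is opt-in; something along these lines (argument names from memory, so verify them with help(model.fit_generator) on your install; model and train_gen are from your own code):

model.fit_generator(train_generator,
                    samples_per_epoch=train_gen.n,
                    nb_epoch=10,
                    max_q_size=10,      # size of the prefetched-batch queue
                    nb_worker=4,        # number of generator workers
                    pickle_safe=True)   # use processes instead of threads

If pickle_safe is left at its default of False, only threads are used and the GIL bottleneck remains.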

I installed Keras 2.0.2, and the same code with fit_generator is slower than with Keras 1.1. I think the speed improvements were not incorporated.

Same question here!!!
After updating to Keras 2.0.2, fit_generator seems slower. Any update on this issue?

Same here :( !!

Can confirm it is 10x to 100x slower. I downgraded to 1.2.2.

For people complaining about Keras 2 being way slower: make sure you have made the following change when fitting your model:

- model.fit_generator(train_generator, train_gen.n, 1)  (old way)
+ model.fit_generator(train_generator, train_gen.n/train_gen.batch_size, 1) (new way)

similar issue: https://github.com/fchollet/keras/issues/6406#issuecomment-308248241
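In other words, Keras 1.x counted samples per epoch while Keras 2.x counts steps (batches) per epoch, so passing the raw sample count makes every epoch batch_size times longer than intended. A quick sketch, reusing the generator names from this thread and rounding up so the last partial batch is not dropped:

import math

# Keras 1.x: the second argument was samples_per_epoch (number of samples).
# Keras 2.x: it is steps_per_epoch (number of batches).
steps = int(math.ceil(train_gen.n / float(train_gen.batch_size)))
model.fit_generator(train_generator, steps_per_epoch=steps, epochs=1)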
