Keras: Flow from generator

Created on 17 Mar 2017 · 5 Comments · Source: keras-team/keras

Currently, ImageDataGenerator can flow either from arrays (X, y) with flow or from a directory with flow_from_directory.
Is it possible to flow from my own custom generator?

Such as:
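```python
# Hypothetical: flow() does not accept a generator, which is what this issue asks for
model.fit_generator(datagen.flow(my_generator, batch_size=32),
                    steps_per_epoch=len(X_train) // 32, epochs=epochs)
```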

All 5 comments

You don't have to include ImageDataGenerator here. ImageDataGenerator.flow simply returns an iterator derived from preprocessing.image.Iterator that returns (X, y) tuples each containing batch_size rows.

If my_generator is a similar iterator (or any generator that yields (X, y) tuples indefinitely), you can pass it directly:

```python
model.fit_generator(my_generator, steps_per_epoch=len(X_train) // batch_size, epochs=epochs)
```
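For fit_generator, a "similar iterator" only needs to yield (X, y) batches forever. A minimal sketch, assuming the data fits in memory as NumPy arrays (`batch_generator` is a hypothetical name, not a Keras function):

```python
import numpy as np

def batch_generator(X, y, batch_size=32, shuffle=True, seed=None):
    """Yield (X_batch, y_batch) tuples forever, similar to what flow() returns.

    Hypothetical helper: Keras only requires that the generator keeps
    yielding batches; it ends each epoch after steps_per_epoch batches.
    """
    rng = np.random.default_rng(seed)
    n = len(X)
    while True:
        # Reshuffle the sample order once per pass over the data.
        idx = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            batch_idx = idx[start:start + batch_size]
            # Index X and y with the same indices so pairs stay aligned.
            yield X[batch_idx], y[batch_idx]
```

It could then be used as `model.fit_generator(batch_generator(X_train, y_train, 32), steps_per_epoch=int(np.ceil(len(X_train) / 32)), epochs=epochs)`. In recent TensorFlow releases `model.fit` accepts such generators directly and `fit_generator` is deprecated.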

I think the OP wants a custom generator that yields images which are then augmented by ImageDataGenerator.

Edit: You can create a custom MyImageDataGenerator class that inherits from ImageDataGenerator and replace its default flow method with your custom generator.
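Subclassing is one route. A simpler variant, swapped in here and not the commenter's exact approach, wraps the custom generator and augments each image with ImageDataGenerator's `random_transform` method. A minimal sketch, assuming channels_last images and that `augmenter` is an ImageDataGenerator instance (`augmented` is a hypothetical helper name):

```python
import numpy as np

def augmented(generator, augmenter):
    """Yield batches from `generator` with each image randomly augmented.

    Assumption: `augmenter` provides random_transform(x) for a single image
    of shape (rows, cols, channels), as Keras' ImageDataGenerator does.
    """
    for x_batch, y_batch in generator:
        # Cast to float first, then transform every image independently.
        x_aug = np.stack([augmenter.random_transform(x.astype("float32"))
                          for x in x_batch])
        # Labels are passed through unchanged.
        yield x_aug, y_batch
```

Usage would look like `model.fit_generator(augmented(my_generator, img_gen), steps_per_epoch=..., epochs=epochs)`. If the ImageDataGenerator also uses options such as rescale or centering, calling `augmenter.standardize(...)` on each transformed image would be needed as well, since `random_transform` alone does not apply them.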

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

Is there any update on this problem? I think it is a common issue. How can we handle it when we have our own image generator and need ImageDataGenerator for augmentation? I have the same problem and need a solution.

Hey there, I just came up with a solution for this!

The trick is to call ImageDataGenerator.flow on each batch of images inside your own generator before yielding it.

Keras Generator

```python
from keras.preprocessing import image

img_gen = image.ImageDataGenerator(shear_range=0.1, rotation_range=50,
                                   width_shift_range=0.2, height_shift_range=0.2,
                                   fill_mode='reflect',
                                   horizontal_flip=True, vertical_flip=False)
```

Generator

```python
# End each iteration of your generator's loop with this:
yield pre_process(batch_x, batch_y, img_gen)
```

Pre-processing function

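```python
def pre_process(x_batch, y_batch, img_gen):
    # Add a channel axis (assumes grayscale images): (n, rows, cols) -> (n, rows, cols, 1)
    x_batch = x_batch.reshape(-1, x_batch.shape[1], x_batch.shape[2], 1)
    # batch_size=len(x_batch) makes flow() return the whole batch, augmented, in one step
    return next(img_gen.flow(x_batch, y_batch, batch_size=len(x_batch)))
```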
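The reshape in pre_process hard-codes one channel, so it only fits grayscale batches shaped (samples, rows, cols). An RGB batch of n images would be reshaped into 3n single-channel samples that no longer line up with y_batch. A hedged alternative, assuming channels_last data (Keras' default), adds the channel axis only when it is missing (`ensure_channel_axis` is a hypothetical helper):

```python
import numpy as np

def ensure_channel_axis(x_batch):
    """Return x_batch as a 4D (samples, rows, cols, channels) array.

    Hypothetical helper, assuming channels_last ordering: a 3D grayscale
    batch gets a trailing channel axis of size 1; a 4D batch is unchanged.
    """
    x_batch = np.asarray(x_batch)
    if x_batch.ndim == 3:
        return x_batch[..., np.newaxis]
    if x_batch.ndim == 4:
        return x_batch
    raise ValueError(f"expected a 3D or 4D batch, got shape {x_batch.shape}")
```

Inside pre_process, the reshape line could then become `x_batch = ensure_channel_axis(x_batch)`.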