Mask_RCNN: model.py's data_generator() may not be thread-safe

Created on 3 Nov 2017 · 26 Comments · Source: matterport/Mask_RCNN

I haven't been able to run train_shapes.ipynb to completion perhaps because of threading issues. Training the head branches fails with the following output:

Starting at epoch 0. LR=0.002

Checkpoint Path: E:\repos\Mask_RCNN.wip\logs\shapes20171102T1726\mask_rcnn_shapes_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4       (TimeDistributed)
mrcnn_mask_bn4         (TimeDistributed)
mrcnn_bbox_fc          (TimeDistributed)
mrcnn_mask_deconv      (TimeDistributed)
mrcnn_class_logits     (TimeDistributed)
mrcnn_mask             (TimeDistributed)
e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\tensorflow\python\ops\gradients_impl.py:95: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\keras\engine\training.py:1987: UserWarning: Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the`keras.utils.Sequence class.
  UserWarning('Using a generator with `use_multiprocessing=True`'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-11-2101606c7d8e> in <module>()
      6             learning_rate=config.LEARNING_RATE,
      7             epochs=1,
----> 8             layers='heads')

E:\repos\Mask_RCNN.wip\model.py in train(self, train_dataset, val_dataset, learning_rate, epochs, layers)
   2089             initial_epoch=self.epoch,
   2090             epochs=epochs,
-> 2091             **fit_kwargs
   2092             )
   2093         self.epoch = max(self.epoch, epochs)

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     85                 warnings.warn('Update your `' + object_name +
     86                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87             return func(*args, **kwargs)
     88         wrapper._original_function = func
     89         return wrapper

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2000                                              use_multiprocessing=use_multiprocessing,
   2001                                              wait_time=wait_time)
-> 2002             enqueuer.start(workers=workers, max_queue_size=max_queue_size)
   2003             output_generator = enqueuer.get()
   2004 

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\keras\utils\data_utils.py in start(self, workers, max_queue_size)
    594                     thread = threading.Thread(target=data_generator_task)
    595                 self._threads.append(thread)
--> 596                 thread.start()
    597         except:
    598             self.stop()

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\multiprocessing\process.py in start(self)
    103                'daemonic processes are not allowed to have children'
    104         _cleanup()
--> 105         self._popen = self._Popen(self)
    106         self._sentinel = self._popen.sentinel
    107         _children.add(self)

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\multiprocessing\context.py in _Popen(process_obj)
    221     @staticmethod
    222     def _Popen(process_obj):
--> 223         return _default_context.get_context().Process._Popen(process_obj)
    224 
    225 class DefaultContext(BaseContext):

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\multiprocessing\context.py in _Popen(process_obj)
    320         def _Popen(process_obj):
    321             from .popen_spawn_win32 import Popen
--> 322             return Popen(process_obj)
    323 
    324     class SpawnContext(BaseContext):

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
     63             try:
     64                 reduction.dump(prep_data, to_child)
---> 65                 reduction.dump(process_obj, to_child)
     66             finally:
     67                 set_spawning_popen(None)

e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)
     61 
     62 #

AttributeError: Can't pickle local object 'GeneratorEnqueuer.start.<locals>.data_generator_task'

I have also tried to modify fit_generator()'s param as follows:

        # Common parameters to pass to fit_generator()
        fit_kwargs = {
            "steps_per_epoch": self.config.STEPS_PER_EPOCH,
            "callbacks": callbacks,
            "validation_data": next(val_generator),
            "validation_steps": self.config.VALIDATION_STPES,
            "max_queue_size": 100,
            "workers": 1,  # Phil: was max(self.config.BATCH_SIZE // 2, 2),
            "use_multiprocessing": False # Phil: was "use_multiprocessing": True,
        }

Unfortunately, that also results in a crash, with the Jupyter notebook crashing without any output...

It could well be that the generator is threadsafe. After a quick perusal, however, I haven't found any serializing code anywhere. Threadsafe data generators usually implement some kind of locking mechanism. Here are examples that are threadsafe: https://github.com/fchollet/keras/issues/1638#issuecomment-182139908 and http://anandology.com/blog/using-iterators-and-generators/

Here's a bit of information about my own GPU config:

(from notebook):

os: nt
sys: 3.6.1 |Continuum Analytics, Inc.| (default, May 11 2017, 13:25:24) [MSC v.1900 64 bit (AMD64)]
numpy: 1.13.3, e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\numpy\__init__.py
matplotlib: 2.0.2, e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\matplotlib\__init__.py
cv2: 3.3.0, e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\cv2.cp36-win_amd64.pyd
tensorflow: 1.3.0, e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\tensorflow\__init__.py
keras: 2.0.8, e:\toolkits.win\anaconda3-4.4.0\envs\dlwin36coco\lib\site-packages\keras\__init__.py

(from jupyter notebook log):

2017-11-02 16:59:50.162182: I C:\tf_jenkins\home\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\common_runtime\gpu\gpu_device.cc:955] Found device 0 with properties:
name: GeForce GTX TITAN X
major: 5 minor: 2 memoryClockRate (GHz) 1.076
pciBusID 0000:03:00.0
Total memory: 12.00GiB
Free memory: 10.06GiB
2

Has anyone else observed similar issues?

Most helpful comment

Hello Folks,

my specs:
Python 3.6
Windows 10
everything else as in requirements for this project.

For me also setting "workers" to 1 and "use_multiprocessing" to False helped to go with the training.

All 26 comments

I think Keras handles thread safety when you pass use_multiprocessing=True to fit_generator(). And the fact that it's crashing even when you set workers=1 tells me that it's not likely to be a thread safety issue.

What version of Keras are you using?

I think I found the problem, but I can't verify it because I can't reproduce the issue on Linux. It's likely a bug in Keras that seems to manifest itself on Windows only according to this thread.

In Keras, in data_utils.py, in class GeneratorEnqueuer, the function data_generator_task() is created as a local function. Later in the process, Python multiprocessing tries to pickle that function, but it can't, because functions can be pickled only if they're defined at the module level.

The right way to fix this is to patch Keras by moving the function data_generator_task() to the module level. I would submit a pull request to Keras, but since I don't get this error on Linux I couldn't do proper testing. Maybe you could submit a PR. Alternatively, patch your local copy of Keras and see if that fixes the problem.

In utils/data_utils.py, in class GeneratorEnqueuer, the function data_generator_task() is indeed created as a local function. Later in the process multiprocessing tries to pickle that function when marshalling the generator to additional processes on Windows. This results in the following error:

AttributeError: Can't pickle local object 'GeneratorEnqueuer.start.<locals>.data_generator_task'

That functions can be pickled only at the python module level was perhaps true with older python/pickle versions. It isn't an absolute requirement on Python 3.x. Making data_generator_task() an instance method of GeneratorEnqueuer fixes the particular problem above. It also opens a big can of worms:

TypeError: can't pickle generator objects

No easy fix for a problem like this one, even though several Keras issues opened over 2016 and 2017 repeatedly pointed to this issue (#3962, #4142, #5071, #5510). These reports are regularly flushed out by Keras' stale issue bot, perhaps preventing any serious attempt at fixing the underlying problem.

I'm currently experimenting with an infinite iterator, modifying utils/data_utils.py and engine/training.py. Meanwhile, I have a question for you, @waleedka: could your data generator be implemented using an OrderedEnqueuer instead of a GeneratorEnqueuer object? There's a chance this object could be marshalled over to multiple windows processes because none of its members are a generator:

class OrderedEnqueuer(SequenceEnqueuer):
    """Builds a Enqueuer from a Sequence.

    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

    # Arguments
        sequence: A `keras.utils.data_utils.Sequence` object.
        use_multiprocessing: use multiprocessing if True, otherwise threading
        shuffle: whether to shuffle the data at the beginning of each epoch
    """

    def __init__(self, sequence,
                 use_multiprocessing=False,
                 shuffle=False):
        self.sequence = sequence
        self.use_multiprocessing = use_multiprocessing
        self.shuffle = shuffle
        self.workers = 0
        self.executor = None
        self.queue = None
        self.run_thread = None
        self.stop_signal = None
...

as opposed to

class GeneratorEnqueuer(SequenceEnqueuer):
    """Builds a queue out of a data generator.

    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

    # Arguments
        generator: a generator function which endlessly yields data
        use_multiprocessing: use multiprocessing if True, otherwise threading
        wait_time: time to sleep in-between calls to `put()`
        random_seed: Initial seed for workers,
            will be incremented by one for each workers.
    """

    def __init__(self, generator,
                 use_multiprocessing=False,
                 wait_time=0.05,
                 random_seed=None):
        self.wait_time = wait_time
        self._generator = generator
        self._use_multiprocessing = use_multiprocessing
        self._threads = []
        self._stop_event = None
        self.queue = None
        self.random_seed = random_seed
...

I investigated this further and this is my understanding.

The data_generator_task() function uses the self variable from the outer scope, which works okay if you're doing multithreading because all threads run in the same process and access the same memory. But, of course, due to the global interpreter lock (GIL), multithreading is not very efficient, and that's why we need multiprocessing. This is where the behavior is different between Windows and Linux.

On Windows it fails. You can't pickle data_generator_task() because it's a local function.

On Linux, pickling is not required because the OS simply copies the whole process and all its memory and creates a new process from that. So the code works. The problem, though, is that the new process has a copy of the variables, but they're not shared. Each process has its own independent copy of the self variable. This causes each copy of the generator to produce its own copy of the input data, leading to duplication, but that's not really a problem if the generators independently shuffle their data. So it's broken on Linux as well, but the problem doesn't have significant side effects for most use cases.

I think the right solution is to change data_generator_task() to be a module-level function and pass variables to it as arguments, rather than having it use variables from the parent scope. The multiprocessing module provides a mechanism to pass shared variables that get synchronized between processes. This StackOverflow question has relevant answers.

I'll read up more about OrderedEnqueuer and then reply to your second question.

@philferriere Following up on our offline chat. I tested whether the data generator on Linux causes data duplication such that all workers return results in the same order. I didn't find any evidence of that. While each generator runs independently and they do return the same data, the order of the returned data is different in each worker, so we don't get a situation where the same image is returned several times in a row. So as far as I can tell, the current Keras implementation, despite its issues, is still working okay (on Linux).

A better solution, though, is the newly added Sequence class. This is the discussion that led to introducing it. And this blog post by the author of the PR explaining it.

I think changing the data_generator() function (in model.py) to a Sequence class should solve this issue. This is likely a better solution than trying to patch the Keras bug, because it looks like Keras is moving away from using generators anyway. And the change shouldn't be that hard, just refactoring data_generator() into a class and replacing yield with return. I'll see if I can work on it this weekend. If you think you can do it before then that's even better, and I'd be happy to review the changes.

I had the same issue on windows, but your 2 line modification worked for me. I am not using the jupyter notebook, just running in a separate python script though.

  - "workers": max(self.config.BATCH_SIZE // 2, 2),
  - "use_multiprocessing": True,
  + "workers": 1,
  + "use_multiprocessing": False,

Hello Folks,

my specs:
Python 3.6
Windows 10
everything else as in requirements for this project.

For me also setting "workers" to 1 and "use_multiprocessing" to False helped to go with the training.

@Westerby @tinku99 I'm pretty new to the whole programming thing and am trying to slog through it! How do you update the train_shapes.ipynb with those variables?

Hi @alexiskattan - go to model.py script, and there you find it around line 2067 :)

@alexiskattan @Westerby @tinku99, I have submitted this PR to Keras to get the issue fixed. You may want to add your vote/voice to the conversation here to get this long-standing Windows issue finally addressed.

Could Config.USE_MULTIPROCESSING be added? model.py would be changed to use for example:

workers=max(self.config.BATCH_SIZE // 2, 2) if self.config.USE_MULTIPROCESSING  else 1,
use_multiprocessing=self.config.USE_MULTIPROCESSING,

@petsuter I think @philferriere's PR in Keras should solve the root cause of the problem without requiring additional changes on our side. Right, Phil?

@waleedka I thought the PR mainly makes the error message clearer, but it's still a problem.

generator methods (fit_generator [...] with use_multiprocessing=True and workers>0 is not supported on Windows (no marshalling of generators across process boundaries) and will result in a ValueError exception

Maybe I misunderstood?

@waleexka You still need to pick up PR #127

I meant @waleedka, you still need to pick up PR #127

Got it, thanks! I merged PR #127 and added a comment in the code linking to this thread.

Thanks
had the same issue

Python 3.5.2
Ubuntu 16.04.3 LTS
Keras 2.0.8

corrected by setting "workers" to 1

@waleedka did you ever work on moving data_generator() to a Sequence class?

I just made this PR to move data_generator() into a Sequence class #740

@coreywho I tried your code. It did help me get further than the initial version (from about 20+ epochs to 50+ epochs)... but I still encounter the multiprocessing error message and I'm kind of confused...

The program got stuck midway with no error message, but the GPU is not working (memory is used, but utilization in nvidia-smi stays at 0% for a long time). When I press Ctrl+C, I see a lot of error output related to multiprocessing, or synchronize.py. It seems that multiprocessing fails to synchronize...

I am using keras 2.0.8 and tensorflow 1.6, is this something related to keras version?

@waleedka Hello.
Have you solved it yet?

I have the same error. I've set workers=1 and use_multiprocessing=False, but I still get the same error. What should I do?

Hi, I can run your code fine on Windows 10, but when I run it in my CentOS Python 3.5 virtual environment (built with conda),
I got
`Exception in thread Thread-2:
Traceback (most recent call last):
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/threading.py", line 914, in _bootstrap_inner
self.run()
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/threading.py", line 862, in run
self._target(*self._args, **self._kwargs)
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/site-packages/keras/utils/data_utils.py", line 666, in _run
with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/site-packages/keras/utils/data_utils.py", line 661, in
initargs=(seqs, self.random_seed))
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/context.py", line 118, in Pool
context=self.get_context())
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/pool.py", line 182, in __init__
self._worker_handler.start()
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/threading.py", line 844, in start
_start_new_thread(self._bootstrap, ())
RuntimeError: can't start new thread

Exception in thread Thread-3:
Traceback (most recent call last):
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/threading.py", line 914, in _bootstrap_inner
self.run()
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/threading.py", line 862, in run
self._target(*self._args, **self._kwargs)
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/site-packages/keras/utils/data_utils.py", line 666, in _run
with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/site-packages/keras/utils/data_utils.py", line 661, in
initargs=(seqs, self.random_seed))
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/context.py", line 118, in Pool
context=self.get_context())
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/pool.py", line 174, in __init__
self._repopulate_pool()
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/pool.py", line 239, in _repopulate_pool
w.start()
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/process.py", line 105, in start
self._popen = self._Popen(self)
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/context.py", line 267, in _Popen
return Popen(process_obj)
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/popen_fork.py", line 20, in __init__
self._launch(process_obj)
File "/home/caibojun/anaconda2/envs/card/lib/python3.5/multiprocessing/popen_fork.py", line 67, in _launch
self.pid = os.fork()
BlockingIOError: [Errno 11] Resource temporarily unavailable`
So what's the problem?

Hello Folks,

my specs:
Python 3.6
Windows 10
everything else as in requirements for this project.

For me also setting "workers" to 1 and "use_multiprocessing" to False helped to go with the training.

Yes, I solved this problem this way, so anyone else who has the same problem can see the method.

I am still having the issue on a ppc64 IBM Power machine running Ubuntu. Can someone advise how I should fix it?

I had this problem, but I changed the data_generator function to a standard keras Sequence data generator and the problem is solved:

class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that produces Mask R-CNN training batches.

    A ``Sequence`` is addressed by batch number (``__getitem__``) instead of
    being a running Python generator, so Keras can fetch batches from worker
    processes without having to pickle live generator state.

    Every batch has exactly ``batch_size`` slots. A slot stays zero-filled
    when its image is skipped (it contains no instances of the classes being
    trained) or when the last batch of an epoch is short.
    """

    def __init__(self, dataset, config, shuffle=True, augment=False,
                 augmentation=None, random_rois=0, batch_size=1,
                 detection_targets=False, no_augmentation_sources=None,
                 mode='training'):
        """Store the generation options and precompute the anchors.

        Args:
            dataset: dataset object; ``images_ids`` and ``images_info`` are
                read from it.
            config: model config; supplies the anchor settings, image shape
                and ground-truth limits (``max_gt_instances`` etc.).
            shuffle: if True, reshuffle the image order in ``on_epoch_end``.
            augment: forwarded unchanged to ``load_image_gt``.
            augmentation: forwarded to ``load_image_gt``, except for images
                whose source is listed in ``no_augmentation_sources``.
            random_rois: when non-zero, passed to ``generate_random_rois``
                (presumably the number of ROIs per image) and the resulting
                ROIs are appended to the inputs.
            batch_size: number of slots per batch.
            detection_targets: if True (together with ``random_rois``), also
                build classifier/box/mask targets for the random ROIs.
            no_augmentation_sources: dataset sources that are never augmented.
            mode: forwarded unchanged to ``load_image_gt``.
        """
        self.dataset = dataset
        self.config = config
        self.shuffle = shuffle
        self.augment = augment
        self.augmentation = augmentation
        self.random_rois = random_rois
        self.batch_size = batch_size
        self.detection_targets = detection_targets
        self.no_augmentation_sources = no_augmentation_sources or []
        self.mode = mode

        # Anchors depend only on the config, so build them once instead of on
        # every batch. Layout: [anchor_count, (y1, x1, y2, x2)].
        backbone_shapes = compute_backbone_shapes(config, config.image_shape)
        self.anchors = Utils.generate_pyramid_anchors(
            scales=config.rpn_anchor_scales,
            ratios=config.rpn_anchor_ratios,
            feature_shapes=backbone_shapes,
            feature_strides=config.backbone_strides,
            anchor_stride=config.rpn_anchor_stride)

        self.on_epoch_end()

    def on_epoch_end(self):
        """Reset (and optionally shuffle) the image order for the next epoch.

        ``self.indexes`` holds *positions* into ``dataset.images_ids``; they
        are mapped to real image ids in ``__data_generation``. (Storing the
        ids themselves and then indexing ``images_ids`` with them again only
        worked when the ids happened to be exactly ``0..N-1``.)
        """
        self.indexes = np.arange(len(self.dataset.images_ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, batch_indexes):
        """Build one batch from the images at ``batch_indexes``.

        Args:
            batch_indexes: positions into ``dataset.images_ids`` (at most
                ``batch_size`` of them).

        Returns:
            ``(inputs, outputs)``: lists of numpy arrays whose first dimension
            is ``batch_size``. ``outputs`` is empty unless both
            ``random_rois`` and ``detection_targets`` are set.

        Raises:
            ValueError: if every image in the batch was skipped for having no
                instances, so there is nothing to size the batch arrays from.
        """
        config = self.config
        anchors = self.anchors

        # The batch arrays are sized from the first image that is actually
        # kept. (Allocating them only at slot 0 crashed with an unbound
        # variable whenever the first image of a batch had to be skipped.)
        batch_images = None

        for slot, position in enumerate(batch_indexes):
            image_id = self.dataset.images_ids[position]

            # Images from sources in no_augmentation_sources bypass the
            # augmentation pipeline.
            source = self.dataset.images_info[image_id]['source']
            augmentation = (None if source in self.no_augmentation_sources
                            else self.augmentation)
            image, image_meta, gt_class_ids, gt_boxes, gt_masks = load_image_gt(
                dataset=self.dataset,
                config=config,
                image_id=image_id,
                augment=self.augment,
                augmentation=augmentation,
                use_mini_mask=config.use_mini_mask,
                mode=self.mode)

            # Skip images that have no instances. This can happen when training
            # on a subset of classes the image does not contain. The slot
            # stays zero-filled.
            if not np.any(gt_class_ids > 0):
                continue

            # RPN targets
            rpn_match, rpn_bbox = build_rpn_targets(
                image.shape, anchors, gt_class_ids, gt_boxes, config)

            # Mask R-CNN head targets from random ROIs
            if self.random_rois:
                rpn_rois = generate_random_rois(
                    image.shape, self.random_rois, gt_class_ids, gt_boxes)
                if self.detection_targets:
                    (rois, mrcnn_class_ids,
                     mrcnn_bbox, mrcnn_mask) = build_detection_targets(
                        rpn_rois, gt_class_ids, gt_boxes, gt_masks, config)

            # Allocate the batch arrays once, from the first kept image.
            if batch_images is None:
                batch_image_meta = np.zeros(
                    (self.batch_size,) + image_meta.shape,
                    dtype=image_meta.dtype)
                batch_rpn_match = np.zeros(
                    [self.batch_size, anchors.shape[0], 1],
                    dtype=rpn_match.dtype)
                batch_rpn_bbox = np.zeros(
                    [self.batch_size, config.rpn_train_anchor_per_image, 4],
                    dtype=rpn_bbox.dtype)
                batch_images = np.zeros(
                    (self.batch_size,) + image.shape, dtype=np.float32)
                batch_gt_class_ids = np.zeros(
                    (self.batch_size, config.max_gt_instances), dtype=np.int32)
                batch_gt_boxes = np.zeros(
                    (self.batch_size, config.max_gt_instances, 4),
                    dtype=np.int32)
                batch_gt_masks = np.zeros(
                    (self.batch_size, gt_masks.shape[0], gt_masks.shape[1],
                     config.max_gt_instances),
                    dtype=gt_masks.dtype)
                if self.random_rois:
                    batch_rpn_rois = np.zeros(
                        (self.batch_size, rpn_rois.shape[0], 4),
                        dtype=rpn_rois.dtype)
                    if self.detection_targets:
                        batch_rois = np.zeros(
                            (self.batch_size,) + rois.shape, dtype=rois.dtype)
                        batch_mrcnn_class_ids = np.zeros(
                            (self.batch_size,) + mrcnn_class_ids.shape,
                            dtype=mrcnn_class_ids.dtype)
                        batch_mrcnn_bbox = np.zeros(
                            (self.batch_size,) + mrcnn_bbox.shape,
                            dtype=mrcnn_bbox.dtype)
                        batch_mrcnn_mask = np.zeros(
                            (self.batch_size,) + mrcnn_mask.shape,
                            dtype=mrcnn_mask.dtype)

            # If there are more instances than fit in the arrays, sub-sample
            # them.
            if gt_boxes.shape[0] > config.max_gt_instances:
                ids = np.random.choice(
                    np.arange(gt_boxes.shape[0]), config.max_gt_instances,
                    replace=False)
                gt_class_ids = gt_class_ids[ids]
                gt_boxes = gt_boxes[ids]
                gt_masks = gt_masks[:, :, ids]

            # Add to batch
            batch_image_meta[slot] = image_meta
            batch_rpn_match[slot] = rpn_match[:, np.newaxis]
            batch_rpn_bbox[slot] = rpn_bbox
            batch_images[slot] = mold_image(image.astype(np.float32), config)
            batch_gt_class_ids[slot, :gt_class_ids.shape[0]] = gt_class_ids
            batch_gt_boxes[slot, :gt_boxes.shape[0]] = gt_boxes
            batch_gt_masks[slot, :, :, :gt_masks.shape[-1]] = gt_masks
            if self.random_rois:
                batch_rpn_rois[slot] = rpn_rois
                if self.detection_targets:
                    batch_rois[slot] = rois
                    batch_mrcnn_class_ids[slot] = mrcnn_class_ids
                    batch_mrcnn_bbox[slot] = mrcnn_bbox
                    batch_mrcnn_mask[slot] = mrcnn_mask

        if batch_images is None:
            raise ValueError(
                "DataGenerator: none of the images in this batch contain any "
                "instances, so the batch cannot be built.")

        inputs = [batch_images, batch_image_meta, batch_rpn_match,
                  batch_rpn_bbox, batch_gt_class_ids, batch_gt_boxes,
                  batch_gt_masks]
        outputs = []

        if self.random_rois:
            inputs.append(batch_rpn_rois)
            if self.detection_targets:
                inputs.append(batch_rois)
                # Keras requires that outputs and targets have the same rank.
                batch_mrcnn_class_ids = np.expand_dims(batch_mrcnn_class_ids, -1)
                outputs.extend(
                    [batch_mrcnn_class_ids, batch_mrcnn_bbox, batch_mrcnn_mask])

        return inputs, outputs

    def __len__(self):
        """Return the number of batches per epoch (the last one may be short)."""
        num_images = len(self.dataset.images_ids)
        return (num_images + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(inputs, outputs)``."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        return self.__data_generation(batch_indexes)
Was this page helpful?
0 / 5 - 0 ratings