Keras: how to do OHEM (online hard example mining) in Keras

Created on 10 May 2017 · 5 comments · Source: keras-team/keras

Hi,
Recently I have been trying to do OHEM in Keras. Has anyone done this in Keras? My idea is as follows:

step 1: get the losses of the 100 samples in each batch during the forward propagation (FP) stage.
step 2: sort the losses of the 100 samples in descending order.
step 3: use only the top-k losses during the backpropagation (BP) stage.
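In isolation, steps 2 and 3 amount to a single tf.nn.top_k call on a rank-1 tensor of per-sample losses. A minimal sketch of the idea (the tensor values and k are made up for illustration):

    import tensorflow as tf

    # pretend these are the per-sample losses of one batch, shape (5,)
    per_sample_losses = tf.constant([0.1, 2.3, 0.5, 1.7, 0.02])

    # tf.nn.top_k returns a (values, indices) pair and requires a tensor
    # of rank >= 1 -- both points matter later in this thread
    hard_losses = tf.nn.top_k(per_sample_losses, k=3).values

    # back-propagating only this mean implements "use top-k losses in BP"
    ohem_loss = tf.reduce_mean(hard_losses)

    with tf.Session() as sess:
        print(sess.run(ohem_loss))  # mean of the 3 largest losses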

However, I can't find an API for this, so I looked through the source code and found the relevant parts, shown below.

  • the code is located around line 888 of keras/engine/training.py:
        # Compute total loss.
        total_loss = None
        for i in range(len(self.outputs)):
            if i in skip_indices:
                continue
            y_true = self.targets[i]
            y_pred = self.outputs[i]
            weighted_loss = weighted_losses[i]
            sample_weight = sample_weights[i]
            mask = masks[i]
            loss_weight = loss_weights_list[i]
            output_loss = weighted_loss(y_true, y_pred,
                                        sample_weight, mask)
            if len(self.outputs) > 1:
                self.metrics_tensors.append(output_loss)
                self.metrics_names.append(self.output_names[i] + '_loss')
            if total_loss is None:
                total_loss = loss_weight * output_loss
            else:
                total_loss += loss_weight * output_loss
        if total_loss is None:
            if not self.losses:
                raise RuntimeError('The model cannot be compiled '
                                   'because it has no loss to optimize.')
            else:
                total_loss = 0.

total_loss is the tensor that aggregates the loss over all samples. Keras then builds train_function around total_loss to train the model and update its parameters; this code is around line 1003 of keras/engine/training.py:

    def _make_train_function(self):
        if not hasattr(self, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.train_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]

            training_updates = self.optimizer.get_updates(
                self._collected_trainable_weights,
                self.constraints,
                self.total_loss)
            updates = self.updates + training_updates
            # Gets loss and metrics. Updates weights at each call.
            self.train_function = K.function(inputs,
                                             [self.total_loss] + self.metrics_tensors,
                                             updates=updates,
                                             **self._function_kwargs)

So I thought I could simply replace total_loss with the top-k loss, which might work, by adding two lines of code after total_loss is computed:

    import tensorflow as tf
    total_loss = tf.nn.top_k(total_loss, k=40)

The whole code is as follows.

        # Compute total loss.
        total_loss = None
        for i in range(len(self.outputs)):
            if i in skip_indices:
                continue
            y_true = self.targets[i]
            y_pred = self.outputs[i]
            weighted_loss = weighted_losses[i]
            sample_weight = sample_weights[i]
            mask = masks[i]
            loss_weight = loss_weights_list[i]
            output_loss = weighted_loss(y_true, y_pred,
                                        sample_weight, mask)
            if len(self.outputs) > 1:
                self.metrics_tensors.append(output_loss)
                self.metrics_names.append(self.output_names[i] + '_loss')
            if total_loss is None:
                total_loss = loss_weight * output_loss
            else:
                total_loss += loss_weight * output_loss
        if total_loss is None:
            if not self.losses:
                raise RuntimeError('The model cannot be compiled '
                                   'because it has no loss to optimize.')
            else:
                total_loss = 0.
        #total_loss = 0.
        # Add regularization penalties
        # and other layer-specific losses.
        for loss_tensor in self.losses:
            total_loss += loss_tensor
        # modify by cxt get top-k loss
        import tensorflow as tf
        total_loss = tf.nn.top_k(total_loss, k=40)

However, it raises the error below. The code runs during the model.compile stage, where total_loss is already a rank-0 scalar (the per-sample losses have been averaged over the batch), so tf.nn.top_k, which requires a tensor of rank at least 1, has nothing to select from. (tf.nn.top_k also returns a (values, indices) pair rather than a plain tensor, which would be the next problem.) How can I get the top-k losses?

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 921, in compile
    total_loss = tf.nn.top_k(total_loss, k=40)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 1998, in top_k
    return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 2502, in _top_kv2
    name=name)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2338, in create_op
    set_shapes_for_outputs(ret)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1719, in set_shapes_for_outputs
    shapes = shape_func(op)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1669, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)

  File "/home/cxt/softwares/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)

ValueError: Shape must be at least rank 1 but is rank 0 for 'TopKV2' (op: 'TopKV2') with input shapes: [], [].

All 5 comments

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

Try returning the losses for all samples, but multiply the top 40 of them by 1 and the others by 0.
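A minimal sketch of that masking idea as a custom loss (the function name, the choice of binary cross-entropy, and k=40 are assumptions for illustration; the key point is that the loss must stay per-sample, i.e. rank 1, so that top-k selection is possible):

    import tensorflow as tf
    from keras import backend as K
    from keras.losses import binary_crossentropy

    def ohem_binary_crossentropy(k=40):
        """Zero out all but the k largest per-sample losses in the batch."""
        def loss(y_true, y_pred):
            # per-sample losses, shape (batch_size,) -- not yet averaged
            per_sample = binary_crossentropy(y_true, y_pred)
            # the k-th largest loss in the batch acts as a threshold
            values, _ = tf.nn.top_k(per_sample, k=k)
            threshold = values[k - 1]
            # multiply hard samples by 1 and easy samples by 0
            mask = K.cast(per_sample >= threshold, K.floatx())
            return per_sample * mask
        return loss

    # usage: model.compile(optimizer='adam', loss=ohem_binary_crossentropy(k=40))

Keras then averages the returned vector over the batch, so easy samples contribute exactly zero gradient. Note that ties at the threshold can let slightly more than k samples through, and k must not exceed the batch size.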

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

I'd be interested in this as well! My current workaround is to use a custom Sequence, which calls predict on random samples of my dataset before each epoch and then supplies fit_generator with hard samples. This is not a perfect workaround because of increased computational cost (not truly online).

Another idea would be to use all samples in the forward propagation and adapt your loss function to ignore easy samples and put a focus on hard samples. Facebook AI Research has had some good success with this – they call it Focal Loss.
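For reference, a rough sketch of focal loss for binary classification (the function name is made up; the alpha=0.25, gamma=2 defaults follow the paper):

    from keras import backend as K

    def binary_focal_loss(alpha=0.25, gamma=2.0):
        """Scale cross-entropy by (1 - p_t)^gamma so that well-classified
        (easy) samples contribute almost nothing to the loss."""
        def loss(y_true, y_pred):
            eps = K.epsilon()
            y_pred = K.clip(y_pred, eps, 1.0 - eps)
            # p_t is the predicted probability of the true class
            p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
            alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
            return K.mean(-alpha_t * K.pow(1.0 - p_t, gamma) * K.log(p_t),
                          axis=-1)
        return loss

    # usage: model.compile(optimizer='adam', loss=binary_focal_loss())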

I am interested in this method. @max-vogler I would be very grateful if you could provide an example of your workaround.
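In the meantime, here is a rough sketch of what such a Sequence could look like (the class name, the pool/keep sizes, and the in-memory NumPy dataset are all assumptions; it mines hard samples once per epoch, so it is not truly online):

    import numpy as np
    from keras.utils import Sequence

    class HardExampleSequence(Sequence):
        """Each epoch: score a random pool of samples with model.predict
        and keep only the hardest ones for training."""

        def __init__(self, model, x, y, batch_size=32, pool=1024, keep=256):
            self.model, self.x, self.y = model, x, y
            self.batch_size, self.pool, self.keep = batch_size, pool, keep
            self.on_epoch_end()  # mine an initial set of hard samples

        def _per_sample_loss(self, y_true, y_pred):
            # per-sample binary cross-entropy (swap in your own loss)
            y_true = y_true.reshape(len(y_true), -1)
            y_pred = np.clip(y_pred.reshape(len(y_pred), -1), 1e-7, 1 - 1e-7)
            bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
            return bce.mean(axis=1)

        def on_epoch_end(self):
            # called by fit_generator after every epoch: re-mine hard samples
            idx = np.random.choice(len(self.x), self.pool, replace=False)
            losses = self._per_sample_loss(self.y[idx],
                                           self.model.predict(self.x[idx]))
            hard = idx[np.argsort(losses)[-self.keep:]]  # hardest `keep` samples
            self.hard_x, self.hard_y = self.x[hard], self.y[hard]

        def __len__(self):
            return int(np.ceil(self.keep / float(self.batch_size)))

        def __getitem__(self, i):
            s = slice(i * self.batch_size, (i + 1) * self.batch_size)
            return self.hard_x[s], self.hard_y[s]

    # usage: model.fit_generator(HardExampleSequence(model, x_train, y_train),
    #                            epochs=10)

Because the Sequence calls model.predict on the model being trained, it is safest with the default single-process generator settings (use_multiprocessing=False). The extra forward passes over the candidate pool are exactly the computational cost mentioned above.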
