Vision: How to modify the loss function of models in torchvision?

Created on 13 Feb 2020 · 7 comments · Source: pytorch/vision

Excuse me if this question is a little stupid; I only recently got into this extraordinary field and could not find the answer after some research.
I used the pretrained Mask R-CNN model in torchvision, but its output wasn't ideal. Can I modify the loss function to improve its performance without rewriting the whole framework?
Thanks a lot for any advice.

Labels: models, question, object detection

All 7 comments

Hi,

The loss for Mask R-CNN can be found in https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py

torchvision contains reference implementations, and it should be easy to modify parts of them to adapt them to your needs.

For the losses in particular, you could hack around it in your code by doing something like this (warning: it's a bit hacky and untested):

from torchvision.models.detection import roi_heads

def my_new_loss(...):
    pass

roi_heads.mask_rcnn_loss = my_new_loss

The mask_rcnn_loss is implemented in https://github.com/pytorch/vision/blob/2d1bf7cb78a9f0505bd1b0899822869d056a9a09/torchvision/models/detection/roi_heads.py#L108-L138
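A minimal sketch of what the body of my_new_loss could look like, assuming you want to keep the stock mask loss and only reweight it. The factory make_weighted_loss and the stand-in fake_maskrcnn_loss are hypothetical; the five parameters mirror torchvision's maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs) in the roi_heads.py linked above, and the stand-in returns a plain float so the sketch runs without torch.

```python
from typing import Any, Callable


def make_weighted_loss(original: Callable[..., Any], weight: float) -> Callable[..., Any]:
    """Wrap an existing mask loss so its value is multiplied by `weight`.

    The wrapper keeps the same five positional parameters as torchvision's
    maskrcnn_loss, so it can replace that function wherever it is called.
    """

    def my_new_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
        # Delegate to the original loss (captured in the closure), then rescale it.
        base = original(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs)
        return weight * base

    return my_new_loss


# Hypothetical stand-in for the real loss: returns a constant so no torch is needed.
def fake_maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    return 0.5


weighted = make_weighted_loss(fake_maskrcnn_loss, weight=2.0)
print(weighted(None, [], [], [], []))  # prints: 1.0
```

Capturing the original function in the closure, instead of looking it up on the module inside the wrapper, matters once the module attribute has been patched: a lookup at call time would find the wrapper itself and recurse. With torchvision installed, original would be the roi_heads module's maskrcnn_loss, saved before patching.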

Hi @fmassa,
I tried to apply what you shared above, but it didn't work.

As a test:

  • I took maskrcnn_loss, renamed it, and added a print to make sure that everything was OK.
  • I tried roi_heads.mask_rcnn_loss = My_Loss
  • I also tried mymodel.roi_heads.mask_rcnn_loss = My_Loss

Unfortunately, in both cases My_Loss was never called (the print never executed).

I use TorchVision 0.6.1.

What am I doing wrong?
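Why the second attempt cannot work, sketched with hypothetical stand-ins (RoIHeadsStandIn and this maskrcnn_loss are illustrations, not torchvision code): setting an attribute on the model's roi_heads submodule only adds an instance attribute, while the forward pass calls a module-level function by its global name, so the instance attribute is never consulted, whatever it is called.

```python
def maskrcnn_loss(*args):
    # Stand-in for the module-level loss that the forward pass calls by name.
    return "original"


class RoIHeadsStandIn:
    def forward(self):
        # Global name lookup: attributes set on `self` play no part here.
        return maskrcnn_loss()


def my_loss(*args):
    return "custom"


heads = RoIHeadsStandIn()
heads.mask_rcnn_loss = my_loss  # creates a new, unused instance attribute
heads.maskrcnn_loss = my_loss   # still unused: forward() reads the global, not self.maskrcnn_loss
print(heads.forward())  # prints: original
```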

Hi,

For what I proposed, you would instead need to do

torchvision.models.detection.roi_heads.mask_rcnn_loss = My_Loss

i.e., directly modify what torchvision is using.
Is that what you tried doing?

Yes, it's exactly what I did.
I forgot to mention that I first used from torchvision.models.detection import roi_heads and then did roi_heads.mask_rcnn_loss = My_Loss, which I think is equivalent to what you proposed.

Of course, I do that before creating my model.
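The two spellings are indeed equivalent: Python caches each imported module in sys.modules, so from package import submodule and the fully qualified package.submodule name are bound to the same module object, and an attribute set through one name is visible through the other. A quick check, using the standard-library json.decoder as a stand-in for torchvision.models.detection.roi_heads:

```python
import sys
import json.decoder
from json import decoder

# Both names refer to the single cached module object.
print(decoder is json.decoder)                 # prints: True
print(sys.modules["json.decoder"] is decoder)  # prints: True

# An attribute set through one name is visible through the other.
decoder.patched_marker = "set via from-import"
print(json.decoder.patched_marker)  # prints: set via from-import
del decoder.patched_marker  # remove the demo attribute again
```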

Hi,
I found the solution. You should change
roi_heads.mask_rcnn_loss = my_new_loss

to

roi_heads.maskrcnn_loss = my_new_loss

@FiReTiTi
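A self-contained sketch of why only the corrected line has an effect. It builds a hypothetical stand-in module with types.ModuleType whose forward function calls maskrcnn_loss by global name, which is how RoIHeads.forward reaches the mask loss in torchvision's roi_heads.py. Because the name is resolved each time forward runs, the patch also applies to a "model" obtained before patching.

```python
import types

# Hypothetical stand-in for torchvision.models.detection.roi_heads.
roi_heads = types.ModuleType("roi_heads_standin")
exec(
    "def maskrcnn_loss(*args):\n"
    "    return 'original'\n"
    "\n"
    "def forward():\n"
    "    # Resolved in this module's globals each time forward() runs.\n"
    "    return maskrcnn_loss()\n",
    roi_heads.__dict__,
)

model_forward = roi_heads.forward  # the "model" is created before any patch


def my_new_loss(*args):
    return "custom"


roi_heads.mask_rcnn_loss = my_new_loss  # misspelled: adds an attribute nothing reads
print(model_forward())  # prints: original

roi_heads.maskrcnn_loss = my_new_loss   # correct name: replaces the function forward() calls
print(model_forward())  # prints: custom
```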

@ErfanMN Thanks a lot, it works. Good catch!!!
Why didn't I get any error when using roi_heads.mask_rcnn_loss = My_Loss?

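@FiReTiTi Because you set a new attribute that is never used. Python lets you assign any name on a module without complaint, but RoIHeads calls maskrcnn_loss, not mask_rcnn_loss.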
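Since a misspelled attribute name fails silently, a small guard can turn that mistake into an immediate error. The helper patch_attr below is hypothetical: it refuses to create attributes that do not already exist and returns the old value so it can be wrapped or restored. It is demonstrated on a throwaway stand-in module rather than torchvision itself.

```python
import types


def patch_attr(module: types.ModuleType, name: str, new_value):
    """Replace an existing module attribute; raise instead of silently adding a new one."""
    if not hasattr(module, name):
        raise AttributeError(f"{module.__name__!r} has no attribute {name!r}; check the spelling")
    old_value = getattr(module, name)
    setattr(module, name, new_value)
    return old_value  # returned so the caller can restore or wrap it


demo = types.ModuleType("demo")
demo.maskrcnn_loss = lambda *args: "original"

try:
    patch_attr(demo, "mask_rcnn_loss", lambda *args: "custom")
except AttributeError as err:
    print(err)  # prints: 'demo' has no attribute 'mask_rcnn_loss'; check the spelling

original = patch_attr(demo, "maskrcnn_loss", lambda *args: "custom")
print(demo.maskrcnn_loss(), original())  # prints: custom original
```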
