Excuse me if this question is a little basic; I only recently got into this field and couldn't find the answer after some research.
I used the pretrained Mask R-CNN model from torchvision, but its output wasn't ideal. Is it possible to modify the loss function to improve its performance without rewriting the whole framework?
Thanks a lot for any advice.
Hi,
The loss for Mask R-CNN can be found in https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py
torchvision contains reference implementations, and it should be easy to modify some parts of them to adapt them to your needs.
For the losses in particular, you could patch them from your own code with something like the following (warning: it's a bit hacky and untested):
from torchvision.models.detection import roi_heads

def my_new_loss(*args, **kwargs):
    # compute and return your custom loss here
    pass

roi_heads.mask_rcnn_loss = my_new_loss
The mask loss (maskrcnn_loss) is implemented in https://github.com/pytorch/vision/blob/2d1bf7cb78a9f0505bd1b0899822869d056a9a09/torchvision/models/detection/roi_heads.py#L108-L138
Hi @fmassa,
I tried to apply what you shared above, but it didn't work.
As a test, I copied maskrcnn_loss, renamed it My_Loss, and added a print to make sure everything was OK. I then tried two ways of plugging it in:
roi_heads.mask_rcnn_loss = My_Loss
mymodel.roi_heads.mask_rcnn_loss = My_Loss
Unfortunately, in both cases My_Loss was never called (the print was never executed).
I use TorchVision 0.6.1.
What am I doing wrong?
Hi,
For what I proposed, you would instead need to do:
torchvision.models.detection.roi_heads.mask_rcnn_loss = My_Loss
i.e., directly modify what torchvision is using.
Is that what you tried doing?
Yes, it's exactly what I did.
I forgot to mention that I first did from torchvision.models.detection import roi_heads and then roi_heads.mask_rcnn_loss = My_Loss, which I think is equivalent to what you proposed.
Of course, I do that before creating my model.
Hi,
I found the solution. You should change
roi_heads.mask_rcnn_loss = my_new_loss
to
roi_heads.maskrcnn_loss = my_new_loss
@FiReTiTi
@ErfanMN Thanks a lot, it works. Good catch!!!
Why didn't I get any error when using roi_heads.mask_rcnn_loss = My_Loss?
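@FiReTiTi Because you set a variable that is never used. Assigning roi_heads.mask_rcnn_loss just creates a new attribute on the module, which Python allows without complaint, while torchvision's code only ever calls maskrcnn_loss.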
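The behaviour in this thread can be reproduced without torchvision. The sketch below builds a hypothetical stand-in module, fake_roi_heads, whose RoIHeads.forward calls a module-level maskrcnn_loss in roughly the way torchvision does; the names mirror the thread, but the code is illustrative only. It replays the assignments tried above and shows that only rebinding the correctly spelled module-level name changes what forward() calls.

```python
import types

# Source for a hypothetical stand-in of torchvision's roi_heads module.
SOURCE = '''
def maskrcnn_loss(x):
    return "original"

class RoIHeads:
    def forward(self, x):
        # maskrcnn_loss is resolved in the module's globals on every call.
        return maskrcnn_loss(x)
'''

fake_roi_heads = types.ModuleType("fake_roi_heads")
exec(SOURCE, fake_roi_heads.__dict__)

heads = fake_roi_heads.RoIHeads()  # created before any patching


def my_loss(x):
    return "custom"


# 1) Misspelled module attribute: Python silently creates a new, unused name.
fake_roi_heads.mask_rcnn_loss = my_loss
after_typo = heads.forward(0)

# 2) Attribute on the instance: forward() never looks it up.
heads.mask_rcnn_loss = my_loss
after_instance = heads.forward(0)

# 3) Correct module-level name: forward() now calls my_loss, even though
#    `heads` was created before the patch.
fake_roi_heads.maskrcnn_loss = my_loss
after_fix = heads.forward(0)

print(after_typo, after_instance, after_fix)  # original original custom
```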