PyTorch Lightning: online hard-example mining/examining under multi-GPU 'dp'

Created on 17 Mar 2020  ·  4 Comments  ·  Source: PyTorchLightning/pytorch-lightning

Background

Hi, I'm trying to track the prediction for each individual sample during training_step/validation_step. The main purpose is to do online hard-example mining/examining.

One way I found to do this is to make the batch passed to training_step/validation_step carry the sample-id information, for example the file name. So I made the input a dictionary.

Example Code

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # buffers for per-sample results, collected over the epoch
        self.validation_result = {'prediction_result': [], 'sample_id': [], 'target': []}

    def forward(self, batch):
        x = batch['x']
        y_hat = self.model(x)
        return y_hat

    def validation_step(self, batch, batch_idx):
        y = batch['target'].float()
        y_hat = self.forward(batch)
        loss = self.get_loss(y_hat, y)

        # append the result for each individual sample
        for i in range(len(batch['sample_id'])):
            self.validation_result['prediction_result'].append(y_hat[i])
            self.validation_result['sample_id'].append(batch['sample_id'][i])
            self.validation_result['target'].append(batch['target'][i])
        return {'val_loss': loss}

Input dict that works on a single GPU but fails under multi-GPU 'dp'

input_batch = {
    'x': Tensor (1st dimension is the batch dimension),
    'target': Tensor (1st dimension is the batch dimension),
    'sample_id': ['a', 'b', 'c']  # plain Python list of strings
}

It took me some time to realize that every value inside the input dictionary should be a torch.Tensor, not a list of strings. Otherwise, when training under multi-GPU 'dp' mode, the list object is not split properly: 'dp' scatters tensors along the batch dimension but simply replicates non-tensor objects, so each GPU sees the full list of ids alongside only its slice of the tensors.

Input dict that works on both single GPU and multi-GPU 'dp'

input_batch = {
    'x': Tensor (1st dimension is the batch dimension),
    'target': Tensor (1st dimension is the batch dimension),
    'sample_id': 1D Tensor of sample ids, e.g. Tensor([1, 3, 5])
}
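
If the identifiers are strings such as file names, one workaround (my sketch, not from the thread; SampleIdDataset and its fields are hypothetical names) is to keep the strings on the Dataset and put only an integer index into each sample, which the default collate function stacks into exactly the kind of 1D id tensor shown above:

import torch
from torch.utils.data import Dataset

class SampleIdDataset(Dataset):
    def __init__(self, file_names, features, targets):
        self.file_names = file_names  # list of strings, kept on the dataset
        self.features = features      # tensor, first dim = num samples
        self.targets = targets        # tensor, first dim = num samples

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        return {
            'x': self.features[idx],
            'target': self.targets[idx],
            # integer id as a tensor so 'dp' can scatter it with the batch
            'sample_id': torch.tensor(idx),
        }

The original file name can then be recovered after the fact with dataset.file_names[sample_id].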

Currently, I still have some doubts about this approach...
Has anyone implemented similar functionality, online hard-example mining, with a different approach?
Thanks : )

question


All 4 comments

Hi! Thanks for your contribution, great first issue!

@neggert @jeffling @jeremyjordan pls ^^

Have you considered using a library such as pytorch-metric-learning?

In general, it would look something like:

import pytorch_lightning as pl
from pytorch_metric_learning import losses, miners

class MinerNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.network = ...  # define network here
        self.miner_function = miners.DistanceWeightedMiner()
        self.objective = losses.TripletMarginLoss()

    def forward(self, data):
        embeddings = self.network(data)
        return embeddings

    def training_step(self, batch, batch_idx):
        data, labels = batch
        embeddings = self(data)
        # mine hard pairs/triplets within the current batch
        pairs = self.miner_function(embeddings, labels)
        loss = self.objective(embeddings, labels, pairs)
        return loss

This does mining within each batch that you pass in. I'm not sure where you're doing the mining currently, but it seems suspicious to be appending data to a class attribute (self.validation_result). This will likely break if you try running on ddp, because a copy of the model is sent to each worker.
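
For completeness, a ddp-friendlier variant of the original validation_step (a sketch assuming a Lightning version that supports validation_epoch_end; get_loss is the question author's placeholder) would return the per-sample data and aggregate it once per epoch instead of mutating self.validation_result:

import torch
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        y = batch['target'].float()
        y_hat = self.forward(batch)
        loss = self.get_loss(y_hat, y)
        # return per-sample data instead of appending to self.*
        return {'val_loss': loss,
                'sample_id': batch['sample_id'],
                'prediction': y_hat.detach(),
                'target': batch['target']}

    def validation_epoch_end(self, outputs):
        # outputs is the list of dicts returned by validation_step
        sample_ids = torch.cat([o['sample_id'] for o in outputs])
        predictions = torch.cat([o['prediction'] for o in outputs])
        targets = torch.cat([o['target'] for o in outputs])
        # rank samples by per-sample loss here to pick out hard examples
        avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
        return {'val_loss': avg_loss}

Note that under ddp each process still only sees the outputs for its own shard of the data, so a cross-process gather would still be needed before ranking hard examples globally.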

Thanks for the reply!

My original purpose is to pick out and record the hard samples during training/validation after every epoch, which is why I append the results to the lightning-module instance. Thanks for pointing out that this design would fail on multi-GPU with ddp mode.

I didn't know about pytorch-metric-learning before. It seems to be one of the right libraries to look at. Really appreciate it!

