When using TrainResult as the return of LightningModule.training_step(),TrainRsult.log() can not add metrics to Trainer.logged_metrics and this make Checkpoint.format_checkpoint_name() works not well.
class Resnet18(pl.LightningModule):
def __init__(self, input_dim=40, numclass=1211, learning_rate=0.1, batch_size=128, num_workers=3,
**kwargs):
super(Resnet18, self).__init__()
self.save_hyperparameters()
self.example_input_array = torch.rand((1, 200, input_dim))
self.net = resnet18(num_classes=numclass)
def forward(self, x):
"""
input: size (batch, seq_len, input_features)
outpu: size (batch, new_seq_len, output_features)
"""
x = torch.unsqueeze(x, 1)
x = torch.cat([x, x, x], 1)
x = self.net(x)
return x
def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor()])
trainloader = Dataset(path,
transform=transform)
trainloader = DataLoader(trainloader,
batch_size=self.hparams.batch_size,
shuffle=True,
num_workers=self.hparams.num_workers,
pin_memory=True)
return trainloader
def loss(self, input, target):
return F.cross_entropy(input, target)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(minimize=loss, checkpoint_on=loss)
result.log(train_loss,loss) # dose not work well
print(result.get_epoch_log_metrics()) # will print None and checkpoint file can't get the metric
result.log_dict({'train_loss': loss}) # work
print(result.get_epoch_log_metrics()) # will print train_loss and checkpoint file correctly get the metric
return result
if __name__ == '__main__':
ckpt = ModelCheckpoint(filepath=osp.join("save/ckpt", "{epoch:03d}-{train_loss:.2f}"),
monitor='checkpoint_on',
mode='min',
save_top_k=-1,
verbose=True,
save_weights_only=True)
callbacks = [ckpt]
# model
model = Resnet18()
# training
trainer = pl.Trainer(gpus=args.gpus,
max_epochs=2,
profiler=True,
checkpoint_callback=ckpt,
early_stop_callback=False,
# callbacks=callbacks,
limit_train_batches=0.01, )
trainer.fit(model)
conda):Hi! thanks for your contribution!, great first issue!
Here
result.log(train_loss, loss)
should be
result.log('train_loss', loss)
@2531919747 can you confirm that @rohitgr7's suggestion works?
Here
result.log(train_loss, loss)should be
result.log('train_loss', loss)
Thank you , I found the default value of "on_epoch" in result.log() is False but "on_step" is True, and I used get_epoch_log_metrics() get the 'train_loss' correctly.
Most helpful comment
Here
should be