Pytorch-lightning: Trainer.from_argparse_args with additional kwargs causes model to not be saved

Created on 3 May 2020  ·  7 comments  ·  Source: PyTorchLightning/pytorch-lightning

🐛 Bug

When using Trainer.from_argparse_args() to initialize the trainer, there are some specific arguments that we would like to keep constant rather than send as part of hparams. If any of these extra arguments is an object, such as a TensorBoardLogger or a ModelCheckpoint, the model is not saved, because these objects get added to hparams.
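The objects end up in hparams because vars() returns the Namespace's own __dict__, not a copy, so calling update() on it writes straight into the Namespace. In the reproduction below, LitModel keeps a reference to that same Namespace as self.hparams. A minimal, library-free sketch of the mechanism (the update call mirrors the from_argparse_args body quoted under Possible Fix; logger is only a placeholder object):

```python
from argparse import Namespace

hparams = Namespace(sample=42)
logger = object()  # placeholder standing in for a TensorBoardLogger instance

params = vars(hparams)        # the Namespace's own __dict__, not a copy
params.update(logger=logger)  # mirrors params.update(**kwargs)

print(params is hparams.__dict__)  # True
print(hparams.logger is logger)    # True: the logger is now part of hparams
```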

To Reproduce

Code sample

import os
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MNIST(
            os.getcwd(), train=True, download=True, transform=transforms.ToTensor()
        )
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
        return loader


def main(hparams):
    logger = TensorBoardLogger(save_dir=os.getenv("HOME"), name="logs")
    net = LitModel(hparams)
    trainer = Trainer.from_argparse_args(
        hparams, logger=logger, checkpoint_callback=True, overfit_pct=0.01,
    )
    trainer.fit(net)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--sample", type=int, default=42, help="Sample Argument")
    hparams = parser.parse_args()
    main(hparams)

Error

TypeError: cannot pickle '_thread.lock' object
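The error comes from pickling: saving the checkpoint goes through torch.save, which pickles the checkpoint (the traceback later in the thread ends in pickler.dump), and in this Lightning version that checkpoint presumably includes the model's hparams. Any object in hparams that holds a threading lock cannot be pickled. A small sketch that reproduces the same TypeError without Lightning; FakeLogger is a hypothetical stand-in, not the real TensorBoardLogger:

```python
import pickle
import threading
from argparse import Namespace


class FakeLogger:
    """Hypothetical stand-in for a logger that owns a lock."""

    def __init__(self):
        self._lock = threading.Lock()


hparams = Namespace(sample=42, logger=FakeLogger())

try:
    pickle.dumps(hparams)  # what saving a checkpoint containing hparams amounts to
except TypeError as err:
    print(err)  # exact wording depends on the Python version
```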

Expected behavior

The model should be saved, and the extra arguments should not be added to hparams.

Environment

  • CUDA:
    • GPU: GeForce GTX 1050 Ti
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.1
    • pyTorch_debug: False
    • pyTorch_version: 1.4.0
    • pytorch-lightning: 0.7.5
    • tensorboard: 2.2.1
    • tqdm: 4.45.0
  • System:
    • OS: Linux
    • architecture: 64bit, ELF
    • processor: x86_64
    • python: 3.8.2
    • version: 32-Ubuntu SMP Wed Apr 22 17:40:10 UTC 2020

Additional context

Possible Fix

I believe a possible fix is to change the from_argparse_args method from

    @classmethod
    def from_argparse_args(cls, args, **kwargs):

        params = vars(args)
        params.update(**kwargs)

        return cls(**params)

to

    @classmethod
    def from_argparse_args(cls, args, **kwargs):

        params = vars(args)

        return cls(**params, **kwargs)

This ensures that the **kwargs are not added to the model's hparams, so the model is saved successfully, but I'm not sure what other impact this change might have.
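One side effect worth checking: with cls(**params, **kwargs), a key that is both a parsed argument and an explicit keyword argument makes Python raise a TypeError ("got multiple values for keyword argument"), whereas the original update() silently let kwargs win. A sketch comparing the proposed version with a merge into a fresh dict; FakeTrainer and both helper names are hypothetical, and this is not necessarily how Lightning itself resolved the issue:

```python
from argparse import Namespace


class FakeTrainer:
    """Hypothetical stand-in for Trainer that just records its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def from_args_proposed(cls, args, **kwargs):
    # The change proposed above: args is no longer mutated.
    return cls(**vars(args), **kwargs)


def from_args_merged(cls, args, **kwargs):
    # Alternative: build a new dict, so args is untouched and kwargs override.
    return cls(**{**vars(args), **kwargs})


logger = object()  # placeholder standing in for a TensorBoardLogger
hparams = Namespace(sample=42)

trainer = from_args_proposed(FakeTrainer, hparams, logger=logger)
print(sorted(vars(hparams)))               # ['sample']: hparams untouched
print(trainer.kwargs["logger"] is logger)  # True

# When a key is both parsed and passed explicitly, the two versions differ.
hparams = Namespace(sample=42, max_epochs=10)
try:
    from_args_proposed(FakeTrainer, hparams, max_epochs=5)
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'max_epochs'
print(from_args_merged(FakeTrainer, hparams, max_epochs=5).kwargs["max_epochs"])  # 5
```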

Priority P0 bug / fix help wanted

All 7 comments

Hi! Thanks for your contribution, great first issue!

I had the same issue. I fixed it using the following "import copy; self.hparams=copy.deepcopy(hparams)". This is a quick fix, but maybe it can help you.
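The deepcopy workaround only helps if the copy is taken before the trainer adds the objects to the Namespace; in the reproduction above, net = LitModel(hparams) runs before Trainer.from_argparse_args, so a copy made in __init__ is still clean. A copy taken afterwards hits the lock itself, which matches the deepcopy traceback reported later in this thread. A sketch, with a bare threading.Lock standing in for the logger:

```python
import copy
import threading
from argparse import Namespace

hparams = Namespace(sample=42)

# Snapshot taken before the trainer is built: only plain values inside.
model_hparams = copy.deepcopy(hparams)

# Simulates what from_argparse_args(hparams, logger=...) does to the Namespace.
vars(hparams).update(logger=threading.Lock())

print(hasattr(model_hparams, "logger"))  # False: the snapshot is unaffected

try:
    copy.deepcopy(hparams)  # a copy taken after the update fails on the lock
except TypeError as err:
    print(err)
```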

I usually have some lines like

mparams = Namespace(**vars(params))
del mparams.logger

for this after I've created the trainer object with params.
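The same idea can be wrapped in a small helper that copies the Namespace and drops the given keys. without_keys is a hypothetical name; unlike del, it does not raise AttributeError when a key is absent. It is a shallow copy, so the remaining values are shared with the original:

```python
from argparse import Namespace


def without_keys(ns, *keys):
    """Hypothetical helper: shallow copy of ns without the given attributes."""
    return Namespace(**{k: v for k, v in vars(ns).items() if k not in keys})


params = Namespace(sample=42, logger=object())  # logger: placeholder object
mparams = without_keys(params, "logger")

print(vars(mparams))              # {'sample': 42}
print(hasattr(params, "logger"))  # True: the original Namespace is untouched
```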

I tried your solution, but then the trainer can't save the model. Here is my code:

def main(hparams):

    checkpoint_callback = ModelCheckpoint(
        filepath=os.getcwd(),
        monitor='val_loss',
        mode='min',
    )

    if hparams.ddp == 1:
        hparams.distributed_backend = 'ddp'
    if hparams.early_stop == 1:
        hparams.early_stop_callback=True
    hparams.img_size = 256

    trainer = pl.Trainer.from_argparse_args(hparams, checkpoint_callback=checkpoint_callback)
    mparams = Namespace(**vars(hparams))
    del mparams.checkpoint_callback
    hparams = mparams
    if hparams.model == 'base':
        model = PVM_Baseline(hparams)
    elif hparams.model == 'gcn':
        model = PVM_GCN(hparams)
    # Run learning rate finder

    if hparams.lr == 0:
        # find best lr
        hparams.lr = 0.001
        hparams.auto_lr_find='lr'
    if hparams.test == 0:
        trainer.fit(model)
    else:
        checkpoint = osp.join(hparams.checkpoint, 'best_model.ckpt')
        tags_csv = osp.join(hparams.checkpoint, 'meta_tags.csv')
        model = model.load_from_metrics(
            weights_path=checkpoint,
            tags_csv=tags_csv,
            on_gpu=True,
            map_location=None
        )
        trainer.test(model)

I have also tried self.hparams = copy.deepcopy(hparams) in dataset.py, but it throws this error:

  File "/Users/xd/code/PVM/models/baseline.py", line 207, in val_dataloader
    dataset = PVMDataset(self.hparams, 'val', self.word_dict, transform)
  File "/Users/xd/code/PVM/dataset/pvm_dataset.py", line 21, in __init__
    self.hparams = deepcopy(hparams)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 248, in _deepcopy_method
    return type(x)(x.__func__, deepcopy(x.__self__, memo))
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/xd/miniconda3/lib/python3.7/copy.py", line 169, in deepcopy
    rv = reductor(4)
TypeError: can't pickle _thread.lock objects

Is there any solution to this problem?

The "import copy; self.hparams=copy.deepcopy(hparams)" fix suggested above solved an issue I had that occurred only when returning val_loss from validation_epoch_end. The error was the same, although the traceback was different:
TypeError: can't pickle _thread.lock objects


Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/lightning_run.py", line 116, in <module>
    model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 793, in fit
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 913, in run_pretrain_routine
    self.train()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 347, in train
    self.run_training_epoch()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 452, in run_training_epoch
    self.call_checkpoint_callback()
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 789, in call_checkpoint_callback
    self.checkpoint_callback.on_validation_end(self, self.get_model())
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py", line 10, in wrapped_fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 231, in on_validation_end
    self._do_check_save(filepath, current, epoch)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 265, in _do_check_save
    self._save_model(filepath)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 142, in _save_model
    self.save_function(filepath)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 253, in save_checkpoint
    self._atomic_save(checkpoint, filepath)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/training_io.py", line 244, in _atomic_save
    torch.save(checkpoint, tmp_path)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 328, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 401, in _legacy_save
    pickler.dump(obj)
TypeError: can't pickle _thread.lock objects

Changing just this small part of the source code (the from_argparse_args change proposed in the issue description) worked for me.

Another workaround is to pass the arguments to the Trainer explicitly, e.g. Trainer(gpus=hparams.gpus, ...), which is not the suggested approach but will definitely work.
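Passing arguments explicitly can be automated by keeping only the parsed values that the constructor accepts and building a new kwargs dict, so the Namespace is never modified. A sketch using inspect.signature; FakeTrainer and build_explicitly are hypothetical stand-ins, not Lightning APIs:

```python
import inspect
from argparse import Namespace


class FakeTrainer:
    """Hypothetical stand-in for pl.Trainer with a few keyword arguments."""

    def __init__(self, gpus=None, max_epochs=1000, logger=True):
        self.gpus = gpus
        self.max_epochs = max_epochs
        self.logger = logger


def build_explicitly(cls, args, **overrides):
    accepted = inspect.signature(cls).parameters  # excludes self
    # Keep only the parsed values the constructor actually accepts.
    kwargs = {k: v for k, v in vars(args).items() if k in accepted}
    kwargs.update(overrides)  # kwargs is a new dict, so args is not modified
    return cls(**kwargs)


hparams = Namespace(gpus=1, sample=42)
trainer = build_explicitly(FakeTrainer, hparams, logger="my_logger")
print(trainer.gpus, trainer.logger)  # 1 my_logger
print(vars(hparams))                 # {'gpus': 1, 'sample': 42}
```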

Just confirmed that #2029 fixes the issue for me.
