Pytorch-lightning: How to keep some LightningModule's parameters on cpu when using CUDA devices for training

Created on 28 Sep 2020 · 17 comments · Source: PyTorchLightning/pytorch-lightning

❓ Questions and Help

What is your question?

I tried to port my code to Lightning yesterday, but a CUDA out-of-memory (OOM) error occurred. My model has a very large parameter, nn.Embedding(24000000, 128) (more than 22GB), which obviously exceeds the memory of my CUDA device. I implemented two classes to solve this problem in my plain PyTorch (torch_version) code; the pseudocode is as follows:

PyTorch Code

import torch.nn as nn

class Emb(nn.Module):
    def __init__(self):
        super().__init__()
        # ... some init operations
        self.emb = nn.Embedding(24000000, 128)

    def forward(self, idx):
        return self.emb(idx)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # ... some init operations
        self.calculation = some_calculation()

    def forward(self, input):
        out = self.calculation(input)
        return out

# train part (some_calculation, some_optimizer, some_loss, num_epochs and
# dataloader are placeholders):
get_emb = Emb()             # the huge embedding stays on the CPU
model = MyModule().cuda()   # everything else lives on the GPU
optimizer = some_optimizer(
    [{"params": get_emb.parameters()}, {"params": model.parameters()}], lr=1e-3
)
loss_metric = some_loss()
for epo in range(num_epochs):
    for x, y in dataloader:
        embs = get_emb(x.cpu()).cuda()  # look up on the CPU, copy the result to the GPU
        out = model(embs)
        loss = loss_metric(out, y.cuda())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The plain PyTorch code above keeps the nn.Embedding on the CPU and runs the rest of the training computation on the CUDA device. But I don't know how to achieve this with pytorch_lightning, because the entire training part is encapsulated in training_step. The PL code is as follows:

PL Code

import pytorch_lightning as pl
import torch.nn as nn

class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ... some init operations
        self.calculation = some_calculation()
        self.emb = nn.Embedding(24000000, 128)
        self.loss_metric = some_loss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        embs = self.emb(x)
        out = self.calculation(embs)
        return {"loss": self.loss_metric(out, y)}

    def configure_optimizers(self):
        return some_optimizer(self.parameters(), lr=1e-3)

# train part
model = MyModule()
trainer = pl.Trainer(gpus=-1)
trainer.fit(model, dataloader)

So, is there a recommended way to keep part of a LightningModule's parameters on the CPU when using CUDA devices for training?

What's your environment?

  • OS: Ubuntu 16.04.6 LTS
  • CUDA: 10.2 (GPU: 2080 Ti)
  • PyTorch Lightning version: 0.9.0
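For scale, a quick back-of-the-envelope estimate of the embedding's size, assuming float32 weights (PyTorch's default dtype) and nn.Embedding's default dense gradient: the weight alone is about 12.3 GB (11.4 GiB), already more than the 11 GB on a 2080 Ti, and its gradient doubles that to about 24.6 GB, which is in the range of the "more than 22GB" mentioned above. A stateful optimizer such as Adam would add two more tensors of the same size on top.

```python
# Back-of-the-envelope memory estimate for nn.Embedding(24_000_000, 128) in float32.
NUM_EMBEDDINGS = 24_000_000
EMBEDDING_DIM = 128
BYTES_PER_FLOAT32 = 4

weight_bytes = NUM_EMBEDDINGS * EMBEDDING_DIM * BYTES_PER_FLOAT32
# With sparse=False (the default) the gradient is a dense tensor of the same shape.
weight_plus_grad_bytes = 2 * weight_bytes

print(f"weight:        {weight_bytes / 1e9:.3f} GB ({weight_bytes / 2**30:.2f} GiB)")
print(f"weight + grad: {weight_plus_grad_bytes / 1e9:.3f} GB")
```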

All 17 comments

Hi! Thanks for your contribution, great first issue!

if you do this:

class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ... some init operations
        self.calculation = some_calculation()
        self.emb = [nn.Embedding(24000000, 128)]  # plain list: not registered, stays on the CPU
        self.loss_metric = some_loss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        embs = self.emb[0](x.cpu()).to(self.device)
        out = self.calculation(embs)
        return {"loss": self.loss_metric(out, y)}

It should work, I guess. Can't think of a better solution than this 😅

@rohitgr7 Really?! In that case, will self.emb be saved in the checkpoint along with the other parameters of MyModule? Also, sorry, I just noticed a typo in my PL code: forward(self, input) should have been training_step(self, batch, batch_idx).

Yeah, it won't be saved. Didn't think of that. Any module registered on the LightningModule is moved to the device when Lightning calls .to() on the model, so we somehow need a way to override this .to behavior for the embedding.
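A small sketch of why the list trick keeps the embedding on the CPU and also keeps it out of the checkpoint. The toy class and attribute names are made up for illustration. nn.Module.to() only recurses into registered submodules (nn.Module objects assigned directly as attributes); a module stored inside a plain Python list is not registered, so .to(), named_parameters() and state_dict() all skip it. The sketch converts the dtype instead of the device so that it runs without a GPU; a device move goes through the same recursive path.

```python
import torch
import torch.nn as nn

class Wrapped(nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 1)                # registered submodule
        self.emb_registered = nn.Embedding(10, 4)  # registered submodule
        self.emb_hidden = [nn.Embedding(10, 4)]    # plain list: NOT registered

# .to(dtype) takes the same recursive path as .to(device), so no GPU is needed here.
model = Wrapped().to(torch.float64)

param_names = [name for name, _ in model.named_parameters()]
hidden_in_state_dict = any(k.startswith("emb_hidden") for k in model.state_dict())

print(param_names)                        # the hidden embedding is absent
print(model.emb_registered.weight.dtype)  # converted by .to()
print(model.emb_hidden[0].weight.dtype)   # left untouched
print(hidden_in_state_dict)               # so it is not checkpointed either
```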

@rohitgr7 I tried to use self.emb = nn.Embedding(24000000, 128).cpu() in the Lightning code, but it failed. Actually, it is very common in recommendation systems to use this kind of large-scale embedding as a trainable weight of the model. For example, a sparse User ID feature (more than 24,000,000 users) can be represented by a dense embedding matrix. So is there any way to implement this in PyTorch Lightning?

I looked at the PyTorch source code and found something. Can you try this, @David-AJ?


import pytorch_lightning as pl
import torch.nn as nn

class SpecialEmbedding(nn.Module):
    def __init__(self, fin, fout):
        super().__init__()
        self.emb = nn.Embedding(fin, fout)

    def _apply(self, fn):
        # .to()/.cuda() reach child modules through _apply; returning self
        # without calling fn keeps this embedding on the CPU.
        return self

    def forward(self, x):
        return self.emb(x)

class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ... some init operations
        self.calculation = some_calculation()
        self.emb = SpecialEmbedding(24000000, 128)
        self.loss_metric = some_loss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        embs = self.emb(x.cpu()).to(self.device)
        out = self.calculation(embs)
        return {"loss": self.loss_metric(out, y)}

# train part
model = MyModule()
trainer = pl.Trainer(gpus=-1)
trainer.fit(model, dataloader)
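To see what the _apply override does, here is a CPU-only sketch with made-up toy sizes. In PyTorch, .to(), .cuda() and .half() all end up in Module._apply, which a parent calls on each registered child; SpecialEmbedding swallows that call. _apply is a private method whose signature has changed across PyTorch releases (newer ones add a recurse argument), so this version of the override accepts and ignores extra arguments. Note that the embedding is still registered: it stays in state_dict(), which is good for checkpointing, but its CPU weights also still appear in parameters().

```python
import torch
import torch.nn as nn

class SpecialEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)

    def _apply(self, fn, *args, **kwargs):
        # Ignore every conversion (.to, .cuda, .half, ...) for this subtree.
        return self

    def forward(self, x):
        return self.emb(x)

class Parent(nn.Module):  # hypothetical stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.emb = SpecialEmbedding(10, 4)
        self.head = nn.Linear(4, 1)

parent = Parent().to(torch.float64)
print(parent.emb.emb.weight.dtype)               # conversion skipped: still float32
print(parent.head.weight.dtype)                  # converted to float64
print("emb.emb.weight" in parent.state_dict())   # still registered and saved
```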

@rohitgr7 Thanks for your kind help! But this SpecialEmbedding code failed again 😅
the error message is as follows:

Traceback (most recent call last):
  File "debug.py", line 742, in <module>
    trainer.fit(model)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1064, in fit
    results = self.accelerator_backend.train()
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/accelerators/dp_backend.py", line 97, in train
    results = self.trainer.run_pretrain_routine(model)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1239, in run_pretrain_routine
    self.train()
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 394, in train
    self.run_training_epoch()
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 491, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 844, in run_training_batch
    self.hiddens
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 1015, in optimizer_closure
    hiddens)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 1197, in training_forward
    output = self.model(*args)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.6/site-packages/pytorch_lightning/overrides/data_parallel.py", line 70, in forward
    "them on device: {}".format(self.src_device_obj, t.device))
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

It seems that pytorch_lightning requires all parameters of a module to be on the same device?

Not 100% sure why this RuntimeError is raised. @awaelchli, any suggestions on how to make this work?

Actually, now I would also like to know whether this is the right way, or whether there is another way around it, since it seems super useful.

@David-AJ is it working on a single GPU device with no distributed backend?

@rohitgr7 did you mean to tag @David-AJ ?

Oops, my bad, I was looking at your issue #3998 side by side 😅

Hi @rohitgr7, I tried to run this code on a single GPU and the RuntimeError was raised again.

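Do you want to train the embedding layer, or is it pretrained?
If you want to train it, I'm afraid you can't keep it on the CPU while also using DP. The error you got above comes from DataParallel: before the forward pass it checks that every parameter and buffer of the module is on device_ids[0], and the embedding's weight is still on the CPU.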
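A sketch of the check behind that RuntimeError: before each forward, DataParallel (and Lightning's override in data_parallel.py, per the traceback) walks module.parameters() and module.buffers() and raises on the first tensor that is not on device_ids[0]. The helper tensors_off_device and the Net class are hypothetical, written only to mimic that loop. torch.device("cuda", 0) is only compared against, so the snippet also runs on a CPU-only machine, where head.* shows up as well because nothing was moved.

```python
from itertools import chain
from typing import List

import torch
import torch.nn as nn

def tensors_off_device(module: nn.Module, device: torch.device) -> List[str]:
    """Hypothetical helper mimicking DataParallel's pre-forward loop over
    module.parameters() and module.buffers(): returns the names of the
    tensors that are not on `device` (DataParallel raises on the first one)."""
    named = chain(module.named_parameters(), module.named_buffers())
    return [name for name, t in named if t.device != device]

class Net(nn.Module):  # hypothetical toy model
    def __init__(self, hide_embedding: bool):
        super().__init__()
        self.head = nn.Linear(4, 1)
        emb = nn.Embedding(10, 4)
        # Registered attribute: visible to the check. Inside a list: invisible.
        self.emb = [emb] if hide_embedding else emb

target = torch.device("cuda", 0)  # only used for comparison, nothing is allocated

registered = tensors_off_device(Net(hide_embedding=False), target)
hidden = tensors_off_device(Net(hide_embedding=True), target)
print(registered)  # includes 'emb.weight'
print(hidden)      # only head.* in this CPU-only sketch
```

This is why SpecialEmbedding still fails under DP (its weight is registered but stays on the CPU), while the list-wrapped embedding passes the check.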

@awaelchli Yes, I want to train it. How about using DDP or another distributed backend? Actually, @rohitgr7's first solution, self.emb = [nn.Embedding(24000000, 128)], does keep self.emb on the CPU while training with DP, but in that case self.emb is neither saved in the checkpoint nor restored by load_from_checkpoint. Could this problem be solved by overriding on_save_checkpoint and on_load_checkpoint?

self.emb = [nn.Embedding(24000000, 128)]

🤣 This is a funny trick. Very creative. Yeah, this makes torch unaware of the module (a plain list is not registered as a submodule), so it stays on the CPU.
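One consequence worth checking, since the thread does not show configure_optimizers: because the list-wrapped embedding is invisible to self.parameters(), an optimizer built only from self.parameters() would never update it. A minimal sketch, assuming a plain nn.Module stand-in, toy sizes and SGD; optimizer_param_groups is a hypothetical helper, and in Lightning the same groups would go into the optimizer returned from configure_optimizers. Gradients still reach the CPU embedding because autograd records the .to(device) copy.

```python
import torch
import torch.nn as nn

class HybridModel(nn.Module):  # hypothetical stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.emb = [nn.Embedding(10, 4)]  # hidden from parameters() and .to()
        self.head = nn.Linear(4, 1)

    def forward(self, idx):
        device = self.head.weight.device
        embs = self.emb[0](idx.cpu()).to(device)  # lookup on CPU, copy to the model's device
        return self.head(embs).squeeze(-1)

    def optimizer_param_groups(self):
        # The hidden embedding must be passed explicitly: self.parameters() skips it.
        return [{"params": self.parameters()}, {"params": self.emb[0].parameters()}]

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridModel().to(device)
optimizer = torch.optim.SGD(model.optimizer_param_groups(), lr=0.1)

idx = torch.tensor([1, 2, 3])
target = torch.zeros(3, device=device)
before = model.emb[0].weight.detach().clone()

loss = nn.functional.mse_loss(model(idx), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

n_in_parameters = sum(p.numel() for p in model.parameters())  # head only: 4 + 1
emb_changed = not torch.equal(before, model.emb[0].weight.detach())
print(n_in_parameters, emb_changed)
```

The plain PyTorch loop in the question already passes both parameter groups to its optimizer; the Lightning version needs the same.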

Could this problem be solved by overriding the on_save_checkpoint and on_load_checkpoint?

Yes, I think that would do the trick!
But won't this embedding layer be a huge bottleneck? You will need to transfer all of its outputs to the GPU, and that blocks execution.

@awaelchli

But will this this embedding layer not be a huge bottleneck? You will need to transfer all outputs to the GPU and this blocks execution.

Looks like I have no choice 🤣 Otherwise I would have to train the whole module on the CPU, and I don't know which would be faster. The application is a recommendation system, and in fact the number of users and items far exceeds 24 million; all these sparse ID features should be represented by embedding layers and trained within the module. Do you have any other suggestions on how to make it faster?

Could this problem be solved by overriding the on_save_checkpoint and on_load_checkpoint?

Yes, I would do it this way: grab the state dict of the embedding layer and add it to the checkpoint dict. When loading, do the opposite and load that state dict back into the layer.
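Following that suggestion, a sketch that stores the embedding's state dict (plain tensors) instead of the module object. It uses a plain nn.Module stand-in so it runs without Lightning; on a LightningModule the two methods would be the on_save_checkpoint / on_load_checkpoint hooks, which receive the checkpoint dict. EMB_KEY is a hypothetical key name, and the "state_dict" entry only imitates where Lightning keeps the registered parameters.

```python
import torch
import torch.nn as nn

class CheckpointedHybrid(nn.Module):  # hypothetical stand-in for the LightningModule
    EMB_KEY = "emb_state_dict"  # hypothetical checkpoint key

    def __init__(self):
        super().__init__()
        self.emb = [nn.Embedding(10, 4)]  # hidden from state_dict(), see above
        self.head = nn.Linear(4, 1)

    def on_save_checkpoint(self, checkpoint):
        # Save only the tensors, not the pickled module object.
        checkpoint[self.EMB_KEY] = self.emb[0].state_dict()

    def on_load_checkpoint(self, checkpoint):
        # The embedding already exists (built in __init__); just copy the weights in.
        self.emb[0].load_state_dict(checkpoint[self.EMB_KEY])

# Simulate a save/load round trip with a plain dict.
src, dst = CheckpointedHybrid(), CheckpointedHybrid()
checkpoint = {"state_dict": src.state_dict()}  # holds head.* only
src.on_save_checkpoint(checkpoint)

dst.load_state_dict(checkpoint["state_dict"])
dst.on_load_checkpoint(checkpoint)
print(torch.equal(src.emb[0].weight, dst.emb[0].weight))
```

Storing the state dict rather than the module keeps the checkpoint to plain tensors, so loading it does not depend on unpickling the embedding's class.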

Do you have any other suggestions on how to make it faster?

Sorry, nothing comes to mind :(

I modified my code like this:

class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ... some init operations
        self.calculation = some_calculation()
        # Wrapping the embedding in a list hides it from torch/Lightning,
        # so it is not moved to the GPU (and not saved automatically).
        self.emb = [nn.Embedding(24000000, 128)]
        self.loss_metric = some_loss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        embs = self.emb[0](x.cpu()).to(self.device)
        out = self.calculation(embs)
        return {"loss": self.loss_metric(out, y)}

    def on_save_checkpoint(self, checkpoint):
        checkpoint["emb"] = self.emb

    def on_load_checkpoint(self, checkpoint):
        self.emb = checkpoint["emb"]

It works! Thank you @rohitgr7 and @awaelchli!
