I just wanted to build a model to see how pytorch-lightning works. I am working in a Jupyter notebook and I stopped the cell in the middle of training. I wanted to free up the CUDA memory and couldn't find a proper way to do that without restarting the kernel. Here is what I tried:
del model # model is a pl.LightningModule
del trainer # pl.Trainer
del train_loader # torch DataLoader
torch.cuda.empty_cache()
# this also gets stuck
pytorch_lightning.utilities.memory.garbage_collection_cuda()
Deleting the model and calling torch.cuda.empty_cache() works in plain PyTorch.
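For comparison, a minimal sketch of that plain-PyTorch cleanup (an illustration, not taken from this thread; it assumes model is an ordinary nn.Module that was moved to the GPU, and the memory_allocated() calls are only there to confirm the memory actually comes back):

import gc
import torch

print(torch.cuda.memory_allocated())  # bytes currently held by live tensors
del model                             # drop the last reference to the module and its parameters
gc.collect()                          # collect any lingering reference cycles
torch.cuda.empty_cache()              # return the now-unused cached blocks to the driver
print(torch.cuda.memory_allocated())  # should be at or near 0 if nothing else holds GPU tensors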
Hi! Thanks for your contribution! Great first issue!
I think
del model
torch.cuda.empty_cache()
is all you need.
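In a notebook that is sometimes not quite enough, because IPython can keep references to the interrupted cell alive (for example through the stored traceback). A slightly more thorough sketch, assuming model, trainer and train_loader are the objects from the original snippet and that lingering notebook references are the culprit (which may not be the case here):

import gc
import sys
import torch

del model, trainer, train_loader  # drop the notebook's own references
sys.last_traceback = None         # the interrupted cell's traceback can keep GPU tensors alive via its frames
gc.collect()                      # break the remaining reference cycles
torch.cuda.empty_cache()          # hand the cached blocks back to the driver
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())  # both should drop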
As I said in my first comment, that did not work. Here I took the example from the README and tried to delete the model:
import os

import torch
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl


# this is just a plain nn.Module with some structure
class LitClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.TrainResult(loss)
        result.log('train_loss', loss, on_epoch=True)
        return result

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


# train!
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

model = LitClassifier()
trainer = pl.Trainer(max_epochs=1, gpus=1)
trainer.fit(model, DataLoader(train), DataLoader(val))

# still doesn't work
del model
torch.cuda.empty_cache()
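For diagnosis it can also help to print the caching allocator's statistics after the cleanup, to see whether anything is still held (just a sketch; memory_summary() is purely informational):

# diagnostic only: show what the CUDA caching allocator still holds after the del above
print(torch.cuda.memory_summary())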
Yeah, that didn't work. I thought it would. Now I also want to know how to do this :sweat_smile:.
@awaelchli can you suggest something here? Thanks.
Yep, I think that is because our subprocesses do not get killed properly on these signals.
I've been working on this in #2165; I'll also check it on Jupyter/Colab once the refactors are done and I can finish this PR. I am fairly confident that this is related and that #2165 fixes it, but not 100% sure.
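In the meantime, one way to check whether leftover worker subprocesses are what is keeping the memory alive is to list the kernel's child processes (a diagnostic sketch based on the explanation above; it assumes the workers were started as ordinary multiprocessing children of the notebook kernel):

import multiprocessing as mp

# any DataLoader workers that survived the interrupt show up as children of this process
for child in mp.active_children():
    print("still alive:", child.pid, child.name)
    child.terminate()  # stop the worker so it releases its resources
    child.join()       # wait until it has actually exited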
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!