Just wondering what the best way to implement mixup in Lightning is. Possibly in the dataset?
Let's use facebookresearch/mixup-cifar10 as an example. They implement it in the training loop:
```python
for batch_idx, (inputs, targets) in enumerate(trainloader):
    if use_cuda:
        inputs, targets = inputs.cuda(), targets.cuda()
    inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
                                                   args.alpha, use_cuda)
    # rest of code
```
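For context, `mixup_data` in that repo looks roughly like this (paraphrased from facebookresearch/mixup-cifar10; check the repo for the exact version). It draws a mixing coefficient from a Beta distribution and mixes each input with a randomly permuted partner from the same batch:

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Returns mixed inputs, the two target sets, and the mixing coefficient lambda."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    # random permutation picks the partner sample each input is mixed with
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
```

Because the mixing happens across a batch, it naturally lives wherever you have a full batch in hand, which is why the repo does it in the training loop rather than in the dataset.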
In PyTorch Lightning, you would do this in your training step:
```python
def training_step(self, batch, batch_idx):
    # no manual .cuda() calls needed: Lightning has already moved the batch to the right device
    inputs, targets = batch
    inputs, targets_a, targets_b, lam = mixup_data(
        inputs, targets, self.hparams.alpha, use_cuda=inputs.is_cuda
    )
    # rest of code
```

Note that `args` isn't available inside the module, so pass alpha in via hyperparameters (e.g. `self.hparams.alpha`), and set `use_cuda` from the batch itself so the permutation index lands on the same device as the inputs.
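The other piece to carry over into `training_step` is the mixed loss. The repo pairs `mixup_data` with a `mixup_criterion` that blends the loss against both target sets with the same lambda; a sketch, assuming a standard criterion like cross-entropy:

```python
import torch
import torch.nn.functional as F

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # weight the two losses by the same lambda used to mix the inputs
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```

Inside `training_step` you'd then return something like `mixup_criterion(F.cross_entropy, self(inputs), targets_a, targets_b, lam)` as the loss.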
I'll give it a shot, thanks.