Let's add a flag:
# default False
Trainer(auto_scale_batch_size=True)
This should do a binary search on the batch size: try progressively larger batch sizes, backing off once one no longer fits, and so on until we find the optimal batch size. At that point, log it so the user knows (including in TensorBoard), and continue training with the new batch size.
Ideally, the user then fixes the batch size in future runs so they can tune the learning rate.
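For the logging part, a rough illustration of what surfacing the result could look like (the metric name and the standalone SummaryWriter are just for illustration, not the actual Trainer integration):

from torch.utils.tensorboard import SummaryWriter

# Illustrative only: once the search settles on a batch size, report it in the
# logs and in TensorBoard so the user can pin it for future runs.
found_batch_size = 64  # whatever the search converged on
print(f"auto_scale_batch_size: settled on batch_size={found_batch_size}")
writer = SummaryWriter()
writer.add_scalar("tuning/batch_size", found_batch_size)
writer.close()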
I'd recommend a proper binary search instead, so more like (see the sketch after this list):
1) start with batch_size = 1
2) double batch_size until OOM
3) binary search between batch_size // 2 and batch_size
Alternative mode: just use the highest power of two batch size that fits.
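A minimal sketch of that procedure, assuming a hypothetical fits(batch_size) predicate that returns True when a trial training run completes without OOM (the predicate itself is not part of the proposal):

def find_max_batch_size(fits, start: int = 1) -> int:
    # Phase 1: double the batch size until the first OOM to bracket the answer.
    batch_size = start
    while fits(batch_size):
        batch_size *= 2
    # Phase 2: binary search between the last size that fit and the first that didn't.
    min_size, max_size = batch_size // 2, batch_size
    while max_size - min_size > 1:
        try_size = (min_size + max_size) // 2
        if fits(try_size):
            min_size = try_size
        else:
            max_size = try_size
    return min_size  # 0 means even `start` did not fit

The "alternative mode" is just phase 1 on its own, returning batch_size // 2 after the first failure.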
Here is what I am doing right now:
from typing import Optional, Union
import logging
import math
import tempfile

import pytorch_lightning as pl
import torch

# `data` is the surrounding project's module providing TextDataset and TextLoader.
logger = logging.getLogger(__name__)


def max_gpu_batch_size(
    dataset: data.TextDataset,
    finetuner: pl.LightningModule,
    n_samples: int = 128,
    device: Union[torch.device, int] = 0,
) -> int:
    """
    Tries to find a maximal batch size for a device, assuming only that the memory usage of
    the model and the total available memory are both stable.

    Should be reliable, but slow; you probably only want to run it once.
    """
    device = torch.device(device)  # type: ignore
    device_max_mem = torch.cuda.get_device_properties(device.index).total_memory

    def test_run(batch_size: int) -> Optional[int]:
        """Run a short training job and return its peak memory usage, or None on OOM."""
        logger.debug(f"Trying a run with batch size {batch_size}")
        with tempfile.TemporaryDirectory() as temp_dir:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)
            loader = data.TextLoader(dataset, batch_size=batch_size)
            trainer = pl.Trainer(
                default_save_path=temp_dir,
                overfit_pct=n_samples / len(loader),
                gpus=[device.index],
                max_epochs=2,
            )
            try:
                trainer.fit(finetuner, train_dataloader=loader)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.debug("Exceeded memory capacity")
                    return None
                else:
                    raise
            usage = torch.cuda.max_memory_allocated(device)
            logger.debug(f"Registered usage: {usage} / {device_max_mem} B")
            return usage

    # Find an upper bound on the max batch size as a power of two
    usage_with_min_size = 0
    for exponent in range(math.floor(math.log2(n_samples)) + 1):
        max_size = 2 ** exponent
        usage_with_max_size = test_run(max_size)
        if usage_with_max_size is None:
            break
        # This only changes as long as we don't break out, at which point it
        # equals the usage for the previous (successful) test run
        usage_with_min_size = usage_with_max_size
    if usage_with_max_size is not None:
        logger.warning(
            f"Ran out of examples without finding a matching batch size (max tried: {max_size}),"
            " you probably want to try with more examples"
        )

    # Bisect between the last size that fit and the first that didn't
    min_size = max_size // 2
    while max_size - min_size > 1:
        try_size = (max_size + min_size) // 2
        usage_with_try_size = test_run(try_size)
        if usage_with_try_size is None:
            max_size = try_size
        else:
            min_size = try_size
            usage_with_min_size = usage_with_try_size
    logger.debug(
        f"Mem usage with inferred batch size: {usage_with_min_size} / {device_max_mem} B"
    )
    return min_size
However, I usually have to downsize it, since in distributed mode, you have the additional requirement that the device batch size should be a multiple of the number of devices used.
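For instance, a hypothetical post-processing step on top of max_gpu_batch_size for that multi-device case (the device count source and the `dataset`/`finetuner` objects are assumptions, reused from the snippet above):

import torch

# Assuming `dataset` and `finetuner` are already constructed as above:
# search once on a single device, then round the result down so it divides
# evenly across the devices used for the real training run.
n_devices = torch.cuda.device_count()
batch_size = max_gpu_batch_size(dataset, finetuner)
batch_size -= batch_size % n_devices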