Pytorch-lightning: Feature to automatically choose batch size

Created on 26 Apr 2020 · 2 Comments · Source: PyTorchLightning/pytorch-lightning

Let's add a flag:

# default False
Trainer(auto_scale_batch_size=True)

This should do binary search on batch size:

  1. Run a few train steps using the current batch size.
  2. If OOM, batch_size = batch_size / 2.
  3. If no OOM, batch_size = batch_size * 1.5.

And so on until we find the optimal batch size. At that point, log it so the user knows (including to TensorBoard), and continue training with the new batch size.
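A minimal sketch of what such a search loop might look like (the `try_steps` callable is a hypothetical stand-in for whatever the Trainer would run internally; none of this is existing Lightning API):

import torch

def scale_batch_size(try_steps, batch_size: int, max_trials: int = 25) -> int:
    """Grow the batch size by 1.5x while a few train steps fit in memory,
    halve it on OOM, and stop once a working size has been bracketed."""
    last_ok = None
    for _ in range(max_trials):
        try:
            try_steps(batch_size)  # run a handful of training steps
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise
            torch.cuda.empty_cache()
            if last_ok is not None:  # an OOM after a success: we are done
                return last_ok
            batch_size = max(1, batch_size // 2)
        else:
            last_ok = batch_size
            batch_size = int(batch_size * 1.5)
    return last_ok if last_ok is not None else batch_size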

Ideally the user fixes the batch size in future runs to tune the learning rate.

Labels: enhancement, help wanted

All 2 comments

I'd recommend a proper binary search instead, so more like:

1) start with batch_size = 1
2) double batch_size until OOM
3) binary-search between batch_size // 2 and batch_size

Alternative mode: just use the highest power of 2 batch size that fits.
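A bare-bones version of that search, assuming a hypothetical fits(batch_size) probe that runs a few train steps and returns False on OOM:

def binsearch_batch_size(fits, start: int = 1, limit: int = 2 ** 16) -> int:
    """Double the batch size until OOM, then bisect between the last fitting
    size and the first failing one. Assumes fits(start) is True."""
    lo = hi = start
    while hi < limit and fits(hi):  # phase 1: double until OOM (or the cap)
        lo, hi = hi, hi * 2
    while hi - lo > 1:              # phase 2: bisect the (lo, hi) bracket
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo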

I'm doing this right now:

import logging
import math
import tempfile
from typing import Union

import pytorch_lightning as pl
import torch

import data  # project-specific module providing TextDataset and TextLoader

logger = logging.getLogger(__name__)


def max_gpu_batch_size(
    dataset: data.TextDataset,
    finetuner: pl.LightningModule,
    n_samples: int = 128,
    device: Union[torch.device, int] = 0,
) -> int:
    """
    Tries to find a maximal batch size for a device, assuming only that the memory usage of
    the model and the total available memory are both stable.

    Should be reliable, but slow; you probably only want to run it once.
    """
    device = torch.device(device)  # type: ignore
    device_max_mem = torch.cuda.get_device_properties(device.index).total_memory

    def test_run(batch_size):
        """Overfit a few samples for two epochs; return the peak memory usage,
        or None if the run went OOM."""
        logger.debug(f"Trying a run with batch size {batch_size}")
        with tempfile.TemporaryDirectory() as temp_dir:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(device)
            loader = data.TextLoader(dataset, batch_size=batch_size)
            trainer = pl.Trainer(
                default_save_path=temp_dir,
                overfit_pct=n_samples / len(loader),
                gpus=[device.index],
                max_epochs=2,
            )
            try:
                trainer.fit(finetuner, train_dataloader=loader)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.debug("Exceeded memory capacity")
                    return None
                raise
        usage = torch.cuda.max_memory_allocated(device)
        logger.debug(f"Registered usage: {usage} / {device_max_mem} B")
        return usage

    # Find an upper bound on the max batch size as a power of two
    usage_with_min_size = 0
    for exponent in range(math.floor(math.log2(n_samples)) + 1):
        max_size = 2 ** exponent
        usage_with_max_size = test_run(max_size)
        if usage_with_max_size is None:
            break
        # This only changes as long as we don't break out, at which point it
        # holds the usage of the previous (successful) test run
        usage_with_min_size = usage_with_max_size
    if usage_with_max_size is not None:
        logger.warning(
            f"Ran out of examples without finding a max batch size (max tried: {max_size});"
            " you probably want to try with more examples"
        )

    # Bisect between the largest fitting power of two and the first failing one
    min_size = max_size // 2
    while max_size - min_size > 1:
        try_size = (max_size + min_size) // 2
        usage_with_try_size = test_run(try_size)
        if usage_with_try_size is None:
            max_size = try_size
        else:
            min_size = try_size
            usage_with_min_size = usage_with_try_size
    logger.debug(
        f"Mem usage with inferred batch size: {usage_with_min_size} / {device_max_mem} B"
    )
    return min_size

However, I usually have to scale the result down, since in distributed mode there is the additional requirement that the batch size be a multiple of the number of devices used.
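For example, rounding the result down to such a multiple (a sketch; using the local GPU count for world_size here, which may differ from your actual distributed setup):

batch_size = max_gpu_batch_size(dataset, finetuner)
world_size = torch.cuda.device_count()  # or your distributed world size
batch_size -= batch_size % world_size   # round down to a multiple of world_size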

