The pruner should prune the trial.
```python
import optuna
study = optuna.study.create_study()
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
pruner = optuna.pruners.ThresholdPruner(upper=2.0, n_warmup_steps=0, interval_steps=1)
trial.report(3.0, 1)
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id)) # prune!
trial.report(3.0, 0)
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id)) # doesn't prune!
```
It's a small bug, but some libraries (e.g., AllenNLP) report metrics starting from epoch==0.
For epoch==0, the pruner never prunes a trial, because step 0 always satisfies step <= n_warmup_steps (n_warmup_steps must be at least 0).
I don't think this has a strong impact, since most pruners rarely prune a trial at the very beginning of training.
However, it could matter for ThresholdPruner.
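A simplified plain-Python sketch of the check described above, assuming the pruner looks only at the largest reported step and skips it while step <= n_warmup_steps. The function name threshold_prune is hypothetical, this is not Optuna's actual code, and lower and interval_steps are left out:

```python
def threshold_prune(intermediate_values, upper, n_warmup_steps):
    """Hypothetical sketch of a threshold-pruning decision (not Optuna's code).

    Uses the inclusive warmup check step <= n_warmup_steps.
    """
    if not intermediate_values:
        return False
    step = max(intermediate_values)  # largest reported step
    if step <= n_warmup_steps:  # step 0 always counts as warmup when n_warmup_steps >= 0
        return False
    return intermediate_values[step] > upper


# Metrics reported from epoch 0, as AllenNLP does.
print(threshold_prune({0: 3.0}, upper=2.0, n_warmup_steps=0))          # False
print(threshold_prune({0: 3.0, 1: 3.0}, upper=2.0, n_warmup_steps=0))  # True
```

Even with n_warmup_steps=0, a value above upper at step 0 is ignored; the trial can only be pruned from step 1 onwards.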
Thank you for your report.
As @hvy mentioned in https://github.com/optuna/optuna/pull/660#discussion_r343654827, the current implementation of n_warmup_steps assumes 1-based indexing, so epoch==0 falls outside the expected range.
This implicit specification is sometimes confusing, and I think we can change it to 0-based indexing, which is more familiar to Python users.
If we adopt 0-based indexing, we can fix this bug, for example, by accepting negative values of n_warmup_steps to disable the warmup.
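One possible reading of this proposal, sketched under the assumption that the inclusive check step <= n_warmup_steps is kept, steps are counted from 0, and any negative n_warmup_steps means "no warmup". The function name threshold_prune is hypothetical, and this is an illustration of the idea rather than the change that was eventually merged:

```python
def threshold_prune(intermediate_values, upper, n_warmup_steps):
    """Hypothetical sketch: warmup covers 0-based steps 0..n_warmup_steps.

    A negative n_warmup_steps (e.g. -1) disables warmup entirely.
    Not Optuna's actual implementation; lower and interval_steps are omitted.
    """
    if not intermediate_values:
        return False
    step = max(intermediate_values)  # largest reported step
    if step <= n_warmup_steps:  # never true when n_warmup_steps < 0
        return False
    return intermediate_values[step] > upper


print(threshold_prune({0: 3.0}, upper=2.0, n_warmup_steps=0))   # False: step 0 is warmup
print(threshold_prune({0: 3.0}, upper=2.0, n_warmup_steps=-1))  # True: no warmup
```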
@himkt Does the code in the Reproducible examples above generate any AssertionError?
At least in my local environment, I don't get any errors...
@bigbird555
Ah, thank you for pointing that out; you're right.
The second check also prunes because trial.report(3.0, 1) was called first: the pruner evaluates the trial's last (largest) reported step, which is still 1.
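To illustrate, here is a minimal plain-Python sketch. It assumes, as Optuna's FrozenTrial.last_step does, that the step a pruner looks at is the largest reported step rather than the most recently reported one; the dict below is only a stand-in for a trial's intermediate values:

```python
# Hypothetical stand-in for a trial's intermediate values (step -> value).
intermediate_values = {}
intermediate_values[1] = 3.0  # like trial.report(3.0, 1)
intermediate_values[0] = 3.0  # like trial.report(3.0, 0), reported afterwards

# The pruner checks the largest step reported so far, not the latest call.
last_step = max(intermediate_values)
print(last_step)  # 1

# The value at step 1 exceeds upper=2.0, so the trial is still pruned.
print(intermediate_values[last_step] > 2.0)  # True
```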
My reproduction code should have been:
```python
import optuna
study = optuna.study.create_study()
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
pruner = optuna.pruners.ThresholdPruner(upper=2.0, n_warmup_steps=0, interval_steps=1)
trial.report(3.0, 0)
# Raises AssertionError: step 0 is treated as warmup, so the trial is not pruned.
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id)) # doesn't prune!
trial.report(3.0, 1)
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id)) # prune!
```
This problem has been resolved by PR #1430, so let me close this issue.