System information
Describe the bug
There seem to be no examples showing how to use AdamW with a learning rate scheduler, so I tried to use AdamW as in the code below. The code works with Adam, but with AdamW and learning rate decay it doesn't work.
Can anyone give a correct example of using AdamW with learning rate decay?
Code to reproduce the issue
import tensorflow as tf
import os
from tensorflow_addons.optimizers import AdamW

if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
    print(gpus)
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-3 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)
    print(lr)
    print(wd)
    # PiecewiseConstantDecay also doesn't seem to work properly
    optimizer = AdamW(learning_rate=lr, weight_decay=wd)
    # optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=40, validation_split=0.1)
    model.evaluate(x_test, y_test, verbose=2)
Other info / logs
Traceback (most recent call last):
File "tmp2.py", line 42, in <module>
model.fit(x_train, y_train, epochs=40, validation_split=0.1)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 324, in fit
total_epochs=epochs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 123, in run_one_epoch
batch_outs = execution_function(iterator)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 86, in execution_function
distributed_function(input_fn))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
result = self._call(*args, **kwds)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py", line 520, in _call
return self._stateless_fn(*args, **kwds)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 1823, in __call__
return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 1141, in _filtered_call
self.captured_inputs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
ctx, args, cancellation_manager=cancellation_manager)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
ctx=ctx)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 2, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation sequential/conv2d/Conv2D/ReadVariableOp: Could not satisfy explicit device specification '/job:localhost/replica:0/task:0/device:GPU:0' because no supported kernel for GPU devices is available.
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=2 requested_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' assigned_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' resource_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' supported_device_types_=[CPU] possible_devices_=[]
RealDiv: GPU CPU XLA_CPU XLA_GPU
LogicalAnd: GPU CPU XLA_CPU XLA_GPU
_Arg: GPU CPU XLA_CPU XLA_GPU
ReadVariableOp: GPU CPU XLA_CPU XLA_GPU
Greater: GPU CPU XLA_CPU XLA_GPU
Sub: GPU CPU XLA_CPU XLA_GPU
Const: GPU CPU XLA_CPU XLA_GPU
Pack: GPU CPU XLA_CPU XLA_GPU
LessEqual: GPU CPU XLA_CPU XLA_GPU
Identity: GPU CPU XLA_CPU XLA_GPU
Cast: GPU CPU XLA_CPU XLA_GPU
Sum: GPU CPU XLA_CPU XLA_GPU
ResourceApplyAdam: GPU CPU XLA_CPU XLA_GPU
Mul: GPU CPU XLA_CPU XLA_GPU
Sqrt: GPU CPU XLA_CPU XLA_GPU
AssignSubVariableOp: GPU CPU XLA_CPU XLA_GPU
AddV2: GPU CPU XLA_CPU XLA_GPU
Pow: GPU CPU XLA_CPU XLA_GPU
Colocation members, user-requested devices, and framework assigned devices, if any:
sequential_conv2d_conv2d_readvariableop_resource (_Arg) framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
adamw_adamw_update_resourceapplyadam_m (_Arg) framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
adamw_adamw_update_resourceapplyadam_v (_Arg) framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
sequential/conv2d/Conv2D/ReadVariableOp (ReadVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Read/ReadVariableOp (ReadVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Const (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Const_1 (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Const_2 (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Const_3 (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Const_4 (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/LessEqual (LessEqual) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Greater (Greater) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/Greater_1 (Greater) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/LessEqual_1 (LessEqual) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/and (LogicalAnd) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/preds_c (Pack) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/Cast (Cast) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/Const (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/num_true_conds (Sum) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/n_true_conds (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/LessEqual (LessEqual) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/Assert/Const (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/Assert/AssertGuard/Identity (Identity) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/PiecewiseConstant/case/cond/Identity (Identity) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/mul/x (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/mul (Mul) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/mul_1/ReadVariableOp (ReadVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/mul_1 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/AssignSubVariableOp (AssignSubVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Identity_1 (Identity) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/add/y (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/add (AddV2) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Cast (Cast) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Identity_2 (Identity) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Identity_3 (Identity) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Pow (Pow) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Pow_1 (Pow) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub/x (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub (Sub) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Sqrt (Sqrt) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub_1/x (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub_1 (Sub) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/truediv (RealDiv) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/mul_2 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/Const (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub_2/x (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub_2 (Sub) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub_3/x (Const) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/sub_3 (Sub) /job:localhost/replica:0/task:0/device:GPU:0
AdamW/AdamW/update/ResourceApplyAdam (ResourceApplyAdam) /job:localhost/replica:0/task:0/device:GPU:0
Op: ReadVariableOp
Node attrs: dtype=DT_FLOAT
Registered kernels:
device='XLA_CPU_JIT'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
device='XLA_GPU_JIT'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64]
device='GPU'
device='CPU'
device='XLA_CPU'
device='XLA_GPU'
[[{{node sequential/conv2d/Conv2D/ReadVariableOp}}]] [Op:__inference_distributed_function_2137]
cc code owner: @PhilJd
When I run the above code on the CPU only, no error is reported. But another problem arises: learning rate decay and weight decay do not work.
I found that when using model.fit(), tf.optimizers.schedules.PiecewiseConstantDecay should be passed directly as learning_rate, as below:
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
    [1407*20, 1407*30], [1e-3, 1e-4, 1e-5])
optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=40, validation_split=0.1)
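For readers wondering about the 1407 in the boundaries: it appears to be the number of optimizer steps per epoch in this setup. A small sketch of that arithmetic, assuming the 50,000-image CIFAR-10 training set, validation_split=0.1, and the Keras default batch size of 32 (the batch size is not stated in the thread, so it is an assumption):

```python
import math

train_images = 50000   # CIFAR-10 training set size
batch_size = 32        # assumption: Keras default, not stated in the thread

# validation_split=0.1 holds out one tenth of the training images.
train_samples = train_images - train_images // 10
steps_per_epoch = math.ceil(train_samples / batch_size)

# Boundaries in optimizer steps that correspond to epochs 20 and 30.
boundaries = [steps_per_epoch * 20, steps_per_epoch * 30]
print(steps_per_epoch, boundaries)
```

With a different batch size or dataset, the boundaries would need to be recomputed the same way.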
So I tried the same with AdamW. Learning rate decay works, but the weight decay doesn't:
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
    [1407*20, 1407*30], [1e-3, 1e-4, 1e-5])
wd = lambda: 1e-1 * schedule(step)
# weight decay cannot be changed with schedule
# AdamW is the class imported from tensorflow_addons.optimizers in the first snippet
optimizer = AdamW(learning_rate=schedule, weight_decay=wd)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=40, validation_split=0.1)
When the weight decay does not follow the learning rate schedule, the learning curve may look like this:

Hope someone can tell me how to do it right, thanks!
It seems like Keras treats instances of learning_rate_schedule.LearningRateSchedule separately (in _get_hyper).
Could you try to create a second schedule and see if that works? I.e., something along these lines:
schedule_lr = tf.optimizers.schedules.PiecewiseConstantDecay(
    [1407*20, 1407*30], [1e-3, 1e-4, 1e-5])
schedule_wd = tf.optimizers.schedules.PiecewiseConstantDecay(
    [1407*20, 1407*30], [1e-4, 1e-5, 1e-6])
# AdamW is the class imported from tensorflow_addons.optimizers
optimizer = AdamW(learning_rate=schedule_lr, weight_decay=schedule_wd)
Thanks :)
It doesn't work:
Traceback (most recent call last):
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/tensor_util.py", line 324, in _AssertCompatible
fn(values)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/tensor_util.py", line 263, in inner
_ = [_check_failed(v) for v in nest.flatten(values)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/tensor_util.py", line 264, in <listcomp>
if not isinstance(v, expected_types)]
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/tensor_util.py", line 248, in _check_failed
raise ValueError(v)
ValueError: <tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PiecewiseConstantDecay object at 0x7f57d72eef60>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "tmp2.py", line 48, in <module>
model.fit(x_train, y_train, epochs=40, validation_split=0.1)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 324, in fit
total_epochs=epochs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 123, in run_one_epoch
batch_outs = execution_function(iterator)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 86, in execution_function
distributed_function(input_fn))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
result = self._call(*args, **kwds)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py", line 503, in _call
self._initialize(args, kwds, add_initializers_to=initializer_map)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py", line 408, in _initialize
*args, **kwds))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 1848, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 73, in distributed_function
per_replica_function, args=(model, x, y, sample_weights))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 760, in experimental_run_v2
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1787, in call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 2132, in _call_for_each_replica
return fn(*args, **kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/autograph/impl/api.py", line 292, in wrapper
return func(*args, **kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 264, in train_on_batch
output_loss_metrics=model._output_loss_metrics)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 311, in train_on_batch
output_loss_metrics=output_loss_metrics))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 272, in _process_single_batch
model.optimizer.apply_gradients(zip(grads, trainable_weights))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py", line 153, in apply_gradients
grads_and_vars, name=name)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py", line 441, in apply_gradients
kwargs={"name": name})
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1917, in merge_call
return self._merge_call(merge_fn, args, kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1924, in _merge_call
return merge_fn(self._strategy, *args, **kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py", line 485, in _distributed_apply
var, apply_grad_to_update_var, args=(grad,), group=False))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1530, in update
return self._update(var, fn, args, kwargs, group)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 2142, in _update
return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 2148, in _update_non_slot
result = fn(*args, **kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py", line 467, in apply_grad_to_update_var
update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py", line 173, in _resource_apply_dense
with tf.control_dependencies([self._decay_weights_op(var)]):
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_addons/optimizers/weight_decay_optimizers.py", line 158, in _decay_weights_op
self._get_hyper('weight_decay', var.dtype) * var,
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/ops/variables.py", line 1079, in _run_op
return tensor_oper(a.value(), *args, **kwargs)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/ops/math_ops.py", line 924, in r_binary_op_wrapper
x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py", line 1184, in convert_to_tensor
return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py", line 1242, in convert_to_tensor_v2
as_ref=False)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/ops.py", line 1296, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/constant_op.py", line 286, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/constant_op.py", line 227, in constant
allow_broadcast=True)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/constant_op.py", line 265, in _constant_impl
allow_broadcast=allow_broadcast))
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/tensor_util.py", line 449, in make_tensor_proto
_AssertCompatible(values, dtype)
File "/home/yetao/.local/lib/python3.5/site-packages/tensorflow_core/python/framework/tensor_util.py", line 331, in _AssertCompatible
(dtype.name, repr(mismatch), type(mismatch).__name__))
TypeError: Expected float32, got <tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PiecewiseConstantDecay object at 0x7f57d72eef60> of type 'PiecewiseConstantDecay' instead.
It says weight decay needs to be float32 rather than a PiecewiseConstantDecay object, but why can the learning rate be one?
I have also seen implementations that tie the weight decay to the learning rate schedule via $wd_t = wd \cdot lr_t / lr$. This seems like a good way to implement it, but I'm not familiar with the implementation in TF 2.0.
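The scaling rule mentioned above can be sketched in plain Python, independent of TensorFlow. The helper names `piecewise_constant` and `coupled_weight_decay` are hypothetical and exist only for this illustration; the boundaries and values are the ones used earlier in the thread, and the initial weight decay of 1e-4 is the value from the first snippet.

```python
import bisect


def piecewise_constant(step, boundaries, values):
    """Return values[i] for the interval containing `step`.

    Follows the PiecewiseConstantDecay convention: values[0] applies while
    step <= boundaries[0], values[1] while step <= boundaries[1], and so on.
    """
    return values[bisect.bisect_left(boundaries, step)]


def coupled_weight_decay(step, wd_init, lr_init, lr_fn):
    """Scale the initial weight decay by the ratio lr_t / lr_init."""
    return wd_init * lr_fn(step) / lr_init


boundaries = [10000, 15000]
lr_values = [1e-3, 1e-4, 1e-5]


def lr_fn(step):
    return piecewise_constant(step, boundaries, lr_values)


for step in (0, 10000, 10001, 20000):
    print(step, lr_fn(step), coupled_weight_decay(step, 1e-4, 1e-3, lr_fn))
```

With this coupling, only the learning rate schedule has to be defined; the weight decay follows it by construction.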
Thanks for trying! I hope to find some time on the weekend to take a closer look.
I've avoided the model.fit function so far as I feel it does too much under the hood but I guess now's the time to dive in ;)
Inspired by https://github.com/sajadn/AdamW/blob/master/DecoupleWeightDecay.py, I found a way to use a callback to adjust the weight decay along with the learning rate schedule at the beginning of each epoch. The code below implements AdamW with a learning rate schedule on epochs (not on each update):
import tensorflow as tf
import os
from tensorflow_addons.optimizers import AdamW
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback


def lr_schedule(epoch):
    """Learning Rate Schedule
    Learning rate is scheduled to be reduced after 20, 30 epochs.
    Called automatically every epoch as part of callbacks during training.
    # Arguments
        epoch (int): The number of epochs
    # Returns
        lr (float32): learning rate
    """
    lr = 1e-3
    if epoch >= 30:
        lr *= 1e-2
    elif epoch >= 20:
        lr *= 1e-1
    print('Learning rate: ', lr)
    return lr


def wd_schedule(epoch):
    """Weight Decay Schedule
    Weight decay is scheduled to be reduced after 20, 30 epochs.
    Called automatically every epoch as part of callbacks during training.
    # Arguments
        epoch (int): The number of epochs
    # Returns
        wd (float32): weight decay
    """
    wd = 1e-4
    if epoch >= 30:
        wd *= 1e-2
    elif epoch >= 20:
        wd *= 1e-1
    print('Weight decay: ', wd)
    return wd


# A copy of the LearningRateScheduler implementation, with lr replaced by weight_decay
class WeightDecayScheduler(Callback):
    """Weight Decay Scheduler.
    Arguments:
        schedule: a function that takes an epoch index as input
            (integer, indexed from 0) and returns a new
            weight decay as output (float).
        verbose: int. 0: quiet, 1: update messages.
    ```python
    # This function keeps the weight decay at 0.001 for the first ten epochs
    # and decreases it exponentially after that.
    def scheduler(epoch):
        if epoch < 10:
            return 0.001
        else:
            return 0.001 * tf.math.exp(0.1 * (10 - epoch))
    callback = WeightDecayScheduler(scheduler)
    model.fit(data, labels, epochs=100, callbacks=[callback],
              validation_data=(val_data, val_labels))
    ```
    """

    def __init__(self, schedule, verbose=0):
        super(WeightDecayScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'weight_decay'):
            raise ValueError('Optimizer must have a "weight_decay" attribute.')
        try:  # new API
            weight_decay = float(K.get_value(self.model.optimizer.weight_decay))
            weight_decay = self.schedule(epoch, weight_decay)
        except TypeError:  # Support for old API for backward compatibility
            weight_decay = self.schedule(epoch)
        if not isinstance(weight_decay, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.weight_decay, weight_decay)
        if self.verbose > 0:
            print('\nEpoch %05d: WeightDecayScheduler reducing weight '
                  'decay to %s.' % (epoch + 1, weight_decay))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['weight_decay'] = K.get_value(self.model.optimizer.weight_decay)

if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
    print(gpus)
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    optimizer = AdamW(learning_rate=lr_schedule(0), weight_decay=wd_schedule(0))
    # optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    tb_callback = tf.keras.callbacks.TensorBoard(os.path.join('logs', 'adamw'),
                                                 profile_batch=0)
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
    wd_callback = WeightDecayScheduler(wd_schedule)
    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=40, validation_split=0.1,
              callbacks=[tb_callback, lr_callback, wd_callback])
    model.evaluate(x_test, y_test, verbose=2)
This can serve as a simple example of using AdamW with tf.keras.
But if someone wants to decay the learning rate at every weight update, as tf.optimizers.schedules.PiecewiseConstantDecay does, it cannot be achieved with the code above.
@PhilJd Thanks!
I would think we have to do something like this to weight_decay if we want to pass an instance of LearningRateSchedule into it.
@WindQAQ I agree with you. If we do not establish a connection between weight_decay and learning_rate through initial values, such as $wd_t = wd_{init} \cdot lr_t / lr_{init}$, then we must apply the same schedule to both, but the current code does not support the case where weight_decay is a learning_rate_schedule.LearningRateSchedule. I also think that weight_decay needs to support the type learning_rate_schedule.LearningRateSchedule.
BTW, the callback method mentioned above works normally.
Agree +1. As there is another request in #865, @PhilJd do you think we should support decaying weight_decay param? Thank you.
@wuliytTaotao @WindQAQ
Hi, with mixed-precision training, the code above with the WD scheduler doesn't work:
in on_epoch_begin(self, epoch, logs)
13 def on_epoch_begin(self, epoch, logs=None):
14 if not hasattr(self.model.optimizer, 'weight_decay'):
---> 15 raise ValueError('Optimizer must have a "weight_decay" attribute.')
16 try: # new API
17 weight_decay = float(K.get_value(self.model.optimizer.weight_decay))
ValueError: Optimizer must have a "weight_decay" attribute.
and because of :
optimizer.weight_decay
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.optimizer
By the way:
Updating WD and lr at every step is unnecessary for Adam, because Adam already adapts its step size within an epoch, and the aim of the weight decay in AdamW ("Decoupled Weight Decay Regularization", the original paper) is to decouple it from the loss function and the lr.
In short, updating at the epoch level is more than sufficient.
@AlexWang1900 Please list your code completely and give your TF version.
@wuliytTaotao
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
optimizer = tfa.optimizers.AdamW(learning_rate=lr_schedule(0), weight_decay=wd_schedule(0),amsgrad=False)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.optimizer
Has anything moved regarding the bug with the GPU?
Hey guys,
While facing a similar issue, and until there is a PR, here is a workaround I found that works both with .fit() (classic Keras usage) and with custom use of the AdamW object.
I first create the AdamW object as opt, then assign to its weight_decay attribute a lambda function returning the value of wd_schedule(opt.iterations). This keeps the weight decay value in step with the optimizer's number of iterations.
Here is a snippet for a training scheme using .fit():
lr_schedule = tf.optimizers.schedules.ExponentialDecay(1e-4, 100, 0.9)
wd_schedule = tf.optimizers.schedules.ExponentialDecay(5e-5, 100, 0.9)
# Placeholder weight decay; it is replaced on the next line once opt exists
opt = AdamW(learning_rate=lr_schedule, weight_decay=lambda: None)
opt.weight_decay = lambda: wd_schedule(opt.iterations)
mlp.compile(
    optimizer=opt,
    loss=tf.keras.losses.BinaryCrossentropy())
If I create a tf.keras.callbacks.Callback to check that the weight decay value does change:
class DecayHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.lr = []
        self.wd = []

    def on_batch_end(self, batch, logs=None):
        # The learning rate is a schedule here, so evaluate it at the current step
        self.lr.append(self.model.optimizer.lr(self.model.optimizer.iterations))
        self.wd.append(self.model.optimizer.weight_decay)
I obtain the expected behavior, as shown in the following plot:

PS: @wuliytTaotao's solution can be made to update at each step by using on_batch_end() instead of on_epoch_end().
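The mechanism behind this workaround can be illustrated without TensorFlow. The sketch below is only an analogy: `ToyOptimizer` is a hypothetical stand-in, not the TensorFlow Addons class, and `wd_schedule` here is a plain function with the same shape as ExponentialDecay(5e-5, 100, 0.9). It shows why a zero-argument lambda that reads the step counter yields a fresh value on every update.

```python
class ToyOptimizer:
    """Hypothetical stand-in for an optimizer; not the TensorFlow Addons API."""

    def __init__(self, weight_decay):
        self.iterations = 0               # incremented once per update
        self.weight_decay = weight_decay  # float or zero-argument callable

    def current_weight_decay(self):
        # Resolve a callable hyperparameter at the moment it is needed.
        wd = self.weight_decay
        return wd() if callable(wd) else wd

    def apply_update(self):
        wd = self.current_weight_decay()
        self.iterations += 1
        return wd


def wd_schedule(step):
    # Same formula as an exponential decay without staircase.
    return 5e-5 * 0.9 ** (step / 100)


opt = ToyOptimizer(weight_decay=0.0)
# Late binding: the lambda reads opt.iterations each time it is called.
opt.weight_decay = lambda: wd_schedule(opt.iterations)

history = [opt.apply_update() for _ in range(300)]
print(history[0], history[100], history[200])
```

The lambda has to be assigned after the optimizer exists, because it refers to the optimizer's own counter; that is the reason for the two-step construction in the snippet above.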
lr_schedule = tf.optimizers.schedules.ExponentialDecay(1e-4, 100, 0.9)
I obtain the expected behavior as shown in the following plot :
You specify 100 decay steps in the code, but in the plot decay continues for the entire plot range (more than 2000 steps). Can you clarify this discrepancy or update the code, please?
Hi,
This is inherent to the way tf.optimizers.schedules.ExponentialDecay is built.
Indeed, if you look at the documentation (https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay), how decay_steps works is not very clear.
What it does: decay_steps is the number of steps it takes to go from a learning rate lr to a learning rate of decay_rate * lr.
As a concrete example, take the parameters of the learning rate scheduler above, with initial_learning_rate = 1e-4, decay_steps = 100, decay_rate = 0.9:
learning_rate_step_100 = 0.9 * initial_learning_rate
learning_rate_step_200 = 0.9 * learning_rate_step_100
Contrary to some other schedulers (such as the cosine scheduler), exponential decay never stops. The general formula with staircase=False is:
lr(step) = initial_learning_rate * decay_rate ** (step / decay_steps)
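A plain-Python rendering of that formula may make the behavior easier to check. The function name `exponential_decay` is made up for this sketch; the staircase branch reflects the documented option of the same name, where the exponent is floored to an integer.

```python
def exponential_decay(step, initial_learning_rate, decay_steps, decay_rate,
                      staircase=False):
    """Evaluate initial_learning_rate * decay_rate ** (step / decay_steps)."""
    if staircase:
        exponent = step // decay_steps   # integer division: decay in discrete jumps
    else:
        exponent = step / decay_steps    # continuous decay
    return initial_learning_rate * decay_rate ** exponent


# Parameters from the snippet discussed above.
for step in (0, 100, 200, 2000):
    print(step, exponential_decay(step, 1e-4, 100, 0.9))
```

The value keeps shrinking for as long as the step count grows, which matches the plot decaying well past step 100.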
Many thanks!
One can also follow https://github.com/tensorflow/addons/pull/1974 to make AdamW support schedulers. Feel free to open a PR and request my review if anyone is interested. Thanks.
@hugoych While your solution works for the schedule, it no longer allows the optimizer to be serialized, due to this line:
With your solution, the parameter is a callable, but it returns a tensor. Following the function self._serialize_hyperparameter:
A callable is resolved differently from a tensor or a schedule object. The fix is to reverse the order of operations: first resolve the callable, then check whether the result is a tensor or a custom object (e.g. in this case a learning rate scheduler).
Until proper schedules are implemented, this solution can be used in conjunction with yours (this is for SGDW, but the same can be done for AdamW):
import tensorflow as tf
import tensorflow_addons as tfa


class SerializableSGDW(tfa.optimizers.SGDW):
    def get_config(self):
        config = tf.keras.optimizers.SGD.get_config(self)
        config.update(
            {"weight_decay": self._fixed_serialize_hyperparameter("weight_decay")}
        )
        return config

    def _fixed_serialize_hyperparameter(self, hyperparameter_name):
        """Serialize a hyperparameter that can be a float, callable, or Tensor."""
        value = self._hyper[hyperparameter_name]
        schedule_cls = tf.keras.optimizers.schedules.LearningRateSchedule
        # First resolve a plain callable (e.g. the lambda from the workaround).
        # Schedule objects are callable too but need a step argument, so skip them.
        if callable(value) and not isinstance(value, schedule_cls):
            value = value()
        if isinstance(value, schedule_cls):
            return tf.keras.optimizers.schedules.serialize(value)
        if tf.is_tensor(value):
            return tf.keras.backend.get_value(value)
        return value
Note, however, that after loading the model, weight_decay will be a variable and no longer the scheduler.