Hey, I am transforming my Keras model to a tf estimator with tf.keras.estimator.model_to_estimator(). However, I would like to use training callbacks (mostly for learning rate decay). How can I do this?
I would also like to use this.
I would like to set a callback because there is no keras optimizer which allows you to set a piecewise learning rate, despite this being part of an optimization strategy.
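For a plain Keras model (not yet converted to an estimator), a piecewise learning rate can be expressed as a schedule function for a callback. Below is a minimal sketch of such a function. The boundaries and values are hypothetical placeholders, and the function only computes the rate per epoch. It does not answer whether the callback survives model_to_estimator, which is the open question in this thread.

```python
import bisect

# Hypothetical schedule: epoch boundaries and the learning rate for each interval.
BOUNDARIES = [30, 60, 90]              # epochs at which the rate changes
VALUES = [0.1, 0.01, 0.001, 0.0001]    # len(VALUES) == len(BOUNDARIES) + 1


def piecewise_schedule(epoch, lr=None):
    """Return the learning rate for a 0-based `epoch`.

    Epochs 0..29 use 0.1, 30..59 use 0.01, 60..89 use 0.001, 90+ use 0.0001.
    The optional `lr` (current rate) is accepted so the function also fits the
    two-argument schedule signature; it is ignored here.
    """
    # bisect_right counts the boundaries <= epoch, which is the index of the active value.
    return VALUES[bisect.bisect_right(BOUNDARIES, epoch)]


if __name__ == "__main__":
    for e in (0, 29, 30, 59, 60, 95):
        print(e, piecewise_schedule(e))
```

With a regular Keras model you would pass tf.keras.callbacks.LearningRateScheduler(piecewise_schedule) in the callbacks list of model.fit(). In TF 2.x there is also tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values), which can be passed directly as an optimizer's learning rate; note that its boundaries count optimizer steps, not epochs.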
I also could not find documentation on how to use a TensorFlow optimizer, though I saw documentation suggesting it was possible to wrap a TensorFlow optimizer.
This seems to me to be urgent, because of the work of Wilson et al. suggesting that adaptive methods may be limited in their applicability.
Still no way to do this?
+1. What's the recommended approach?
I feel like this is more appropriately an issue for the TensorFlow team, so I filed an issue there.
+1. Any solutions right now?
If the reason for transforming a keras model to a tf.estimator is training on multiple GPUs, then you should definitely try the (tf.)keras.utils.multi_gpu_model function:
1. Build your_keras_model.
2. your_keras_model_multi_gpu = multi_gpu_model(your_keras_model, gpus=<number of available GPUs>)
3. Use your_keras_model_multi_gpu just as a standard keras model, including all the callback functionality.
4. Save the weights with your_keras_model.save_weights(). The weights of your_keras_model are updated while training your_keras_model_multi_gpu.
As @ptiwald mentioned, using multi_gpu_model is an option. Unfortunately, it doesn't fully utilize the GPUs' capacity. In other words, it's slow, effectively defeating the purpose of a multi-GPU setup. Is there a workaround to use tf estimator with callbacks in a distributed environment? I don't want to use horovod or other libraries.
I am trying to use TFRecords to improve performance and was thinking about distributed training down the line. I have adaptive learning rate and early stopping callbacks to train the model in Keras. Not sure if I can do the same with tf.estimator. Any ideas would help.