As described in #358, it would be relatively simple to add invalid value detection to Stable-Baselines.
Adding CheckNaNTensorFlow to the losses, and using tf.add_check_numerics_ops in the policies, allows invalid value detection while preventing check_numerics from wrapping every operation in the graph (which tf.add_check_numerics_ops does by default).
Code example
import tensorflow as tf  # TensorFlow 1.x graph-mode API


class CheckNaNTensorFlow:
    def __init__(self):
        self.check_nan_ops = []
        self.old_ops = []

    def __enter__(self):
        # snapshot the operations that already exist, so only new ones get checked
        self.old_ops = list(tf.get_default_graph().get_operations())
        return self  # needed so that "as tf_nan" is bound to this object

    def __exit__(self, exc_type, exc_value, traceback):
        for op in tf.get_default_graph().get_operations():
            if op not in self.old_ops:
                # check_numerics expects a tensor, so wrap each output of the operation
                for tensor in op.outputs:
                    try:
                        self.check_nan_ops.append(tf.check_numerics(tensor, op.name))
                    except TypeError:  # non-float outputs (integer, boolean, ...) are skipped
                        pass


with CheckNaNTensorFlow() as tf_nan:
    ...  # define the TF graph you want to check here

sess.run([val_1, val_2] + tf_nan.check_nan_ops, td_map)  # add to run
Additional context
This will probably add some overhead; my guess is only about as much as an if instruction per check. On CPU, branch prediction should keep the performance impact small. It might, however, hurt GPU performance significantly; this needs verifying.
PPO2, CPU, forward and back propagation without invalid value checking: [figure not preserved]
With invalid value checking: [figure not preserved]
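To verify the overhead on other hardware (GPU in particular), a small timing harness along these lines might help. This is only a sketch: `time_calls` is a hypothetical helper, not part of Stable-Baselines or TensorFlow, and the commented usage assumes the `sess`, `val_1`, `val_2`, `td_map` and `tf_nan` names from the code example above.

```python
import time
from statistics import mean, stdev


def time_calls(fn, n_warmup=10, n_runs=100):
    """Return (mean, standard deviation) of the wall-clock time of fn(), in seconds."""
    for _ in range(n_warmup):  # warm-up calls are executed but not measured
        fn()
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return mean(durations), stdev(durations)  # stdev needs n_runs >= 2


# Hypothetical usage, assuming the names from the code example above:
# base = time_calls(lambda: sess.run([val_1, val_2], td_map))
# checked = time_calls(lambda: sess.run([val_1, val_2] + tf_nan.check_nan_ops, td_map))
# print("relative overhead: {:.1%}".format(checked[0] / base[0] - 1))
```

Running both variants in the same session and comparing the means against the standard deviations should indicate whether the difference is larger than run-to-run noise.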