Addons: WeightNormalization doesn't work with @tf.function in an edge case

Created on 27 Jan 2020  ·  11 Comments  ·  Source: tensorflow/addons

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.15.2 (19C57)
  • TensorFlow version and how it was installed (source or binary): 2.1.0 via pip
  • TensorFlow-Addons version and how it was installed (source or binary): 0.8.0-dev via pip
  • Python version: 3.7.4
  • Is GPU used? (yes/no): no

Describe the bug

When the train function is decorated with @tf.function, using multiple WeightNormalization layers in a multi-output model causes training to hang without any meaningful log output.

The model works if I:

  • comment out the @tf.function decorator
  • use only one WeightNormalization layer
  • use only one output in the InnerModel

I don't really know what is happening, so I have posted the full code to reproduce the issue. The issue is also reproducible on Linux (Ubuntu 16.04) with a GPU (RTX 2080, CUDA 10.1) enabled.

Code to reproduce the issue

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np

class InnerModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.model1 = tf.keras.Sequential([
            tfa.layers.WeightNormalization(
                tf.keras.layers.Conv2D(3, 1),
            ),
        ])
        self.model2 = tf.keras.Sequential([
            tfa.layers.WeightNormalization(
                tf.keras.layers.Conv2D(3, 1),
            ),
        ])

    def call(self, inputs):
        out1 = self.model1(inputs)
        out2 = self.model2(out1)
        return out1, out2

class OuterModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.model1 = InnerModel()
        self.model2 = InnerModel()
        self.downsample = tf.keras.layers.AveragePooling2D(2, 1, 'same')

    def call(self, inputs):
        init_out = self.model1(inputs)
        inputs = self.downsample(inputs)
        final_out = self.model2(inputs)
        return init_out, final_out

def loss_fn(out):
    loss = 0.
    for i in range(len(out)):
        for j in range(len(out[i])):
            loss += tf.reduce_mean(out[i][j])
    return loss

net = OuterModel()
opt = tf.keras.optimizers.Adam(1e-4)

@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        out = net(inputs)
        loss = loss_fn(out)
    grads = tape.gradient(loss, net.trainable_variables)
    opt.apply_gradients(zip(grads, net.trainable_variables))

def train():
    for i in range(20):
        x = np.random.randn(1, 32, 32, 3)
        train_step(x)
        print(f'step {i}')

train()

Other info / logs

2020-01-27 15:35:40.604091: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-01-27 15:35:40.624043: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fc57e1f3f50 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-01-27 15:35:40.624082: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
WARNING:tensorflow:Layer outer_model is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Labels: bug, help wanted, layers

All 11 comments

Disabling autograph could fix the bug.

Doesn't @tf.function mean using autograph? If I remove the @tf.function decorator, training is 3x slower (in my case).

@ye2020 Use it with the WeightNormalization layer.
I can fix the bug later.

https://github.com/tensorflow/addons/blob/6052477edcefad9691357e84312b127d5b57f4c0/tensorflow_addons/layers/wrappers.py#L62 seems to be the root cause. I suspect tf.CriticalSection honors neither the name nor the shared_name argument and creates its own unique CS every time. Not only does this fail to prevent the parallel execution at https://github.com/tensorflow/addons/blob/6052477edcefad9691357e84312b127d5b57f4c0/tensorflow_addons/layers/wrappers.py#L134-L135, it also creates a total mess with locks (and eventually a deadlock) somewhere deep inside. This tiny workaround https://github.com/failure-to-thrive/addons/commit/ec008b35e89d810bc7d7b41960664f98859336ea works like a charm.
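
To illustrate the suspicion above, here is a minimal pure-Python analogy, not TensorFlow code. It assumes that threading.Lock can stand in for tf.CriticalSection; the helper names make_lock_per_call, try_enter and shared_lock are hypothetical. It only shows why separately created locks cannot serialize anything, while a single shared lock can. It does not reproduce the hang itself.

```python
import threading

# Analogy only: threading.Lock stands in for tf.CriticalSection.


def make_lock_per_call():
    """Mimic the suspected behavior: every call builds a fresh, unrelated lock."""
    return threading.Lock()


# Hypothetical lock created once and reused by everyone.
shared_lock = threading.Lock()


def try_enter(lock):
    """Try to enter the guarded region without blocking; return True on success."""
    return lock.acquire(blocking=False)


# Two separately created locks: holding one does not keep anyone out of the other.
a, b = make_lock_per_call(), make_lock_per_call()
entered_a = try_enter(a)
entered_b = try_enter(b)  # also succeeds, so both "critical" regions overlap
a.release()
b.release()

# One shared lock: the second attempt is refused while the first holder is inside.
first = try_enter(shared_lock)
second = try_enter(shared_lock)
shared_lock.release()

print(entered_a, entered_b, first, second)  # True True True False
```

If tf.CriticalSection really does behave like make_lock_per_call here, the layers would each be guarded by a different lock, which matches the observation that parallel execution is not prevented.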

Another possible explanation: multiple tf.CriticalSection calls are not allowed inside @tf.function, just as with tf.Variable. The solution is the same: call them outside and pass the value in via an argument or a global variable. Unfortunately, the documentation on tf.CriticalSection is very sparse; even the shared_name argument is not explained at all.

Autograph causes the mess and the deadlock. A global CS changes the op's behavior and makes the op tests fail.

According to the documentation: "A CriticalSection object is a resource in the graph which executes subgraphs in serial order."
Each layer has its own variables, and WeightNormalization.call executes _update_weights only once. So I believe there is no need to use CriticalSection.
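
For context on what the weight update recomputes, here is a rough pure-Python sketch of the weight normalization reparameterization w = g * v / ||v||. It is not the addons implementation: the function name weight_norm and the use of a single flat weight vector are simplifying assumptions for illustration. The point is that the result depends only on the layer's own v and g, with no state shared between layers.

```python
import math


def weight_norm(v, g):
    """Return w = g * v / ||v|| for a flat list of weights v and a scalar gain g.

    Illustrative only: the real layer operates on tensors, not Python lists.
    """
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]


# Each "layer" owns its own direction v and gain g; nothing is shared.
w1 = weight_norm([3.0, 4.0], g=2.0)
w2 = weight_norm([1.0, 0.0], g=5.0)

# After the reparameterization, the length of w equals the gain g.
len_w1 = math.sqrt(sum(x * x for x in w1))
len_w2 = math.sqrt(sum(x * x for x in w2))
```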

Just for reference here was the issue for why it was added:
https://github.com/tensorflow/addons/issues/722

Attaching the CS to the wrapped layer may fix the issue.

Fixed by #1115

This should be fixed on nightly. Please let us know if it is not:
https://github.com/tensorflow/addons/pull/1190
