Addons: BeamSearchDecoder with non LSTM cells raises ValueError exception

Created on 31 May 2020 · 3 Comments · Source: tensorflow/addons

System information

  • OS Platform and Distribution: Tensorflow Docker Container based on image tensorflow/tensorflow:2.2.0-jupyter
  • TensorFlow version and how it was installed (source or binary): 2.2.0
  • TensorFlow-Addons version and how it was installed (source or binary): tfa-nightly 0.11.0.dev20200520032238
  • Python version: 3.6.9
  • Is GPU used? (yes/no): no

Describe the bug

While working with @w4nderlust on this PR (https://github.com/uber/ludwig/pull/699), we encountered an issue with seq2seq beam search. If we use SimpleRNNCell or GRUCell with beam search and no Attention, we see this error:

ValueError: The two structures don't have the same nested structure.
First structure: type=list str=[<tf.Tensor: shape=(384, 256), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>]
Second structure: type=int str=256
More specifically: Substructure "type=list str=[<tf.Tensor: shape=(384, 256), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>]" is a sequence, while substructure "type=int str=256" is not

No error occurs when we use LSTMCell with beam search and no Attention.
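
The traceback compares a one-element list holding the tiled state tensor against the bare integer 256, which matches RNN_UNITS. That suggests the state is being checked against the cell's state_size. The sketch below illustrates that kind of mismatch in plain Python. The `same_structure` helper is hypothetical, a simplified stand-in for a nested-structure check and not TensorFlow's implementation, and the state_size layouts (a single int for SimpleRNNCell/GRUCell, a pair for LSTMCell) are stated as assumptions.

```python
# Simplified stand-in for a nested-structure check.
# Hypothetical helper for illustration; NOT TensorFlow's implementation.
def same_structure(a, b):
    """Return True when a and b nest lists/tuples in the same way."""
    a_is_seq = isinstance(a, (list, tuple))
    b_is_seq = isinstance(b, (list, tuple))
    if a_is_seq != b_is_seq:
        return False  # one side is a sequence, the other is a leaf
    if not a_is_seq:
        return True   # both sides are leaves
    return len(a) == len(b) and all(
        same_structure(x, y) for x, y in zip(a, b)
    )


RNN_UNITS = 256

# Placeholders standing in for state tensors; only the nesting matters here.
h = "h_tensor"
c = "c_tensor"

# Assumed state_size layouts: one int for SimpleRNNCell/GRUCell,
# a two-element list for LSTMCell.
gru_state_size = RNN_UNITS
lstm_state_size = [RNN_UNITS, RNN_UNITS]

gru_wrapped_ok = same_structure([h], gru_state_size)  # list vs. int
gru_bare_ok = same_structure(h, gru_state_size)       # leaf vs. int
lstm_ok = same_structure([h, c], lstm_state_size)     # list vs. list

print(gru_wrapped_ok, gru_bare_ok, lstm_ok)
```

Under these assumptions, only the list-wrapped single state fails the check, which is consistent with the error appearing for SimpleRNNCell and GRUCell but not for LSTMCell.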

Refer to this posting for more context.

Code to reproduce the issue
Minimal working example (note: updated to simplify the code)

import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa

# ValueError exception raised for both SimpleRNNCell and GRUCell

#DECODER_CELL_TYPE = keras.layers.SimpleRNNCell  # Raises ValueError Exception
#DECODER_CELL_TYPE = keras.layers.GRUCell        # Raises ValueError Exception
DECODER_CELL_TYPE = keras.layers.LSTMCell        # Does not raise ValueError Exception

VOCAB_SIZE = 100
EMBED_SIZE = 10
RNN_UNITS = 256
INPUT_SEQUENCE_SIZE = 10
OUTPUT_SEQUENCE_SIZE = 15
BATCH_SIZE = 128
BEAM_WIDTH = 3
SIMULATE_LSTM_ENCODER = False

#============== simulate output from encoder ========================
encoder_outputs = tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32)
if SIMULATE_LSTM_ENCODER:
    encoder_state = [tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32),
                     tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32)]
else:
    encoder_state = [tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32)]



# ================ Setup Decoder =====================
embeddings_dec = keras.layers.Embedding(VOCAB_SIZE, EMBED_SIZE)
decoder_cell = DECODER_CELL_TYPE(RNN_UNITS)
output_layer = keras.layers.Dense(VOCAB_SIZE)

GO_SYMBOL = VOCAB_SIZE - 1
END_SYMBOL = 0
batch_size = BATCH_SIZE
encoder_sequence_length = tf.convert_to_tensor(
    np.array([INPUT_SEQUENCE_SIZE] * BATCH_SIZE),
    tf.int32
)

decoder_input = tf.expand_dims(
    [GO_SYMBOL] * batch_size, 1)
start_tokens = tf.fill([batch_size], GO_SYMBOL)
end_token = END_SYMBOL

decoder_emb_inp = embeddings_dec(decoder_input)

if DECODER_CELL_TYPE is not keras.layers.LSTMCell:
    # Non-LSTM cells carry a single state tensor, so keep only the first one.
    encoder_state = [encoder_state[0]]


#================= setup for beam search ==================
decoder_initial_state = decoder_cell.get_initial_state(
    batch_size=BATCH_SIZE,
    dtype=tf.float32
)

if not isinstance(decoder_initial_state, list):
    decoder_initial_state = [decoder_initial_state]

tiled_decoder_initial_state = tfa.seq2seq.tile_batch(
    decoder_initial_state,
    multiplier=BEAM_WIDTH
)

decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(
    cell=decoder_cell,
    beam_width=BEAM_WIDTH,
    output_layer=output_layer
)

#================= perform beam search  =====================
decoder_embedding_matrix = embeddings_dec.variables[0]

(
    first_finished,
    first_inputs,
    first_state
) = decoder.initialize(
    decoder_embedding_matrix,
    start_tokens=start_tokens,
    end_token=end_token,
    initial_state=tiled_decoder_initial_state
)

inputs = first_inputs
state = first_state

for j in range(OUTPUT_SEQUENCE_SIZE):
    outputs, next_state, next_inputs, finished = decoder.step(
        j, inputs, state, training=False)
    inputs = next_inputs
    state = next_state

Labels: bug, seq2seq

All 3 comments

@dynamicwebpaige this is another blocker, similar to the problem that was solved before.

Thank you for the quick resolution!

Thank you as well. I can confirm the fix addresses the issue we encountered.
