System information
Describe the bug
While working with @w4nderlust on PR https://github.com/uber/ludwig/pull/699, we encountered an issue with seq2seq beam search. Using SimpleRNNCell or GRUCell with beam search and no Attention raises this error:
ValueError: The two structures don't have the same nested structure.
First structure: type=list str=[<tf.Tensor: shape=(384, 256), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>]
Second structure: type=int str=256
More specifically: Substructure "type=list str=[<tf.Tensor: shape=(384, 256), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>]" is a sequence, while substructure "type=int str=256" is not
No error occurs when we use LSTMCell with beam search and no Attention.
Refer to this posting for more context.
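The error message itself hints at where the mismatch comes from: the first structure is a one-element list holding the state tensor, while the second is the bare integer 256, which matches RNN_UNITS in the reproduction code. The sketch below is a plain-Python illustration of that kind of nested-structure check, not TensorFlow's implementation. `same_structure` is a hypothetical helper, and the state sizes used (a single integer for SimpleRNNCell/GRUCell, a two-element list for LSTMCell) are an assumption inferred from the traceback.

```python
def same_structure(a, b):
    """Hypothetical helper: True if a and b nest lists/tuples identically."""
    a_is_seq = isinstance(a, (list, tuple))
    b_is_seq = isinstance(b, (list, tuple))
    if a_is_seq != b_is_seq:
        return False  # one side is a sequence, the other a leaf
    if not a_is_seq:
        return True  # both sides are leaves
    if len(a) != len(b):
        return False
    return all(same_structure(x, y) for x, y in zip(a, b))


RNN_UNITS = 256
state = "h"  # placeholder standing in for a (batch, units) state tensor

# Assumed state sizes: a bare int for GRU/SimpleRNN, a two-element list for LSTM.
gru_like_state_size = RNN_UNITS
lstm_like_state_size = [RNN_UNITS, RNN_UNITS]

wrapped_single_state = [state]  # single state wrapped in a list, as in the error
lstm_state = [state, state]     # LSTM carries two states

print(same_structure(wrapped_single_state, gru_like_state_size))  # False
print(same_structure(state, gru_like_state_size))                 # True
print(same_structure(lstm_state, lstm_like_state_size))           # True
```

Under these assumptions, only the list-wrapped single state fails the check, which would be consistent with LSTMCell working while SimpleRNNCell and GRUCell do not.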
Code to reproduce the issue
Minimal working example (NOTE: updated to simplify the code)
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
# ValueError exception raised for both SimpleRNNCell and GRUCell
#DECODER_CELL_TYPE = keras.layers.SimpleRNNCell # Raises ValueError Exception
#DECODER_CELL_TYPE = keras.layers.GRUCell # Raises ValueError Exception
DECODER_CELL_TYPE = keras.layers.LSTMCell # Does not raise ValueError Exception
VOCAB_SIZE = 100
EMBED_SIZE = 10
RNN_UNITS = 256
INPUT_SEQUENCE_SIZE = 10
OUTPUT_SEQUENCE_SIZE = 15
BATCH_SIZE = 128
BEAM_WIDTH = 3
SIMULATE_LSTM_ENCODER = False
#============== simulate output from encoder ========================
encoder_outputs = tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32)
if SIMULATE_LSTM_ENCODER:
    encoder_state = [tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32),
                     tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32)]
else:
    encoder_state = [tf.zeros([BATCH_SIZE, RNN_UNITS], dtype=tf.float32)]
# ================ Setup Decoder =====================
embeddings_dec = keras.layers.Embedding(VOCAB_SIZE, EMBED_SIZE)
decoder_cell = DECODER_CELL_TYPE(RNN_UNITS)
output_layer = keras.layers.Dense(VOCAB_SIZE)
GO_SYMBOL = VOCAB_SIZE - 1
END_SYMBOL = 0
batch_size = BATCH_SIZE
encoder_sequence_length = tf.convert_to_tensor(
    np.array([INPUT_SEQUENCE_SIZE] * BATCH_SIZE),
    tf.int32
)
decoder_input = tf.expand_dims(
    [GO_SYMBOL] * batch_size, 1)
start_tokens = tf.fill([batch_size], GO_SYMBOL)
end_token = END_SYMBOL
decoder_emb_inp = embeddings_dec(decoder_input)
if DECODER_CELL_TYPE._keras_api_names[0] != 'keras.layers.LSTMCell':
    # adjust for non LSTMCell decoder
    encoder_state = [encoder_state[0]]
#================= setup for beam search ==================
decoder_initial_state = decoder_cell.get_initial_state(
    batch_size=BATCH_SIZE,
    dtype=tf.float32
)
if not isinstance(decoder_initial_state, list):
    decoder_initial_state = [decoder_initial_state]
tiled_decoder_initial_state = tfa.seq2seq.tile_batch(
    decoder_initial_state,
    multiplier=BEAM_WIDTH
)
decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(
    cell=decoder_cell,
    beam_width=BEAM_WIDTH,
    output_layer=output_layer
)
#================= perform beam search =====================
decoder_embedding_matrix = embeddings_dec.variables[0]
(
    first_finished,
    first_inputs,
    first_state
) = decoder.initialize(
    decoder_embedding_matrix,
    start_tokens=start_tokens,
    end_token=end_token,
    initial_state=tiled_decoder_initial_state
)
inputs = first_inputs
state = first_state
for j in range(OUTPUT_SEQUENCE_SIZE):
    outputs, next_state, next_inputs, finished = decoder.step(
        j, inputs, state, training=False)
    inputs = next_inputs
    state = next_state
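As a side note on the shapes in the traceback: the leading dimension 384 in shape=(384, 256) is consistent with tile_batch repeating each of the 128 batch entries BEAM_WIDTH = 3 times. The following plain-Python sketch only illustrates that repetition pattern on a list of rows. `tile_batch_rows` is a hypothetical stand-in, not the tfa.seq2seq.tile_batch implementation.

```python
def tile_batch_rows(rows, multiplier):
    """Hypothetical stand-in: repeat each batch entry `multiplier` times in place."""
    return [row for row in rows for _ in range(multiplier)]


BATCH_SIZE = 128
BEAM_WIDTH = 3

rows = list(range(BATCH_SIZE))  # one integer per batch entry, standing in for a state row
tiled = tile_batch_rows(rows, BEAM_WIDTH)

print(len(tiled))  # 384
print(tiled[:6])   # [0, 0, 0, 1, 1, 1]
```

Each entry's copies sit next to each other, so the tiled state ends up with BATCH_SIZE * BEAM_WIDTH rows.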
@dynamicwebpaige this is another blocker, similar to the problem that was solved before.
Thank you for the quick resolution!
Thank you as well. I can confirm the fix addresses the issue we encountered.