TFA BeamSearchDecoder Clarification Request

If the question seems too dumb, it is because I am new to TensorFlow. I was implementing a toy encoder-decoder problem using TensorFlow 2's TFA seq2seq implementation. The API was clear to me until I wanted to replace my BasicDecoder with a BeamSearchDecoder. My question is about how to initialize the start_tokens and end_token arguments of BeamSearchDecoder.

Here is a copy of the implementation; any help is appreciated.


import numpy as np
import tensorflow as tf
from tensorflow import keras

tf.keras.backend.clear_session()
tf.random.set_seed(42)

enc_vocab_size = len(train_vocab) + 1
dec_vocab_size = len(target_vocab) + 1
embed_size = 10


import tensorflow_addons as tfa

encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)


encoder_embeddings = keras.layers.Embedding(enc_vocab_size, embed_size)(encoder_inputs)
encoder = keras.layers.LSTM(512, return_state = True)
encoder_outputs, state_h, state_c = encoder(encoder_embeddings)
encoder_state = [state_h, state_c]


sampler = tfa.seq2seq.sampler.TrainingSampler()


decoder_embeddings = keras.layers.Embedding(dec_vocab_size, embed_size)(decoder_inputs)
decoder_cell = keras.layers.LSTMCell(512)
output_layer = keras.layers.Dense(dec_vocab_size)



beam_width = 10
start_tokens = tf.zeros([32], tf.dtypes.int32)
end_token = tf.constant(1, tf.dtypes.int32)
decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(cell = decoder_cell, beam_width = beam_width, output_layer = output_layer)
decoder_initial_state = tfa.seq2seq.beam_search_decoder.tile_batch(encoder_state, multiplier = beam_width)
outputs, _, _ = decoder(decoder_embeddings, start_tokens = start_tokens, end_token = end_token, initial_state = decoder_initial_state)
Y_proba = tf.nn.softmax(outputs.rnn_output)


model = keras.models.Model(inputs = [encoder_inputs, decoder_inputs], outputs = [Y_proba])
model.compile(loss="sparse_categorical_crossentropy", optimizer = 'adam', metrics = ['accuracy'])

Error Trace:


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-101-6cf083735ed0> in <module>()
     34 decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(cell = decoder_cell, beam_width = beam_width, output_layer = output_layer)
     35 decoder_initial_state = tfa.seq2seq.beam_search_decoder.tile_batch(encoder_state, multiplier = beam_width)
---> 36 outputs, _, _ = decoder(decoder_embeddings, start_tokens = start_tokens, end_token = end_token, initial_state = decoder_initial_state)
     37 Y_proba = tf.nn.softmax(outputs.rnn_output)
     38 

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    690       except Exception as e:  # pylint:disable=broad-except
    691         if hasattr(e, 'ag_error_metadata'):
--> 692           raise e.ag_error_metadata.to_exception(e)
    693         else:
    694           raise

ValueError: Exception encountered when calling layer "beam_search_decoder" (type BeamSearchDecoder).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/beam_search_decoder.py", line 941, in call  *
        self,
    File "/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py", line 262, in wrapper  *
        retval = func(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/decoder.py", line 430, in body  *
        (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/beam_search_decoder.py", line 705, in step  *
        cell_outputs, next_cell_state = self._cell(
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler  **
        raise e.with_traceback(filtered_tb) from None

    ValueError: Exception encountered when calling layer "lstm_cell_1" (type LSTMCell).
    
    Dimensions must be equal, but are 80 and 320 for '{{node beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/mul}} = Mul[T=DT_FLOAT](beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/Sigmoid_1, beam_search_decoder/decoder/while/BeamSearchDecoderStep/Reshape_2)' with input shapes: [320,80,2048], [320,512].
    
    Call arguments received:
      • inputs=tf.Tensor(shape=(320, None, 10), dtype=float32)
      • states=ListWrapper(['tf.Tensor(shape=(320, 512), dtype=float32)', 'tf.Tensor(shape=(320, 512), dtype=float32)'])
      • training=None


Call arguments received:
  • embedding=tf.Tensor(shape=(None, None, 10), dtype=float32)
  • start_tokens=tf.Tensor(shape=(32,), dtype=int32)
  • end_token=tf.Tensor(shape=(), dtype=int32)
  • initial_state=['tf.Tensor(shape=(None, 512), dtype=float32)', 'tf.Tensor(shape=(None, 512), dtype=float32)']
  • training=None
  • kwargs=<class 'inspect._empty'>



Solution 1:

I answered this in this GitHub issue: https://github.com/ageron/handson-ml2/issues/541

Here is a minimalistic implementation, without attention, of what you want. Beam search is used during inference, once training is done: BeamSearchDecoder expects the decoder's embedding matrix (or an embedding function) together with one start token id per example and a single end token id, so it cannot simply replace BasicDecoder inside the training model that consumes the already-embedded target sequence; that is what produces the shape mismatch in your trace.

For the implementation of the encoder-decoder (training) part, see the provided GitHub issue.
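In case the link is unavailable, here is a rough sketch of what that training-time setup looks like. It is reconstructed from the question's own code rather than copied from the notebook; it reuses the question's decoder_cell, sampler, output_layer, decoder_embeddings, encoder_state and sequence_lengths, and it keeps BasicDecoder for training:

# Rough sketch, reconstructed from the question's code (not from the linked notebook).
# BasicDecoder + TrainingSampler consume the embedded target sequence during training;
# the BeamSearchDecoder below is built as a separate model, for inference only.
training_decoder = tfa.seq2seq.basic_decoder.BasicDecoder(decoder_cell, sampler, output_layer = output_layer)
final_outputs, final_state, final_sequence_lengths = training_decoder(decoder_embeddings, initial_state = encoder_state, sequence_length = sequence_lengths)
Y_proba = tf.nn.softmax(final_outputs.rnn_output)

training_model = keras.models.Model(inputs = [encoder_inputs, decoder_inputs, sequence_lengths], outputs = [Y_proba])
training_model.compile(loss = "sparse_categorical_crossentropy", optimizer = "adam", metrics = ["accuracy"])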

Implementing Beam Search

def beam_search_inference_model(beam_width):
  batch_size = tf.shape(encoder_input)[:1]
  max_output_length = Y_train.shape[1]
  # One start-of-sequence token id per example in the batch.
  start_tokens = tf.fill(dims = batch_size, value = sos_id)
  # The encoder state must be tiled so that every beam gets its own copy of it.
  decoder_initial_state = tfa.seq2seq.tile_batch(encoder_state_HC, multiplier = beam_width)
  beam_search_inference = tfa.seq2seq.BeamSearchDecoder(cell = LSTMCell, beam_width = beam_width, output_layer = output_layer, maximum_iterations = max_output_length)
  # The decoder receives the embedding weights (not embedded inputs); decoding stops at end_token.
  outputs, _, _ = beam_search_inference(decoder_embd_layer.variables, start_tokens = start_tokens, end_token = 0, initial_state = decoder_initial_state)
  final_outputs = tf.transpose(outputs.predicted_ids, perm = (0, 2, 1))
  beam_scores = tf.transpose(outputs.beam_search_decoder_output.scores, perm = (0, 2, 1))
  return keras.Model(inputs = [encoder_input], outputs = [final_outputs, beam_scores])

beam_search_inference_model = beam_search_inference_model(3)
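As for the two arguments the question asked about: start_tokens is just one target-vocabulary id per example in the batch (the decoder's first input), and end_token is the single id at which a beam stops. The sketch below shows the assumption this answer's code makes, namely that sos_id is an id reserved for the start-of-sequence marker and that the padding id 0 doubles as the end token; the exact values depend on how the vocabulary is built in the linked notebook:

# Hypothetical illustration; the exact ids depend on the vocabulary in the linked notebook.
end_token = 0                                                # padding id, reused as end-of-sequence
batch_size = tf.shape(encoder_input)[:1]                     # dynamic batch size, as in the model above
start_tokens = tf.fill(dims = batch_size, value = sos_id)    # one start-of-sequence id per example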

Utility function: I copied this function from TFA's API tutorial and adapted it.

def beam_translate(sentence):
  X = prepare_date_strs_padded(sentence)
  result, beam_scores = beam_search_inference_model.predict(X)
  for beam, score in zip(result, beam_scores):
    output = ids_to_date_strs(beam)
    beam_score = [a.sum() for a in score]
    print('Input: %s' % sentence)
    print('-----' * 12)
    for i in range(len(output)):
      print('{} Predicted translation: {}  {}'.format(i + 1, output[i], beam_score[i]))
    print('\n')

Output

beam_translate(["July 14, 1789", "September 01, 2020"])

Input: ['July 14, 1789', 'September 01, 2020']
------------------------------------------------------------
1 Predicted translation: 2288-01-11  -83.7786865234375
2 Predicted translation: 2288-01-10  -83.90345764160156
3 Predicted translation: 2288-01-21  -84.30797576904297


Input: ['July 14, 1789', 'September 01, 2020']
------------------------------------------------------------
1 Predicted translation: 2221-02-26  -79.02340698242188
2 Predicted translation: 2222-02-26  -79.29275512695312
3 Predicted translation: 2221-02-21  -80.06587982177734

I hope this helps!

Cheers, Kasra

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
