query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
// Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
// Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
// Invoke the model with a random set of mask data. This should mask at least
// one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
// Invoke the model on the same data, but with a null mask (where no elements
// are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
// Because one run is masked and the other is not, the outputs should not be
// the same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
// Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(
After Change
// Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
// Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
// Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
// Invoke the model with a random set of mask data. This should mask at least
// one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
// Invoke the model on the same data, but with a null mask (where no elements
// are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
// Because one run is masked and the other is not, the outputs should not be
// the same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
// Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(