def encode(self, x=None):
    """Embed ``x`` and add a timing signal from ``get_timing_signal_1d``.

    The parent class's ``encode`` result is multiplied by ``sqrt(self.dsz)``.
    The timing signal is then added to it.

    :param x: forwarded unchanged to the parent ``encode``. The default
        ``None`` is passed through as-is.
    :return: the scaled embeddings plus the timing signal.
    """
    embedded = super(PositionalLookupTableEmbeddings, self).encode(x)
    embedded = embedded * math.sqrt(self.dsz)
    # NOTE(review): the three-way unpacking requires a rank-3 shape, presumably
    # (batch, time, channels). Only the last two dimensions are needed here.
    _, seq_len, depth = get_shape_as_list(embedded)
    timing = get_timing_signal_1d(
        seq_len,
        depth,
        min_timescale=1.0,
        max_timescale=self.max_timescale,
        start_index=0,
    )
    return embedded + timing
After Change
def encode(self, x):
    """Embed ``x``, add a positional term, and apply dropout.

    :param x: forwarded to the parent ``encode``.
    :return: the result of ``self.dropout`` applied to the embeddings
        (scaled by ``self.scale``) plus ``self.positional`` of the length
        of axis 1. ``training`` is taken from ``TRAIN_FLAG()``.
    """
    scaled = super().encode(x) * tf.constant(self.scale)
    # tf.shape reports the runtime shape, so axis 1 (presumably time) may vary
    # from batch to batch.
    seq_len = tf.shape(scaled)[1]
    positions = self.positional(seq_len)
    return self.dropout(scaled + positions, training=TRAIN_FLAG())
class LearnedPositionalLookupTableEmbeddings(LearnedPositionalMixin, LookupTableEmbeddings):