# Pre-change version of the check: run dot-product attention with an
# all-zero query and compare the result against the value tensor.
# Split the (query, key, value) triple provided by the enclosing test.
q, k, v = qkv
# Create the zeroed query under a CPU device scope.
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the attention call below is also inside this device scope; confirm.
with tf.device("/cpu:0"):
# With q == 0, every query/key dot product is 0. Assuming scores are q . k^T
# with no masking, the attention weights become uniform, so the output
# presumably reduces to an average of v. TODO: confirm against the
# assertions that follow.
q = tf.zeros_like(q)
res = dot_product_attention(q, k, v)
# TF1 graph mode: evaluate the attention output and v in one Session.run call.
# `gold` holds the evaluated v, presumably the reference for the comparison.
with tf.Session() as sess:
res, gold = sess.run([res, v])
# q's static shape is unpacked into four values, so it must be rank 4.
# The names suggest (batch, heads, time, depth); that reading is an assumption.
B, H, T, _ = q.get_shape().as_list()
# Iterate over the first dimension. The per-position comparison is outside this view.
for b in range(B):
After Change
# Post-change version of the same zero-query check. It uses the
# SeqDotProductAttention layer and supports both TF1 (session) and
# TF2 (eager) execution.
# NOTE(review): q, k and v are assumed to be bound earlier in the enclosing
# test (outside this view); confirm.
with tf.device("/cpu:0"):
# All-zero query: presumably makes the attention weights uniform over keys.
# Confirm.
q = tf.zeros_like(q)
# NOTE(review): the 0.0 argument is presumably a dropout rate, where 0 means
# deterministic output. Confirm against SeqDotProductAttention's signature.
dot_product_attention = SeqDotProductAttention(0.0)
# The layer takes its inputs as one tuple. The trailing None is presumably
# "no mask"; confirm.
res = dot_product_attention((q, k, v, None))
# get_version presumably returns the TensorFlow major version.
# TF1: tensors are symbolic, so evaluate them inside a Session.
if get_version(tf) < 2:
with tf.Session() as sess:
res, gold = sess.run([res, v])
# TF2: eager tensors convert directly to NumPy arrays.
else:
res, gold = res.numpy(), v.numpy()
# q's static shape is unpacked into four values, so it must be rank 4.
# The names suggest (batch, heads, time, depth); that reading is an assumption.
B, H, T, _ = q.get_shape().as_list()
# Visit every (b, h, t) position. The per-position comparison is outside this view.
for b in range(B):
for h in range(H):
for t in range(T):