# Unpack the query tensor's dimensions. The 4-way unpack requires q to be
# rank-4; presumably (batch, heads, time, head_dim) — TODO confirm.
B, H, T, _ = q.shape
# Zero the queries (assumed in-place, PyTorch-style trailing-underscore op).
# With all-zero queries every attention score should be identical, so each
# output position presumably reduces to an average of the values it is allowed
# to see — this is likely what the per-(b, h) loop below checks; verify.
q = q.zero_()
# Mask over T positions; presumably causal (position t attends only to <= t).
mask = subsequent_mask(T)
# This version of `attn` takes q, k, v positionally plus a `mask` keyword and
# returns a pair; only the first element (the attention output) is kept.
res, _ = attn(q, k, v, mask=mask)
# Convert the output and the raw values to NumPy arrays for comparison.
res = res.numpy()
gold = v.numpy()
for b in range(B):
for h in range(H):
After Change
# Unpack the query tensor's dimensions. The 4-way unpack requires q to be
# rank-4; presumably (batch, heads, time, head_dim) — TODO confirm.
B, H, T, _ = q.shape
# Zero the queries (assumed in-place, PyTorch-style trailing-underscore op).
# With all-zero queries every attention score should be identical, so each
# output position presumably reduces to an average of the values it is allowed
# to see — this is likely what the per-(b, h) loop below checks; verify.
q = q.zero_()
# Mask over T positions; presumably causal (position t attends only to <= t).
mask = subsequent_mask(T)
# This version of `attn` takes a single tuple argument (q, k, v, mask) and its
# return value is used directly as the attention output (no pair to unpack).
res = attn((q, k, v, mask))
# Convert the output and the raw values to NumPy arrays for comparison.
res = res.numpy()
gold = v.numpy()
for b in range(B):
for h in range(H):