650f6ee1e0b3c2888a2c6d7db9c3d159cae5a583,tests/pytorch/test_nn.py,,test_set_trans,#,212
Before Change
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
if F.gpu_ctx():
st_enc_0.cuda()
st_enc_1.cuda()
st_dec.cuda()
print(st_enc_0, st_enc_1, st_dec)
// test#1: basic
After Change
assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def test_set_trans():
ctx = F.ctx()
g = dgl.DGLGraph(nx.path_graph(15))
st_enc_0 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "sab")
st_enc_1 = nn.SetTransformerEncoder(50, 5, 10, 100, 2, "isab", 3)
st_dec = nn.SetTransformerDecoder(50, 5, 10, 100, 2, 4)
if F.gpu_ctx():
st_enc_0 = st_enc_0.to(ctx)
st_enc_1 = st_enc_1.to(ctx)
st_dec = st_dec.to(ctx)
print(st_enc_0, st_enc_1, st_dec)
// test#1: basic
h0 = F.randn((g.number_of_nodes(), 50))
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 5
Instances
Project Name: dmlc/dgl
Commit Name: 650f6ee1e0b3c2888a2c6d7db9c3d159cae5a583
Time: 2019-08-27
Author: expye@outlook.com
File Name: tests/pytorch/test_nn.py
Class Name:
Method Name: test_set_trans
Project Name: dmlc/dgl
Commit Name: 650f6ee1e0b3c2888a2c6d7db9c3d159cae5a583
Time: 2019-08-27
Author: expye@outlook.com
File Name: tests/pytorch/test_nn.py
Class Name:
Method Name: test_set2set
Project Name: dmlc/dgl
Commit Name: 650f6ee1e0b3c2888a2c6d7db9c3d159cae5a583
Time: 2019-08-27
Author: expye@outlook.com
File Name: tests/pytorch/test_nn.py
Class Name:
Method Name: test_glob_att_pool