1f0750670cf8ea24ad264debf9603002ab0fb565,test/nn/pool/test_mem_pool.py,,test_mem_pool,#,5
Before Change
mpool = MemPool(heads=3, num_keys=2, in_channels=2, out_channels=3)
assert mpool.__repr__() == "MemPool(2, 3, heads=3, num_keys=2)"
x = torch.rand((5, 4, 2))
mask = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 1],
                     [1, 1, 1, 1], [1, 1, 1, 0]]).bool()
x, S = mpool(x, mask)
loss, P = MemPool.kl_loss(S, mask)
assert x.shape == torch.Size([5, 2, 3])
assert (S[~mask] == 0).all()
assert P.shape == S.shape
After Change
assert out.size() == (5, 2, 8)
assert S[~mask].sum() == 0
assert S[mask].sum() == x.size(0)
assert float(loss) > 0
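The After Change assertions check two invariants of the dense assignment matrix `S`. Padded positions (`mask` False) carry no assignment mass, and each real node's row sums to 1, so the total over `S[mask]` equals the number of real nodes. The `S[mask].sum() == x.size(0)` check only makes sense if `x` holds one row per real node, which this record does not show. The following is a minimal pure-Python sketch of those invariants. It assumes a softmax-style, per-node normalized assignment. The helper `masked_soft_assignment` and the scores are hypothetical illustrations, not PyG's MemPool implementation. The mask is the one from the Before Change test.

```python
import math


def masked_soft_assignment(scores, mask):
    """Hypothetical helper (not the PyG MemPool code): row-wise softmax over
    clusters for real nodes, all-zero rows for padded positions."""
    S = []
    for graph_scores, graph_mask in zip(scores, mask):
        rows = []
        for node_scores, is_real in zip(graph_scores, graph_mask):
            if not is_real:
                # Padding node: no assignment mass at all.
                rows.append([0.0] * len(node_scores))
                continue
            # Numerically stable softmax over the cluster dimension.
            m = max(node_scores)
            exps = [math.exp(s - m) for s in node_scores]
            total = sum(exps)
            rows.append([e / total for e in exps])
        S.append(rows)
    return S


# Same padding mask as the Before Change test: 5 graphs, up to 4 nodes each.
mask = [[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 1],
        [1, 1, 1, 1], [1, 1, 1, 0]]
# Hypothetical per-node scores for 2 clusters.
scores = [[[float(b + n), float(b - n)] for n in range(4)] for b in range(5)]

S = masked_soft_assignment(scores, mask)

# Plain-Python equivalents of S[~mask].sum() and S[mask].sum().
padded_mass = sum(sum(row) for g, gm in zip(S, mask)
                  for row, r in zip(g, gm) if not r)
real_mass = sum(sum(row) for g, gm in zip(S, mask)
                for row, r in zip(g, gm) if r)
num_real_nodes = sum(sum(gm) for gm in mask)

print(padded_mass, num_real_nodes, math.isclose(real_mass, num_real_nodes))
# prints: 0.0 18 True
```

Floating-point row sums are only approximately 1, so the sketch compares with `math.isclose` rather than exact equality.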
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 3
Instances
Project Name: rusty1s/pytorch_geometric
Commit Name: 1f0750670cf8ea24ad264debf9603002ab0fb565
Time: 2021-03-15
Author: matthias.fey@tu-dortmund.de
File Name: test/nn/pool/test_mem_pool.py
Class Name:
Method Name: test_mem_pool
Project Name: markovmodel/PyEMMA
Commit Name: da0655a02813a76215ee1aaaff446b6f3f4a96da
Time: 2017-04-12
Author: m.scherer@fu-berlin.de
File Name: pyemma/_base/serialization/serialization.py
Class Name: _SerializableBase
Method Name: _save_data_producer
Project Name: NifTK/NiftyNet
Commit Name: 928662beef975cbd2e1d200ade2cd558a3c3d650
Time: 2017-08-27
Author: wenqi.li@ucl.ac.uk
File Name: niftynet/engine/sampler_resize.py
Class Name: ResizeSampler
Method Name: __init__