a6667bf98c59a7447a6bc55869459e83f5bdb603,dnc/sparse_memory.py,SparseMemory,read_from_sparse_memory,#SparseMemory#Any#Any#Any#Any#Any#,185

Before Change


  def read_from_sparse_memory(self, memory, indexes, keys, last_used_mem, usage):
    """Read from ``memory`` at the positions returned by the per-batch indexes.

    For each batch element, ``indexes[batch].search(keys[batch])`` yields
    ``(distances, positions)``. The distances are collected as the raw read
    weights. The positions are clamped to ``[0, self.mem_size - 1]``.

    ``last_used_mem`` is then concatenated onto the read positions (dim 2).
    A matching extra weight block of shape ``(b, 1, 1)``, filled with ``δ``,
    is concatenated onto the weights. The weights are multiplied element-wise
    by ``usage`` gathered at the read positions. Finally, the corresponding
    entries of ``memory`` (gathered along dim 1 and expanded over the word
    size ``w``) are stored in ``read_vectors``.

    NOTE(review):
      * The comment "add weight of 0 for least used mem block" disagrees with
        the code, which fills that weight with ``δ``. ``δ`` is defined
        elsewhere, presumably as a small constant; confirm which is intended.
      * ``read_positions.squeeze()`` removes every size-1 dimension. With a
        batch size of 1 the batch dimension is presumably dropped too, so
        ``usage.gather(1, ...)`` and the later ``unsqueeze(2).expand(...)``
        would see the wrong rank. Verify before relying on ``b == 1``.
      * ``expand(b, self.K+1, w)`` assumes each search returns ``self.K``
        positions per key (plus the one appended slot). Confirm this against
        the index implementation.
      * This excerpt computes ``read_vectors`` but shows no ``return``
        statement.
    """
    b = keys.size(0)
    read_positions = []
    read_weights = []

    // print(keys.squeeze())
    // non-differentiable operations
    for batch in range(b):
      distances, positions = indexes[batch].search(keys[batch])
      read_weights.append(distances)
      read_positions.append(T.clamp(positions, 0, self.mem_size - 1))

    // add least used mem to read positions
    read_positions = T.stack(read_positions, 0)

    // TODO: explore possibility of reading co-locations and such
    // if read_collocations:
      // read the previous and the next memory locations
      // read_positions = T.cat([read_positions, read_positions-1, read_positions+1], -1)

    read_positions = var(read_positions)
    read_positions = T.cat([read_positions, last_used_mem.unsqueeze(1)], 2)
    // print(read_positions.squeeze())

    // add weight of 0 for least used mem block
    read_weights = T.stack(read_weights, 0)
    new_block = read_weights.new(b, 1, 1)
    new_block.fill_(δ)
    read_weights = T.cat([read_weights, new_block], 2)
    read_weights = var(read_weights)
    // condition read weights by their usages
    relevant_usages = usage.gather(1, read_positions.squeeze())
    read_weights = (read_weights.squeeze(1) * relevant_usages).unsqueeze(1)

    (b, m, w) = memory.size()
    read_vectors = memory.gather(1, read_positions.squeeze().unsqueeze(2).expand(b, self.K+1, w))

After Change


    // TODO: explore possibility of reading co-locations or ranges and such
    (b, r, k) = read_positions.size()
    read_positions = var(read_positions)
    read_positions = T.cat([read_positions.view(b, -1), last_used_mem], 1)

    (b, m, w) = memory.size()
    visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))

    read_weights = F.softmax(θ(visible_memory, keys), dim=2)
    read_vectors = T.bmm(read_weights, visible_memory)

    return read_vectors, read_positions, read_weights, visible_memory
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 6

Instances


Project Name: ixaxaar/pytorch-dnc
Commit Name: a6667bf98c59a7447a6bc55869459e83f5bdb603
Time: 2017-12-07
Author: root@ixaxaar.in
File Name: dnc/sparse_memory.py
Class Name: SparseMemory
Method Name: read_from_sparse_memory


Project Name: ruotianluo/ImageCaptioning.pytorch
Commit Name: c8fadd2d970f1c62ae8a842464056263f8d1232f
Time: 2017-02-13
Author: rluo@ttic.edu
File Name: resnet.py
Class Name: myResnet
Method Name: forward


Project Name: rusty1s/pytorch_geometric
Commit Name: 51b53dcbab8ec7ab0b6e8a64284a919db2d2254a
Time: 2018-05-08
Author: matthias.fey@tu-dortmund.de
File Name: torch_geometric/transform/local_cartesian.py
Class Name: LocalCartesian
Method Name: __call__