65f5a9006cf5e7e8daef6187e7987d7792f0feb6,se3_cnn/non_linearities/gated_activation.py,GatedActivation,forward,#GatedActivation#Any#,61

Before Change


            g = self.gates(x)
            begin_g = 0  // index of first scalar gate capsule

        zs = []

        for n, mul in enumerate(self.repr_in):
            if mul == 0:
                continue
            dim = 2 * n + 1

            // crop out capsules of order n
            field_x = x[:, begin_x: begin_x + mul * dim]  // [batch, feature * repr, x, y, z]
            begin_x += mul * dim

            if n == 0:
                if self.scalar_act is not None:
                    field = self.scalar_act(field_x)
                else:
                    field = field_x
            else:
                if self.gates is not None:
                    // reshape channels in capsules and capsule entries
                    field_x = field_x.contiguous()
                    field_x = field_x.view(nbatch, mul, dim, nx, ny, nz)  // [batch, feature, repr, x, y, z]

                    // crop out corresponding scalar gates
                    field_g = g[:, begin_g: begin_g + mul]  // [batch, feature, x, y, z]
                    begin_g += mul
                    // reshape channels for broadcasting
                    field_g = field_g.contiguous()
                    field_g = field_g.view(nbatch, mul, 1, nx, ny, nz)  // [batch, feature, repr, x, y, z]

                    // scale non-scalar capsules by gate values
                    field = field_x * field_g  // [batch, feature, repr, x, y, z]
                    field = field.view(nbatch, mul * dim, nx, ny, nz)  // [batch, feature * repr, x, y, z]
                else:
                    field = field_x

            zs.append(field)

        // TODO change this cat into new_empty and fill
        return torch.cat(zs, dim=1)  // does not contain gates

After Change


            g = self.gates(x)
            begin_g = 0  // index of first scalar gate capsule

        size_out = sum(mul * (2 * n + 1) for n, mul in enumerate(self.repr_in))
        z = x.new_empty((x.size(0), size_out, x.size(2), x.size(3), x.size(4)))

        for n, mul in enumerate(self.repr_in):
            if mul == 0:
                continue
(After Change snippet truncated here; the rest of the loop body is not shown.)
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 8

Instances


Project Name: mariogeiger/se3cnn
Commit Name: 65f5a9006cf5e7e8daef6187e7987d7792f0feb6
Time: 2018-06-12
Author: geiger.mario@gmail.com
File Name: se3_cnn/non_linearities/gated_activation.py
Class Name: GatedActivation
Method Name: forward


Project Name: OpenNMT/OpenNMT-py
Commit Name: 685126644ae540be72eb662527269a0395e2c9eb
Time: 2017-09-05
Author: bpeters@coli.uni-saarland.de
File Name: onmt/IO.py
Class Name:
Method Name: make_features


Project Name: mariogeiger/se3cnn
Commit Name: aa7c004df5c781fc3b5b8131d7a9e64fd196203e
Time: 2018-05-26
Author: geiger.mario@gmail.com
File Name: se3_cnn/blocks/gated_block.py
Class Name: GatedBlock
Method Name: forward