g = self.gate_act(g)
begin_g = 0 // index of first scalar gate capsule
zs = []
for n, mul in enumerate(self.repr_out):
if mul == 0:
continue
dim = 2 * n + 1
// crop out capsules of order n
field_y = y[:, begin_y: begin_y + mul * dim] // [batch, feature * repr, x, y, z]
begin_y += mul * dim
if n == 0:
// Scalar activation
if self.scalar_act is not None:
field = self.scalar_act(field_y)
else:
field = field_y
else:
if self.gate_act is not None:
// reshape channels in capsules and capsule entries
field_y = field_y.contiguous()
field_y = field_y.view(nbatch, mul, dim, nx, ny, nz) // [batch, feature, repr, x, y, z]
// crop out corresponding scalar gates
field_g = g[:, begin_g: begin_g + mul] // [batch, feature, x, y, z]
begin_g += mul
// reshape channels for broadcasting
field_g = field_g.contiguous()
field_g = field_g.view(nbatch, mul, 1, nx, ny, nz) // [batch, feature, repr, x, y, z]
// scale non-scalar capsules by gate values
field = field_y * field_g // [batch, feature, repr, x, y, z]
field = field.view(nbatch, mul * dim, nx, ny, nz) // [batch, feature * repr, x, y, z]
else:
field = field_y
zs.append(field)
z = torch.cat(zs, dim=1)
// dropout
if self.dropout is not None:
z = self.dropout(z)