# Apply the per-order nonlinearity to the feature tensor ``y``:
#   * order-0 (scalar) capsules go through ``self.scalar_act`` (identity if None);
#   * higher-order capsules are multiplied by one gate scalar per capsule taken
#     from ``g`` (passed through unchanged when ``self.gate_act`` is None).
# The per-order results are concatenated back along the channel axis and
# optionally passed through ``self.dropout``.
# NOTE(review): ``y``, ``g``, ``nbatch``, ``nx``, ``ny``, ``nz`` are defined earlier
# in the enclosing method (outside this view) — confirm when re-indenting.
if self.gate_act is not None:
    # Only squash the gates when a gate activation is configured; the gating
    # branch below is skipped under the same condition.
    g = self.gate_act(g)
begin_g = 0  # index of first scalar gate capsule
begin_y = 0  # index of first channel of the current order in ``y``

zs = []
for n, mul in enumerate(self.repr_out):
    if mul == 0:
        # No capsules of this order: nothing to crop, gate, or emit.
        continue
    dim = 2 * n + 1

    # crop out capsules of order n
    field_y = y[:, begin_y: begin_y + mul * dim]  # [batch, feature * repr, x, y, z]
    begin_y += mul * dim

    if n == 0:
        # Scalar activation
        if self.scalar_act is not None:
            field = self.scalar_act(field_y)
        else:
            field = field_y
    elif self.gate_act is not None:
        # reshape channels into capsules and capsule entries
        # (contiguous() is required because the channel slice may not be)
        field_y = field_y.contiguous().view(nbatch, mul, dim, nx, ny, nz)  # [batch, feature, repr, x, y, z]

        # crop out the scalar gates belonging to these capsules
        field_g = g[:, begin_g: begin_g + mul]  # [batch, feature, x, y, z]
        begin_g += mul
        # add a size-1 axis so each gate broadcasts over its capsule's entries
        field_g = field_g.contiguous().view(nbatch, mul, 1, nx, ny, nz)  # [batch, feature, 1, x, y, z]

        # scale non-scalar capsules by gate values
        field = field_y * field_g  # [batch, feature, repr, x, y, z]
        field = field.view(nbatch, mul * dim, nx, ny, nz)  # [batch, feature * repr, x, y, z]
    else:
        field = field_y

    zs.append(field)

z = torch.cat(zs, dim=1)

# dropout
if self.dropout is not None:
    z = self.dropout(z)