g = self.gates(x)
begin_g = 0 // index of first scalar gate capsule
zs = []
for n, mul in enumerate(self.repr_in):
if mul == 0:
continue
dim = 2 * n + 1
// crop out capsules of order n
field_x = x[:, begin_x: begin_x + mul * dim] // [batch, feature * repr, x, y, z]
begin_x += mul * dim
if n == 0:
if self.scalar_act is not None:
field = self.scalar_act(field_x)
else:
field = field_x
else:
if self.gates is not None:
// reshape channels in capsules and capsule entries
field_x = field_x.contiguous()
field_x = field_x.view(nbatch, mul, dim, nx, ny, nz) // [batch, feature, repr, x, y, z]
// crop out corresponding scalar gates
field_g = g[:, begin_g: begin_g + mul] // [batch, feature, x, y, z]
begin_g += mul
// reshape channels for broadcasting
field_g = field_g.contiguous()
field_g = field_g.view(nbatch, mul, 1, nx, ny, nz) // [batch, feature, repr, x, y, z]
// scale non-scalar capsules by gate values
field = field_x * field_g // [batch, feature, repr, x, y, z]
field = field.view(nbatch, mul * dim, nx, ny, nz) // [batch, feature * repr, x, y, z]
else:
field = field_x
zs.append(field)
// TODO change this cat into new_empty and fill
return torch.cat(zs, dim=1) // does not contain gates