# Normalise the representation lists: keep only (multiplicity, l) pairs whose
# multiplicity is at least 1, re-packed as 2-tuples.
self.Rs_out = [(mul, deg) for mul, deg in Rs_out if mul >= 1]
self.Rs_in = [(mul, deg) for mul, deg in Rs_in if mul >= 1]
# Per-entry multiplicities, taken from the first slot of each normalised pair.
self.multiplicities_out = [pair[0] for pair in self.Rs_out]
self.multiplicities_in = [pair[0] for pair in self.Rs_in]
# Per-entry dimension 2 * l + 1, taken from the second slot of each pair.
self.dims_out = [2 * pair[1] + 1 for pair in self.Rs_out]
self.dims_in = [2 * pair[1] + 1 for pair in self.Rs_in]
self.radial_function = radial_function
# Stored through register_buffer rather than as a plain attribute.
self.register_buffer("radii", radii)
self.J_filter_max = J_filter_max
# Flattened feature sizes: sum over entries of multiplicity * dimension.
self.n_out = sum(mul * dim for mul, dim in zip(self.multiplicities_out, self.dims_out))
self.n_in = sum(mul * dim for mul, dim in zip(self.multiplicities_in, self.dims_in))
self.sh_backwardable = sh_backwardable
# Weight counter starts at 0 here (presumably accumulated by code past this view).
self.nweights = 0
set_of_irreps = set()
After Change
# Stored through register_buffer rather than as a plain attribute.
self.register_buffer("radii", radii)
self.J_filter_max = J_filter_max
# Flattened feature sizes: every (multiplicity, l) entry contributes
# multiplicity * (2 * l + 1) components.
self.n_out = sum(mul * (2 * deg + 1) for (mul, deg) in self.Rs_out)
self.n_in = sum(mul * (2 * deg + 1) for (mul, deg) in self.Rs_in)
self.sh_backwardable = sh_backwardable
# Weight counter starts at 0 here (presumably accumulated by code past this view).
self.nweights = 0
set_of_irreps = set()