def _kinetic_energy(self, r):
    """Return the kinetic energy ``0.5 * r^T M^{-1} r`` for the momenta ``r``.

    ``self.inverse_mass_matrix`` (``M^{-1}``) may take either of two forms:

    * a single tensor covering every site in ``r``. The per-site tensors of
      ``r`` are flattened and concatenated in sorted site-name order before
      being contracted with it.
    * a dict mapping a tuple of site names to a tensor covering just those
      sites. Those sites' tensors are flattened and concatenated in the
      tuple's order, and the result is the sum of the per-block energies.

    For the single-tensor form, a 2-D tensor is used as a dense matrix and
    any other tensor as the diagonal (``dot`` with ``r ** 2``). For the dict
    form, a 1-D tensor is used as the diagonal and any other tensor as a
    dense matrix.

    :param dict r: maps site name to a tensor; every site referenced by the
        inverse mass matrix must be present.
    :returns: the kinetic energy as a scalar tensor (plain ``0.0`` when the
        dict form contains no blocks).
    """
    inverse_mass_matrix = self.inverse_mass_matrix
    if isinstance(inverse_mass_matrix, dict):
        # Block form: each entry pairs a group of sites with its own matrix.
        blocks = inverse_mass_matrix.items()
        dense_dims = None
    else:
        # Single matrix over all sites; sorting gives a deterministic layout
        # that must match how the matrix was built.
        blocks = [(sorted(r), inverse_mass_matrix)]
        dense_dims = 2

    energy = 0.0
    for site_names, inv_mass_matrix in blocks:
        r_flat = torch.cat([r[name].reshape(-1) for name in site_names])
        if dense_dims is None:
            # The block form treats only 1-D tensors as diagonals.
            is_dense = inv_mass_matrix.dim() != 1
        else:
            # The single-tensor form treats only 2-D tensors as dense.
            is_dense = inv_mass_matrix.dim() == dense_dims
        if is_dense:
            energy = energy + 0.5 * inv_mass_matrix.matmul(r_flat).dot(r_flat)
        else:
            energy = energy + 0.5 * inv_mass_matrix.dot(r_flat ** 2)
    return energy
def _energy(self, z, r):
After Change
for site_names, inv_mass_matrix in self.inverse_mass_matrix.items():
r_flat = torch.cat([r[site_name].reshape(-1) for site_name in site_names])
if inv_mass_matrix.dim() == 1:
energy = energy + 0.5 * inv_mass_matrix.dot(r_flat ** 2)
else:
energy = energy + 0.5 * inv_mass_matrix.matmul(r_flat).dot(r_flat)
return energy