# Test with the same number of base points and points: the log map is
# expected to be evaluated pairwise, giving one tangent vector per sample.
points = self.group.random_uniform(n_samples=n_samples)
base_points = self.group.random_uniform(n_samples=n_samples)
results = metric.log(points, base_points)
# Compare the shape exactly. np.allclose would broadcast the two shape
# tuples and could accept a wrong-rank result, e.g. (3,) vs (3, 3).
self.assertEqual(results.shape, (n_samples, self.group.dimension))
After Change
# TODO(nina): this test fails
n_samples = self.n_samples
for metric_type in self.metrics:
    metric = self.metrics[metric_type]
    # NOTE(review): one_point and n_base_point are unused in the lines
    # shown here; they are kept because later cases (outside this view)
    # may use them, and dropping them would change the random stream.
    one_point = self.group.random_uniform(n_samples=1)
    one_base_point = self.group.random_uniform(n_samples=1)
    n_point = self.group.random_uniform(n_samples=n_samples)
    n_base_point = self.group.random_uniform(n_samples=n_samples)

    # Test with 1 base point and several different points: the single
    # base point should be broadcast against every point.
    result = metric.log(n_point, one_base_point)
    # Exact shape comparison; np.allclose on shape tuples broadcasts and
    # can accept a wrong-rank result.
    self.assertEqual(result.shape, (n_samples, self.group.dimension),
                     "with metric {}".format(metric_type))

    # Reference value: apply log point by point and stack the rows.
    expected = np.vstack([metric.log(point, one_base_point)
                          for point in n_point])
    self.assertEqual(expected.shape, (n_samples, self.group.dimension),
                     "with metric {}".format(metric_type))
    self.assertTrue(np.allclose(result, expected),
                    "with metric {}".format(metric_type))