dccb5015ca3443c490aa4f1100892b0bfb5f957b,geomstats/riemannian_metric.py,RiemannianMetric,mean,#RiemannianMetric#Any#Any#Any#Any#Any#,241

Before Change


        # i.e. what to do with sq_dists_between_iterates.

        if isinstance(points, list):
            points = gs.vstack(points)

        if point_type == "vector":
            points = gs.to_ndarray(points, to_ndim=2)
        if point_type == "matrix":
            points = gs.to_ndarray(points, to_ndim=3)
        n_points = gs.shape(points)[0]

        if isinstance(weights, list):
            weights = gs.vstack(weights)
        if weights is None:
            weights = gs.ones((n_points, 1))

        weights = gs.array(weights)
        weights = gs.to_ndarray(weights, to_ndim=2, axis=1)

        sum_weights = gs.sum(weights)

        mean = points[0]
        if n_points == 1:
            return mean

        sq_dists_between_iterates = []
        iteration = 0
        while iteration < n_max_iterations:
            a_tangent_vector = self.log(mean, mean)
            tangent_mean = gs.zeros_like(a_tangent_vector)

            logs = self.log(point=points, base_point=mean)
            tangent_mean += gs.einsum("nk,nj->j", weights, logs)

After Change



        mean = points[0]
        if point_type == "vector":
            mean = gs.to_ndarray(mean, to_ndim=2)
        if point_type == "matrix":
            mean = gs.to_ndarray(mean, to_ndim=3)

        if n_points == 1:
            return mean

        sq_dists_between_iterates = []
        iteration = 0
        sq_dist = gs.array([[0.]])
        variance = gs.array([[0.]])

        # iteration = gs.constant(0)

        def while_loop_body(iteration, mean, variance, sq_dist):
            """Run one update of the mean estimate.

            The incoming ``variance`` and ``sq_dist`` are not read; they are
            accepted so the signature matches the loop variables and are
            recomputed from the updated estimate.

            Returns the list ``[iteration + 1, updated mean, variance at the
            updated mean, squared distance between the two estimates]``.
            As a side effect, the squared distance is appended to the
            enclosing ``sq_dists_between_iterates`` list.
            """
            # Logs of every point, taken at the current estimate.
            log_vectors = self.log(point=points, base_point=mean)

            # Weighted sum of the logs, accumulated in place into a buffer
            # shaped like ``mean``, then normalised by the total weight.
            weighted_log = gs.zeros_like(mean)
            weighted_log += gs.einsum("nk,nj->j", weights, log_vectors)
            weighted_log /= sum_weights

            # Shoot from the current estimate along the averaged tangent.
            updated_mean = self.exp(
                tangent_vec=weighted_log,
                base_point=mean)

            # How far this step moved the estimate; kept for the caller.
            step_sq_dist = self.squared_dist(updated_mean, mean)
            sq_dists_between_iterates.append(step_sq_dist)

            updated_variance = self.variance(
                points=points,
                weights=weights,
                base_point=updated_mean)

            return [iteration + 1, updated_mean, updated_variance, step_sq_dist]

        def while_loop_cond(iteration, mean, variance, sq_dist):
            """Return whether the mean iteration should keep going.

            The stop criterion is: the variance is numerically zero, or the
            squared distance between the last two estimates is at most
            ``epsilon * variance``. This predicate is handed to
            ``gs.while_loop`` as its ``cond``, so it must be true while the
            loop should CONTINUE, i.e. while the stop criterion does not hold.
            NOTE(review): this assumes ``gs.while_loop`` follows the
            ``tf.while_loop`` convention (repeat ``body`` while ``cond`` is
            true) -- confirm against the backend implementation.

            ``iteration`` is compared with 0 to force the first pass;
            ``mean`` is unused and only present to match the loop variables.
            """
            should_continue = ~gs.logical_or(
                gs.isclose(variance, 0.),
                gs.less_equal(sq_dist, epsilon * variance))
            # The loop starts with zero placeholders for ``variance`` and
            # ``sq_dist``, which already satisfy the stop criterion; always
            # let the first pass run so real values get computed.
            return iteration == 0 or should_continue[0, 0]

        last_iteration, mean, variance, sq_dist = gs.while_loop(
            lambda i, m, v, sq: while_loop_cond(i, m, v, sq),
            lambda i, m, v, sq: while_loop_body(i, m, v, sq),
            loop_vars=[iteration, mean, variance, sq_dist],
            maximum_iterations=n_max_iterations)
        # while iteration < n_max_iterations:

        #     if gs.isclose(variance, 0.)[0, 0]:
        #         break
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: geomstats/geomstats
Commit Name: dccb5015ca3443c490aa4f1100892b0bfb5f957b
Time: 2018-12-31
Author: ninamio78@gmail.com
File Name: geomstats/riemannian_metric.py
Class Name: RiemannianMetric
Method Name: mean


Project Name: geomstats/geomstats
Commit Name: 01673d1a6dcb41a20e19f951ee450c44c07aeafd
Time: 2019-06-16
Author: ninamio78@gmail.com
File Name: geomstats/riemannian_metric.py
Class Name: RiemannianMetric
Method Name: mean


Project Name: geomstats/geomstats
Commit Name: 31d8076c8dd31c28054e820571ef38234950e101
Time: 2018-05-08
Author: ninamio78@gmail.com
File Name: geomstats/spd_matrices_space.py
Class Name:
Method Name: group_log