# Prepare matrix_b for core-by-core contraction with tt_matrix_a.
# Target layout: data is (K, j0, ..., jd-2) x jd-1 x 1.
# tf.transpose with no `perm` reverses all axes; for a 2-D matrix_b of
# shape (J, K) this yields (K, J). NOTE(review): assumes matrix_b is 2-D
# and J factorizes row-major as j0 * ... * jd-1 -- confirm with caller.
data = tf.transpose(matrix_b)
# Split off the last column mode jd-1 (presumably a_raw_shape[1] holds the
# column modes of A -- verify) and append a singleton axis, which looks like
# the initial TT-rank of 1 for the contraction.
data = tf.reshape(data, (-1, a_raw_shape[1][-1], 1))
for core_idx in range(ndims - 1, -1, -1):
curr_core = tt_matrix_a.tt_cores[core_idx]
// On the k = core_idx iteration, after applying einsum the shape of data
// becomes ik x (ik-1..., id-1, K, j0, ..., jk-1) x rank_k
After Change
# Prepare matrix_b for core-by-core contraction with tt_matrix_a.
# Target layout: data is (K, j0, ..., jd-2) x jd-1 x 1.
# tf.transpose with no `perm` reverses all axes; for a 2-D matrix_b of
# shape (J, K) this yields (K, J). NOTE(review): assumes matrix_b is 2-D
# and J factorizes row-major as j0 * ... * jd-1 -- confirm with caller.
data = tf.transpose(matrix_b)
# Split off the last column mode jd-1 (presumably a_raw_shape[1] holds the
# column modes of A -- verify) and append a singleton axis, which looks like
# the initial TT-rank of 1 for the contraction.
data = tf.reshape(data, (-1, a_raw_shape[1][-1], 1))
for core_idx in reversed(range(ndims)):
curr_core = tt_matrix_a.tt_cores[core_idx]
// On the k = core_idx iteration, after applying einsum the shape of data
// becomes ik x (ik-1..., id-1, K, j0, ..., jk-1) x rank_k