num_dims = self.ndims()
if self.is_tt_matrix():
M, N = 1, 1
for i in range(num_dims):
curr_core_shape = self._tt_cores[i].get_shape()
M *= curr_core_shape[1]
After Change
N = np.prod(raw_shape[1].as_list())
return (M, N)
else:
return self.get_raw_shape()[0]
@property
def tt_cores(self):
A tuple of TT-cores.