def backward(self, y):
    """Extract the diagonals of a stack of square matrices.

    ``y`` is reshaped to ``(-1, self.dim, self.dim)``; the main diagonal of
    each matrix is taken and all diagonals are concatenated into a flat
    vector, which is then passed through ``self._positive_transform.backward``.

    NOTE(review): presumably this inverts ``forward_tensor``, which applies
    ``self._positive_transform.forward_tensor`` before building the
    diagonal matrices — confirm against the positive transform's contract.
    """
    # Return diagonals of matrices (reshape tolerates already-flattened input).
    diagonals = y.reshape(-1, self.dim, self.dim).diagonal(offset=0, axis1=1, axis2=2).flatten()
    return self._positive_transform.backward(diagonals)
def forward_tensor(self, x):
    """Turn a packed tensor into a batch of diagonal matrices.

    The input is first mapped through ``self._positive_transform.forward_tensor``,
    then laid out as rows of length ``self.dim``; each row becomes the
    diagonal of one ``dim x dim`` matrix.
    """
    positive = self._positive_transform.forward_tensor(x)
    rows = tf.reshape(positive, (-1, self.dim))
    return tf.matrix_diag(rows)
After Change
def backward(self, y):
    """Extract the diagonals of a stack of square matrices.

    Args:
        y: array of shape ``(N, self.dim, self.dim)``.

    Returns:
        A flat array of length ``N * self.dim`` holding the main diagonal of
        each matrix, concatenated in order.

    Raises:
        ValueError: if ``y`` is not rank 3 or its trailing dimensions are not
            both equal to ``self.dim``.
    """
    # Check the rank first: the short-circuiting ``or`` guarantees that
    # y.shape[1] / y.shape[2] are only read when they exist, and any rank
    # other than 3 is rejected instead of slipping past the check.
    if len(y.shape) != 3 or not (y.shape[1] == y.shape[2] == self.dim):
        raise ValueError("shape of input does not match this transform")
    # Return diagonals of matrices
    return y.diagonal(offset=0, axis1=1, axis2=2).flatten()
def forward_tensor(self, x):
    """Build a batch of diagonal matrices from a packed tensor.

    ``x`` is reshaped into rows of length ``self.dim``; each row becomes the
    diagonal of one ``dim x dim`` matrix. No positive transform is applied
    here (the values are used as-is).
    """
    # Create diagonal matrices, one per row of length self.dim.
    return tf.matrix_diag(tf.reshape(x, (-1, self.dim)))