# Compute the inducing/test cross-covariance and the test covariance, then
# flatten Kmm and (in the matching full_cov/full_output_cov case) Kmn into
# 2-D matrices.
Kmn = covariances.Kuf(inducing_variable, kernel, Xnew)  # [M, L, N, P]
Knn = kernel(Xnew, full=full_cov, full_output_cov=full_output_cov)  # [N, P](x N)x P or [N, P](x P)
# Read the sizes from the *dynamic* shape. The static Kmn.shape[i] can be None
# when a dimension (e.g. the number of test points N) is unknown when the graph
# is built, and `M * L` / `N * K` would then fail. tf.unstack needs a static
# rank, which Kmn.shape.ndims supplies; the 4-way unpacking requires rank 4.
M, L, N, K = tf.unstack(tf.shape(Kmn), num=Kmn.shape.ndims, axis=0)
# NOTE(review): assumes Kmm was computed earlier with M * L * M * L elements
# (e.g. [L, M, M] is *not* compatible unless L == 1) -- TODO confirm.
Kmm = tf.reshape(Kmm, (M * L, M * L))
if full_cov == full_output_cov:
    # Flatten the (M, L) and (N, P) axis pairs so Kmn becomes a 2-D matrix.
    # NOTE(review): the rest of this branch, and any else-branch, lies outside
    # this fragment.
    Kmn = tf.reshape(Kmn, (M * L, N * K))
# After change:
# Compute the inducing/test cross-covariance and the test covariance, then
# flatten Kmm and (in the matching full_cov/full_output_cov case) Kmn into
# 2-D matrices.
Kmn = covariances.Kuf(inducing_variable, kernel, Xnew)  # [M, L, N, P]
Knn = kernel(Xnew, full=full_cov, full_output_cov=full_output_cov)  # [N, P](x N)x P or [N, P](x P)
# Sizes come from the dynamic shape tensor, so they stay valid even when a
# dimension is not known statically. tf.unstack needs a static rank, which
# Kmn.shape.ndims supplies; the 4-way unpacking requires rank 4.
M, L, N, K = tf.unstack(tf.shape(Kmn), num=Kmn.shape.ndims, axis=0)
# NOTE(review): assumes Kmm was computed earlier with M * L * M * L elements
# -- TODO confirm against the preceding code.
Kmm = tf.reshape(Kmm, (M * L, M * L))
if full_cov == full_output_cov:
    # Flatten the (M, L) and (N, P) axis pairs so Kmn becomes a 2-D matrix.
    # NOTE(review): the rest of this branch, and any else-branch, lies outside
    # this fragment.
    Kmn = tf.reshape(Kmn, (M * L, N * K))