nn_probs = nn_state.probability(samples, Z)
NLL = torch.sum(probs_to_logits(nn_probs))
else:
unitary_dict = nn_state.unitary_dict
// print(train_bases)
for i in range(len(samples)):
// Check whether the sample was measured the reference basis
is_reference_basis = True
for j in range(nn_state.num_visible):
if train_bases[i][j] != "Z":
is_reference_basis = False
break
if is_reference_basis is True:
nn_probs = nn_state.probability(samples[i], Z)
NLL += torch.sum(probs_to_logits(nn_probs))
else:
psi_r = rotate_psi(nn_state, train_bases[i], space, unitary_dict)
// Get the index value of the sample state
ind = 0
for j in range(nn_state.num_visible):
if samples[i, nn_state.num_visible - j - 1] == 1:
ind += pow(2, j)
probs_r = cplx.norm_sqr(psi_r[:, ind]) / Z
NLL -= probs_to_logits(probs_r).item()
return (NLL / float(len(samples))).item()
def KL(nn_state, target_psi, space, bases=None, **kwargs):