vector_to_grads(all_grads[0], self.nn_state.rbm_am.parameters())
vector_to_grads(all_grads[1], self.nn_state.rbm_ph.parameters())
# assign all available gradients to the corresponding parameters
# for net in self.nn_state.networks:
After Change
# self.nn_state.rbm_am.hidden_bias], lr=lr)
batch_bases = None
callbacks.on_train_start(self)
t0 = time.time()
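# number of positive-phase mini-batches per epoch (m is presumably an alias for the math module)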
batch_num = m.ceil(train_samples.shape[0] / pos_batch_size)
for ep in progress_bar(range(1, epochs + 1), desc="Epochs ",
disable=disable_progbar):
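# reshuffle the training samples (and their measurement bases, if provided) at the start of each epoch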
random_permutation = torch.randperm(train_samples.shape[0])
shuffled_samples = train_samples[random_permutation]
# List of all the batches for the positive phase.
pos_batches = [shuffled_samples[batch_start:(batch_start + pos_batch_size)]
for batch_start in range(0, len(train_samples), pos_batch_size)]
if input_bases is not None:
shuffled_bases = input_bases[random_permutation]
pos_batches_bases = [shuffled_bases[batch_start:(batch_start + pos_batch_size)]
for batch_start in range(0, len(train_samples), pos_batch_size)]
# DATALOADER
# pos_batches = DataLoader(train_samples, batch_size=pos_batch_size,
#                          shuffle=shf)
# multiplier = int((neg_batch_size / pos_batch_size) + 0.5)
#
# neg_batches = [DataLoader(train_samples, batch_size=neg_batch_size,
#                           shuffle=True)
#                for i in range(multiplier)]
# neg_batches = chain(*neg_batches)
callbacks.on_epoch_start(self, ep)
if self.stop_training:  # check for the stop_training signal
break
# for batch_num, (pos_batch, neg_batch) in enumerate(zip(pos_batches,
#                                                        neg_batches)):
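# one parameter update per positive-phase batch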
for b in range(batch_num):
callbacks.on_batch_start(self, ep, b)
# if input_bases is not None:
#     pos_batch_bases = input_bases[batch_num * pos_batch_size:(batch_num + 1) * pos_batch_size]
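# negative phase: draw a fresh random batch, taken from the training samples when no measurement bases are given, otherwise from z_samples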
if input_bases is None:
random_permutation = torch.randperm(train_samples.shape[0])
neg_batch = train_samples[random_permutation][0:neg_batch_size]
# neg_batches = [shuffled_samples[batch_start:(batch_start + neg_batch_size)]
#                for batch_start in range(0, len(train_samples), neg_batch_size)]
# if input_bases is not None:
else:
random_permutation = torch.randperm(z_samples.shape[0])
neg_batch = z_samples[random_permutation][0:neg_batch_size]
batch_bases = pos_batches_bases[b]
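# initialize the visible layer for the negative phase, then compute this batch's gradients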
self.nn_state.set_visible_layer(neg_batch)
all_grads = self.compute_batch_gradients(k, pos_batches[b], batch_bases)
optimizer.zero_grad()  # clear any cached gradients
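# write each network's flattened gradient vector into the .grad fields of its parameters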
for p, net in enumerate(self.nn_state.networks):
rbm = getattr(self.nn_state, net)
vector_to_grads(all_grads[p], rbm.parameters())
# vector_to_grads(all_grads[0], self.nn_state.rbm_am.parameters())
# vector_to_grads(all_grads[1], self.nn_state.rbm_ph.parameters())
# assign all available gradients to the corresponding parameters
# for net in self.nn_state.networks:
#     rbm = getattr(self.nn_state, net)
#     for param in all_grads[net].keys():
#         getattr(rbm, param).grad = all_grads[net][param]
optimizer.step()  # tell the optimizer to apply the gradients
callbacks.on_batch_end(self, ep, b)
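# every observer.frequency epochs, let the observer record statistics about the current state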
if observer is not None:
if (ep % observer.frequency) == 0:
# if target_psi is not None:
stat = observer.scan(ep, self.nn_state)
callbacks.on_epoch_end(self, ep)
# F = self.fidelity(target_psi, vis)
# print(F.item())
# callbacks.on_train_end(self)
t1 = time.time()
print("\nElapsed time = %.2f" %(t1-t0))
def vector_to_grads(vec, parameters):