# Single-GPU sampling loop (tail of a function whose header is outside this view).
# NOTE(review): x_gen, sess, new_x_gen and x_sample are not defined in the visible
# lines; presumably x_gen is a NumPy array of shape (batch,) + obs_shape — confirm.
# The [:,yi,xi,:] indexing below needs a 4-D x_gen, hence the rank-3 obs_shape check.
assert len(obs_shape) == 3, "assumed right now"
# Visit every position: yi walks obs_shape[0], xi walks obs_shape[1] (nested loop).
for yi in range(obs_shape[0]):
for xi in range(obs_shape[1]):
# One sess.run per position, feeding the current buffer into x_sample.
new_x_gen_np = sess.run(new_x_gen, {x_sample: x_gen})
# Keep only position (yi, xi) from the result; x_gen is updated in place and fed
# back on the next run, so that run sees every position written so far.
x_gen[:,yi,xi,:] = new_x_gen_np[:,yi,xi,:].copy()
return x_gen
# get loss gradients over multiple GPUs
After Change
x_gen = [np.zeros((args.batch_size,) + obs_shape, dtype=np.float32) for i in range(args.nr_gpu)]
for yi in range(obs_shape[0]):
for xi in range(obs_shape[1]):
new_x_gen_np = sess.run(new_x_gen, {xs[i]: x_gen[i] for i in range(args.nr_gpu)})
for i in range(args.nr_gpu):
x_gen[i][:,yi,xi,:] = new_x_gen_np[i][:,yi,xi,:]
return np.concat(x_gen, axis=0)