// want a sequence of dictionaries of tensors.
// First, unzip the dictionary into a sequence of keys and a
// sequence of tensor-like sequences.
keys, values = zip(*((k, [v_chunk for v_chunk in v_split])
for k, (_, v_split) in non_none.items()))
// Now, yield a dictionary for each shard. The keys are always
// the same. values is a sequence of length //keys where each
// element is a sequence of length //shards. We want to iterate
// over the shards, not over the keys: therefore, the values need
// to be re-zipped by shard and then each shard can be paired
// with the keys.
for shard_tensors in zip(*values):
yield dict(zip(keys, shard_tensors))
# Assumed backprop'd: the code below reads `.grad` from every chunk in
# v_split, so the caller is presumably expected to have run backward on
# each yielded shard by this point — confirm against callers.
# Collect (slice of the original state tensor, gradient of matching shard)
# pairs, then push all of those gradients back through the original graph
# with a single torch.autograd.backward call.
variables = []
for k, (v, v_split) in non_none.items():
    # Only entries whose original state tensor takes part in autograd can
    # receive gradients; everything else is skipped.
    if isinstance(v, torch.Tensor) and state[k].requires_grad:
        # torch.split chunks along dim 0; presumably this mirrors how
        # v_split was produced — TODO confirm against the splitting helper.
        # NOTE(review): a chunk whose .grad is None (not backpropagated)
        # will make torch.autograd.backward fail for non-scalar inputs.
        variables.extend(zip(torch.split(state[k], shard_size),
                             [v_chunk.grad for v_chunk in v_split]))
# When no state tensor requires grad there is nothing to backpropagate;
# skip the call instead of crashing on the empty `zip(*variables)` unpack.
if variables:
    inputs, grads = zip(*variables)
    torch.autograd.backward(inputs, grads)