for i, s in enumerate(self._all_states):
    param_groups.extend(s["param_groups"])
    end = start + len(s["param_groups"])
    partition.append((start, end))
    start = end

return {
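    # (Truncated in the original. Presumably the returned dict collects the shard
    #  states plus the bookkeeping built above; the keys below are a hedged guess,
    #  not the verbatim upstream code.)
    "state": [s["state"] for s in self._all_states],
    "param_groups": param_groups,
    "partition": partition,
}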
After Change
# Unify the shard states and the state that pytorch would expect, given the model.
# Indexing needs several redirections, since each shard only knows a limited scope of the model:
# - get the pytorch-compliant parameter indexing
state_dict = super().state_dict()

# - go through the per-shard states, which are all indexed locally
for rank, s in enumerate(self._all_states):
    # -- match the local indexing and the global partition, update the corresponding saved state globally
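    # (Truncated in the original. A minimal sketch of the remaining loop body,
    #  assuming the (start, end) `partition` spans built in the snippet above are
    #  still available here; names are illustrative, not the verbatim upstream code.)
    start, _end = partition[rank]
    for local_index, local_pg in enumerate(s["param_groups"]):
        global_pg = state_dict["param_groups"][start + local_index]
        # copy the shard's hyperparameters onto the matching global param group
        for key, value in local_pg.items():
            if key != "params":
                global_pg[key] = value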