# Flatten the param_groups, save the partition which logs the rank <> shard correspondence
partition: List[Tuple[int, int]] = []
param_groups: List[Dict[Any, Any]] = []
start = 0
for i, s in enumerate(self._all_states):
    param_groups.extend(s["param_groups"])
    # record which (start, end) slice of the flattened list belongs to this rank
    # (assumed completion: the excerpt stopped after the extend call, leaving partition/start unused)
    end = start + len(s["param_groups"])
    partition.append((start, end))
    start = end
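As a side note, the (start, end) tuples are what later lets a consumer recover a single rank's slice of the flattened param_groups. A minimal sketch, with an illustrative helper name that is not part of the library:

# Sketch only (illustrative, not the library's API): recover the param_groups
# owned by one rank from the flattened list, using the (start, end) partition above.
from typing import Any, Dict, List, Tuple

def param_groups_for_rank(
    param_groups: List[Dict[Any, Any]],
    partition: List[Tuple[int, int]],
    rank: int,
) -> List[Dict[Any, Any]]:
    start, end = partition[rank]
    return param_groups[start:end]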
After Change
# Unify the shard states and the state that pytorch would expect, given the model.
# Indexing needs several redirections, since each shard only knows a limited scope of the model
# - get the pytorch compliant parameter indexing
state_dict = super().state_dict()

# - go through the per-shard states, which are all indexed locally
for rank, s in enumerate(self._all_states):
    # -- match the local indexing and the global partition, update the corresponding saved state globally
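The body of that loop is not shown in the excerpt. A minimal sketch of what "update the corresponding saved state globally" could look like, assuming each shard carries a locally indexed "state" dict and that a per-rank partition lists the matching global parameter indices (both structures and the function name are illustrative, not the library's implementation):

# Sketch only, with assumed structures: all_states[rank]["state"] is indexed locally,
# partition[rank] lists the global parameter indices owned by that rank.
from typing import Any, Dict, List

def merge_shard_states(
    state_dict: Dict[str, Any],
    all_states: List[Dict[str, Any]],
    partition: List[List[int]],
) -> Dict[str, Any]:
    for rank, shard in enumerate(all_states):
        # match the local indexing against the global partition and copy the
        # per-parameter state under its globally consistent index
        for local_index, global_index in enumerate(partition[rank]):
            if local_index in shard["state"]:
                state_dict["state"][global_index] = shard["state"][local_index]
    return state_dict

With this kind of remapping, the per-rank states end up under the globally consistent indices that super().state_dict() established for the full model.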