]
# Average the per-parameter diffs over a random subset of the received diffs.
# NOTE(review): unclear whether "max_workers" is really the number of diffs
# that should be averaged -- confirm against the server config semantics.
# random.sample raises ValueError if max_workers > len(diffs), and TypeError
# if the key is absent (server_config.get returns None).
diffs = random.sample(diffs, server_config.get("max_workers"))
# Transpose: raw_diffs[i] collects the i-th parameter entry of every sampled
# diff, so each row has len(diffs) items.
raw_diffs = [
    [diff[model_param] for diff in diffs]
    for model_param in range(len(model_params))
]
logging.info("raw diffs lengths: %s", [len(row) for row in raw_diffs])
# Element-wise sum of each parameter across the sampled diffs
# (presumably torch tensors, given th.add / .shape -- verify upstream).
sums = [reduce(th.add, param) for param in raw_diffs]
logging.info("sums shapes: %s", [param_sum.shape for param_sum in sums])
# Divide each sum by the number of sampled diffs to get the mean diff.
diff_avg = [th.div(param_sum, len(diffs)) for param_sum in sums]
logging.info("diff_avg shapes: %s", [d.shape for d in diff_avg])