>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = (
process_group.size() if process_group is not None else dist.get_world_size()
)
tensor = bucket.get_tensors()[0]
fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
After Change
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
tensor = bucket.get_tensors()[0]
fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()