# Fragment (incomplete: the loop/branch defining `dp` starts before this view,
# and original indentation appears lost in extraction).
# Copies a distributed-params override into the apex AMP kwargs when the key
# is already present, then initializes apex AMP.
if dp in amp_params:
amp_params[dp] = distributed_params[dp]
# NOTE(review): apex.amp.initialize returns (model, optimizer) when an
# optimizer is passed and just the patched model otherwise — the branch
# below unpacks accordingly. Confirm against the apex API docs.
amp_result = apex.amp.initialize(model, optimizer, **amp_params)
if optimizer is not None:
model, optimizer = amp_result
else:
model = amp_result
After the change, the inline apex initialization is replaced by a dispatch branch that delegates to an `initialize_apex` helper:
# Fragment of a dispatch chain (the opening `if` branch lies outside this
# view; original indentation appears lost in extraction).
# Dict-of-models case: wrap every sub-model in torch.nn.DataParallel,
# preserving the dict keys.
elif isinstance(model, dict):
model = {k: torch.nn.DataParallel(v) for k, v in model.items()}
# Apex path: delegate AMP setup to a helper, forwarding all distributed
# options. NOTE(review): presumably `initialize_apex` wraps
# apex.amp.initialize as in the pre-change snippet earlier in this file —
# confirm at its definition.
elif use_apex:
model, optimizer = initialize_apex(
model, optimizer, **distributed_params
)