net = net.module
// if you are using PyTorch newer than 0.4 (e.g., built from
// GitHub source), you can remove str() on self.device
state_dict = torch.load(save_path, map_location=str(self.device))
// patch InstanceNorm checkpoints saved with PyTorch versions prior to 0.4
for key in state_dict:
self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
net.load_state_dict(state_dict)
// print network information
def print_networks(self, verbose):
print("---------- Networks initialized -------------")