if args.load_weights:
    # Load pretrained weights, but ignore layers that don't match: any entry
    # whose name is absent from the current model, or whose tensor size
    # differs, is dropped and the model keeps its own initial value for it.
    print("Loading pretrained weights from '{}'".format(args.load_weights))
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # hosts; the model is moved to the GPU later only when use_gpu is set.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(args.load_weights, map_location="cpu")
    pretrain_dict = checkpoint["state_dict"]
    num_pretrained = len(pretrain_dict)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    # Start from the model's own state so skipped layers keep their values,
    # then overwrite only the compatible entries.
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
    print("- loaded {} layers, skipped {} mismatched".format(
        len(pretrain_dict), num_pretrained - len(pretrain_dict)))
if args.resume:
    if osp.isfile(args.resume):
        # Restore full model state plus bookkeeping saved with the checkpoint.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        # NOTE: torch.load unpickles the file - only resume from trusted files.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        # NOTE(review): training restarts at the saved "epoch" value itself;
        # if the checkpoint stores the last *completed* epoch this should be
        # checkpoint["epoch"] + 1 - confirm against the saving code.
        args.start_epoch = checkpoint["epoch"]
        rank1 = checkpoint["rank1"]
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, rank1))
    else:
        print("=> No checkpoint found at '{}'".format(args.resume))
# Optionally wrap the model for multi-GPU data parallelism and move it to the
# GPU. This runs after all checkpoint loading above. Note that nn.DataParallel
# prefixes state_dict keys with "module.".
model = nn.DataParallel(model).cuda() if use_gpu else model