optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
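# Optional warm-up setup: when fixbase_epoch > 0, build a temporary optimizer over the
# classifier parameters only, so the base network can stay frozen for the first epochs.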
if args.fixbase_epoch > 0:
    if hasattr(model, "classifier") and isinstance(model.classifier, nn.Module):
        optimizer_tmp = init_optimizer(model.classifier.parameters(), **optimizer_kwargs(args))
    else:
        print('Warn: model has no attribute "classifier" and fixbase_epoch is reset to 0')
        args.fixbase_epoch = 0
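# Optionally initialize from pretrained weights. Only parameters whose names and tensor
# shapes match the current model are copied; mismatched layers (e.g. a classifier sized
# for a different number of training identities) are silently skipped.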
if args.load_weights and check_isfile(args.load_weights):
    # load pretrained weights but ignore layers that don't match in size
    checkpoint = torch.load(args.load_weights)
    pretrain_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
    print("Loaded pretrained weights from '{}'".format(args.load_weights))
if args.resume and check_isfile(args.resume):
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint["state_dict"])
    args.start_epoch = checkpoint["epoch"] + 1
    print("Loaded checkpoint from '{}'".format(args.resume))
    print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, checkpoint["rank1"]))
if use_gpu:
    model = nn.DataParallel(model).cuda()
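# Evaluation-only mode: for every target dataset, compute the query-gallery distance
# matrix, optionally dump the top-ranked matches to disk, and exit without training.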
if args.evaluate:
    print("Evaluate only")

    for name in args.target_names:
        print("Evaluating {} ...".format(name))
        queryloader = testloader_dict[name]["query"]
        galleryloader = testloader_dict[name]["gallery"]
        distmat = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu, return_distmat=True)

        if args.visualize_ranks:
            visualize_ranked_results(
                distmat, dm.return_testdataset_by_name(name),
                save_dir=osp.join(args.save_dir, "ranked_results", name),
                topk=20
            )
    return
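# Training starts below. RankLogger (defined elsewhere in this codebase) presumably
# records per-dataset rank-1 results as training progresses.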
start_time = time.time()
ranklogger = RankLogger(args.source_names, args.target_names)
train_time = 0
print("==> Start training")
if args.fixbase_epoch > 0:
    print("Train classifier for {} epochs while keeping base network frozen".format(args.fixbase_epoch))

    for epoch in range(args.fixbase_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer_tmp, trainloader, use_gpu, freeze_bn=True)
        train_time += round(time.time() - start_train_time)

    del optimizer_tmp
    print("Now open all layers for training")
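# Stage 2: train all layers for the remaining epochs with the main optimizer.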
for epoch in range(args.start_epoch, args.max_epoch):
    start_train_time = time.time()
    train(epoch, model, criterion, optimizer, trainloader, use_gpu)
After Change
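# Revised warm-up (after the change): the same main optimizer is reused with fixbase=True
# instead of a separate classifier-only optimizer, and its state is restored afterwards,
# presumably so the warm-up epochs do not disturb the optimizer statistics used later.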
# snapshot the optimizer state so it can be restored once the warm-up epochs finish
initial_optim_state = optimizer.state_dict()

for epoch in range(args.fixbase_epoch):
    start_train_time = time.time()
    train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
    train_time += round(time.time() - start_train_time)

print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
optimizer.load_state_dict(initial_optim_state)
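# Illustration only, not part of the original script: one hypothetical way a train()
# helper could interpret fixbase=True is to disable gradients for every top-level
# module except an allowed set of "open" layers. The helper name, the open_layers
# default, and the BatchNorm handling below are assumptions, not the project's API.
def set_trainable_layers(model, open_layers=("classifier",), fixbase=False):
    for name, module in model.named_children():
        trainable = (not fixbase) or (name in open_layers)
        for p in module.parameters():
            p.requires_grad = trainable
        # keep frozen modules in eval mode so BatchNorm running stats are not updated
        module.train(trainable)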