raise NotImplementedError
# Build the network selected by args.model_type. `args` and `n_class` are
# defined outside this view (args is presumably an argparse Namespace —
# verify against the caller).
# NOTE: .cuda() places the model on a CUDA device unconditionally, so this
# path fails on hosts without a GPU.
if args.model_type == "mobilenet":
    net = MobileNet(n_class=n_class).cuda()
elif args.model_type == "mobilenetv2":
    net = MobileNetV2(n_class=n_class).cuda()
else:
    # Name the rejected value so an unexpected model_type is easy to diagnose.
    raise NotImplementedError(
        "unsupported model_type: {!r}".format(args.model_type)
    )
After Change
# Restore trained weights. Loading onto CPU first lets a checkpoint saved on a
# GPU be read on a CPU-only host; load_state_dict then copies the tensors into
# net's existing parameters, and the model is moved to args.device below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
net.load_state_dict(torch.load(args.ckpt_path, map_location="cpu"))
if args.mask_path is not None:
    # ModelSpeedup needs a dummy input to trace the network: 224x224 images for
    # "imagenet", otherwise 32x32 (batch of 2, 3 channels).
    SZ = 224 if args.dataset == "imagenet" else 32
    data = torch.randn(2, 3, SZ, SZ)
    # Apply the masks stored at args.mask_path to the traced network.
    ms = ModelSpeedup(net, data, args.mask_path)
    ms.speedup_model()
# Place the (possibly sped-up) model on the target device.
net.to(args.device)
if torch.cuda.is_available() and args.n_gpu > 1: