40c5039380d45c96ddbd3fd951e5c9adf167647d,experimental/unet/train_unet_demo.py,,cli_main,#Any#,17
Before Change
// 1 INIT LIGHTNING MODEL
// ------------------------
pl.seed_everything(args.seed)
model = UnetModule(**vars(args))
// ------------------------
// 2 INIT TRAINER
// ------------------------
trainer = pl.Trainer.from_argparse_args(args)
// ------------------------
// 3 START TRAINING OR TEST
// ------------------------
if args.mode == "train":
trainer.fit(model)
elif args.mode == "test":
assert args.resume_from_checkpoint is not None
outputs = trainer.test(model)
fastmri.save_reconstructions(outputs, args.default_root_dir / "reconstructions")
else:
raise ValueError(f"unrecognized mode {args.mode}")
After Change
// data
// ------------
// this creates a k-space mask for transforming input data
mask = create_mask_for_mask_type(
args.mask_type, args.center_fractions, args.accelerations
)
// use random masks for train transform, fixed masks for val transform
train_transform = UnetDataTransform(args.challenge, mask_func=mask, use_seed=False)
val_transform = UnetDataTransform(args.challenge, mask_func=mask)
test_transform = UnetDataTransform(args.challenge, mask_func=mask)
// ptl data module - this handles data loaders
data_module = FastMriDataModule(
data_path=args.data_path,
challenge=args.challenge,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
test_split=args.test_split,
test_path=args.test_path,
sample_rate=args.sample_rate,
batch_size=args.batch_size,
num_workers=args.num_workers,
distributed_sampler=(args.accelerator == "ddp"),
)
// ------------
// model
// ------------
model = UnetModule(
in_chans=args.in_chans,
out_chans=args.out_chans,
chans=args.chans,
num_pool_layers=args.num_pool_layers,
drop_prob=args.drop_prob,
lr=args.lr,
lr_step_size=args.lr_step_size,
lr_gamma=args.lr_gamma,
weight_decay=args.weight_decay,
)
In pattern: SUPERPATTERN
Frequency: 4
Non-data size: 29
Instances
Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: experimental/unet/train_unet_demo.py
Class Name:
Method Name: cli_main
Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: tests/test_modules.py
Class Name:
Method Name: test_unet_trainer
Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: experimental/unet/train_unet_demo.py
Class Name:
Method Name: cli_main
Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: experimental/varnet/train_varnet_demo.py
Class Name:
Method Name: cli_main
Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: tests/test_modules.py
Class Name:
Method Name: test_varnet_trainer