40c5039380d45c96ddbd3fd951e5c9adf167647d,experimental/unet/train_unet_demo.py,,cli_main,#Any#,17

Before Change


    // 1 INIT LIGHTNING MODEL
    // ------------------------
    pl.seed_everything(args.seed)
    model = UnetModule(**vars(args))

    // ------------------------
    // 2 INIT TRAINER
    // ------------------------
    trainer = pl.Trainer.from_argparse_args(args)

    // ------------------------
    // 3 START TRAINING OR TEST
    // ------------------------
    if args.mode == "train":
        trainer.fit(model)
    elif args.mode == "test":
        assert args.resume_from_checkpoint is not None
        outputs = trainer.test(model)
        fastmri.save_reconstructions(outputs, args.default_root_dir / "reconstructions")
    else:
        raise ValueError(f"unrecognized mode {args.mode}")

After Change


    // data
    // ------------
    // this creates a k-space mask for transforming input data
    mask = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    // use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform(args.challenge, mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform(args.challenge, mask_func=mask)
    test_transform = UnetDataTransform(args.challenge, mask_func=mask)
    // ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=(args.accelerator == "ddp"),
    )

    // ------------
    // model
    // ------------
    model = UnetModule(
        in_chans=args.in_chans,
        out_chans=args.out_chans,
        chans=args.chans,
        num_pool_layers=args.num_pool_layers,
        drop_prob=args.drop_prob,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        weight_decay=args.weight_decay,
    )
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 29

Instances


Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: experimental/unet/train_unet_demo.py
Class Name:
Method Name: cli_main


Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: tests/test_modules.py
Class Name:
Method Name: test_unet_trainer


Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: experimental/unet/train_unet_demo.py
Class Name:
Method Name: cli_main


Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: experimental/varnet/train_varnet_demo.py
Class Name:
Method Name: cli_main


Project Name: facebookresearch/fastMRI
Commit Name: 40c5039380d45c96ddbd3fd951e5c9adf167647d
Time: 2020-10-21
Author: matt.muckley@gmail.com
File Name: tests/test_modules.py
Class Name:
Method Name: test_varnet_trainer