import os

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

# NOTE: the test-helper import paths below are assumed from the Lightning test-suite layout
import tests.base.develop_utils as tutils
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel


@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_cpu_slurm_save_load(enable_pl_optimizer, tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        enable_pl_optimizer=enable_pl_optimizer,
    )
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete (Trainer.fit() returns 1 on success in this version of Lightning)
    assert result == 1, "cpu model failed to complete"

    # make a prediction with the trained model before saving
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break
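    # only the first batch (of the last dataloader) is needed for the smoke-test prediction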
    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)
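    # keep this prediction around so the restored model's output can be compared against it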
    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
    assert os.path.exists(saved_filepath)
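    # hpc_save() returns the path of the checkpoint it wrote (typically an hpc_ckpt_*.ckpt file)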
    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        enable_pl_optimizer=enable_pl_optimizer,
    )
    model = EvalModelTemplate(**hparams)
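    # a fresh model instance with the same hyperparameters, to be restored from the HPC checkpoint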
    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0
After Change
    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)
    model = BoringModel()

    class _StartCallback(Callback):
        # set the epoch start hook so we can predict before the model does the full training
        def on_train_epoch_start(self, trainer, model):
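            # minimal sketch of the truncated hook body: it mirrors the before-change
            # assert_pred_same() check; anything beyond this is cut off in the diff
            assert trainer.global_step == real_global_step and trainer.global_step > 0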