# NOTE(review): this is the body of a model-loading function whose `def`
# line lies above this chunk (presumably something like
# load_test_model(opt, dummy_opt, model_path=None)); indentation appears
# to have been stripped in this excerpt — confirm against the full file.
# Fall back to the model path carried on the options object.
if model_path is None:
model_path = opt.model
# Load the checkpoint onto CPU regardless of where it was saved:
# map_location returns `storage` unchanged, i.e. keeps tensors on CPU.
checkpoint = torch.load(model_path,
map_location=lambda storage, loc: storage)
# Rebuild the vocabulary `Field` objects serialized in the checkpoint.
fields = onmt.io.load_fields_from_vocab(
checkpoint["vocab"], data_type=opt.data_type)
# Options the model was trained with.
model_opt = checkpoint["opt"]
# Backfill any option missing from the checkpoint (e.g. added to the
# codebase after this model was trained) with the supplied defaults.
for arg in dummy_opt:
if arg not in model_opt:
model_opt.__dict__[arg] = dummy_opt[arg]
# Build the model and load the checkpoint weights into it.
model = make_base_model(model_opt, fields,
use_gpu(opt), checkpoint)
# Inference mode: disables dropout/batch-norm updates on the model
# and its generator head.
model.eval()
model.generator.eval()
return fields, model, model_opt
def make_base_model(model_opt, fields, gpu, checkpoint=None):
"""
Build the NMT model from options, optionally restoring checkpoint weights.

Args:
model_opt: the option loaded from checkpoint.
fields: `Field` objects for the model.
gpu(bool): whether to use gpu.
checkpoint: the model generated by train phase, or a resumed snapshot
model from a stopped training.
Returns:
the NMTModel.
"""
# NOTE(review): `assert` is stripped under `python -O`; raising
# ValueError would make this validation unconditional.
assert model_opt.model_type in ["text", "img", "audio"], \
("Unsupported model type %s" % (model_opt.model_type))
# Make encoder.