# (pre-refactor path) Report the vocab size of every src/tgt feature.
# NOTE(review): this appears to be superseded by the per-side loop that
# follows later in the file — confirm whether it should be removed.
src_features, tgt_features = _collect_report_features(fields)
for j, feat in enumerate(src_features):
    logger.info(" * src feature %d size = %d"
                % (j, len(fields[feat].vocab)))
for j, feat in enumerate(tgt_features):
    logger.info(" * tgt feature %d size = %d"
                % (j, len(fields[feat].vocab)))
# --- After Change (diff marker from the original extraction) ---
model_opt = default_opt
model_opt.__dict__.update(checkpoint["opt"].__dict__)
logger.info("Loading vocab from checkpoint at %s." % opt.train_from)
vocab = checkpoint["vocab"]
else:
checkpoint = None
model_opt = opt
vocab = torch.load(opt.data + ".vocab.pt")
# Load one shard of the training set only to determine the data_type
# (all shards share the same data_type).
# TODO: this should be refactored out of existence reasonably soon.
first_dataset = torch.load(glob.glob(opt.data + ".train*.pt")[0])
data_type = first_dataset.data_type

# Older preprocessing saved a raw vocab instead of fields; convert it
# when detected (in the future this will be done in a smarter way).
if old_style_vocab(vocab):
    fields = load_fields_from_vocab(vocab, data_type)
else:
    fields = vocab
# Report src and tgt vocab sizes, including for features.
for side in ["src", "tgt"]:
    for name, f in fields[side]:
        # Fields without a vocab (e.g. raw numeric fields) are skipped.
        if f.use_vocab:
            logger.info(" * %s vocab size = %d" % (name, len(f.vocab)))
# Build the model (optionally initialized from the checkpoint) and log
# its parameter counts.
model = build_model(model_opt, opt, fields, checkpoint)
n_params, enc, dec = _tally_parameters(model)
logger.info("encoder: %d" % enc)
logger.info("decoder: %d" % dec)
logger.info("* number of parameters: %d" % n_params)
_check_save_model_path(opt)

# Build optimizer (restored from the checkpoint when resuming).
optim = build_optim(model, opt, checkpoint)

# Build the model saver, then the trainer that drives the training loop.
model_saver = build_model_saver(model_opt, opt, model, fields, optim)
trainer = build_trainer(opt, device_id, model, fields,
                        optim, data_type, model_saver=model_saver)
# Flatten the per-side (name, field) pair lists into one {name: field}
# dict — a temporary kludge because different objects expect `fields`
# to have a different structure.
dataset_fields = dict(chain.from_iterable(fields.values()))

train_iter = build_dataset_iter("train", dataset_fields, opt)
valid_iter = build_dataset_iter(
    "valid", dataset_fields, opt, is_train=False)