# Training: either fit a new model and checkpoint its RNN weights, or build a
# model via init_model and restore RNN weights from a previously saved checkpoint.
# NOTE(review): args.pretrain is only tested against None; the checkpoint path is
# derived from args.dataset, not from args.pretrain — confirm this is intended.
# Single shared path so the save and load branches always agree on the file name.
rnn_ckpt_path = "rnn_model {}".format(args.dataset)
if args.pretrain is None:
    # No pretrained weights requested: train, then save only the RNN state_dict.
    model = fit(args, train_sequence, train_cluster_id)
    torch.save(model.rnn_model.state_dict(), rnn_ckpt_path)
else:
    # Use pretrained model: construct it via init_model (presumably without
    # training — confirm), then load the saved RNN weights into it.
    # NOTE(review): torch.load is called without map_location; a checkpoint saved
    # from GPU tensors may fail to load on a CPU-only machine — confirm.
    model = init_model(args, train_sequence, train_cluster_id)
    model.rnn_model.load_state_dict(torch.load(rnn_ckpt_path))