# Run one forward pass and package the result as a training-step output dict.
# NOTE(review): the enclosing method header is not visible here; the names
# below (fc_feats, att_feats, labels, masks, att_masks, data, sc_flag,
# struc_flag) are presumably unpacked/defined earlier in the same method.
# The second-to-last positional args are data["gts"] plus an index tensor
# 0..len(data["gts"])-1 built with torch.arange -- presumably one index per
# ground-truth entry; confirm against lw_model's signature.
# NOTE(review): sc_flag / struc_flag look like loss-mode switches (e.g.
# self-critical / structured loss) -- inferred from names only, verify.
model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks, data["gts"], torch.arange(0, len(data["gts"])), sc_flag, struc_flag)
# model_out is indexed by key and iterated with .items(), so it is dict-like
# and must contain a "loss" entry.
loss = model_out["loss"]
# Logged values: every model output except "loss", unwrapped via `.data`.
# NOTE(review): `.data` requires each value to be a tensor (or similar);
# a plain Python number in model_out would raise AttributeError here.
logger_logs = {k:v.data for k,v in model_out.items() if k != "loss"}
# NOTE(review): ss_prob is read from self.model, while the forward call above
# goes through self.lw_model -- confirm these refer to the same underlying
# model / that ss_prob is kept in sync.
logger_logs["scheduled_sampling_prob"] = self.model.ss_prob
# NOTE(review): unlike the other entries, the raw `loss` object is logged
# (no `.data`); if the logger does not need gradients, loss.detach() would
# avoid holding a reference to the autograd graph.
logger_logs["training_loss"] = loss
# Returned dict: "loss" is passed through unchanged (presumably so the caller
# can backpropagate through it); all other outputs are unwrapped via `.data`.
output = {k:v if k == "loss" else v.data for k,v in model_out.items()}
# Attach the logging dict under "log".
# NOTE(review): this matches the dict-return convention of older PyTorch
# Lightning `training_step` ("loss" + "log" keys) -- confirm the framework
# version in use still consumes the "log" key.
output["log"] = logger_logs
return output