aux_loss_fn = None
if args.thresh_test_preds:
thresholds = pd.read_csv(args.thresh_test_preds, header=None).values.squeeze()
elif len(last_thresholds) > 0:
# Re-use previous thresholds, if provided.
# Why? More accurate reporting, and not that slow. Don't compute thresholds on training, for example -- but we can recycle the val threshold.
thresholds = last_thresholds
else:
# Default thresholds -- faster, but less accurate
thresholds = np.array([default_threshold for _ in range(int(model.out_dim/heads_per_class))])
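# Note: thresholds are per *class*, not per head -- the array has out_dim / heads_per_class entries,
# one decision threshold per output class, shared across that class's heads.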
total_loss = 0
total_classifier_loss = 0
total_lm_loss = 0
total_multihead_variance_loss = 0
class_accuracies = torch.zeros(model.out_dim).cuda()
if model.out_dim/heads_per_class > 1 and not args.use_softmax:
keys = list(args.non_binary_cols)
elif args.use_softmax:
keys = [str(m) for m in range(model.out_dim)]
else:
keys = [""]
info_dicts = [{"fp" : 0, "tp" : 0, "fn" : 0, "tn" : 0, "std" : 0,
"metric" : args.metric, "micro" : args.micro} for k in keys]
# Sanity check -- should do this sooner. Does the number of classes match the expected output dimension?
assert model.out_dim == len(keys) * heads_per_class, "model.out_dim does not match keys (%s) x heads_per_class (%d)" % (keys, heads_per_class)
batch_adjustment = 1. / len(text)
# Save all outputs *IF* small enough, and requested for thresholding -- basically, on validation
# if threshold_validation and LR is not None:
all_batches = []
all_stds = []
all_labels = []
print("Running %d batches" % len(text))
for i, data in tqdm(enumerate(text), total=len(text), unit="batch", desc=tqdm_desc, position=1, ncols=100):
text_batch, labels_batch, length_batch = get_supervised_batch(data, args.cuda, model, args.ids, args, heads_per_class=args.heads_per_class)
class_out, (lm_out, _) = transform(model, text_batch, labels_batch, length_batch, args, LR)
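# transform() presumably returns the classifier activations (class_out) plus the language-model
# outputs (lm_out) used for the auxiliary LM loss below; with multiple heads per class, class_out
# unpacks into per-head outputs, their mean, and their std.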
class_std = None
if heads_per_class > 1:
all_heads, class_out, class_std = class_out
classifier_loss = clf_loss_fn(all_heads, labels_batch)
else:
classifier_loss = clf_loss_fn(class_out, labels_batch)
loss = classifier_loss
classifier_loss = classifier_loss.clone()  # save for reporting
# Also compute multihead variance loss -- from classifier [divide by output size since it scales linearly]
if args.aux_head_variance_loss_weight > 0.:
multihead_variance_loss = model.classifier.get_last_layer_variance() / model.out_dim
loss = loss + multihead_variance_loss * args.aux_head_variance_loss_weight
# Divide by the number of batches? Since we're looking at the parameters here, this should be batch-independent.
# multihead_variance_loss *= batch_adjustment
if args.aux_lm_loss:
lm_labels = text_batch[1:]
lm_losses = aux_loss_fn(lm_out[:-1].view(-1, lm_out.size(2)).contiguous().float(),
lm_labels.contiguous().view(-1))
padding_mask = (torch.arange(lm_labels.size(0)).unsqueeze(1).cuda() > length_batch).float()
portion_unpadded = padding_mask.sum() / padding_mask.size(0)
lm_loss = portion_unpadded * torch.mean(lm_losses * (padding_mask.view(-1).float()))
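# lm_losses is a flat per-token loss; multiplying by the flattened mask zeroes out the positions the
# mask excludes, and portion_unpadded rescales the mean so masked-out positions do not dilute the loss.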
# Scale LM loss -- since it's so big
if args.aux_lm_loss_weight > 0.:
loss = loss + lm_loss * args.aux_lm_loss_weight
# Training
if LR is not None:
LR.optimizer.zero_grad()
loss.backward()
LR.optimizer.step()
LR.step()
# Remove loss from CUDA -- kill gradients and save memory.
total_loss += loss.detach().cpu().numpy()
if args.use_softmax:
labels_batch = onehot(labels_batch.squeeze(), model.out_dim)
class_out = onehot(torch.max(class_out, -1)[1].squeeze(), int(model.out_dim/heads_per_class))
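# In softmax mode, both the labels and the argmax predictions are converted to one-hot vectors, so the
# same per-class tp/fp/fn/tn bookkeeping below applies unchanged.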
total_classifier_loss += classifier_loss.detach().cpu().numpy()
if args.aux_lm_loss:
total_lm_loss += lm_loss.detach().cpu().numpy()
if args.aux_head_variance_loss_weight > 0:
total_multihead_variance_loss += multihead_variance_loss.detach().cpu().numpy()
for j in range(int(model.out_dim/heads_per_class)):
std = None
if class_std is not None:
std = class_std[:,j]
info_dicts[j] = update_info_dict(info_dicts[j], labels_batch[:, j], class_out[:, j], thresholds[j], std=std)
# Save, for overall thresholding (not on training)
if threshold_validation and LR is None:
all_labels.append(labels_batch.detach().cpu().numpy())
all_batches.append(class_out.detach().cpu().numpy())
if class_std is not None:
all_stds.append(class_std.detach().cpu().numpy())
if threshold_validation and LR is None:
all_batches = np.concatenate(all_batches)
all_labels = np.concatenate(all_labels)
if heads_per_class > 1:
all_stds = np.concatenate(all_stds)
# Compute new thresholds -- per class
_, thresholds, _, _ = _binary_threshold(all_batches, all_labels, args.threshold_metric, True, global_tweaks=0, heads_per_class=heads_per_class, class_single_threshold=False)
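# _binary_threshold presumably sweeps candidate thresholds per class over the concatenated validation
# outputs and returns the ones that maximize args.threshold_metric.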
info_dicts = [{"fp" : 0, "tp" : 0, "fn" : 0, "tn" : 0, "std" : 0.,
"metric" : args.metric, "micro" : args.micro} for k in keys]
# In multihead case, look at class averages? Why? More predictive. Works especially well when we force single per-class threshold.
for j in range(int(model.out_dim/heads_per_class)):
std = None
if heads_per_class > 1:
std = all_stds[:, j]
info_dicts[j] = update_info_dict(info_dicts[j], all_labels[:, j], all_batches[:, j], thresholds[j], std=std)
# Metrics for all items -- with current best thresholds
total_metrics, class_metric_strs = get_metric_report(info_dicts, args, keys, LR)
After Change
else:
class_out, clf_out = class_out
if args.dual_thresh:
class_out = class_out[:, :-1]
classifier_loss = clf_loss_fn(class_out, labels_batch)
if args.use_softmax:
class_out = F.softmax(class_out, -1)
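# With softmax applied, class_out holds a normalized probability distribution over classes rather than raw logits.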
loss = classifier_loss
classifier_loss = classifier_loss.clone()  # save for reporting
# Also compute multihead variance loss -- from classifier [divide by output size since it scales linearly]
if args.aux_head_variance_loss_weight > 0.:
multihead_variance_loss = model.classifier.get_last_layer_variance() / model.out_dim
loss = loss + multihead_variance_loss * args.aux_head_variance_loss_weight
# Divide by the number of batches? Since we're looking at the parameters here, this should be batch-independent.
# multihead_variance_loss *= batch_adjustment
if args.aux_lm_loss:
lm_labels = text_batch[1:]
lm_losses = aux_loss_fn(lm_out[:-1].view(-1, lm_out.size(2)).contiguous().float(),
lm_labels.contiguous().view(-1))
padding_mask = (torch.arange(lm_labels.size(0)).unsqueeze(1).cuda() > length_batch).float()
portion_unpadded = padding_mask.sum() / padding_mask.size(0)
lm_loss = portion_unpadded * torch.mean(lm_losses * (padding_mask.view(-1).float()))
# Scale LM loss -- since it's so big
if args.aux_lm_loss_weight > 0.:
loss = loss + lm_loss * args.aux_lm_loss_weight
# Training
if LR is not None:
LR.optimizer.zero_grad()
loss.backward()
LR.optimizer.step()
LR.step()
# Remove loss from CUDA -- kill gradients and save memory.
total_loss += loss.detach().cpu().numpy()
if args.use_softmax:
labels_batch = onehot(labels_batch.squeeze(), model.out_dim)
class_out = onehot(clf_out.view(-1), int(model.out_dim/heads_per_class))
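# After the change, clf_out presumably already carries the predicted class indices, replacing the
# explicit torch.max(class_out, -1)[1] call used before.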
total_classifier_loss += classifier_loss.detach().cpu().numpy()
if args.aux_lm_loss:
total_lm_loss += lm_loss.detach().cpu().numpy()
if args.aux_head_variance_loss_weight > 0:
total_multihead_variance_loss += multihead_variance_loss.detach().cpu().numpy()
for j in range(int(model.out_dim/heads_per_class)):
std = None
if class_std is not None:
std = class_std[:,j]
info_dicts[j] = update_info_dict(info_dicts[j], labels_batch[:, j], class_out[:, j], thresholds[j], std=std)
# Save, for overall thresholding (not on training)
if threshold_validation and LR is None:
all_labels.append(labels_batch.detach().cpu().numpy())
all_batches.append(class_out.detach().cpu().numpy())
if class_std is not None:
all_stds.append(class_std.detach().cpu().numpy())
if threshold_validation and LR is None:
all_batches = np.concatenate(all_batches)
all_labels = np.concatenate(all_labels)
if heads_per_class > 1:
all_stds = np.concatenate(all_stds)
# Compute new thresholds -- per class
_, thresholds, _, _ = _binary_threshold(all_batches, all_labels, args.threshold_metric, args.micro, global_tweaks=args.global_tweaks)
info_dicts = [{"fp" : 0, "tp" : 0, "fn" : 0, "tn" : 0, "std" : 0.,
"metric" : args.report_metric, "micro" : args.micro} for k in keys]
# In multihead case, look at class averages? Why? More predictive. Works especially well when we force single per-class threshold.
for j in range(int(model.out_dim/heads_per_class)):
std = None
if heads_per_class > 1:
std = all_stds[:, j]
info_dicts[j] = update_info_dict(info_dicts[j], all_labels[:, j], all_batches[:, j], thresholds[j], std=std)
# Metrics for all items -- with current best thresholds
total_metrics, class_metric_strs = get_metric_report(info_dicts, args, keys, LR)
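# A minimal sketch (not the actual _binary_threshold implementation) of how per-class thresholds
# could be chosen from the accumulated validation outputs: sweep candidate thresholds for each class
# and keep the one that maximizes a chosen metric. The names below (f1_at_threshold,
# pick_class_thresholds) are illustrative only.
import numpy as np

def f1_at_threshold(scores, labels, threshold):
    # Binarize scores at the threshold and compute F1 against binary labels.
    preds = (scores >= threshold).astype(int)
    tp = int(((preds == 1) & (labels == 1)).sum())
    fp = int(((preds == 1) & (labels == 0)).sum())
    fn = int(((preds == 0) & (labels == 1)).sum())
    denom = 2 * tp + fp + fn
    return (2.0 * tp / denom) if denom > 0 else 0.0

def pick_class_thresholds(all_batches, all_labels, candidates=np.linspace(0.05, 0.95, 19)):
    # all_batches, all_labels: arrays of shape (num_examples, num_classes).
    thresholds = []
    for j in range(all_batches.shape[1]):
        best = max(candidates, key=lambda t: f1_at_threshold(all_batches[:, j], all_labels[:, j], t))
        thresholds.append(best)
    return np.array(thresholds)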