loss = criterion(output[0], target)
for j in range(1, len(output)):
loss += criterion(output[j], target)
output = output[0]
else: // single output
loss = criterion(output, target)
acc = accuracy(output, target, idx)
After Change
# Compute output: run the forward pass, then the training loss and accuracy.
# NOTE(review): `model`, `input`, `criterion`, `target`, `target_weight`,
# `accuracy` and `idx` are all defined outside this fragment -- their shapes
# and semantics are not visible here.
output = model(input)
if isinstance(output, list):
    # Multiple outputs: sum the weighted criterion over every element.
    # Presumably these are intermediate-supervision stages -- confirm against
    # the model. `sum` starts from 0 and adds left to right, exactly like the
    # former `loss = 0` + `loss += ...` accumulator.
    loss = sum(criterion(o, target, target_weight) for o in output)
    # Accuracy is measured on the last element only (presumably the final
    # stage). An empty list would raise IndexError here.
    output = output[-1]
else:
    # Single output.
    loss = criterion(output, target, target_weight)
acc = accuracy(output, target, idx)