def forward(self, x):
    """Extract pooled features and, in training mode, classify them.

    The input is passed through ``self.base``, average-pooled with a kernel
    equal to the feature map's dimensions from index 2 onward, and then
    flattened to one row per sample.

    Args:
        x: Input batch accepted by ``self.base``.

    Returns:
        The flattened feature tensor ``f`` when ``self.training`` is false.
        In training mode: the classifier output ``y`` when ``self.loss``
        equals ``{"xent"}``, or the tuple ``(y, f)`` when it equals
        ``{"xent", "htri"}``.

    Raises:
        KeyError: If ``self.loss`` matches neither supported set.
    """
    feature_map = self.base(x)
    # The kernel spans every dimension after the first two, so pooling covers
    # the whole map (assumes a 4-D (N, C, H, W) base output -- TODO confirm).
    pooled = F.avg_pool2d(feature_map, feature_map.size()[2:])
    f = pooled.view(pooled.size(0), -1)

    # Outside training only the flattened features are returned.
    if not self.training:
        return f

    y = self.classifier(f)
    if self.loss == {"xent"}:
        return y
    if self.loss == {"xent", "htri"}:
        return y, f
    raise KeyError("Unsupported loss: {}".format(self.loss))
After Change
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residualout = self.relu(out)return out
class Bottleneck(nn.Module):
expansion = 4