def forward(self, inputs):
    """Merge a sequence of tensors into one according to ``self.op``.

    The membership tests run in a fixed order (concat, sum, multiply,
    soft-sum) and the first group containing ``self.op`` decides the result:

    * ``CONCAT_OPS``  -- inputs pass through ``self.resize`` and are joined
      with ``torch.cat`` along dim 1.
    * ``SUM_OPS``     -- inputs pass through ``self.resize``, are stacked on a
      new leading dim and summed over it (element-wise sum).
    * ``MULTI_OPS``   -- inputs pass through ``self.resize`` and are multiplied
      together element-wise, starting from the Python int ``1`` (so an
      empty sequence yields ``1``).
    * ``SOFTSUM_OPS`` -- each input is passed through the matching entry of
      ``self.conv`` (no resize), then the results are stacked and summed.

    Args:
        inputs: sequence of tensors to combine. NOTE(review): ``self.resize``
            is opaque from here; presumably it brings the tensors to a common
            shape -- confirm against its definition.

    Returns:
        The combined tensor (or ``1`` for an empty multiply).

    Raises:
        ValueError: if ``self.op`` is in none of the supported groups.
    """
    op = self.op

    if op in self.CONCAT_OPS:
        return torch.cat(self.resize(inputs), dim=1)

    if op in self.SUM_OPS:
        stacked = torch.stack(self.resize(inputs), dim=0)
        return stacked.sum(dim=0)

    if op in self.MULTI_OPS:
        product = 1
        for tensor in self.resize(inputs):
            # Keep the accumulator on the left so operator dispatch matches
            # the running-product order.
            product = product * tensor
        return product

    if op in self.SOFTSUM_OPS:
        # NOTE(review): zip() stops at the shorter of self.conv / inputs, so a
        # length mismatch silently drops the extra items.
        projected = [layer(x) for layer, x in zip(self.conv, inputs)]
        return torch.stack(projected, dim=0).sum(dim=0)

    raise ValueError(
        "Combine operation must be in {}, instead got {}.".format(self.ALL_OPS, op)
    )
def extra_repr(self):
After Change
if callable(self.op):
return self.op(inputs)
if self.op in self.OPS:
return self.OPS[self.op](inputs)
raise ValueError("Combine operation must be a callable or \
one from {}, instead got {}.".format(self.ALL_OPS, self.op))