if max_batch < 0:
max_batch = input_len
result = self.forward_impl(texts[:max_batch], dense_feat[:max_batch])
if input_len > max_batch:
texts = texts[max_batch:]
dense_feat = dense_feat[max_batch:]
while len(texts) > 0:
result_extension = self.forward_impl(
texts[:max_batch], dense_feat[:max_batch]
)
// the result of forward is either a torch.Tensor or a List[Any]
if isinstance(result, torch.Tensor):
result = torch.cat([result, result_extension], dim=0)
else:
result.extend(result_extension)
texts = texts[max_batch:]
dense_feat = dense_feat[max_batch:]
return result
@torch.jit.script_method
def make_prediction(
self,