    has the same shape as the tensor, then flatten them both to be 2D, pass them through
    :func:`masked_softmax`, then put the tensor back in its original shape.
    """
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor.size()[-1])
    if mask is not None:
        while mask.dim() < tensor.dim():
            mask = mask.unsqueeze(1)
        mask = mask.expand_as(tensor).contiguous().float()
        mask = mask.view(-1, mask.size()[-1])
    reshaped_result = masked_softmax(reshaped_tensor, mask)
    return reshaped_result.view(*tensor_shape)
def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
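Taken together, the old body expands the mask to the tensor's shape, flattens both to 2D, applies :func:`masked_softmax` row by row, and restores the original shape. Below is a minimal usage sketch with purely illustrative shapes; it assumes the function shown above is named ``last_dim_softmax`` and is in scope together with ``masked_softmax``, since neither signature appears in this excerpt.

import torch

# Illustrative shapes only: batch of 2, 3 rows per example, sequence length 4.
tensor = torch.randn(2, 3, 4)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]]).float()   # (batch_size, sequence_length)

# Assumes last_dim_softmax (defined as above) is in scope.
result = last_dim_softmax(tensor, mask)

print(result.shape)        # torch.Size([2, 3, 4]) -- original shape restored
print(result.sum(dim=-1))  # each row sums to ~1 over the unmasked positions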
After Change
    assume the tensor has shape ``(batch_size, ..., sequence_length)`` and that the mask (if given)
    has shape ``(batch_size, sequence_length)``.
    """
    return _last_dimension_applicator(masked_softmax, tensor, mask)
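The refactored body delegates the flatten/apply/reshape steps to ``_last_dimension_applicator``, whose definition is not shown in this excerpt. A minimal sketch of such a helper, assuming it simply factors out the logic from the old body above (the parameter name ``function_to_apply`` is an assumption):

from typing import Callable, Optional

import torch

def _last_dimension_applicator(function_to_apply: Callable,
                               tensor: torch.Tensor,
                               mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Flatten everything but the last dimension, apply the function, then
    # restore the original shape -- the same steps the old body performed inline.
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor.size()[-1])
    if mask is not None:
        while mask.dim() < tensor.dim():
            mask = mask.unsqueeze(1)
        mask = mask.expand_as(tensor).contiguous().float()
        mask = mask.view(-1, mask.size()[-1])
    reshaped_result = function_to_apply(reshaped_tensor, mask)
    return reshaped_result.view(*tensor_shape)

With such a helper in place, ``last_dim_softmax`` reduces to the one-line body shown above.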
def last_dim_log_softmax(tensor: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: