@dtypes(*torch.testing.get_all_dtypes())
def test_int_scalarlist(self, device, dtype):
    """Check foreach binary ops against per-tensor torch ops with int scalar lists.

    For each tensor-list size ``N`` and each (out-of-place, in-place, reference)
    op triple, builds ``N`` test tensors and a matching list of integer scalars,
    then compares the foreach result with the per-tensor reference, covering the
    dtype-dependent error and promotion special cases below.
    """
    for N in N_values:
        for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
                                                                 self.foreach_bin_ops_,
                                                                 self.torch_bin_ops):
            tensors = self._get_test_data(device, dtype, N)
            scalars = [1 for _ in range(N)]
            expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]

            # We don't support bool and complex types on CUDA for now: both the
            # in-place and out-of-place foreach variants must raise.
            # NOTE(review): this `return` (and the ones below) exits the whole
            # test after the first (N, op) combination instead of moving on to
            # the next one — confirm that is intentional.
            if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == "cuda":
                with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                    foreach_bin_op_(tensors, scalars)
                with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                    foreach_bin_op(tensors, scalars)
                return

            res = foreach_bin_op(tensors, scalars)

            if dtype == torch.bool:
                # Bool inputs promote to float32 out-of-place; in-place fails
                # because the float result cannot be cast back to bool.
                self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])
                with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
                    foreach_bin_op_(tensors, scalars)
                return

            if dtype in torch.testing.integral_types():
                if self.device_type == "cpu":
                    # On CPU, integral inputs are promoted to float32.
                    self.assertEqual(res, [e.to(torch.float32) for e in expected])
                else:
                    # TODO[type promotion]: Fix once type promotion is enabled.
                    self.assertEqual(res, [e.to(dtype) for e in expected])
            else:
                self.assertEqual(res, expected)

            if dtype in torch.testing.integral_types() and self.device_type == "cpu":
                # In-place on CPU integral tensors would need to cast the
                # promoted float result back down, which must raise.
                with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
                    foreach_bin_op_(tensors, scalars)
                return
            else:
                foreach_bin_op_(tensors, scalars)
                self.assertEqual(res, tensors)