443a431ac35b19d5948f5dcb7bce1a2e7b1034a7,test/test_foreach.py,TestForeach,test_int_scalarlist,#TestForeach#Any#Any#,379

Before Change


                    self.assertEqual(res, expected)

                // test in-place
                if dtype in torch.testing.floating_types() and self.device_type == "cpu":
                    foreach_bin_op_(tensors, scalars)
                    return
                else:
                    if foreach_bin_op_ == torch._foreach_div_ and \
                       dtype in torch.testing.integral_types() and \
                       self.device_type == "cpu":
                        with self.assertRaisesRegex(RuntimeError, "can"t be cast to the desired output type"):
                            foreach_bin_op_(tensors, scalars)
                    else:
                        foreach_bin_op_(tensors, scalars)
                        self.assertEqual(res, tensors)

    @skipCUDAIfRocm
    @dtypes(*torch.testing.get_all_dtypes())
    def test_float_scalar(self, device, dtype):
        for N in N_values:

After Change


    @dtypes(*torch.testing.get_all_dtypes())
    def test_int_scalarlist(self, device, dtype):
        """Check foreach binary ops against a list of Python ``int`` scalars.

        For each ``N`` in ``N_values`` and each (out-of-place, in-place,
        reference) op triple from ``self.foreach_bin_ops``,
        ``self.foreach_bin_ops_`` and ``self.torch_bin_ops``, the test builds
        ``N`` tensors with ``self._get_test_data`` and a list of ``N`` ones.

        * On CUDA, complex and bool dtypes must make both foreach variants
          raise "not implemented for".
        * For bool, the out-of-place result must match the reference op
          applied to the inputs cast to float32. The in-place op must raise
          a Float-to-output cast error.
        * For integral dtypes on CPU, the result must match the reference
          cast to float32, and the in-place op must raise the same cast
          error. On other devices the result is compared after casting the
          reference back to ``dtype``.
        * Otherwise the result must match the reference. Running the
          in-place op must then make ``tensors`` equal to that result.

        NOTE(review): every ``return`` below exits the whole test, not just
        the current op. On those paths only the first op for the first ``N``
        is checked. ``continue`` would widen coverage, but that is unverified
        against the remaining ops.
        """
        for N in N_values:
            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
                                                                     self.foreach_bin_ops_,
                                                                     self.torch_bin_ops):
                tensors = self._get_test_data(device, dtype, N)
                scalars = [1] * N
                expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]

                # We don't support bool and complex types on CUDA for now.
                if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and \
                        self.device_type == "cuda":
                    with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                        foreach_bin_op_(tensors, scalars)

                    with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                        foreach_bin_op(tensors, scalars)
                    return

                res = foreach_bin_op(tensors, scalars)

                if dtype == torch.bool:
                    self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])

                    with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
                        foreach_bin_op_(tensors, scalars)
                    return

                if dtype in torch.testing.integral_types():
                    if self.device_type == "cpu":
                        self.assertEqual(res, [e.to(torch.float32) for e in expected])
                    else:
                        # TODO[type promotion]: Fix once type promotion is enabled.
                        self.assertEqual(res, [e.to(dtype) for e in expected])
                else:
                    self.assertEqual(res, expected)

                if dtype in torch.testing.integral_types() and self.device_type == "cpu":
                    # On CPU the integral result is Float, which cannot be written back in place.
                    with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
                        foreach_bin_op_(tensors, scalars)
                    return

                foreach_bin_op_(tensors, scalars)
                self.assertEqual(res, tensors)
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 9

Instances


Project Name: pytorch/pytorch
Commit Name: 443a431ac35b19d5948f5dcb7bce1a2e7b1034a7
Time: 2021-02-03
Author: ngimel@fb.com
File Name: test/test_foreach.py
Class Name: TestForeach
Method Name: test_int_scalarlist


Project Name: pytorch/pytorch
Commit Name: 443a431ac35b19d5948f5dcb7bce1a2e7b1034a7
Time: 2021-02-03
Author: ngimel@fb.com
File Name: test/test_foreach.py
Class Name: TestForeach
Method Name: test_float_scalarlist


Project Name: pytorch/pytorch
Commit Name: 443a431ac35b19d5948f5dcb7bce1a2e7b1034a7
Time: 2021-02-03
Author: ngimel@fb.com
File Name: test/test_foreach.py
Class Name: TestForeach
Method Name: test_complex_scalarlist