cce84b5ca5fb02bda814138bd361eea6cafc16d5,test/test_foreach.py,TestForeach,test_complex_scalarlist,#TestForeach#Any#Any#,584

Before Change


    @dtypes(*torch.testing.get_all_dtypes())
    def test_complex_scalarlist(self, device, dtype):
        for N in N_values:
            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
                                                                     self.foreach_bin_ops_,
                                                                     self.torch_bin_ops):
                tensors = self._get_test_data(device, dtype, N)
                scalars = [3 + 5j for _ in range(N)]
                expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]

After Change


                    res = foreach_bin_op(tensors, scalars)
                    self.assertEqual(res, expected)

                if dtype not in [torch.complex64, torch.complex128]:
                    with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
                        foreach_bin_op_(tensors, scalars)
                else:
                    foreach_bin_op_(tensors, scalars)
                    self.assertEqual(res, tensors)

    @skipCUDAIfRocm
    @dtypes(*torch.testing.get_all_dtypes())
    def test_bool_scalar(self, device, dtype):
        for N in N_values:
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: pytorch/pytorch
Commit Name: cce84b5ca5fb02bda814138bd361eea6cafc16d5
Time: 2021-02-02
Author: iuriiz@devfair004.maas
File Name: test/test_foreach.py
Class Name: TestForeach
Method Name: test_complex_scalarlist


Project Name: dnouri/skorch
Commit Name: 53ef00376510ee9ba4506918db06b25dad4a7ea4
Time: 2017-07-31
Author: benjamin.bossan@ottogroup.com
File Name: inferno/callbacks.py
Class Name: BestLoss
Method Name: initialize


Project Name: pytorch/pytorch
Commit Name: 110a17a4d96d21cecf449073c9b66e1c888d2573
Time: 2021-03-04
Author: iuriiz@fb.com
File Name: test/test_foreach.py
Class Name: TestForeach
Method Name: test_complex_scalarlist