b63bd3184d5708a11c522e6e2a11ad79834422d5,pixyz/losses/adversarial_loss.py,AdversarialKullbackLeibler,_get_estimated_value,#AdversarialKullbackLeibler#Any#Any#,168

Before Change



        if discriminator:
            # sample x from q
            x_dict = get_dict_values(x, self._q.input_var, True)
            x_q_dict = self._q.sample(x_dict, batch_size=batch_size)
            x_q_dict = get_dict_values(x_q_dict, self.d.input_var, True)

            # sample y_p from d (x_p_dict is computed earlier in the method, outside this excerpt)
            y_p_dict = self.d.sample(detach_dict(x_p_dict))
            y_p = get_dict_values(y_p_dict, self.d.var)[0]

            # sample y_q from d
            y_q_dict = self.d.sample(detach_dict(x_q_dict))
            y_q = get_dict_values(y_q_dict, self.d.var)[0]

            return self.d_loss(y_p, y_q, batch_size), x

After Change



        if discriminator:
            # sample x_q from q
            x_q_dict = get_dict_values(self._q.sample(x, batch_size=batch_size), self.d.input_var, True)

            # sample y_p from d
            y_p = get_dict_values(self.d.sample(detach_dict(x_p_dict)), self.d.var)[0]
            # sample y_q from d
            y_q = get_dict_values(self.d.sample(detach_dict(x_q_dict)), self.d.var)[0]

            return self.d_loss(y_p, y_q, batch_size), x

        # sample y from d
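The change inlines the temporaries x_dict, y_p_dict and y_q_dict, and it also passes x to self._q.sample directly instead of first filtering it to self._q.input_var. The sketch below checks, under stated assumptions, that both versions produce the same discriminator inputs. get_dict_values and detach_dict here are simplified hypothetical stand-ins, not pixyz's implementations, and ToyQ, ToyD, before and after are hypothetical names. The equivalence relies on the sampler ignoring extra keys in its input dict, which the toy sampler does; whether pixyz's sample() behaves the same way is an assumption.

```python
from typing import Any, Dict, List, Union


def get_dict_values(dicts: Dict[str, Any], keys: List[str],
                    return_dict: bool = False) -> Union[Dict[str, Any], List[Any]]:
    """Hypothetical stand-in: select `keys` from `dicts` as a dict or as a list of values."""
    if return_dict:
        return {k: dicts[k] for k in keys if k in dicts}
    return [dicts[k] for k in keys if k in dicts]


def detach_dict(dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in: with tensors this would detach each value; floats are just copied."""
    return dict(dicts)


class ToyQ:
    """Hypothetical sampler q(x|z): returns its input plus x = 2 * z (extra keys are ignored)."""
    input_var = ["z"]

    def sample(self, x_dict, batch_size=None):
        out = dict(x_dict)
        out["x"] = 2.0 * x_dict["z"]
        return out


class ToyD:
    """Hypothetical discriminator d(t|x): returns its input plus t = x + 1."""
    input_var = ["x"]
    var = ["t"]

    def sample(self, x_dict):
        out = dict(x_dict)
        out["t"] = x_dict["x"] + 1.0
        return out


def before(q, d, x, x_p_dict, batch_size):
    # Mirrors the "Before Change" body: one temporary per step.
    x_dict = get_dict_values(x, q.input_var, True)
    x_q_dict = q.sample(x_dict, batch_size=batch_size)
    x_q_dict = get_dict_values(x_q_dict, d.input_var, True)
    y_p_dict = d.sample(detach_dict(x_p_dict))
    y_p = get_dict_values(y_p_dict, d.var)[0]
    y_q_dict = d.sample(detach_dict(x_q_dict))
    y_q = get_dict_values(y_q_dict, d.var)[0]
    return y_p, y_q


def after(q, d, x, x_p_dict, batch_size):
    # Mirrors the "After Change" body: temporaries inlined, x passed to q unfiltered.
    x_q_dict = get_dict_values(q.sample(x, batch_size=batch_size), d.input_var, True)
    y_p = get_dict_values(d.sample(detach_dict(x_p_dict)), d.var)[0]
    y_q = get_dict_values(d.sample(detach_dict(x_q_dict)), d.var)[0]
    return y_p, y_q


if __name__ == "__main__":
    x = {"z": 1.5, "other": 0.0}   # "other" is an extra key q does not need
    x_p_dict = {"x": 4.0}
    print(before(ToyQ(), ToyD(), x, x_p_dict, batch_size=1))  # (5.0, 4.0)
    print(after(ToyQ(), ToyD(), x, x_p_dict, batch_size=1))   # (5.0, 4.0)
```

If a sampler instead rejected or used unexpected keys, the unfiltered call in the "After Change" version could behave differently from the original, so the refactoring is only behavior-preserving when that assumption holds.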
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: masa-su/pixyz
Commit Name: b63bd3184d5708a11c522e6e2a11ad79834422d5
Time: 2019-03-14
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: pixyz/losses/adversarial_loss.py
Class Name: AdversarialKullbackLeibler
Method Name: _get_estimated_value


Project Name: masa-su/pixyz
Commit Name: 0c81011805c9ab4d6f7f314f674d39e51f5ba8eb
Time: 2018-10-30
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: Tars/losses/gan_loss.py
Class Name: GANLoss
Method Name: estimate


Project Name: masa-su/pixyz
Commit Name: b63bd3184d5708a11c522e6e2a11ad79834422d5
Time: 2019-03-14
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: pixyz/losses/adversarial_loss.py
Class Name: AdversarialJensenShannon
Method Name: _get_estimated_value