b53e87871ca515b140c6c089d012854ae5b40660,Tars/models/semi_vi.py,SemiVI,__init__,#SemiVI#Any#Any#Any#Any#Any#Any#Any#,11
Before Change
self.other_distributions = tolist(other_distributions)
for distribution in other_distributions:
params += list(distribution.parameters())
self.input_var += distribution.cond_var
self.input_var = list(set(self.input_var))
self.optimizer = optimizer(params, **optimizer_params)
def train(self, train_x, u_train_x, supervised_rate=1, **kwargs):
self.p.train()
self.q.train()
After Change
self.p = p
self.q = approximate_dist
self.d = discriminator
self.other_distributions = nn.ModuleList(tolist(other_distributions))
self.other_losses = other_losses
// set params and optim
q_params = list(self.q.parameters())
p_params = list(self.p.parameters())
d_params = list(self.d.parameters())
other_params = list(self.other_distributions.parameters())
params = q_params + p_params + d_params + other_params
self.optimizer = optimizer(params, **optimizer_params)
// set input_var
self.input_var = copy(self.q.cond_var)
for distribution in self.other_distributions:
self.input_var += copy(distribution.cond_var)
self.input_var = list(set(self.input_var))
def train(self, train_x, u_train_x, supervised_rate=1, **kwargs):
self.p.train()
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 17
Instances
Project Name: masa-su/pixyz
Commit Name: b53e87871ca515b140c6c089d012854ae5b40660
Time: 2018-10-14
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: Tars/models/semi_vi.py
Class Name: SemiVI
Method Name: __init__
Project Name: masa-su/pixyz
Commit Name: b53e87871ca515b140c6c089d012854ae5b40660
Time: 2018-10-14
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: Tars/models/vi.py
Class Name: VI
Method Name: __init__
Project Name: masa-su/pixyz
Commit Name: b53e87871ca515b140c6c089d012854ae5b40660
Time: 2018-10-14
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: Tars/models/vae.py
Class Name: VAE
Method Name: __init__