def get_output(self, train):
    """Return the hidden representation, or run the weight-tying step.

    Args:
        train: Forwarded to ``self._get_hidden`` on the early-return path.

    Returns:
        ``self._get_hidden(train)`` when not training and
        ``self.output_reconstruction`` is falsy; otherwise ``None`` as far as
        this body goes.
    """
    if not train and not self.output_reconstruction:
        return self._get_hidden(train)
    if self.tie_weights:
        for e, d in zip(self.encoders, self.decoders):
            # An explicit loop is used because ``map(...)`` with its result
            # discarded is a no-op on Python 3 (map is lazy). The loop pairs
            # the two weight lists positionally, stopping at the shorter one.
            for enc_weight, dec_weight in zip(e.get_weights(), d.get_weights()):
                self._tranpose_weights(enc_weight, dec_weight)
    # NOTE(review): nothing is returned after the tie step, so this path
    # yields None. Confirm whether the rest of this method was lost.
After Change
return self.encoder.get_output(train)
def get_output(self, train):
    """Return the decoder's output, tying encoder weights to it first if enabled.

    Args:
        train: Forwarded unchanged to ``self.decoder.get_output``.

    Returns:
        Whatever ``self.decoder.get_output(train)`` returns.

    When ``self.tie_weights`` is truthy:
      * each encoder parameter with more than one dimension is replaced by
        the transpose of the decoder parameter at the same position;
      * 1-D parameters (presumably biases) keep the encoder's own values;
      * encoder parameters beyond the decoder's count are left unchanged.
    The result is written back with ``self.encoder.set_weights``.

    NOTE(review): this assumes the encoder exposes ``set_weights`` accepting
    a list shaped like ``get_weights()``'s result. The copy happens on each
    call; it does not create a shared variable.
    """
    if self.tie_weights:
        encoder_params = self.encoder.get_weights()
        decoder_params = self.decoder.get_weights()
        # Assigning to a loop variable never reaches the layer, so build the
        # new parameter list and push it back explicitly.
        tied_params = [
            dec_param.T if len(dec_param.shape) > 1 else enc_param
            for enc_param, dec_param in zip(encoder_params, decoder_params)
        ]
        # Keep unpaired trailing encoder params so the list passed back has
        # the same length as the one get_weights() returned.
        tied_params.extend(encoder_params[len(tied_params):])
        self.encoder.set_weights(tied_params)
    # Tie before computing the output so any computation that reads the
    # encoder's current weights sees the tied values.
    return self.decoder.get_output(train)
def get_config(self):
return {"name":self.__class__.__name__,
"encoder_config":self.encoder.get_config(),