class MNISTGenerator(BaseGenerator):
    """Generator that flattens ``gan.inputs.x`` and projects it through a linear layer.

    The projection width comes from ``config.end_features``; when that value
    is falsy the width falls back to 10. The projected tensor is passed
    through the activation returned by ``ops.lookup("tanh")``.
    """

    def create(self):
        """Build the generator graph and return its output.

        The input batch is reshaped to ``[gan.batch_size(), -1]`` before the
        linear layer. The final output is stored on both ``self.fy`` and
        ``self.sample`` and also returned.
        """
        gan, config, ops = self.gan, self.config, self.ops

        # Falsy values (e.g. None or 0) fall back to a width of 10.
        width = config.end_features or 10

        # NOTE(review): presumably tags/labels the ops for this generator;
        # semantics of ops.describe are defined outside this file.
        ops.describe("custom_generator")

        images = gan.inputs.x
        flat = ops.reshape(images, [gan.batch_size(), -1])
        projected = ops.linear(flat, width)
        output = ops.lookup("tanh")(projected)

        # Expose the same output under both attribute names.
        self.fy = self.sample = output
        return output

    def layer(self, name):
        """Return the attribute of this generator called ``name``."""
        found = getattr(self, name)
        return found
class MNISTDiscriminator(BaseDiscriminator):