# Build the secure-prediction graph: the model owner and the prediction
# client each supply a private input, and the computation is expressed through
# a Pond protocol instance created from three players taken from `config`.
model_trainer = ModelTrainer(config.get_player("model-trainer"))
prediction_client = PredictionClient(config.get_player("prediction-client"))

server0 = config.get_player("server0")
server1 = config.get_player("server1")
crypto_producer = config.get_player("crypto-producer")

with tfe.protocol.Pond(server0, server1, crypto_producer) as prot:

    # get model parameters as private tensors from model owner
    params = prot.define_private_input(model_trainer.player, model_trainer.provide_input, masked=True)  # pylint: disable=E0632

    # we'll use the same parameters for each prediction so we cache them to avoid re-training each time
    params = prot.cache(params)

    # get prediction input from client
    x, = prot.define_private_input(prediction_client.player, prediction_client.provide_input, masked=True)  # pylint: disable=E0632

    # helpers
    def conv(inputs, weights):
        """Apply `prot.conv2d` with the fixed arguments `1` and "VALID"."""
        return prot.conv2d(inputs, weights, 1, "VALID")

    def pool(inputs):
        """Apply `prot.avgpool2d` with (2, 2), (2, 2) and "VALID"."""
        return prot.avgpool2d(inputs, (2, 2), (2, 2), "VALID")

    # compute prediction
    Wconv1, bconv1, Wconv2, bconv2, Wfc1, bfc1, Wfc2, bfc2 = params

    # NOTE(review): the conv biases are reshaped to [-1, 1, 1], presumably so
    # they broadcast against the convolution outputs -- confirm the layout.
    bconv1 = prot.reshape(bconv1, [-1, 1, 1])
    bconv2 = prot.reshape(bconv2, [-1, 1, 1])

    layer1 = pool(prot.relu(conv(x, Wconv1) + bconv1))
    layer2 = pool(prot.relu(conv(layer1, Wconv2) + bconv2))
    # flatten to rows of ModelTrainer.HIDDEN_FC1 values before the matmuls
    layer2 = prot.reshape(layer2, [-1, ModelTrainer.HIDDEN_FC1])
    layer3 = prot.matmul(layer2, Wfc1) + bfc1
    logits = prot.matmul(layer3, Wfc2) + bfc2

    # send prediction output back to client
    prediction_op = prot.define_output(prediction_client.player, [logits], prediction_client.receive_output)
with tfe.Session() as sess:
After Change
return op
# Post-change variant: the parties are constructed without a player argument
# and the private input is declared through the module-level `tfe` function,
# naming the owning player by the string "model-trainer".
model_trainer = ModelTrainer()
prediction_client = PredictionClient()

# get model parameters as private tensors from model owner
params = tfe.define_private_input("model-trainer", model_trainer.provide_input, masked=True)  # pylint: disable=E0632