# Save the original graph in the backward cache *before* rebinding `g` to its
# local-variable copy; after `g = g.local_var()` the name `g` no longer refers
# to the original graph.
ctx.backward_cache = g
g = g.local_var()
g.edata["s"] = score
g.update_all(fn.copy_e("s", "m"), fn.max("m", "smax"))
g.apply_edges(fn.e_sub_v("s", "smax", "out"))
g.edata["out"] = th.exp(g.edata["out"])
g.update_all(fn.copy_e("out", "m"), fn.sum("m", "out_sum"))
g.apply_edges(fn.e_div_v("out", "out_sum", "out"))
out = g.edata["out"]
ctx.save_for_backward(out)
return out
@staticmethod