# Gather the value of every trace site named "obs_<digits>" and stack them
# into a single tensor, one entry per matching site. re.match anchors only
# at the start of the name, so any name that begins with "obs_" followed by
# a digit is included. Entry order follows the iteration order of
# tr.trace.nodes.
# NOTE(review): assumes `tr` is a Pyro poutine.trace handler whose nodes are
# recorded in time order, so index i of the result is step i -- confirm.
# torch.stack raises if no site matches (it needs a non-empty list).
obs = torch.stack([site["value"]
for name, site in tr.trace.nodes.items()
if re.match("obs_[0-9]+", name)])
# Same collection for the "S2I_<digits>" sites (presumably the per-step
# susceptible-to-infected counts -- confirm against the model definition).
S2I = torch.stack([site["value"]
for name, site in tr.trace.nodes.items()
if re.match("S2I_[0-9]+", name)])
# Sanity check: exactly one obs site per element of empty_data.
# Being an assert, this check is skipped under `python -O`.
assert len(obs) == len(empty_data)
# Totals over the first args.duration entries only. Any later entries are
# excluded from the acceptance test below.
obs_sum = int(obs[:args.duration].sum())
S2I_sum = int(S2I[:args.duration].sum())
# Accept this sample only when it produced at least args.min_observations
# observations within the first args.duration entries.
if obs_sum >= args.min_observations:
# Log the "observed/S2I total" counts followed by the per-entry observed
# values for the first args.duration entries.
logging.info("Observed {:d}/{:d} infections:\n{}".format(
obs_sum, S2I_sum, " ".join([str(int(x)) for x in obs[:args.duration]])))
# Return the full, untruncated tensors rather than only the first
# args.duration entries.
return {"S2I": S2I, "obs": obs}