91f31ae7eb26298d13f86ec146c7a4ced7af744d,examples/sir_hmc.py,,generate_data,#Any#,98

Before Change


def generate_data(args):
    logging.info("Generating data...")
    params = {"R0": torch.tensor(args.basic_reproduction_number),
              "tau": torch.tensor(args.recovery_time),
              "rho": torch.tensor(args.response_rate)}
    empty_data = [None] * args.duration

After Change


        obs = torch.stack([site["value"]
                           for name, site in tr.trace.nodes.items()
                           if re.match("obs_[0-9]+", name)])
        S2I = torch.stack([site["value"]
                          for name, site in tr.trace.nodes.items()
                          if re.match("S2I_[0-9]+", name)])
        assert len(obs) == len(empty_data)

        obs_sum = int(obs[:args.duration].sum())
        S2I_sum = int(S2I[:args.duration].sum())
        if obs_sum >= args.min_observations:
            logging.info("Observed {:d}/{:d} infections:\n{}".format(
                obs_sum, S2I_sum, " ".join([str(int(x)) for x in obs[:args.duration]])))
            return {"S2I": S2I, "obs": obs}
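
The "After Change" excerpt gathers site values whose names match a numbered pattern, sums them over the first `args.duration` steps, and accepts the sample only if enough infections were observed. A minimal sketch of that filter-and-threshold step follows. It rests on assumptions that are not part of the commit: a plain dict stands in for `tr.trace.nodes`, Python ints stand in for the tensors passed to `torch.stack`, and `collect`, `duration`, and `min_observations` are hypothetical names. It also uses `fullmatch`, whereas the excerpt uses `re.match`, which anchors only at the start of the name.

```python
import re

# Hypothetical stand-in for tr.trace.nodes: site name -> {"value": ...}.
nodes = {
    "R0": {"value": 1.5},
    "S2I_0": {"value": 3},
    "obs_0": {"value": 1},
    "S2I_1": {"value": 5},
    "obs_1": {"value": 2},
}


def collect(nodes, prefix):
    """Return values of sites named <prefix>_<digits>, in insertion order."""
    pattern = re.compile(prefix + r"_[0-9]+")
    return [site["value"] for name, site in nodes.items()
            if pattern.fullmatch(name)]


obs = collect(nodes, "obs")
S2I = collect(nodes, "S2I")

duration = 2          # hypothetical; plays the role of args.duration
min_observations = 3  # hypothetical; plays the role of args.min_observations

# Sum over the observed window, then apply the acceptance threshold.
obs_sum = sum(obs[:duration])
S2I_sum = sum(S2I[:duration])
accepted = obs_sum >= min_observations
```

Here the "R0" site is skipped because its name does not match either pattern, and the sample is accepted because the observed count reaches the threshold.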
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 6

Instances


Project Name: uber/pyro
Commit Name: 91f31ae7eb26298d13f86ec146c7a4ced7af744d
Time: 2020-04-21
Author: fritzo@uber.com
File Name: examples/sir_hmc.py
Class Name:
Method Name: generate_data


Project Name: cornellius-gp/gpytorch
Commit Name: 91b0d220c8e816766fd4565e1d2f5115d3afbefe
Time: 2018-10-12
Author: gpleiss@gmail.com
File Name: test/functions/test_inv_quad_log_det.py
Class Name: TestInvQuadLogDetBatch
Method Name: test_log_det_only


Project Name: kengz/SLM-Lab
Commit Name: 4342569bc5e505a9d1a36b3b4a3d1269d3a2bcab
Time: 2018-09-21
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/hydra_dqn.py
Class Name: MultitaskDQN
Method Name: space_act