8b8c38c3f88a6d9ef8c0ef3ff4dc3643632f4d18,tests/distributions/test_zero_inflated.py,,test_zinb_0_gate,#Any#Any#,65

Before Change


def test_zinb_0_gate(total_count, probs):
    """Check that a zero-gate ZINB agrees with a plain NegativeBinomial.

    Builds a ``ZeroInflatedNegativeBinomial`` whose gate is a tensor of
    zeros and a ``NegativeBinomial`` from the same ``total_count`` and
    ``probs``, draws 20 samples from the latter, and asserts that both
    distributions assign the same log-probability to those samples.

    Args:
        total_count: value wrapped in ``torch.tensor`` and used as the
            ``total_count`` of both distributions.
        probs: value wrapped in ``torch.tensor`` and used as the ``probs``
            of both distributions.
    """
    # if gate is 0 ZINB is NegativeBinomial
    # NOTE(review): the gate is passed as the first positional argument here;
    # presumably this matches the constructor signature at this revision.
    zinb_ = ZeroInflatedNegativeBinomial(
        torch.zeros(1), total_count=torch.tensor(total_count), probs=torch.tensor(probs)
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    # Samples come from the reference distribution so both log_prob calls
    # are evaluated on identical inputs.
    s = neg_bin.sample((20,))
    zinb_prob = zinb_.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb_prob, neg_bin_prob)

After Change


def test_zinb_0_gate(total_count, probs):
    """Check that a zero-gate ZINB agrees with a plain NegativeBinomial.

    Two ``ZeroInflatedNegativeBinomial`` instances are built from the same
    ``total_count`` and ``probs``: one with ``gate`` set to a tensor of
    zeros, the other with ``gate_logits`` set to ``-99.9``. Twenty samples
    are drawn from a matching ``NegativeBinomial`` and the test asserts
    that each zero-inflated variant assigns them the same log-probability
    as the reference distribution.

    Args:
        total_count: value wrapped in ``torch.tensor`` and used as the
            ``total_count`` of all three distributions.
        probs: value wrapped in ``torch.tensor`` and used as the ``probs``
            of all three distributions.
    """
    # if gate is 0 ZINB is NegativeBinomial
    zinb1 = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count), gate=torch.zeros(1), probs=torch.tensor(probs)
    )
    # NOTE(review): -99.9 is presumably meant as a logit whose gate
    # probability is effectively zero; confirm against the distribution's
    # gate_logits parameterization.
    zinb2 = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate_logits=torch.tensor(-99.9),
        probs=torch.tensor(probs),
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    # Samples come from the reference distribution so every log_prob call
    # is evaluated on identical inputs.
    s = neg_bin.sample((20,))
    zinb1_prob = zinb1.log_prob(s)
    zinb2_prob = zinb2.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb1_prob, neg_bin_prob)
    assert_close(zinb2_prob, neg_bin_prob)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 9

Instances


Project Name: uber/pyro
Commit Name: 8b8c38c3f88a6d9ef8c0ef3ff4dc3643632f4d18
Time: 2020-10-01
Author: martinjankowiak@users.noreply.github.com
File Name: tests/distributions/test_zero_inflated.py
Class Name:
Method Name: test_zinb_0_gate


Project Name: uber/pyro
Commit Name: 8b8c38c3f88a6d9ef8c0ef3ff4dc3643632f4d18
Time: 2020-10-01
Author: martinjankowiak@users.noreply.github.com
File Name: tests/distributions/test_zero_inflated.py
Class Name:
Method Name: test_zinb_0_gate


Project Name: uber/pyro
Commit Name: 8b8c38c3f88a6d9ef8c0ef3ff4dc3643632f4d18
Time: 2020-10-01
Author: martinjankowiak@users.noreply.github.com
File Name: tests/distributions/test_zero_inflated.py
Class Name:
Method Name: test_zip_1_gate


Project Name: uber/pyro
Commit Name: 8b8c38c3f88a6d9ef8c0ef3ff4dc3643632f4d18
Time: 2020-10-01
Author: martinjankowiak@users.noreply.github.com
File Name: tests/distributions/test_zero_inflated.py
Class Name:
Method Name: test_zinb_1_gate


Project Name: uber/pyro
Commit Name: 8b8c38c3f88a6d9ef8c0ef3ff4dc3643632f4d18
Time: 2020-10-01
Author: martinjankowiak@users.noreply.github.com
File Name: tests/distributions/test_zero_inflated.py
Class Name:
Method Name: test_zip_0_gate