8950729ae44ebd96e72ca65a17efd3cc36c7c301,pliers/tests/extractors/test_text_extractors.py,,test_bert_LM_extractor,#,396

Before Change


    res_file = ext.transform(stim_file).to_df()
    res_target = ext_target.transform(stim).to_df()
    res_topn = ext_topn.transform(stim).to_df()
    res_threshold = ext_threshold.transform(stim).to_df()
    res_default = ext_default.transform(stim_masked).to_df()
    res_return_mask = ext_return_mask.transform(stim).to_df()

    assert res.shape[0] == 1

    // Test onset/duration
    assert res_file["onset"][0] == 1.0
    assert res_file["duration"][0] == 0.2

    // Check target words
    assert all([w.capitalize() in res_target.columns for w in target_wds])
    assert res_target.shape[1] == 6

    // Check top_n
    assert res_topn.shape[1] == 104
    assert all([res_topn.iloc[:,3][0] > res_topn.iloc[:,i][0] for i in range(4,103)])

    // Check threshold and return_softmax
    tknz = BertTokenizer.from_pretrained("bert-base-uncased")
    vocab = tknz.vocab.keys()
    for v in vocab:
        if v.capitalize() in res_threshold.columns:
            assert res_threshold[v.capitalize()][0] >= .1
            assert res_threshold[v.capitalize()][0] <= 1

    // Test update mask method
    assert ext_target.mask == 1
    ext_target.update_mask(new_mask="sentence")
    assert ext_target.mask == "sentence"
    res_target_new = ext_target.transform(stim).to_df()
    assert all([res_target[c][0] != res_target_new[c][0]
                for c in ["Target", "Word"]])
    with pytest.raises(ValueError) as err:
        ext_target.update_mask(new_mask=["some", "mask"])
    assert "must be a string" in str(err.value)
    
    // Test default mask
    assert res_default.shape[0] == 1

    // Test return mask and input
    assert res_return_mask["true_word"][0] == "is"
    assert "true_word_score" in res_return_mask.columns
    assert res_return_mask["sequence"][0] == "This is not a tokenized sentence ."

    // Clean up: delete the extractors and result objects created above
    del ext, ext_masked, ext_target, ext_topn, ext_threshold, ext_default, \
        ext_return_mask
    del res, res_masked, res_file, res_target, res_topn, res_threshold, \
        res_default, res_return_mask

After Change


    res_masked = BertLMExtractor().transform(stim_masked).to_df()
    res_file =  BertLMExtractor(mask=2).transform(stim_file).to_df()
    res_target = ext_target.transform(stim).to_df()
    res_topn = BertLMExtractor(mask=3, top_n=100).transform(stim).to_df()
    res_threshold = BertLMExtractor(mask=4, threshold=.1, return_softmax=True).transform(stim).to_df()
    res_default = BertLMExtractor().transform(stim_masked).to_df()
    res_return_mask = BertLMExtractor(mask=1, top_n=10, return_masked_word=True, return_input=True).transform(stim).to_df()

    assert res.shape[0] == 1

    // Test onset/duration
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 11

Instances


Project Name: tyarkoni/pliers
Commit Name: 8950729ae44ebd96e72ca65a17efd3cc36c7c301
Time: 2020-04-02
Author: rbrrcc@gmail.com
File Name: pliers/tests/extractors/test_text_extractors.py
Class Name:
Method Name: test_bert_LM_extractor


Project Name: tyarkoni/pliers
Commit Name: 8950729ae44ebd96e72ca65a17efd3cc36c7c301
Time: 2020-04-02
Author: rbrrcc@gmail.com
File Name: pliers/tests/extractors/test_text_extractors.py
Class Name:
Method Name: test_bert_sequence_extractor


Project Name: tyarkoni/pliers
Commit Name: 8950729ae44ebd96e72ca65a17efd3cc36c7c301
Time: 2020-04-02
Author: rbrrcc@gmail.com
File Name: pliers/tests/extractors/test_text_extractors.py
Class Name:
Method Name: test_bert_LM_extractor


Project Name: tyarkoni/pliers
Commit Name: 8950729ae44ebd96e72ca65a17efd3cc36c7c301
Time: 2020-04-02
Author: rbrrcc@gmail.com
File Name: pliers/tests/extractors/test_text_extractors.py
Class Name:
Method Name: test_bert_sentiment_extractor