8a015062717a7eee6d6325e06ed9fc4d369fedc6,deepchem/models/tests/test_reload.py,,test_smiles2vec_reload,#,973

Before Change



  n_tasks = 1
  data_points = 10
  mols = ["CCCCCCCC"] * data_points
  X = featurizer(mols)

  y = np.random.randint(0, 2, size=(data_points, n_tasks))
  w = np.ones(shape=(data_points, n_tasks))
  dataset = dc.data.NumpyDataset(X, y, w, mols)

After Change



  loader = dc.data.CSVLoader(
      tasks=chembl25_tasks, smiles_field="smiles", featurizer=feat)
  dataset = loader.create_dataset(
      inputs=[dataset_file], shard_size=10000, data_dir=tempfile.mkdtemp())
  y = np.random.randint(0, 2, size=(data_points, n_tasks))
  w = np.ones(shape=(data_points, n_tasks))
  dataset = dc.data.NumpyDataset(dataset.X[:data_points, :max_seq_len], y, w,
                                 dataset.ids[:data_points])

  classsification_metric = dc.metrics.Metric(
      dc.metrics.roc_auc_score, np.mean, mode="classification")

  model_dir = tempfile.mkdtemp()
  model = dc.models.Smiles2Vec(
      char_to_idx=char_to_idx,
      max_seq_len=max_seq_len,
      use_conv=True,
      n_tasks=n_tasks,
      model_dir=model_dir,
      mode="classification")
  model.fit(dataset, nb_epoch=3)

  // Reload Trained Model
  reloaded_model = dc.models.Smiles2Vec(
      char_to_idx=char_to_idx,
      max_seq_len=max_seq_len,
      use_conv=True,
      n_tasks=n_tasks,
      model_dir=model_dir,
      mode="classification")
  reloaded_model.restore()

  // Check predictions match on original dataset
  origpred = model.predict(dataset)
  reloadpred = reloaded_model.predict(dataset)
  assert np.all(origpred == reloadpred)

Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: deepchem/deepchem
Commit Name: 8a015062717a7eee6d6325e06ed9fc4d369fedc6
Time: 2020-10-20
Author: bharath@Bharaths-MBP.zyxel.com
File Name: deepchem/models/tests/test_reload.py
Class Name:
Method Name: test_smiles2vec_reload


Project Name: RasaHQ/rasa
Commit Name: 06999395c7898abd4bc9ac3bae57d3516f7bca92
Time: 2019-03-14
Author: ric.wkr@gmail.com
File Name: tests/test_server.py
Class Name:
Method Name: test_stack_training


Project Name: pantsbuild/pants
Commit Name: 6ffda41d1538a5f8e6eab953346a95505c84c40c
Time: 2014-01-17
Author: travis@twitter.com
File Name: src/python/twitter/pants/python/resolver.py
Class Name: MultiResolver
Method Name: __init__