7da9bf6e3f5ed3dcd6f3adc6cf674e11c307741e,baseline/pytorch/tagger/model.py,TaggerModelBase,create,#Any#Any#Any#,51
Before Change
model = cls()
model.lengths_key = kwargs.get("lengths_key")
model.activation_type = kwargs.get("activation", "tanh")
model.embeddings = embeddings
model.pdrop = float(kwargs.get("dropout", 0.5))
model.dropin_values = kwargs.get("dropin", {})
model.labels = labels
embed_model = model.init_embed(**kwargs)
transducer_model = model.init_encoder(embed_model.output_dim, **kwargs)
use_crf = bool(kwargs.get("crf", False))
constraint_mask = kwargs.get("constraint_mask")
if constraint_mask is not None:
constraint_mask = constraint_mask.unsqueeze(0)
if use_crf:
decoder_model = CRF(len(labels), constraint_mask=constraint_mask, batch_first=True)
else:
decoder_model = TaggerGreedyDecoder(
len(labels),
constraint_mask=constraint_mask,
batch_first=True,
reduction=kwargs.get("reduction", "batch")
)
model.layers = TagSequenceModel(len(labels), embed_model, transducer_model, decoder_model)
logger.info(model.layers)
return model
def drop_inputs(self, key, x):
After Change
model.pdrop = float(kwargs.get("dropout", 0.5))
model.dropin_values = kwargs.get("dropin", {})
model.labels = labels
model.create_layers(embeddings, **kwargs)
return model
def create_layers(self, embeddings: Dict[str, TensorDef], **kwargs):
This method defines the model itself, and must be overloaded by derived classes
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 8
Instances Project Name: dpressel/mead-baseline
Commit Name: 7da9bf6e3f5ed3dcd6f3adc6cf674e11c307741e
Time: 2020-04-22
Author: dpressel@gmail.com
File Name: baseline/pytorch/tagger/model.py
Class Name: TaggerModelBase
Method Name: create
Project Name: dpressel/mead-baseline
Commit Name: 7da9bf6e3f5ed3dcd6f3adc6cf674e11c307741e
Time: 2020-04-22
Author: dpressel@gmail.com
File Name: baseline/pytorch/classify/model.py
Class Name: ClassifierModelBase
Method Name: create
Project Name: dpressel/mead-baseline
Commit Name: 7da9bf6e3f5ed3dcd6f3adc6cf674e11c307741e
Time: 2020-04-22
Author: dpressel@gmail.com
File Name: baseline/pytorch/tagger/model.py
Class Name: TaggerModelBase
Method Name: create
Project Name: dpressel/mead-baseline
Commit Name: 7da9bf6e3f5ed3dcd6f3adc6cf674e11c307741e
Time: 2020-04-22
Author: dpressel@gmail.com
File Name: baseline/tf/tagger/model.py
Class Name: TaggerModelBase
Method Name: create