def init_state(self, src, memory_bank, enc_hidden):
    """Initialise the state of every decoder held by this ensemble.

    See :obj:`RNNDecoderBase.init_state()`.

    Each decoder in ``self.model_decoders`` receives the same ``src``
    together with the entries of ``memory_bank`` and ``enc_hidden``
    found at that decoder's position in the list.

    Args:
        src: source input, forwarded unchanged to every decoder.
        memory_bank: indexable collection, read at each decoder's index
            (presumably one encoder memory bank per model -- confirm
            against the caller).
        enc_hidden: indexable collection, read at each decoder's index
            (presumably one encoder final state per model -- confirm
            against the caller).
    """
    for position, decoder in enumerate(self.model_decoders):
        # Pick out the encoder outputs that belong to this decoder.
        decoder_memory = memory_bank[position]
        decoder_hidden = enc_hidden[position]
        decoder.init_state(src, decoder_memory, decoder_hidden)
def map_state(self, fn):
    """Forward ``fn`` to ``map_state`` of every decoder in the ensemble.

    Args:
        fn: callable passed through unchanged to each decoder's own
            ``map_state``; this method never calls it directly.
    """
    for decoder in self.model_decoders:
        decoder.map_state(fn)
class EnsembleGenerator(nn.Module):
Dummy Generator that delegates to individual real Generators,
and then averages the resulting target distributions.
def __init__(self, model_generators, raw_probs=False):
    """Wrap the individual model generators in a single module.

    Args:
        model_generators: iterable of generator modules; stored on
            ``self.model_generators`` as an ``nn.ModuleList``.
        raw_probs: flag stored on ``self._raw_probs``; it is not
            interpreted here.
    """
    super(EnsembleGenerator, self).__init__()
    # Set after Module initialisation; the two assignments are independent.
    self._raw_probs = raw_probs
    self.model_generators = nn.ModuleList(model_generators)
def forward(self, hidden, attn=None, src_map=None):
Compute a distribution over the target dictionary
by averaging distributions from models in the ensemble.
All models in the ensemble must share a target vocabulary.
distributions = torch.stack(