# 1.1 embedding for story and query
story_embedding = tf.nn.embedding_lookup(self.Embedding, self.story)  # [batch_size,story_length,sequence_length,embed_size]
query_embedding = tf.nn.embedding_lookup(self.Embedding, self.query)  # [batch_size,sequence_length,embed_size]
# 1.2 mask for story and query
story_mask = tf.get_variable("story_mask", [self.sequence_length, 1], initializer=tf.constant_initializer(1.0))
query_mask = tf.get_variable("query_mask", [self.sequence_length, 1], initializer=tf.constant_initializer(1.0))
# 1.3 multiply embedding by mask for story and query
self.story_embedding = tf.multiply(story_embedding, story_mask)  # [batch_size,story_length,sequence_length,embed_size]
self.query_embedding = tf.multiply(query_embedding, query_mask)  # [batch_size,sequence_length,embed_size]
# 1.4 use bag of words to encode story and query
self.story_embedding = tf.reduce_sum(self.story_embedding, axis=2)  # [batch_size,story_length,embed_size]
self.query_embedding = tf.reduce_sum(self.query_embedding, axis=1)  # [batch_size,embed_size]; sum over positions, completing step 1.4 for the query as well
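For reference, here is a minimal, self-contained sketch of the masked bag-of-words encoding above, runnable under TensorFlow 1.x; the concrete sizes (batch_size=2, story_length=3, sequence_length=4, embed_size=5, vocab_size=10) are made up for illustration:

import numpy as np
import tensorflow as tf

batch_size, story_length, sequence_length, embed_size, vocab_size = 2, 3, 4, 5, 10

story = tf.placeholder(tf.int32, [batch_size, story_length, sequence_length])
embedding = tf.get_variable("Embedding", [vocab_size, embed_size])
# trainable per-position weights, initialized to 1.0 (plain bag of words at the start)
story_mask = tf.get_variable("story_mask", [sequence_length, 1],
                             initializer=tf.constant_initializer(1.0))

story_embedding = tf.nn.embedding_lookup(embedding, story)  # [batch_size,story_length,sequence_length,embed_size]
masked = tf.multiply(story_embedding, story_mask)           # mask broadcasts over the batch and story dims
encoded = tf.reduce_sum(masked, axis=2)                     # [batch_size,story_length,embed_size]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ids = np.random.randint(0, vocab_size, size=(batch_size, story_length, sequence_length))
    print(sess.run(encoded, {story: ids}).shape)            # (2, 3, 5)

Because the mask is a trained variable rather than a constant, this "bag of words" can learn to weight token positions differently instead of summing them uniformly.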
After Change
"""main computation graph here: 1. input encoder 2. dynamic memory 3. output layer"""
# 1. input encoder
self.embedding_with_mask()
if self.use_bi_lstm:
    self.input_encoder_bi_lstm()
else:
    self.input_encoder_bow()
# 2. dynamic memory (a hedged sketch of one update step follows this snippet)
self.hidden_state = self.rnn_story()  # [batch_size,block_size,hidden_size]. get hidden state after processing the story
# 3. output layer (a hedged sketch of a typical output module also follows)
logits = self.output_module()  # [batch_size,vocab_size]
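The body of self.rnn_story() is not shown here. As a point of reference, one update step of the dynamic memory in the Recurrent Entity Network paper (Henaff et al., 2016) looks roughly like the sketch below; the names (memory_step_sketch, w for block keys, s_t for the current sentence encoding) and the choice of relu for the activation are assumptions for illustration:

def memory_step_sketch(h, w, s_t, hidden_size):
    """One dynamic-memory update, per Henaff et al. (2016); rnn_story()
    would loop something like this over every sentence of the story.
    h:   [batch_size, block_size, hidden_size]  current memory of each block
    w:   [block_size, hidden_size]              per-block key vectors
    s_t: [batch_size, hidden_size]              encoding of sentence t"""
    s = tf.expand_dims(s_t, 1)  # [batch_size,1,hidden_size]
    # gate: how strongly sentence t matches each block's memory and its key
    g = tf.sigmoid(tf.reduce_sum(s * h, axis=2) + tf.reduce_sum(s * w, axis=2))  # [batch_size,block_size]
    U = tf.get_variable("U", [hidden_size, hidden_size])
    V = tf.get_variable("V", [hidden_size, hidden_size])
    W = tf.get_variable("W", [hidden_size, hidden_size])
    # candidate memory (the paper uses PReLU; relu is an assumption here)
    h_candidate = tf.nn.relu(tf.tensordot(h, U, axes=1)
                             + tf.expand_dims(tf.matmul(w, V), 0)
                             + tf.expand_dims(tf.matmul(s_t, W), 1))  # [batch_size,block_size,hidden_size]
    # gated update, then renormalize so each block's memory stays on the unit sphere
    h_new = h + tf.expand_dims(g, 2) * h_candidate
    return tf.nn.l2_normalize(h_new, axis=2)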
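Similarly, self.output_module() is only called here, not defined. The usual Entity Network output layer attends over the block memories with the query and projects through one hidden layer; the sketch below follows that formulation and assumes the query has already been encoded to hidden_size:

def output_module_sketch(hidden_state, query, hidden_size, vocab_size):
    """Entity Network-style output layer (Henaff et al., 2016).
    hidden_state: [batch_size, block_size, hidden_size]
    query:        [batch_size, hidden_size]"""
    # p_j = softmax(q . h_j): attention of the query over memory blocks
    p = tf.nn.softmax(tf.reduce_sum(hidden_state * tf.expand_dims(query, 1), axis=2))  # [batch_size,block_size]
    u = tf.reduce_sum(hidden_state * tf.expand_dims(p, 2), axis=1)                     # [batch_size,hidden_size]
    H = tf.get_variable("H", [hidden_size, hidden_size])
    R = tf.get_variable("R", [hidden_size, vocab_size])
    # y = R * phi(u + H q); relu again stands in for the paper's PReLU
    return tf.matmul(tf.nn.relu(u + tf.matmul(query, H)), R)                           # [batch_size,vocab_size]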