# Build the recurrent cell stored in ``self.cell``:
#   1. a factory that creates one fresh base cell (wrapped with input dropout
#      when ``dropout`` is set),
#   2. ``n_layer`` independent cells stacked with MultiRNNCell when n_layer > 1,
#   3. output dropout applied once on top of the (possibly stacked) cell.
# NOTE(review): assumes cell_fn, n_hidden, cell_init_args, dropout and n_layer
# are in scope from the enclosing layer constructor — confirm against caller.


# Creates the base cell. A factory (rather than one shared instance) is used
# so that every stacked layer gets its own cell object and its own weights.
def rnn_creator():
    return cell_fn(num_units=n_hidden, **cell_init_args)


# Apply dropout
if dropout:
    # ``dropout`` is either (in_keep_prob, out_keep_prob) or a single float
    # used for both.
    if isinstance(dropout, (tuple, list)):
        in_keep_prob = dropout[0]
        out_keep_prob = dropout[1]
    elif isinstance(dropout, float):
        in_keep_prob, out_keep_prob = dropout, dropout
    else:
        raise TypeError("Invalid dropout type (must be a 2-D tuple of "
                        "float)")
    try:  # TF1.0
        DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper
    except (AttributeError, ImportError):
        DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper

    # Input dropout is applied to every layer's cell. Output dropout is NOT
    # applied here; it is added once to the final cell below, so it is not
    # compounded with the next layer's input dropout.
    def cell_creator():
        return DropoutWrapper_fn(rnn_creator(),
                                 input_keep_prob=in_keep_prob,
                                 output_keep_prob=1.0)
else:
    cell_creator = rnn_creator

# Apply multiple layers
if n_layer > 1:
    try:
        MultiRNNCell_fn = tf.contrib.rnn.MultiRNNCell
    except (AttributeError, ImportError):
        MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell

    # Create a separate cell per layer. ``[cell] * n_layer`` would put the same
    # cell object (and therefore the same weights) in every layer; newer TF
    # versions raise a variable-reuse error for that.
    try:
        self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)],
                                    state_is_tuple=True)
    except Exception:  # when GRU
        self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
else:
    self.cell = cell_creator()

# Output dropout: applied exactly once, on top of the whole (possibly stacked)
# cell, so the user-supplied out_keep_prob is no longer silently ignored.
if dropout:
    self.cell = DropoutWrapper_fn(self.cell,
                                  input_keep_prob=1.0,
                                  output_keep_prob=out_keep_prob)
After Change
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
except: // when GRU
// cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) // HanSheng
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
if dropout:
self.cell = DropoutWrapper_fn(self.cell,