f7c759ca562303127a9991574d5a985d4dff99e8,sonnet/python/modules/gated_rnn_test.py,ConvLSTMTest,testTraining,#ConvLSTMTest#Any#Any#Any#,996
Before Change
initial_state=initial_state,
dtype=tf.float32)
loss = tf.reduce_mean(tf.square(output))
train_op = tf.train.GradientDescentOptimizer(1).minimize(loss)
init = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init)
sess.run(train_op)
@parameterized.parameters(
(snt.Conv1DLSTM, 1, False, 1, 1),
(snt.Conv1DLSTM, 1, False, 1, 5),
(snt.Conv1DLSTM, 1, False, 6, 1),
After Change
(snt.Conv2DLSTM, 2, False),
(snt.Conv2DLSTM, 2, True),
)
def testTraining(self, lstm_class, dim, trainable_initial_state):
  """Test that training works, with or without trainable initial state.

  Builds a ConvLSTM of the given class and unrolls it over a single time step
  with `tf.nn.dynamic_rnn`. It then computes the mean squared output as a loss
  and runs one gradient-descent step. Both graph and eager execution are
  handled (see the `tf.executing_eagerly()` branch below).

  Args:
    lstm_class: ConvLSTM module class under test (e.g. `snt.Conv1DLSTM`,
      `snt.Conv2DLSTM`).
    dim: Number of spatial dimensions of the input; paired with the matching
      class in the parameterization (1 for Conv1DLSTM, 2 for Conv2DLSTM).
    trainable_initial_state: Passed to `lstm.initial_state` to request a
      trainable (rather than fixed) initial state.
  """
  time_steps = 1
  batch_size = 2
  spatial_shape = (8,) * dim
  input_channels = 3
  output_channels = 5
  # Per-step input shape: (batch, spatial..., channels).
  input_shape = (batch_size,) + spatial_shape + (input_channels,)

  lstm = lstm_class(
      input_shape=input_shape[1:],  # Drop the batch dimension.
      output_channels=output_channels,
      kernel_shape=1)
  # Time-major inputs: (time, batch, spatial..., channels).
  inputs = tf.random_normal((time_steps,) + input_shape, dtype=tf.float32)
  initial_state = lstm.initial_state(
      batch_size, tf.float32, trainable_initial_state)
  # NOTE(review): initial_state is built outside loss_fn, so in eager mode it
  # is not recomputed when the optimizer evaluates loss_fn. Confirm that
  # gradients still reach the trainable initial-state variables there.

  def loss_fn():
    """Unrolls the LSTM over `inputs` and returns the mean squared output."""
    output, _ = tf.nn.dynamic_rnn(lstm,
                                  inputs,
                                  time_major=True,
                                  initial_state=initial_state,
                                  dtype=tf.float32)
    return tf.reduce_mean(tf.square(output))

  # In eager mode `minimize` is given the loss as a callable; in graph mode it
  # is given the loss tensor built by calling loss_fn once.
  train_op = tf.train.GradientDescentOptimizer(1).minimize(
      loss_fn if tf.executing_eagerly() else loss_fn())
  # Created after the loss/train op so that, in graph mode, the variables
  # built above already exist and are covered by the initializer.
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(train_op)
@parameterized.parameters(
(snt.Conv1DLSTM, 1, False, 1, 1),
(snt.Conv1DLSTM, 1, False, 1, 5),
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 12
Instances
Project Name: deepmind/sonnet
Commit Name: f7c759ca562303127a9991574d5a985d4dff99e8
Time: 2018-07-17
Author: tomhennigan@google.com
File Name: sonnet/python/modules/gated_rnn_test.py
Class Name: ConvLSTMTest
Method Name: testTraining
Project Name: deepmind/sonnet
Commit Name: f7c759ca562303127a9991574d5a985d4dff99e8
Time: 2018-07-17
Author: tomhennigan@google.com
File Name: sonnet/python/modules/gated_rnn_test.py
Class Name: ConvLSTMTest
Method Name: testTraining
Project Name: deepmind/sonnet
Commit Name: f7c759ca562303127a9991574d5a985d4dff99e8
Time: 2018-07-17
Author: tomhennigan@google.com
File Name: sonnet/python/modules/gated_rnn_test.py
Class Name: ConvLSTMTest
Method Name: testDilatedConv
Project Name: deepmind/sonnet
Commit Name: f7c759ca562303127a9991574d5a985d4dff99e8
Time: 2018-07-17
Author: tomhennigan@google.com
File Name: sonnet/python/modules/gated_rnn_test.py
Class Name: ConvLSTMTest
Method Name: testLayerNorm