c141e570011e7adf3634bd65a3e7de30d8fbdca2,opennmt/utils/checkpoint.py,,average_checkpoints,#Any#Any#Any#Any#,190

Before Change



  for checkpoint_path in checkpoints_path:
    tf.logging.info("Loading checkpoint %s" % checkpoint_path)
    reader = tf.train.load_checkpoint(checkpoint_path)
    for name in avg_values:
      avg_values[name] += reader.get_tensor(name) / num_checkpoints

  latest_step = int(checkpoints_path[-1].split("-")[-1])
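The loop above computes a uniform average: each checkpoint contributes `value / num_checkpoints` to a fixed set of variable names, and the step is parsed from the suffix of the last checkpoint path. Below is a minimal sketch of that logic. It assumes each checkpoint is a plain Python dict mapping variable names to NumPy arrays, standing in for `tf.train.load_checkpoint` readers. The helper names `average_variables` and `step_from_path` are hypothetical and are not part of OpenNMT-tf.

```python
import numpy as np


def average_variables(checkpoints):
    """Uniformly average variables across checkpoints.

    `checkpoints` is a list of dicts {name: np.ndarray}, a hypothetical
    stand-in for TensorFlow checkpoint readers.
    """
    num_checkpoints = len(checkpoints)
    # Accumulators are created from the variable names of the first checkpoint,
    # mirroring the fixed `avg_values` dict in the original loop.
    avg_values = {
        name: np.zeros_like(value, dtype=np.float64)
        for name, value in checkpoints[0].items()
    }
    for variables in checkpoints:
        for name in avg_values:
            # Scale each term before summing, as the original code does.
            avg_values[name] += variables[name] / num_checkpoints
    return avg_values


def step_from_path(checkpoint_path):
    # Assumes the checkpoint prefix ends in "-<global_step>", e.g. "model.ckpt-5000".
    return int(checkpoint_path.split("-")[-1])
```

Note that this version averages every variable it finds, including integer counters such as a global step, which is the behavior the change below addresses.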

After Change


  for i, checkpoint_path in enumerate(checkpoints_path):
    tf.logging.info("Loading checkpoint %s" % checkpoint_path)
    variables = get_checkpoint_variables(checkpoint_path)
    for name, value in six.iteritems(variables):
      if _variable_is_trainable(name, value):
        scaled_value = value / num_checkpoints
        if name in new_variables:
          new_variables[name] += scaled_value
        else:
          new_variables[name] = scaled_value
      elif i + 1 == num_checkpoints:  # Take non-trainable variables from the last checkpoint.
        new_variables[name] = value

  return _create_checkpoint_from_variables(
      new_variables,
      output_dir,
      session_config=session_config)
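The revised loop averages only variables classified as trainable and copies every other variable unchanged from the last checkpoint. The result dict is built from whatever names each checkpoint contains, rather than from a fixed set of names. Below is a minimal sketch of that control flow, again assuming checkpoints are dicts of NumPy arrays. The `is_trainable` rule here is a hypothetical stand-in for `_variable_is_trainable`, whose real implementation is not shown in this record. The sketch treats floating-point tensors as trainable and integer tensors (such as a step counter) as non-trainable.

```python
import numpy as np


def is_trainable(name, value):
    # Hypothetical rule for this sketch only: floating-point tensors are
    # averaged, and integer tensors (e.g. a global step) are not.
    return np.issubdtype(value.dtype, np.floating)


def average_trainable(checkpoints):
    """Average trainable variables; copy the rest from the last checkpoint."""
    num_checkpoints = len(checkpoints)
    new_variables = {}
    for i, variables in enumerate(checkpoints):
        for name, value in variables.items():
            if is_trainable(name, value):
                scaled_value = value / num_checkpoints
                if name in new_variables:
                    new_variables[name] += scaled_value
                else:
                    new_variables[name] = scaled_value
            elif i + 1 == num_checkpoints:
                # Non-trainable values come from the last checkpoint only.
                new_variables[name] = value
    return new_variables


ckpts = [
    {"w": np.array([1.0, 2.0]), "global_step": np.array(1000, dtype=np.int64)},
    {"w": np.array([3.0, 4.0]), "global_step": np.array(2000, dtype=np.int64)},
]
out = average_trainable(ckpts)
```

With this input, `out["w"]` is `[2.0, 3.0]` and `out["global_step"]` stays the integer `2000`. Under the earlier uniform average, the step would instead become the float `1500.0`.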
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 6

Instances


Project Name: OpenNMT/OpenNMT-tf
Commit Name: c141e570011e7adf3634bd65a3e7de30d8fbdca2
Time: 2018-10-18
Author: guillaumekln@users.noreply.github.com
File Name: opennmt/utils/checkpoint.py
Class Name:
Method Name: average_checkpoints


Project Name: elbayadm/attn2d
Commit Name: eea50f3869d720a0b4ae64960da11bc3bc59881c
Time: 2017-10-19
Author: myleott@fb.com
File Name: train.py
Class Name:
Method Name: main