R = self.distance_matrix(D)
sym = []
rsf_zeros = tf.zeros((B, N, M))
variables = []
for param in self.radial_params:
// We apply the radial pooling filter before atom type conv
// to reduce computation
param_variables, rsf = self.radial_symmetry_function(R, *param)
variables += param_variables
if not self.atom_types:
cond = tf.not_equal(Nbrs_Z, 0.0)
sym.append(tf.reduce_sum(tf.where(cond, rsf, rsf_zeros), 2))
else:
for j in range(len(self.atom_types)):
cond = tf.equal(Nbrs_Z, self.atom_types[j])
sym.append(tf.reduce_sum(tf.where(cond, rsf, rsf_zeros), 2))
layer = tf.stack(sym)
layer = tf.transpose(layer, [1, 2, 0]) // (l, B, N) -> (B, N, l)
m, v = tf.nn.moments(layer, axes=[0])
out_tensor = tf.nn.batch_normalization(layer, m, v, None, None, 1e-3)
if set_tensors:
self.variables = variables
self.out_tensor = out_tensor
return out_tensor
After Change
self.atom_types = atom_types
super(AtomicConvolution, self).__init__(**kwargs)
def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """Build the atomic-convolution output tensor.

    Parameters
    ----------
    in_layers: list, optional
      Input layers, resolved through ``self._get_input_tensors``. They are
      read positionally as ``[X, Nbrs, Nbrs_Z]``:
      X: tf.Tensor of shape (B, N, d)
        Coordinates/features.
      Nbrs: tf.Tensor of shape (B, N, M)
        Neighbor list (cast to int32 here).
      Nbrs_Z: tf.Tensor of shape (B, N, M)
        Atomic numbers of neighbor atoms.
    set_tensors: bool, optional
      If True, record the variable scope and store the result in
      ``self.out_tensor``.

    Returns
    -------
    layer: tf.Tensor of shape (B, N, l)
      A new tensor representing the output of the atomic conv layer.
      ``l`` is ``len(self.radial_params)`` when ``self.atom_types`` is
      empty/None, otherwise ``len(self.radial_params) * len(self.atom_types)``.
      Channels are ordered radial-parameter-major, atom-type-minor.

    Raises
    ------
    ValueError
      If ``self.radial_params`` yields no parameter sets, so no output
      channels would be produced.
    """
    inputs = self._get_input_tensors(in_layers)
    X = inputs[0]
    Nbrs = tf.cast(inputs[1], tf.int32)
    Nbrs_Z = inputs[2]

    # N: maximum number of atoms
    # M: maximum number of neighbors
    # d: number of coordinates/features/filters
    # B: batch size
    # These come from static shapes, so they must be known when the graph is
    # built (Dimension.value is None otherwise, which tf.zeros cannot take).
    N = X.get_shape()[-2].value
    d = X.get_shape()[-1].value
    M = Nbrs.get_shape()[-1].value
    B = X.get_shape()[0].value

    D = self.distance_tensor(X, Nbrs, self.boxsize, B, N, M, d)
    R = self.distance_matrix(D)

    # Fill value for masked-out neighbors. tf.where needs this to match the
    # shape of rsf, presumably (B, N, M).
    rsf_zeros = tf.zeros((B, N, M))
    sym = []
    for param in self.radial_params:
        # Apply the radial pooling filter before the atom-type convolution to
        # reduce computation. The returned variable list is not collected
        # here. NOTE(review): presumably they are tracked through the
        # variable scope recorded below; confirm.
        _, rsf = self.radial_symmetry_function(R, *param)
        if not self.atom_types:
            # Single channel: sum over all neighbors with nonzero Z
            # (presumably Z == 0 marks padded neighbor slots).
            mask = tf.not_equal(Nbrs_Z, 0.0)
            sym.append(tf.reduce_sum(tf.where(mask, rsf, rsf_zeros), 2))
        else:
            # One channel per atom type: sum only neighbors of that type.
            for atom_type in self.atom_types:
                mask = tf.equal(Nbrs_Z, atom_type)
                sym.append(tf.reduce_sum(tf.where(mask, rsf, rsf_zeros), 2))

    if not sym:
        raise ValueError(
            "AtomicConvolution needs at least one radial parameter set.")

    # Each entry of sym is (B, N); stack -> (l, B, N), transpose -> (B, N, l).
    layer = tf.transpose(tf.stack(sym), [1, 2, 0])

    # Normalize with the current batch's statistics only (no moving
    # averages, no learned offset/scale), so outputs depend on batch content.
    mean, variance = tf.nn.moments(layer, axes=[0])
    out_tensor = tf.nn.batch_normalization(
        layer, mean, variance, offset=None, scale=None, variance_epsilon=1e-3)
    if set_tensors:
        self._record_variable_scope(self.name)
        self.out_tensor = out_tensor
    return out_tensor
def radial_symmetry_function(self, R, rc, rs, e):