# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A graph neural network based model to predict particle mobilities.

The architecture and performance of this model is described in our publication:
"Unveiling the predictive power of static structure in glassy systems".
"""

import functools
from typing import Any, Dict, Text, Tuple, Optional

from graph_nets import graphs
from graph_nets import modules as gn_modules
from graph_nets import utils_tf
import sonnet as snt
import tensorflow.compat.v1 as tf


def make_graph_from_static_structure(
    positions: tf.Tensor,
    types: tf.Tensor,
    box: tf.Tensor,
    edge_threshold: float) -> graphs.GraphsTuple:
  """Returns graph representing the static structure of the glass.

  Each particle is represented by a node in the graph. The particle type is
  stored as a node feature.
  Two particles at a distance less than the threshold are connected by an edge.
  The relative distance vector is stored as an edge feature.

  The diagonal of the pairwise distance matrix is zero, so for a positive
  threshold every particle is also connected to itself by a self-edge.

  Args:
    positions: particle positions with shape [n_particles, 3].
    types: particle types with shape [n_particles].
    box: dimensions of the cubic box that contains the particles with shape [3].
    edge_threshold: particles at distance less than threshold are connected by
      an edge.

  Returns:
    A `GraphsTuple` holding a single graph. Nodes have shape [n_particles, 1]
    (the particle type), edges have shape [n_edges, 3] (position of the
    receiver minus position of the sender, wrapped into the periodic box) and
    the globals are a [1, 1] tensor of zeros.
  """
  # Calculate pairwise relative distances between particles: shape [n, n, 3].
  # Entry [i, j] is positions[j] - positions[i].
  cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]

  # Enforces periodic boundary conditions (minimum image convention). The masks
  # are cast to the dtype of the displacements rather than a fixed float32, so
  # that inputs of another float precision are wrapped without a dtype error.
  box_ = box[tf.newaxis, tf.newaxis, :]
  wrap_dtype = cross_positions.dtype
  cross_positions += tf.cast(cross_positions < -box_ / 2., wrap_dtype) * box_
  cross_positions -= tf.cast(cross_positions > box_ / 2., wrap_dtype) * box_

  # Calculates adjacency matrix in a sparse format (indices), based on the given
  # distances and threshold.
  distances = tf.norm(cross_positions, axis=-1)
  indices = tf.where(distances < edge_threshold)

  # Defines graph. The row index of each selected pair is used as the sender
  # and the column index as the receiver.
  nodes = types[:, tf.newaxis]
  senders = indices[:, 0]
  receivers = indices[:, 1]
  edges = tf.gather_nd(cross_positions, indices)

  return graphs.GraphsTuple(
      nodes=tf.cast(nodes, tf.float32),
      n_node=tf.reshape(tf.shape(nodes)[0], [1]),
      edges=tf.cast(edges, tf.float32),
      n_edge=tf.reshape(tf.shape(edges)[0], [1]),
      globals=tf.zeros((1, 1), dtype=tf.float32),
      receivers=tf.cast(receivers, tf.int32),
      senders=tf.cast(senders, tf.int32)
  )


def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
  """Returns randomly rotated graph representation.

  The rotation is an element of O(3) with rotation angles multiple of pi/2.
  This function assumes that the relative particle distances are stored in
  the edge features.

  Args:
    graph: The graphs tuple as defined in `graph_nets.graphs`.

  Returns:
    A copy of `graph` in which only the edge features are replaced: their three
    components are randomly permuted and each component is randomly negated.
  """
  # Transposes edge features, so that the axes are in the first dimension.
  # Outputs a tensor of shape [3, n_edges].
  xyz = tf.transpose(graph.edges)
  # Random pi/2 rotation(s): a random permutation of the three axes.
  permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
  xyz = tf.gather(xyz, permutation)
  # Random reflections: draws 0 or 1 per axis and maps it to a sign (+1 or -1).
  symmetry = tf.random.uniform([3], minval=0, maxval=2, dtype=tf.int32)
  symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
  xyz = xyz * symmetry
  edges = tf.transpose(xyz)
  return graph.replace(edges=edges)


class GraphBasedModel(snt.AbstractModule):
  """Graph based model which predicts particle mobilities from their positions.

  This network encodes the nodes and edges of the input graph independently, and
  then performs message-passing on this graph, updating its edges based on their
  associated nodes, then updating the nodes based on the input nodes' features
  and their associated updated edge features.
  This update is repeated several times.
  Afterwards the resulting node embeddings are decoded to predict the particle
  mobility.
  """

  def __init__(self,
               n_recurrences: int,
               mlp_sizes: Tuple[int, ...],
               mlp_kwargs: Optional[Dict[Text, Any]] = None,
               name='Graph'):
    """Creates a new GraphBasedModel object.

    Args:
      n_recurrences: the number of message passing steps in the graph network.
      mlp_sizes: the number of neurons in each layer of the MLP. Any sequence
        of ints is accepted; it is converted to a tuple internally.
      mlp_kwargs: additional keyword arguments passed to the MLP.
      name: the name of the Sonnet module.
    """
    super(GraphBasedModel, self).__init__(name=name)
    self._n_recurrences = n_recurrences

    if mlp_kwargs is None:
      mlp_kwargs = {}

    # Normalises to a tuple so that `mlp_sizes + (1,)` below also works when
    # the caller passes a list.
    mlp_sizes = tuple(mlp_sizes)

    model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes,
        activate_final=True,
        **mlp_kwargs)

    # Same MLP with one extra linear output unit and no final activation; it
    # is used for the nodes of the decoder, whose output is the prediction.
    final_model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes + (1,),
        activate_final=False,
        **mlp_kwargs)

    with self._enter_variable_scope():
      self._encoder = gn_modules.GraphIndependent(
          node_model_fn=model_fn,
          edge_model_fn=model_fn)

      # The propagation network is only created when message passing is
      # requested; `_build` does not reference it otherwise.
      if self._n_recurrences > 0:
        self._propagation_network = gn_modules.GraphNetwork(
            node_model_fn=model_fn,
            edge_model_fn=model_fn,
            # We do not use globals, hence we just pass the identity function.
            global_model_fn=lambda: lambda x: x,
            reducer=tf.unsorted_segment_sum,
            edge_block_opt=dict(use_globals=False),
            node_block_opt=dict(use_globals=False),
            global_block_opt=dict(use_globals=False))

      self._decoder = gn_modules.GraphIndependent(
          node_model_fn=final_model_fn,
          edge_model_fn=model_fn)

  def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
    """Connects the model into the tensorflow graph.

    Args:
      graphs_tuple: input graph tensor as defined in `graph_nets.graphs`.

    Returns:
      tensor with shape [n_particles] containing the predicted particle
      mobilities.
    """
    encoded = self._encoder(graphs_tuple)
    outputs = encoded

    # The same propagation network is applied at every step.
    for _ in range(self._n_recurrences):
      # Adds skip connections.
      inputs = utils_tf.concat([outputs, encoded], axis=-1)
      outputs = self._propagation_network(inputs)

    decoded = self._decoder(outputs)
    # The decoded node features have a trailing axis of size 1; drop it.
    return tf.squeeze(decoded.nodes, axis=-1)