forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PAC Bayes Quadratic bound open sourcing.
PiperOrigin-RevId: 302629851
- Loading branch information
Showing
13 changed files
with
1,080 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Unveiling the predictive power of static structure in glassy systems | ||
|
||
This repository contains an open source implementation of the graph neural | ||
network model described in our paper. | ||
The model can be trained using the training binary included in this repository, | ||
and the dataset published with our paper. | ||
|
||
|
||
## Abstract | ||
|
||
Despite decades of theoretical studies, the nature of the glass transition | ||
remains elusive and debated, while the existence of structural predictors of the | ||
dynamics is a major open question. Recent approaches propose inferring | ||
predictors from a variety of human-defined features using machine learning. | ||
We learn the long time evolution of a glassy system solely from the initial | ||
particle positions and without any hand-crafted features, using a powerful | ||
model: graph neural networks. We show that this method strongly outperforms | ||
state-of-the-art methods, generalizing over a wide range of temperatures, | ||
pressures, and densities. In shear experiments, it predicts the location of | ||
rearranging particles. The structural predictors learned by our network unveil a | ||
correlation length which increases with larger timescales to reach the size of | ||
our system. Beyond glasses, our method could apply to many other physical | ||
systems that map to a graph of local interactions. | ||
|
||
|
||
## Dataset | ||
|
||
### System description | ||
|
||
The dataset was generated with the LAMMPS molecular dynamics package. | ||
The simulated system has periodic boundaries and is a binary mixture of 4096 | ||
large (A) and small (B) particles that interact via a 6-12 Lennard-Jones | ||
potential. | ||
The interaction coefficients are set for a typical Kob-Andersen configuration. | ||
|
||
### Data format | ||
|
||
The data is stored in Python's pickle format protocol version 3. | ||
Each file contains the data for one of the equilibrated systems in a Python | ||
dictionary. The dictionary contains the following entries: | ||
|
||
- `positions` the particle positions of the equilibrated system. | ||
- `types` the particle types (0 == type A and 1 == type B) of the equilibrated | ||
system. | ||
- `box` the dimensions of the periodic cubic simulation box. | ||
- `time` the logarithmically sampled time points. | ||
- `time_indices` the indices of the time points for which the sampled | ||
trajectories on average reach a certain value of the intermediate | ||
scattering function. | ||
- `is_values` the values of the intermediate scattering function associated | ||
with each time index. | ||
- `trajectory_start_velocities` the velocities drawn from a Boltzmann | ||
distribution at the start of each trajectory. | ||
- `trajectory_target_positions` the positions of the particles for each of | ||
the trajectories at selected time points (as defined by the `time_indices` | ||
array and the corresponding values of the intermediate scattering function | ||
stored in `is_values`). | ||
- `metadata` a dictionary containing additional metadata: | ||
- `temperature` the temperature at which the system was equilibrated. | ||
- `pressure` the pressure at which the system was equilibrated. | ||
- `fluid` the type of fluid which was simulated (Kob-Andersen). | ||
|
||
All units are in Lennard-Jones units. The positions are stored in the absolute | ||
coordinate system i.e. they are outside of the simulation box if the particle | ||
crossed a periodic boundary during the simulation. | ||
|
||
|
||
## Reference | ||
|
||
If this repository is helpful for your research please cite the following | ||
publication: | ||
|
||
Unveiling the predictive power of static structure in glassysystems | ||
V. Bapst, T. Keck, A. Grabska-Barwinska, C. Donner, E. D. Cubuk, | ||
S. S. Schoenholz, A.Obika, A. W. R. Nelson, T. Back, D. Hassabis and P. Kohli | ||
|
||
|
||
## Disclaimer | ||
This is not an official Google product. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright 2019 Deepmind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Applies a graph-based network to predict particle mobilities in glasses.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
|
||
from __future__ import print_function | ||
|
||
import os | ||
|
||
from absl import app | ||
from absl import flags | ||
|
||
from glassy_dynamics import train | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_string( | ||
'data_directory', | ||
'', | ||
'Directory which contains the train or test datasets.') | ||
flags.DEFINE_integer( | ||
'time_index', | ||
9, | ||
'The time index of the target mobilities.') | ||
flags.DEFINE_integer( | ||
'max_files_to_load', | ||
None, | ||
'The maximum number of files to load.') | ||
flags.DEFINE_string( | ||
'checkpoint_path', | ||
'checkpoints/t044_s09.ckpt', | ||
'Path used to load the model.') | ||
|
||
|
||
def main(argv): | ||
if len(argv) > 1: | ||
raise app.UsageError('Too many command-line arguments.') | ||
|
||
file_pattern = os.path.join(FLAGS.data_directory, 'aggregated*') | ||
train.apply_model( | ||
checkpoint_path=FLAGS.checkpoint_path, | ||
file_pattern=file_pattern, | ||
max_files_to_load=FLAGS.max_files_to_load, | ||
time_index=FLAGS.time_index) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Copyright 2019 Deepmind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""A graph neural network based model to predict particle mobilities. | ||
The architecture and performance of this model is described in our publication: | ||
"Unveiling the predictive power of static structure in glassy systems". | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
|
||
from __future__ import print_function | ||
|
||
import functools | ||
|
||
from graph_nets import graphs | ||
from graph_nets import modules as gn_modules | ||
from graph_nets import utils_tf | ||
|
||
import sonnet as snt | ||
import tensorflow.compat.v1 as tf | ||
from typing import Any, Dict, Text, Tuple, Optional | ||
|
||
|
||
def make_graph_from_static_structure( | ||
positions, | ||
types, | ||
box, | ||
edge_threshold): | ||
"""Returns graph representing the static structure of the glass. | ||
Each particle is represented by a node in the graph. The particle type is | ||
stored as a node feature. | ||
Two particles at a distance less than the threshold are connected by an edge. | ||
The relative distance vector is stored as an edge feature. | ||
Args: | ||
positions: particle positions with shape [n_particles, 3]. | ||
types: particle types with shape [n_particles]. | ||
box: dimensions of the cubic box that contains the particles with shape [3]. | ||
edge_threshold: particles at distance less than threshold are connected by | ||
an edge. | ||
""" | ||
# Calculate pairwise relative distances between particles: shape [n, n, 3]. | ||
cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :] | ||
# Enforces periodic boundary conditions. | ||
box_ = box[tf.newaxis, tf.newaxis, :] | ||
cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_ | ||
cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_ | ||
# Calculates adjacency matrix in a sparse format (indices), based on the given | ||
# distances and threshold. | ||
distances = tf.norm(cross_positions, axis=-1) | ||
indices = tf.where(distances < edge_threshold) | ||
|
||
# Defines graph. | ||
nodes = types[:, tf.newaxis] | ||
senders = indices[:, 0] | ||
receivers = indices[:, 1] | ||
edges = tf.gather_nd(cross_positions, indices) | ||
|
||
return graphs.GraphsTuple( | ||
nodes=tf.cast(nodes, tf.float32), | ||
n_node=tf.reshape(tf.shape(nodes)[0], [1]), | ||
edges=tf.cast(edges, tf.float32), | ||
n_edge=tf.reshape(tf.shape(edges)[0], [1]), | ||
globals=tf.zeros((1, 1), dtype=tf.float32), | ||
receivers=tf.cast(receivers, tf.int32), | ||
senders=tf.cast(senders, tf.int32) | ||
) | ||
|
||
|
||
def apply_random_rotation(graph): | ||
"""Returns randomly rotated graph representation. | ||
The rotation is an element of O(3) with rotation angles multiple of pi/2. | ||
This function assumes that the relative particle distances are stored in | ||
the edge features. | ||
Args: | ||
graph: The graphs tuple as defined in `graph_nets.graphs`. | ||
""" | ||
# Transposes edge features, so that the axes are in the first dimension. | ||
# Outputs a tensor of shape [3, n_particles]. | ||
xyz = tf.transpose(graph.edges) | ||
# Random pi/2 rotation(s) | ||
permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32)) | ||
xyz = tf.gather(xyz, permutation) | ||
# Random reflections. | ||
symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32) | ||
symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32) | ||
xyz = xyz * symmetry | ||
edges = tf.transpose(xyz) | ||
return graph.replace(edges=edges) | ||
|
||
|
||
class GraphBasedModel(snt.AbstractModule): | ||
"""Graph based model which predicts particle mobilities from their positions. | ||
This network encodes the nodes and edges of the input graph independently, and | ||
then performs message-passing on this graph, updating its edges based on their | ||
associated nodes, then updating the nodes based on the input nodes' features | ||
and their associated updated edge features. | ||
This update is repeated several times. | ||
Afterwards the resulting node embeddings are decoded to predict the particle | ||
mobility. | ||
""" | ||
|
||
def __init__(self, | ||
n_recurrences, | ||
mlp_sizes, | ||
mlp_kwargs = None, | ||
name='Graph'): | ||
"""Creates a new GraphBasedModel object. | ||
Args: | ||
n_recurrences: the number of message passing steps in the graph network. | ||
mlp_sizes: the number of neurons in each layer of the MLP. | ||
mlp_kwargs: additional keyword aguments passed to the MLP. | ||
name: the name of the Sonnet module. | ||
""" | ||
super(GraphBasedModel, self).__init__(name=name) | ||
self._n_recurrences = n_recurrences | ||
|
||
if mlp_kwargs is None: | ||
mlp_kwargs = {} | ||
|
||
model_fn = functools.partial( | ||
snt.nets.MLP, | ||
output_sizes=mlp_sizes, | ||
activate_final=True, | ||
**mlp_kwargs) | ||
|
||
final_model_fn = functools.partial( | ||
snt.nets.MLP, | ||
output_sizes=mlp_sizes + (1,), | ||
activate_final=False, | ||
**mlp_kwargs) | ||
|
||
with self._enter_variable_scope(): | ||
self._encoder = gn_modules.GraphIndependent( | ||
node_model_fn=model_fn, | ||
edge_model_fn=model_fn) | ||
|
||
if self._n_recurrences > 0: | ||
self._propagation_network = gn_modules.GraphNetwork( | ||
node_model_fn=model_fn, | ||
edge_model_fn=model_fn, | ||
# We do not use globals, hence we just pass the identity function. | ||
global_model_fn=lambda: lambda x: x, | ||
reducer=tf.unsorted_segment_sum, | ||
edge_block_opt=dict(use_globals=False), | ||
node_block_opt=dict(use_globals=False), | ||
global_block_opt=dict(use_globals=False)) | ||
|
||
self._decoder = gn_modules.GraphIndependent( | ||
node_model_fn=final_model_fn, | ||
edge_model_fn=model_fn) | ||
|
||
def _build(self, graphs_tuple): | ||
"""Connects the model into the tensorflow graph. | ||
Args: | ||
graphs_tuple: input graph tensor as defined in `graphs_tuple.graphs`. | ||
Returns: | ||
tensor with shape [n_particles] containing the predicted particle | ||
mobilities. | ||
""" | ||
encoded = self._encoder(graphs_tuple) | ||
outputs = encoded | ||
|
||
for _ in range(self._n_recurrences): | ||
# Adds skip connections. | ||
inputs = utils_tf.concat([outputs, encoded], axis=-1) | ||
outputs = self._propagation_network(inputs) | ||
|
||
decoded = self._decoder(outputs) | ||
return tf.squeeze(decoded.nodes, axis=-1) |
Oops, something went wrong.