PAC Bayes Quadratic bound open sourcing.
PiperOrigin-RevId: 302629851
Vikram Tankasali authored and derpson committed Apr 6, 2020
1 parent afcdc77 commit f6395d7
Showing 13 changed files with 1,080 additions and 0 deletions.
80 changes: 80 additions & 0 deletions glassy_dynamics/README.md
@@ -0,0 +1,80 @@
# Unveiling the predictive power of static structure in glassy systems

This repository contains an open-source implementation of the graph neural
network model described in our paper.
The model can be trained with the training binary included in this repository
and the dataset published with our paper.


## Abstract

Despite decades of theoretical studies, the nature of the glass transition
remains elusive and debated, while the existence of structural predictors of the
dynamics is a major open question. Recent approaches propose inferring
predictors from a variety of human-defined features using machine learning.
We learn the long time evolution of a glassy system solely from the initial
particle positions and without any hand-crafted features, using a powerful
model: graph neural networks. We show that this method strongly outperforms
state-of-the-art methods, generalizing over a wide range of temperatures,
pressures, and densities. In shear experiments, it predicts the location of
rearranging particles. The structural predictors learned by our network unveil a
correlation length which increases with larger timescales to reach the size of
our system. Beyond glasses, our method could apply to many other physical
systems that map to a graph of local interactions.


## Dataset

### System description

The dataset was generated with the LAMMPS molecular dynamics package.
The simulated system has periodic boundaries and is a binary mixture of 4096
large (A) and small (B) particles that interact via a 6-12 Lennard-Jones
potential.
The interaction coefficients are set for a typical Kob-Andersen configuration.
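
As a rough illustration of the interaction described above, the sketch below
evaluates the 6-12 Lennard-Jones pair potential in reduced units. The
coefficient table holds the values commonly quoted in the literature for the
Kob-Andersen mixture; it is an assumption here and is not read from this
repository or from the LAMMPS input used to generate the dataset. The function
names are hypothetical, and no cutoff or energy shift is applied.

```python
# Illustrative only: plain 6-12 Lennard-Jones pair energy in reduced units.
# KA_COEFFICIENTS holds values commonly quoted for the Kob-Andersen mixture;
# they are an assumption and are not taken from this repository.
KA_COEFFICIENTS = {
    ('A', 'A'): (1.0, 1.0),   # (epsilon, sigma)
    ('A', 'B'): (1.5, 0.8),
    ('B', 'B'): (0.5, 0.88),
}


def lennard_jones(r, epsilon, sigma):
  """Returns 4 * epsilon * ((sigma / r)**12 - (sigma / r)**6)."""
  sr6 = (sigma / r) ** 6
  return 4.0 * epsilon * (sr6 * sr6 - sr6)


def pair_energy(r, type_i, type_j):
  """Looks up the assumed coefficients for an unordered pair of types."""
  key = tuple(sorted((type_i, type_j)))
  epsilon, sigma = KA_COEFFICIENTS[key]
  return lennard_jones(r, epsilon, sigma)
```

The potential crosses zero at `r == sigma` and reaches its minimum of
`-epsilon` at `r == 2 ** (1 / 6) * sigma`, which gives a quick sanity check for
any coefficient set.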

### Data format

The data is stored in Python's pickle format (protocol version 3).
Each file contains the data for one of the equilibrated systems in a Python
dictionary. The dictionary contains the following entries:

- `positions`: the particle positions of the equilibrated system.
- `types`: the particle types (0 == type A and 1 == type B) of the equilibrated
  system.
- `box`: the dimensions of the periodic cubic simulation box.
- `time`: the logarithmically sampled time points.
- `time_indices`: the indices of the time points for which the sampled
  trajectories on average reach a certain value of the intermediate
  scattering function.
- `is_values`: the values of the intermediate scattering function associated
  with each time index.
- `trajectory_start_velocities`: the velocities drawn from a Boltzmann
  distribution at the start of each trajectory.
- `trajectory_target_positions`: the positions of the particles for each of
  the trajectories at selected time points (as defined by the `time_indices`
  array and the corresponding values of the intermediate scattering function
  stored in `is_values`).
- `metadata`: a dictionary containing additional metadata:
  - `temperature`: the temperature at which the system was equilibrated.
  - `pressure`: the pressure at which the system was equilibrated.
  - `fluid`: the type of fluid which was simulated (Kob-Andersen).

All quantities are given in Lennard-Jones units. The positions are stored in
the absolute coordinate system, i.e., a particle's coordinates lie outside of
the simulation box if it crossed a periodic boundary during the simulation.
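
A minimal sketch of reading one file and folding the absolute coordinates back
into the periodic box. It assumes that the entries unpickle to NumPy-compatible
arrays, that `positions` has shape `[n_particles, 3]`, and that the box spans
`[0, box)` along each axis; the origin convention is an assumption and is not
stated above. The helper names are hypothetical.

```python
import pickle

import numpy as np


def load_system(path):
  """Loads one equilibrated system (a Python dictionary) from a pickle file."""
  with open(path, 'rb') as f:
    return pickle.load(f)


def wrap_into_box(positions, box):
  """Maps absolute coordinates into [0, box) along each axis.

  Assumes the box origin is at zero; shift the positions first if the data
  uses a different origin.
  """
  positions = np.asarray(positions, dtype=np.float64)
  box = np.asarray(box, dtype=np.float64)
  return np.mod(positions, box)
```

With a loaded dictionary `data`, the call would look like
`wrap_into_box(data['positions'], data['box'])`. Only unpickle files from a
source you trust, since `pickle.load` can execute arbitrary code.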


## Reference

If this repository is helpful for your research, please cite the following
publication:

Unveiling the predictive power of static structure in glassy systems
V. Bapst, T. Keck, A. Grabska-Barwinska, C. Donner, E. D. Cubuk,
S. S. Schoenholz, A. Obika, A. W. R. Nelson, T. Back, D. Hassabis and P. Kohli


## Disclaimer
This is not an official Google product.

62 changes: 62 additions & 0 deletions glassy_dynamics/apply_binary.py
@@ -0,0 +1,62 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Applies a graph-based network to predict particle mobilities in glasses."""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function

import os

from absl import app
from absl import flags

from glassy_dynamics import train

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'data_directory',
    '',
    'Directory which contains the train or test datasets.')
flags.DEFINE_integer(
    'time_index',
    9,
    'The time index of the target mobilities.')
flags.DEFINE_integer(
    'max_files_to_load',
    None,
    'The maximum number of files to load.')
flags.DEFINE_string(
    'checkpoint_path',
    'checkpoints/t044_s09.ckpt',
    'Path used to load the model.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  file_pattern = os.path.join(FLAGS.data_directory, 'aggregated*')
  train.apply_model(
      checkpoint_path=FLAGS.checkpoint_path,
      file_pattern=file_pattern,
      max_files_to_load=FLAGS.max_files_to_load,
      time_index=FLAGS.time_index)


if __name__ == '__main__':
  app.run(main)
Binary file not shown.
Binary file added glassy_dynamics/checkpoints/t044_s09.ckpt.index
Binary file not shown.
Binary file added glassy_dynamics/checkpoints/t044_s09.ckpt.meta
Binary file not shown.
190 changes: 190 additions & 0 deletions glassy_dynamics/graph_model.py
@@ -0,0 +1,190 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A graph neural network based model to predict particle mobilities.
The architecture and performance of this model are described in our
publication: "Unveiling the predictive power of static structure in glassy
systems".
"""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function

import functools

from graph_nets import graphs
from graph_nets import modules as gn_modules
from graph_nets import utils_tf

import sonnet as snt
import tensorflow.compat.v1 as tf
from typing import Any, Dict, Text, Tuple, Optional


def make_graph_from_static_structure(
    positions,
    types,
    box,
    edge_threshold):
  """Returns graph representing the static structure of the glass.

  Each particle is represented by a node in the graph. The particle type is
  stored as a node feature.
  Two particles at a distance less than the threshold are connected by an edge.
  The relative distance vector is stored as an edge feature.

  Args:
    positions: particle positions with shape [n_particles, 3].
    types: particle types with shape [n_particles].
    box: dimensions of the cubic box that contains the particles with shape [3].
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
  """
  # Calculate pairwise relative distances between particles: shape [n, n, 3].
  cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
  # Enforces periodic boundary conditions.
  box_ = box[tf.newaxis, tf.newaxis, :]
  cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
  cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
  # Calculates adjacency matrix in a sparse format (indices), based on the given
  # distances and threshold.
  distances = tf.norm(cross_positions, axis=-1)
  indices = tf.where(distances < edge_threshold)

  # Defines graph.
  nodes = types[:, tf.newaxis]
  senders = indices[:, 0]
  receivers = indices[:, 1]
  edges = tf.gather_nd(cross_positions, indices)

  return graphs.GraphsTuple(
      nodes=tf.cast(nodes, tf.float32),
      n_node=tf.reshape(tf.shape(nodes)[0], [1]),
      edges=tf.cast(edges, tf.float32),
      n_edge=tf.reshape(tf.shape(edges)[0], [1]),
      globals=tf.zeros((1, 1), dtype=tf.float32),
      receivers=tf.cast(receivers, tf.int32),
      senders=tf.cast(senders, tf.int32)
  )
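

# Illustrative sketch, not part of the original commit: a NumPy counterpart of
# the minimum-image edge construction above, which may help when checking the
# TensorFlow graph on small inputs. The name `numpy_edges_reference` is
# hypothetical. Like the code above, it applies at most one box-length shift
# per component, so it assumes pairwise differences stay within 1.5 box lengths.
import numpy as np


def numpy_edges_reference(positions, box, edge_threshold):
  """Returns (senders, receivers, edge_vectors) under periodic boundaries."""
  positions = np.asarray(positions, dtype=np.float64)
  box = np.asarray(box, dtype=np.float64)
  # Pairwise displacement vectors: shape [n, n, 3].
  cross_positions = positions[np.newaxis, :, :] - positions[:, np.newaxis, :]
  # Shifts each component by one box length towards [-box / 2, box / 2].
  cross_positions += (cross_positions < -box / 2.) * box
  cross_positions -= (cross_positions > box / 2.) * box
  distances = np.linalg.norm(cross_positions, axis=-1)
  # As in the TensorFlow version, self-edges (distance zero) are kept.
  senders, receivers = np.nonzero(distances < edge_threshold)
  return senders, receivers, cross_positions[senders, receivers]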


def apply_random_rotation(graph):
  """Returns randomly rotated graph representation.

  The rotation is an element of O(3) with rotation angles multiple of pi/2.
  This function assumes that the relative particle distances are stored in
  the edge features.

  Args:
    graph: The graphs tuple as defined in `graph_nets.graphs`.
  """
  # Transposes edge features, so that the axes are in the first dimension.
  # Outputs a tensor of shape [3, n_edges].
  xyz = tf.transpose(graph.edges)
  # Random pi/2 rotation(s)
  permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
  xyz = tf.gather(xyz, permutation)
  # Random reflections.
  symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32)
  symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
  xyz = xyz * symmetry
  edges = tf.transpose(xyz)
  return graph.replace(edges=edges)
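

# Illustrative sketch, not part of the original commit: a NumPy counterpart of
# apply_random_rotation. The name `numpy_random_signed_permutation` and the
# `rng` argument (a numpy.random.Generator) are hypothetical. Permuting the
# axes and flipping their signs leaves every edge length unchanged, which is
# why this augmentation is compatible with the distance-based graph.
import numpy as np


def numpy_random_signed_permutation(edges, rng):
  """Permutes the three axes and flips each one with probability 1/2."""
  edges = np.asarray(edges, dtype=np.float64)
  permutation = rng.permutation(3)
  signs = 1.0 - 2.0 * rng.integers(0, 2, size=3)
  return edges[:, permutation] * signs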


class GraphBasedModel(snt.AbstractModule):
  """Graph based model which predicts particle mobilities from their positions.

  This network encodes the nodes and edges of the input graph independently, and
  then performs message-passing on this graph, updating its edges based on their
  associated nodes, then updating the nodes based on the input nodes' features
  and their associated updated edge features.
  This update is repeated several times.
  Afterwards the resulting node embeddings are decoded to predict the particle
  mobility.
  """

  def __init__(self,
               n_recurrences,
               mlp_sizes,
               mlp_kwargs=None,
               name='Graph'):
    """Creates a new GraphBasedModel object.

    Args:
      n_recurrences: the number of message passing steps in the graph network.
      mlp_sizes: the number of neurons in each layer of the MLP.
      mlp_kwargs: additional keyword arguments passed to the MLP.
      name: the name of the Sonnet module.
    """
    super(GraphBasedModel, self).__init__(name=name)
    self._n_recurrences = n_recurrences

    if mlp_kwargs is None:
      mlp_kwargs = {}

    model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes,
        activate_final=True,
        **mlp_kwargs)

    final_model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes + (1,),
        activate_final=False,
        **mlp_kwargs)

    with self._enter_variable_scope():
      self._encoder = gn_modules.GraphIndependent(
          node_model_fn=model_fn,
          edge_model_fn=model_fn)

      if self._n_recurrences > 0:
        self._propagation_network = gn_modules.GraphNetwork(
            node_model_fn=model_fn,
            edge_model_fn=model_fn,
            # We do not use globals, hence we just pass the identity function.
            global_model_fn=lambda: lambda x: x,
            reducer=tf.unsorted_segment_sum,
            edge_block_opt=dict(use_globals=False),
            node_block_opt=dict(use_globals=False),
            global_block_opt=dict(use_globals=False))

      self._decoder = gn_modules.GraphIndependent(
          node_model_fn=final_model_fn,
          edge_model_fn=model_fn)

  def _build(self, graphs_tuple):
    """Connects the model into the tensorflow graph.

    Args:
      graphs_tuple: input graph tensor as defined in `graph_nets.graphs`.

    Returns:
      tensor with shape [n_particles] containing the predicted particle
      mobilities.
    """
    encoded = self._encoder(graphs_tuple)
    outputs = encoded

    for _ in range(self._n_recurrences):
      # Adds skip connections.
      inputs = utils_tf.concat([outputs, encoded], axis=-1)
      outputs = self._propagation_network(inputs)

    decoded = self._decoder(outputs)
    return tf.squeeze(decoded.nodes, axis=-1)
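

# Illustrative sketch, not part of the original commit: one message-passing
# step written in NumPy to show the update order described in the class
# docstring (edges first, then nodes, with sum aggregation and no globals).
# The name `numpy_message_passing_step` and the `edge_fn` / `node_fn` callables
# are hypothetical stand-ins for the MLPs; the concatenation order is an
# assumption about the graph_nets default blocks, not read from this file.
import numpy as np


def numpy_message_passing_step(nodes, edges, senders, receivers,
                               edge_fn, node_fn):
  """Returns (new_nodes, new_edges) after a single propagation step."""
  # Edge update: each edge sees its own features and both endpoint nodes.
  edge_inputs = np.concatenate(
      [edges, nodes[receivers], nodes[senders]], axis=-1)
  new_edges = edge_fn(edge_inputs)
  # Sums the updated features of all edges arriving at each node.
  aggregated = np.zeros((nodes.shape[0], new_edges.shape[-1]))
  np.add.at(aggregated, receivers, new_edges)
  # Node update: aggregated incoming edges together with the node features.
  new_nodes = node_fn(np.concatenate([aggregated, nodes], axis=-1))
  return new_nodes, new_edges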