Commit d78eee8

thomaskeckderpson authored and committed
Adds JAX version of glassy dynamics training pipeline.
PiperOrigin-RevId: 348791013
1 parent a6aeb26 commit d78eee8

2 files changed: +294 −1 lines

2 files changed

+294
-1
lines changed

glassy_dynamics/train_binary.py (+7 −1)

@@ -19,7 +19,8 @@
 from absl import app
 from absl import flags
 
-from glassy_dynamics import train
+from glassy_dynamics import train as train_using_tf
+from glassy_dynamics import train_using_jax
 
 FLAGS = flags.FLAGS
 
@@ -39,6 +40,10 @@
     'checkpoint_path',
     None,
     'Path used to store a checkpoint of the best model.')
+flags.DEFINE_boolean(
+    'use_jax',
+    False,
+    'Uses jax to train model.')
 
 
 def main(argv):
@@ -47,6 +52,7 @@ def main(argv):
 
   train_file_pattern = os.path.join(FLAGS.data_directory, 'train/aggregated*')
   test_file_pattern = os.path.join(FLAGS.data_directory, 'test/aggregated*')
+  train = train_using_jax if FLAGS.use_jax else train_using_tf
   train.train_model(
       train_file_pattern=train_file_pattern,
       test_file_pattern=test_file_pattern,
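
With this change the JAX implementation is selected from the existing entry point through the new --use_jax flag. A usage sketch (the data_directory value is a placeholder):

python3 -m glassy_dynamics.train_binary \
  --data_directory=/path/to/dataset \
  --use_jax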

glassy_dynamics/train_using_jax.py (+287 −0)

@@ -0,0 +1,287 @@
+# Lint as: python3
+# Copyright 2019 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training pipeline for the prediction of particle mobilities in glasses."""
+
+import enum
+import functools
+import logging
+import pickle
+import random
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import jraph
+import numpy as np
+import optax
+
+# Only used for file operations.
+# You can use glob.glob and python's open function to replace the tf usage below
+# on most platforms.
+import tensorflow.compat.v1 as tf
+
+
+class ParticleType(enum.IntEnum):
+  """The simulation contains two particle types, identified as type A and B.
+
+  The dataset encodes the particle type in an integer.
+    - 0 corresponds to particle type A.
+    - 1 corresponds to particle type B.
+  """
+  A = 0
+  B = 1
+
+
+def make_graph_from_static_structure(positions, types, box, edge_threshold):
+  """Returns graph representing the static structure of the glass.
+
+  Each particle is represented by a node in the graph. The particle type is
+  stored as a node feature.
+  Two particles at a distance less than the threshold are connected by an edge.
+  The relative distance vector is stored as an edge feature.
+
+  Args:
+    positions: particle positions with shape [n_particles, 3].
+    types: particle types with shape [n_particles].
+    box: dimensions of the cubic box that contains the particles with shape [3].
+    edge_threshold: particles at distance less than threshold are connected by
+      an edge.
+  """
+  # Calculate pairwise relative distances between particles: shape [n, n, 3].
+  cross_positions = positions[None, :, :] - positions[:, None, :]
+  # Enforces periodic boundary conditions.
+  box_ = box[None, None, :]
+  cross_positions += (cross_positions < -box_ / 2.).astype(np.float32) * box_
+  cross_positions -= (cross_positions > box_ / 2.).astype(np.float32) * box_
+  # Calculates adjacency matrix in a sparse format (indices), based on the given
+  # distances and threshold.
+  distances = np.linalg.norm(cross_positions, axis=-1)
+  indices = np.where(distances < edge_threshold)
+  # Defines graph.
+  nodes = types[:, None]
+  senders = indices[0]
+  receivers = indices[1]
+  edges = cross_positions[indices]
+
+  return jraph.pad_with_graphs(jraph.GraphsTuple(
+      nodes=nodes.astype(np.float32),
+      n_node=np.reshape(nodes.shape[0], [1]),
+      edges=edges.astype(np.float32),
+      n_edge=np.reshape(edges.shape[0], [1]),
+      globals=np.zeros((1, 1), dtype=np.float32),
+      receivers=receivers.astype(np.int32),
+      senders=senders.astype(np.int32)
+  ), n_node=4097, n_edge=200000)
+
+
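Two details above are easy to miss. The pad target of 4097 nodes leaves room for exactly one dummy node beyond the particles of a configuration; load_data masks that node out below. Fixed pad sizes also give every graph identical array shapes, so the jit-compiled train and evaluation steps in train_model compile once rather than once per graph. The periodic-boundary correction is a minimum-image convention; a minimal sketch of its effect (box size and positions are made up):

import numpy as np

box = np.array([10., 10., 10.], dtype=np.float32)
positions = np.array([[0.5, 0., 0.],
                      [9.5, 0., 0.]], dtype=np.float32)
cross_positions = positions[None, :, :] - positions[:, None, :]
box_ = box[None, None, :]
cross_positions += (cross_positions < -box_ / 2.).astype(np.float32) * box_
cross_positions -= (cross_positions > box_ / 2.).astype(np.float32) * box_
# Off-diagonal entries are 1.0 rather than 9.0: the two particles are
# neighbours across the periodic boundary.
print(np.linalg.norm(cross_positions, axis=-1))
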
+def get_targets(initial_positions, trajectory_target_positions):
+  """Returns the averaged particle mobilities from the sampled trajectories.
+
+  Args:
+    initial_positions: the initial positions of the particles with shape
+      [n_particles, 3].
+    trajectory_target_positions: the absolute positions of the particles at the
+      target time for all sampled trajectories, each with shape
+      [n_particles, 3].
+  """
+  targets = np.mean([np.linalg.norm(t - initial_positions, axis=-1)
+                     for t in trajectory_target_positions], axis=0)
+  return targets.astype(np.float32)
+
+
+def load_data(file_pattern, time_index, max_files_to_load=None):
+  """Returns the graphs and targets of the training or test dataset.
+
+  Args:
+    file_pattern: pattern matching the files with the simulation data.
+    time_index: the time index of the targets.
+    max_files_to_load: the maximum number of files to load.
+  """
+  filenames = tf.io.gfile.glob(file_pattern)
+  if max_files_to_load:
+    filenames = filenames[:max_files_to_load]
+
+  graphs_and_targets = []
+  for filename in filenames:
+    with tf.io.gfile.GFile(filename, 'rb') as f:
+      data = pickle.load(f)
+    mask = (data['types'] == ParticleType.A).astype(np.int32)
+    # Masks the dummy node added by the padding.
+    mask = np.concatenate([mask, np.zeros((1,), dtype=np.int32)], axis=-1)
+    targets = get_targets(
+        data['positions'], data['trajectory_target_positions'][time_index])
+    targets = np.concatenate(
+        [targets, np.zeros((1,), dtype=np.float32)], axis=-1)
+    graphs_and_targets.append(
+        (make_graph_from_static_structure(
+            data['positions'].astype(np.float32),
+            data['types'].astype(np.int32),
+            data['box'].astype(np.float32),
+            edge_threshold=2.0),
+         targets,
+         mask))
+  return graphs_and_targets
+
+
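load_data assumes each shard is a pickled dict holding the keys used above. A quick inspection sketch (the filename is a placeholder):

import pickle

with open('/path/to/dataset/train/aggregated_000.pickle', 'rb') as f:
  data = pickle.load(f)
print(data['positions'].shape)  # [n_particles, 3]
print(data['types'].shape)      # [n_particles]; 0 is type A, 1 is type B
print(data['box'].shape)        # [3]
print(len(data['trajectory_target_positions']))  # sampled trajectories
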
+def apply_random_rotation(graph):
+  """Returns randomly rotated graph representation.
+
+  The rotation is an element of O(3) with rotation angles multiple of pi/2.
+  This function assumes that the relative particle distances are stored in
+  the edge features.
+
+  Args:
+    graph: The graphs tuple as defined in `jraph`.
+  """
+  # Transposes edge features, so that the axes are in the first dimension.
+  # Outputs a tensor of shape [3, n_particles].
+  xyz = np.transpose(graph.edges)
+  # Random pi/2 rotation(s).
+  permutation = np.array([0, 1, 2], dtype=np.int32)
+  np.random.shuffle(permutation)
+  xyz = xyz[permutation]
+  # Random reflections.
+  symmetry = np.random.randint(0, 2, [3])
+  symmetry = 1 - 2 * np.reshape(symmetry, [3, 1]).astype(np.float32)
+  xyz = xyz * symmetry
+  edges = np.transpose(xyz)
+  return graph._replace(edges=edges)
+
+
+def network_definition(graph):
+  """Defines a graph neural network.
+
+  Args:
+    graph: GraphsTuple the network processes.
+
+  Returns:
+    Decoded nodes.
+  """
+  model_fn = functools.partial(
+      hk.nets.MLP,
+      w_init=hk.initializers.VarianceScaling(1.0),
+      b_init=hk.initializers.VarianceScaling(1.0))
+  mlp_sizes = (64, 64)
+  num_message_passing_steps = 7
+
+  node_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
+  edge_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
+  node_decoder = model_fn(output_sizes=mlp_sizes + (1,), activate_final=False)
+
+  node_encoding = node_encoder(graph.nodes)
+  edge_encoding = edge_encoder(graph.edges)
+  graph = graph._replace(nodes=node_encoding, edges=edge_encoding)
+
+  update_edge_fn = jraph.concatenated_args(
+      model_fn(output_sizes=mlp_sizes, activate_final=True))
+  update_node_fn = jraph.concatenated_args(
+      model_fn(output_sizes=mlp_sizes, activate_final=True))
+  gn = jraph.InteractionNetwork(
+      update_edge_fn=update_edge_fn,
+      update_node_fn=update_node_fn,
+      include_sent_messages_in_node_update=True)
+
+  for _ in range(num_message_passing_steps):
+    graph = graph._replace(
+        nodes=jnp.concatenate([graph.nodes, node_encoding], axis=-1),
+        edges=jnp.concatenate([graph.edges, edge_encoding], axis=-1))
+    graph = gn(graph)
+
+  return jnp.squeeze(node_decoder(graph.nodes), axis=-1)
+
+
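network_definition is written with Haiku modules, so it has to be wrapped in hk.transform before use, exactly as train_model does below. A smoke-test sketch on a made-up dummy graph (graph contents and sizes are illustrative only):

import haiku as hk
import jax
import jraph
import numpy as np

from glassy_dynamics import train_using_jax

# Dummy graph with the pipeline's feature layout: one type feature per
# node and a 3-vector relative distance per edge.
dummy_graph = jraph.GraphsTuple(
    nodes=np.zeros((4, 1), dtype=np.float32),
    edges=np.zeros((6, 3), dtype=np.float32),
    senders=np.array([0, 0, 1, 2, 2, 3], dtype=np.int32),
    receivers=np.array([1, 2, 0, 0, 3, 2], dtype=np.int32),
    n_node=np.array([4]),
    n_edge=np.array([6]),
    globals=np.zeros((1, 1), dtype=np.float32))

network = hk.without_apply_rng(
    hk.transform(train_using_jax.network_definition))
params = network.init(jax.random.PRNGKey(0), dummy_graph)
predictions = network.apply(params, dummy_graph)  # shape [4]
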
+def train_model(train_file_pattern,
+                test_file_pattern,
+                max_files_to_load=None,
+                n_epochs=1000,
+                time_index=9,
+                learning_rate=1e-4,
+                grad_clip=1.0,
+                measurement_store_interval=1000,
+                checkpoint_path=None):
+  """Trains the graph network model using JAX.
+
+  Args:
+    train_file_pattern: pattern matching the files with the training data.
+    test_file_pattern: pattern matching the files with the test data.
+    max_files_to_load: the maximum number of train and test files to load.
+      If None, all files will be loaded.
+    n_epochs: the number of passes through the training dataset (epochs).
+    time_index: the time index (0-9) of the target mobilities.
+    learning_rate: the learning rate used by the optimizer.
+    grad_clip: all gradients are clipped to the given value.
+    measurement_store_interval: number of steps between storing objective values
+      (loss and correlation).
+    checkpoint_path: ignored by this implementation.
+  """
+  if checkpoint_path:
+    logging.warning('The checkpoint_path argument is ignored.')
+  random.seed(42)
+  np.random.seed(42)
+  # Loads train and test dataset.
+  dataset_kwargs = dict(
+      time_index=time_index,
+      max_files_to_load=max_files_to_load)
+  logging.info('Load training data')
+  training_data = load_data(train_file_pattern, **dataset_kwargs)
+  logging.info('Load test data')
+  test_data = load_data(test_file_pattern, **dataset_kwargs)
+  logging.info('Finished loading data')
+
+  network = hk.without_apply_rng(hk.transform(network_definition))
+  params = network.init(jax.random.PRNGKey(42), training_data[0][0])
+
+  opt_init, opt_update = optax.chain(
+      optax.clip_by_global_norm(grad_clip),
+      optax.scale_by_adam(0.9, 0.999, 1e-8),
+      optax.scale(-learning_rate))
+  opt_state = opt_init(params)
+
+  network_apply = jax.jit(network.apply)
+
+  @jax.jit
+  def loss_fn(params, graph, targets, mask):
+    decoded_nodes = network_apply(params, graph) * mask
+    return (jnp.sum((decoded_nodes - targets)**2 * mask) /
+            jnp.sum(mask))
+
+  @jax.jit
+  def update(params, opt_state, graph, targets, mask):
+    loss, grads = jax.value_and_grad(loss_fn)(params, graph, targets, mask)
+    updates, opt_state = opt_update(grads, opt_state)
+    return optax.apply_updates(params, updates), opt_state, loss
+
+  train_stats = []
+  i = 0
+  logging.info('Start training')
+  for epoch in range(n_epochs):
+    logging.info('Start epoch %r', epoch)
+    random.shuffle(training_data)
+    for graph, targets, mask in training_data:
+      graph = apply_random_rotation(graph)
+      params, opt_state, loss = update(params, opt_state, graph, targets, mask)
+      train_stats.append(loss)
+
+      if (i + 1) % measurement_store_interval == 0:
+        logging.info('Start evaluation run')
+        test_stats = []
+        for test_graph, test_targets, test_mask in test_data:
+          predictions = network_apply(params, test_graph)
+          test_stats.append(np.corrcoef(
+              predictions[test_mask == 1], test_targets[test_mask == 1])[0, 1])
+        logging.info('Train loss %r', np.mean(train_stats))
+        logging.info('Test correlation %r', np.mean(test_stats))
+        train_stats = []
+      i += 1
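
The pipeline can also be driven directly, without the flag wrapper in train_binary.py. A minimal sketch (paths are placeholders; the small max_files_to_load and n_epochs are only for a quick test run):

from glassy_dynamics import train_using_jax

train_using_jax.train_model(
    train_file_pattern='/path/to/dataset/train/aggregated*',
    test_file_pattern='/path/to/dataset/test/aggregated*',
    max_files_to_load=10,
    n_epochs=2)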
