@@ -11,12 +11,12 @@ conditioned on graph and generate graphs given text.

[Jax](https://github.com/google/jax#installation),
[Haiku](https://github.com/deepmind/dm-haiku#installation),
- [Optax](https://github.com/deepmind/dm-haiku#installation), and
+ [Optax](https://optax.readthedocs.io/en/latest/#installation), and
[Jraph](https://github.com/deepmind/jraph) are needed for this package. It has
been developed and tested on python 3 with the following packages:

* Jax==0.2.13
- * Haiku==0.0.5
+ * Haiku==0.0.5.dev
* Optax==0.0.6
* Jraph==0.0.1.dev

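To double-check that an installed environment matches the pins above, a small standard-library sketch like the following can be used (illustrative only, not part of the repository; `jax`, `dm-haiku`, `optax` and `jraph` are the PyPI distribution names, and `importlib.metadata` needs Python 3.8+):

```python
# Illustrative version check, not part of this repository: print the installed
# versions of the four dependencies for comparison with the tested pins above.
from importlib.metadata import PackageNotFoundError, version

for dist in ("jax", "dm-haiku", "optax", "jraph"):
  try:
    print(f"{dist}=={version(dist)}")
  except PackageNotFoundError:
    print(f"{dist} is not installed")
```
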
@@ -167,38 +167,76 @@ it elsewhere.

## Run baseline models

- Note: our code supports training with multiple GPUs.
-
- To run the default baseline GNN-based TransformerXL on Wikigraphs with 8
- GPUs:
+ To quickly test-run a small model with 1 GPU:

```bash
python main.py --model_type=graph2text \
  --dataset=freebase2wikitext \
  --checkpoint_dir=/tmp/graph2text \
  --job_mode=train \
+   --train_batch_size=2 \
+   --gnn_num_layers=1 \
+   --num_gpus=1
+ ```
+
+ To run the default baseline unconditional TransformerXL on Wikigraphs with 8
+ GPUs:
+
+ ```bash
+ python main.py --model_type=text \
+   --dataset=freebase2wikitext \
+   --checkpoint_dir=/tmp/text \
+   --job_mode=train \
+   --train_batch_size=64 \
+   --gnn_num_layers=1 \
+   --num_gpus=8
+ ```
+
+ To run the default baseline BoW-based TransformerXL on Wikigraphs with 8
+ GPUs:
+
+ ```bash
+ python main.py --model_type=bow2text \
+   --dataset=freebase2wikitext \
+   --checkpoint_dir=/tmp/bow2text \
+   --job_mode=train \
  --train_batch_size=64 \
  --gnn_num_layers=1 \
  --num_gpus=8
```

- We ran our experiments in the paper using 8 Nvidia V100 GPUs. To allow for
- batch parallization for the GNN-based (graph2text) model, we pad graphs to
- the largest graph in the batch. The full run takes almost 4 days. BoW- and
- nodes-based models can be trained within 14 hours because there is no
- additional padding.
+ To run the default baseline Nodes-only GNN-based TransformerXL on Wikigraphs
+ with 8 GPUs:
+
+ ```bash
+ python main.py --model_type=bow2text \
+   --dataset=freebase2wikitext \
+   --checkpoint_dir=/tmp/bow2text \
+   --job_mode=train \
+   --train_batch_size=64 \
+   --gnn_num_layers=0 \
+   --num_gpus=8
+ ```

- Or to quickly test-run a small model:
+ To run the default baseline GNN-based TransformerXL on Wikigraphs with 8
+ GPUs:

```bash
python main.py --model_type=graph2text \
  --dataset=freebase2wikitext \
  --checkpoint_dir=/tmp/graph2text \
  --job_mode=train \
-   --train_batch_size=2 \
-   --gnn_num_layers=1
+   --train_batch_size=64 \
+   --gnn_num_layers=1 \
+   --num_gpus=8
```

+ We ran our experiments in the paper using 8 Nvidia V100 GPUs. Reduce the
+ batch size if the model does not fit into memory. To allow for batch
+ parallelization for the GNN-based (graph2text) model, we pad graphs to the
+ largest graph in the batch. The full run takes almost 4 days. BoW- and
+ nodes-based models can be trained within 14 hours because there is no
+ additional padding.
+
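
As a rough illustration of the padding described above (not the repository's actual batching code; the helper name `pad_batch_to_largest` is invented here), the general pattern with `jraph.pad_with_graphs` looks like this:

```python
# Illustrative sketch only: pad every graph in a batch to the node/edge counts
# of the largest graph so that all examples share one static shape, which is
# what allows straightforward batch parallelism across devices in JAX.
import numpy as np
import jraph


def pad_batch_to_largest(graphs):
  """Pads each single-example jraph.GraphsTuple to the largest graph's size."""
  # jraph appends one dummy padding graph that must hold at least one node,
  # hence the +1 node budget and n_graph=2 (1 real graph + 1 padding graph).
  max_nodes = max(int(np.sum(g.n_node)) for g in graphs) + 1
  max_edges = max(int(np.sum(g.n_edge)) for g in graphs)
  return [
      jraph.pad_with_graphs(g, n_node=max_nodes, n_edge=max_edges, n_graph=2)
      for g in graphs
  ]
```
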
To evaluate the model on the validation set (this only uses 1 GPU):

```bash