Commit 662c29b (1 parent: 8a4609d)

Fix pytorch gcn example.

2 files changed: +11, -17 lines

examples/pytorch/gcn/README.md (+5, -9)
@@ -9,14 +9,6 @@ to a list of pyG `Data` objects. Then, we use the `PyGDataLoader` to merge the list
 of `Data` to pyG `Batch` object. Finally, we implement the sampling-based
 `GCN` based on pyG `GCNConv`.
 
-
-## Script arguments (partial)
-
-| Args | Description | Default |
-| ---------- | ------------------------------------------------------------------ | ------------ |
-| client_num | Set to value bigger than zero to enable multi-processing dataload. | 0 (int) |
-| local_mode | Graph learn init use $PWD/tracker_path | False (bool) |
-
 ## How to run
 ### Supervised node classification.
 1. Prepare data
@@ -31,4 +23,8 @@ of `Data` to pyG `Batch` object. Finally, we implement the sampling-based
 ```shell script
 cd ../pytorch/gcn/
 python train.py
-```
+```
+3. Training with pytorch DDP
+```
+python -m torch.distributed.launch --use_env train.py --ddp
+```
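For reference, `python -m torch.distributed.launch --use_env train.py --ddp` exports the process coordinates (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`) as environment variables instead of passing a `--local_rank` argument. A minimal sketch of reading them with single-process fallbacks (this helper is hypothetical, not part of the example's code):

```python
import os

def read_dist_env():
    """Read the env vars that torch.distributed.launch --use_env exports,
    falling back to single-process defaults when they are unset."""
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return rank, world_size, local_rank

# Returns (0, 1, 0) when the launcher has not set anything.
print(read_dist_env())
```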

examples/pytorch/gcn/train.py (+6, -8)
@@ -57,6 +57,7 @@ def load_graph(args):
   dataset_folder = args.dataset_folder
   node_type = 'item'
   edge_type = 'relation'
+  # should be split for distributed training.
   node_path = dataset_folder + "node_table"
   edge_path = dataset_folder + "edge_table"
   train_path = dataset_folder + "train_table"
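The new comment notes that the node/edge tables should be split across workers for distributed training. A sketch of one possible split, where each worker reads a rank-suffixed shard (the `_<rank>` naming scheme is an assumption for illustration, not graph-learn's actual convention):

```python
def shard_path(base_path, rank, world_size):
    # Hypothetical naming scheme: worker `rank` of `world_size` reads
    # "<base_path>_<rank>"; a single process reads the whole table.
    if world_size <= 1:
        return base_path
    return "%s_%d" % (base_path, rank)

print(shard_path("node_table", 2, 4))  # node_table_2
print(shard_path("node_table", 0, 1))  # node_table
```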
@@ -146,11 +147,7 @@ def run(args):
     thg.set_client_num(args.client_num)
     thg.launch_server(g)
   else:
-    if args.local_mode:
-      tracker_path = './tracker_path/'
-      g.init(task_index=args.rank, task_count=args.world_size, tracker=tracker_path)
-    else:
-      g.init(task_index=args.rank, task_count=args.world_size)
+    g.init(task_index=args.rank, task_count=args.world_size)
 
   # TODO(baole): This is an estimate and an accurate value will be needed from graphlearn.
   length_per_worker = args.train_length // args.train_batch_size // args.world_size
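The `length_per_worker` estimate divides the number of training examples by the batch size and then by the worker count, truncating at each step. With made-up sizes (none of these values come from the actual dataset):

```python
# Made-up sizes for illustration, not from the actual dataset.
train_length = 140000
train_batch_size = 512
world_size = 4

# 140000 // 512 = 273 global batches; 273 // 4 = 68 per worker.
# The double floor division drops remainders, which is why the TODO asks
# for an accurate value from graphlearn.
length_per_worker = train_length // train_batch_size // world_size
print(length_per_worker)  # 68
```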
@@ -210,9 +207,10 @@ def run(args):
   argparser.add_argument('--drop_rate', type=float, default=0.0)
   argparser.add_argument('--learning_rate', type=float, default=0.01)
   argparser.add_argument('--epoch', type=int, default=60)
-  argparser.add_argument('--client_num', type=int, default=0)
-  argparser.add_argument('--ddp', type=bool, default=True)
-  argparser.add_argument('--local_mode', type=bool, default=False)
+  argparser.add_argument('--client_num', type=int, default=0,
+                         help="Set to value bigger than zero to enable multi-processing dataload.")
+  argparser.add_argument('--ddp', action='store_true',
+                         help="whether to use ddp")
   args = argparser.parse_args()
 
   init_env(args)
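The switch from `type=bool` to `action='store_true'` for `--ddp` also fixes a common argparse pitfall: `type=bool` calls `bool()` on the raw command-line string, and any non-empty string, including "False", is truthy. A standalone demonstration (the parsers below are illustrative, not the script's own):

```python
import argparse

# With type=bool, the string "False" still parses as True:
buggy = argparse.ArgumentParser()
buggy.add_argument("--ddp", type=bool, default=False)
print(buggy.parse_args(["--ddp", "False"]).ddp)  # True

# With action='store_true', presence of the flag decides the value:
fixed = argparse.ArgumentParser()
fixed.add_argument("--ddp", action="store_true")
print(fixed.parse_args([]).ddp)          # False
print(fixed.parse_args(["--ddp"]).ddp)   # True
```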
