@@ -57,6 +57,7 @@ def load_graph(args):
57
57
dataset_folder = args .dataset_folder
58
58
node_type = 'item'
59
59
edge_type = 'relation'
60
+ # should be split when doing distributed training.
60
61
node_path = dataset_folder + "node_table"
61
62
edge_path = dataset_folder + "edge_table"
62
63
train_path = dataset_folder + "train_table"
@@ -146,11 +147,7 @@ def run(args):
146
147
thg .set_client_num (args .client_num )
147
148
thg .launch_server (g )
148
149
else :
149
- if args .local_mode :
150
- tracker_path = './tracker_path/'
151
- g .init (task_index = args .rank , task_count = args .world_size , tracker = tracker_path )
152
- else :
153
- g .init (task_index = args .rank , task_count = args .world_size )
150
+ g .init (task_index = args .rank , task_count = args .world_size )
154
151
155
152
# TODO(baole): This is an estimate and an accurate value will be needed from graphlearn.
156
153
length_per_worker = args .train_length // args .train_batch_size // args .world_size
@@ -210,9 +207,10 @@ def run(args):
210
207
argparser .add_argument ('--drop_rate' , type = float , default = 0.0 )
211
208
argparser .add_argument ('--learning_rate' , type = float , default = 0.01 )
212
209
argparser .add_argument ('--epoch' , type = int , default = 60 )
213
- argparser .add_argument ('--client_num' , type = int , default = 0 )
214
- argparser .add_argument ('--ddp' , type = bool , default = True )
215
- argparser .add_argument ('--local_mode' , type = bool , default = False )
210
+ argparser .add_argument ('--client_num' , type = int , default = 0 ,
211
+ help = "Set to value bigger than zero to enable multi-processing dataload." )
212
+ argparser .add_argument ('--ddp' , action = 'store_true' ,
213
+ help = "whether to use ddp" )
216
214
args = argparser .parse_args ()
217
215
218
216
init_env (args )
0 commit comments