-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
186 lines (162 loc) · 6.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# train.py
import pytorch_lightning as pl
import argparse
import yaml
import os
from pl_modules.citywalk_datamodule import CityWalkDataModule
from pl_modules.urbannav_datamodule import UrbanNavDataModule
from pl_modules.urban_nav_module import UrbanNavModule
from pl_modules.urbannav_jepa_module import UrbanNavJEPAModule
from pl_modules.citywalk_jepa_datamodule import CityWalkJEPADataModule
from pytorch_lightning.strategies import DDPStrategy
import torch
import glob
# Module-level setup, executed as soon as this file is imported or run:
# select the 'medium' float32 matmul precision setting in torch and seed the
# run with a fixed seed (42) through Lightning, including workers=True.
torch.set_float32_matmul_precision('medium')
pl.seed_everything(42, workers=True)
# WandbLogger is deliberately not imported at the top of the file; main()
# imports it lazily so that a missing wandb install only disables Wandb
# logging instead of breaking the whole script.
class DictNamespace(argparse.Namespace):
    """Namespace giving attribute-style access to a (possibly nested) dict.

    Every keyword argument becomes an attribute of the same name. Values
    that are dicts are wrapped in their own ``DictNamespace`` recursively;
    all other values (including lists that contain dicts) are stored as-is.
    """

    def __init__(self, **kwargs):
        wrapped = {
            name: DictNamespace(**item) if isinstance(item, dict) else item
            for name, item in kwargs.items()
        }
        # argparse.Namespace.__init__ sets one attribute per keyword.
        super().__init__(**wrapped)
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave exactly as before.
            Passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace: Namespace with ``config`` (path to the YAML
        config, default ``'config/default.yaml'``) and ``checkpoint``
        (path to a checkpoint, default ``None``).
    """
    parser = argparse.ArgumentParser(description='Train UrbanNav model')
    parser.add_argument(
        '--config',
        type=str,
        default='config/default.yaml',
        help='Path to config file',
    )
    parser.add_argument(
        '--checkpoint',
        type=str,
        default=None,
        help='Path to model checkpoint. If not provided, the latest checkpoint will be used.',
    )
    return parser.parse_args(argv)
def load_config(config_path):
    """Load a YAML config file into a nested ``DictNamespace``.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        DictNamespace: The parsed config with attribute-style access.

    Raises:
        TypeError: If the file is empty or its top level is not a mapping.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg_dict = yaml.safe_load(f)
    # An empty file parses to None and a top-level list/scalar cannot be
    # unpacked into keyword arguments; fail with a readable message instead
    # of the opaque error raised by ``**`` unpacking.
    if not isinstance(cfg_dict, dict):
        raise TypeError(
            f"Config file {config_path} must contain a top-level mapping, "
            f"got {type(cfg_dict).__name__}"
        )
    return DictNamespace(**cfg_dict)
def find_latest_checkpoint(checkpoint_dir):
    """
    Return the most recently modified ``*.ckpt`` file in a directory.

    Args:
        checkpoint_dir (str): Directory that is searched (non-recursively)
            for files ending in ``.ckpt``.

    Returns:
        str: Path of the checkpoint with the newest modification time.

    Raises:
        FileNotFoundError: If the directory contains no ``*.ckpt`` files.
    """
    print(checkpoint_dir)
    candidates = glob.glob(os.path.join(checkpoint_dir, '*.ckpt'))
    if len(candidates) == 0:
        raise FileNotFoundError(f"No checkpoint files found in directory: {checkpoint_dir}")
    # The newest file is the one with the largest modification timestamp.
    return max(candidates, key=os.path.getmtime)
def _namespace_to_dict(value):
    """Recursively turn ``argparse.Namespace`` objects into plain dicts."""
    if isinstance(value, argparse.Namespace):
        return {key: _namespace_to_dict(item) for key, item in vars(value).items()}
    return value


def _resolve_resume_checkpoint(explicit_path, result_dir):
    """Pick the checkpoint to resume from, or ``None`` to train from scratch.

    An explicitly given path is used when it points to an existing file;
    otherwise ``<result_dir>/checkpoints/last.ckpt`` is used when present.
    """
    if explicit_path:
        checkpoint_path = explicit_path
    else:
        checkpoint_path = os.path.join(result_dir, 'checkpoints', 'last.ckpt')

    if not os.path.isfile(checkpoint_path):
        # Best effort: a missing checkpoint falls back to a fresh run.
        print("No checkpoint found. Training from scratch.")
        return None

    if not explicit_path:
        print(f"No checkpoint specified. Using the latest checkpoint: {checkpoint_path}")
    print(f"Training resume from checkpoint: {checkpoint_path}")
    return checkpoint_path


def main():
    """Train the model described by the YAML config given on the command line.

    Steps: parse arguments and load the config, create the run's result
    directory and store a copy of the config there, build the datamodule and
    model selected by ``cfg.data.type`` / ``cfg.model.type``, optionally set
    up Wandb logging, build the Lightning ``Trainer`` and call ``fit``
    (resuming from a checkpoint when ``cfg.training.resume`` is set).

    Raises:
        ValueError: If ``cfg.data.type`` or ``cfg.model.type`` is not one of
            the supported values.
    """
    args = parse_args()
    cfg = load_config(args.config)

    # Create result directory
    result_dir = os.path.join(cfg.project.result_dir, cfg.project.run_name)
    os.makedirs(result_dir, exist_ok=True)
    cfg.project.result_dir = result_dir  # Update result_dir in cfg

    # Save config file in result directory. Nested namespaces are converted
    # to plain dicts first; dumping the namespace objects directly would emit
    # Python-specific object tags that yaml.safe_load cannot read back.
    with open(os.path.join(result_dir, 'config.yaml'), 'w') as f:
        yaml.dump(_namespace_to_dict(cfg), f)

    # Initialize the DataModule
    if cfg.data.type == 'citywalk':
        datamodule = CityWalkDataModule(cfg)
    elif cfg.data.type == 'urbannav':
        datamodule = UrbanNavDataModule(cfg)
    elif cfg.data.type == 'citywalk_jepa':
        datamodule = CityWalkJEPADataModule(cfg)
    else:
        # Report the attribute that was actually dispatched on.
        raise ValueError(f"Invalid dataset: {cfg.data.type}")

    # Initialize the model
    if cfg.model.type == 'urbannav':
        model = UrbanNavModule(cfg)
    elif cfg.model.type == 'urbannav_jepa':
        model = UrbanNavJEPAModule(cfg)
    else:
        raise ValueError(f"Invalid model: {cfg.model.type}")
    print(pl.utilities.model_summary.ModelSummary(model, max_depth=2))

    # Initialize logger: stays None unless Wandb logging is enabled in the
    # config and the Wandb logger can be imported.
    logger = None
    if cfg.logging.enable_wandb:
        try:
            # Imported here so a missing wandb install is handled gracefully.
            from pytorch_lightning.loggers import WandbLogger
            logger = WandbLogger(
                project=cfg.project.name,
                name=cfg.project.run_name,
                save_dir=result_dir,
            )
            print("WandbLogger initialized.")
        except ImportError:
            print("Wandb is not installed. Skipping Wandb logging.")

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(result_dir, 'checkpoints'),
        save_last=True,
        save_top_k=1,
        monitor='val/direction_loss',
    )

    num_gpu = torch.cuda.device_count()

    # Set up Trainer. The single- and multi-GPU configurations share every
    # option except the DDP strategy, which is only added for multiple GPUs.
    trainer_kwargs = dict(
        default_root_dir=result_dir,
        max_epochs=cfg.training.max_epochs,
        logger=logger,  # Pass the logger (WandbLogger or None)
        devices=num_gpu,
        precision='16-mixed' if cfg.training.amp else 32,
        accelerator='gpu',
        callbacks=[
            checkpoint_callback,
            pl.callbacks.TQDMProgressBar(refresh_rate=cfg.logging.pbar_rate),
        ],
        log_every_n_steps=1,
    )
    if num_gpu > 1:
        trainer_kwargs['strategy'] = DDPStrategy(find_unused_parameters=True)
    trainer = pl.Trainer(**trainer_kwargs)

    # Determine the checkpoint to resume from; None starts a fresh run.
    checkpoint_path = None
    if cfg.training.resume:
        checkpoint_path = _resolve_resume_checkpoint(args.checkpoint, result_dir)

    # Start training
    trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()