Showing 6 changed files with 649 additions and 0 deletions.
@@ -0,0 +1,48 @@
project:
  name: 'UrbanNav_Citywalk'
  run_name: 'toy_data_pretrain'
  result_dir: 'results'

training:
  batch_size: 8
  max_epochs: 10
  gpus: 1
  amp: true
  normalize_step_length: true
  resume: true
  direction_loss_weight: 1.0
  reg_coeff: 0.1

scheduler:
  name: 'cosine'
  step_size: 10
  gamma: 0.1

optimizer:
  name: 'adamw'
  lr: 1e-4

model:
  do_rgb_normalize: true
  do_resize: true
  obs_encoder:
    type: 'vit_l_16'
    context_size: 5
    target_size: 5
  encoder_feat_dim: 1024
  ema_range: [0.998, 1]

data:
  type: citywalk
  video_dir: '/home/xinhao/citywalk/test_videos'
  num_workers: 6
  pose_fps: 5
  video_fps: 30
  target_fps: 1
  num_train: 40
  num_val: 20
  num_test: 20

logging:
  enable_wandb: false  # Set to false to disable Wandb logging
  pbar_rate: 1
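The dataset, model, and datamodule below read this file through attribute-style access (cfg.data.video_dir, cfg.model.obs_encoder.context_size, ...). A minimal loading sketch, assuming OmegaConf is used and the file lives at config/citywalk_pretrain.yaml; the actual entry point and path are not part of this commit:

# Hypothetical loader; the commit does not show how the YAML is actually read.
from omegaconf import OmegaConf

cfg = OmegaConf.load('config/citywalk_pretrain.yaml')  # assumed path
print(cfg.model.obs_encoder.type)   # 'vit_l_16'
print(cfg.training.batch_size)      # 8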
@@ -0,0 +1,129 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from decord import VideoReader, cpu
import torch.nn.functional as F
from tqdm import tqdm

class VideoJEPADataset(Dataset):
    def __init__(self, cfg, mode):
        super().__init__()
        self.cfg = cfg
        self.mode = mode
        self.video_dir = cfg.data.video_dir
        self.context_size = cfg.model.obs_encoder.context_size  # Number of input frames
        self.target_size = cfg.model.obs_encoder.target_size  # Number of target frames
        self.video_fps = cfg.data.video_fps
        self.target_fps = cfg.data.target_fps
        self.frame_multiplier = self.video_fps // self.target_fps

        # Load video paths
        self.video_paths = [
            os.path.join(self.video_dir, f)
            for f in sorted(os.listdir(self.video_dir))
            if f.endswith('.mp4')
        ]
        print(f"Total videos found: {len(self.video_paths)}")

        # Split videos according to mode
        if mode == 'train':
            self.video_paths = self.video_paths[:cfg.data.num_train]
        elif mode == 'val':
            self.video_paths = self.video_paths[cfg.data.num_train: cfg.data.num_train + cfg.data.num_val]
        elif mode == 'test':
            self.video_paths = self.video_paths[cfg.data.num_train + cfg.data.num_val:
                                                cfg.data.num_train + cfg.data.num_val + cfg.data.num_test]
        else:
            raise ValueError(f"Invalid mode {mode}")

        print(f"Number of videos for {mode}: {len(self.video_paths)}")

        # Build the look-up table (lut) and video_ranges
        self.lut = []
        self.video_ranges = []
        idx_counter = 0
        for video_idx, video_path in enumerate(tqdm(self.video_paths, desc="Building LUT")):
            # Initialize VideoReader to get number of frames
            vr = VideoReader(video_path, ctx=cpu(0))
            num_frames = len(vr)
            usable_frames = num_frames // self.frame_multiplier - (self.context_size + self.target_size)
            if usable_frames <= 0:
                continue  # Skip videos that are too short
            start_idx = idx_counter
            for frame_start in range(0, usable_frames, self.context_size):
                self.lut.append((video_idx, frame_start))
                idx_counter += 1
            end_idx = idx_counter
            self.video_ranges.append((start_idx, end_idx))
        assert len(self.lut) > 0, "No usable samples found."

        print(f"Total samples in LUT: {len(self.lut)}")
        print(f"Total video ranges: {len(self.video_ranges)}")

        # Initialize the video reader cache per worker
        self.video_reader_cache = {'video_idx': None, 'video_reader': None}

    def __len__(self):
        return len(self.lut)

    def __getitem__(self, index):
        video_idx, frame_start = self.lut[index]

        # Retrieve or create the VideoReader for the current video
        if self.video_reader_cache['video_idx'] != video_idx:
            # Replace the old VideoReader with the new one
            self.video_reader_cache['video_reader'] = VideoReader(self.video_paths[video_idx], ctx=cpu(0))
            self.video_reader_cache['video_idx'] = video_idx
        video_reader = self.video_reader_cache['video_reader']

        # Compute actual frame indices for input and target
        actual_frame_start = frame_start * self.frame_multiplier
        frame_indices_input = actual_frame_start + np.arange(self.context_size) * self.frame_multiplier
        frame_indices_target = actual_frame_start + (self.context_size + np.arange(self.target_size)) * self.frame_multiplier

        # Ensure frame indices are within the video length
        num_frames = len(video_reader)
        frame_indices_input = [min(idx, num_frames - 1) for idx in frame_indices_input]
        frame_indices_target = [min(idx, num_frames - 1) for idx in frame_indices_target]

        # Load the frames
        input_frames = video_reader.get_batch(frame_indices_input).asnumpy()
        target_frames = video_reader.get_batch(frame_indices_target).asnumpy()

        # Process frames
        input_frames = self.process_frames(input_frames)
        target_frames = self.process_frames(target_frames)

        sample = {
            'input_frames': input_frames,    # Shape: (context_size, 3, H, W)
            'target_frames': target_frames   # Shape: (target_size, 3, H, W)
        }

        return sample

    def process_frames(self, frames):
        """
        Convert frames to tensor, normalize, and resize if necessary.
        Args:
            frames (numpy.ndarray): Array of frames with shape (N, H, W, C).
        Returns:
            torch.Tensor: Processed frames with shape (N, 3, H, W).
        """
        # Convert frames to tensor and normalize
        frames = torch.tensor(frames).permute(0, 3, 1, 2).float() / 255.0  # Shape: (N, 3, H, W)

        # # Optional resizing
        # if self.cfg.data.do_resize:
        #     frames = F.interpolate(frames, size=(self.cfg.data.resize_height, self.cfg.data.resize_width),
        #                            mode='bilinear', align_corners=False)

        # # Optional normalization
        # if self.cfg.data.do_rgb_normalize:
        #     mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(frames.device)
        #     std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(frames.device)
        #     frames = (frames - mean) / std

        return frames
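For orientation, a minimal usage sketch of the dataset defined above. The cfg object is assumed to be the pretrain config from this commit loaded into an attribute-style container; the shapes in the comments follow from context_size = target_size = 5.

# Hypothetical smoke test; not part of the commit.
dataset = VideoJEPADataset(cfg, mode='val')
sample = dataset[0]
print(sample['input_frames'].shape)   # torch.Size([5, 3, H, W]) input (context) frames
print(sample['target_frames'].shape)  # torch.Size([5, 3, H, W]) target (future) frames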
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from torchvision import models
from efficientnet_pytorch import EfficientNet

class ImageEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.obs_encoder_type = cfg.model.obs_encoder.type
        self.do_rgb_normalize = cfg.model.do_rgb_normalize
        self.do_resize = cfg.model.do_resize
        self.encoder_feat_dim = cfg.model.encoder_feat_dim

        if self.do_rgb_normalize:
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # Build the observation encoder
        if self.obs_encoder_type.startswith("efficientnet"):
            self.obs_encoder = EfficientNet.from_name(self.obs_encoder_type, in_channels=3)
            self.num_obs_features = self.obs_encoder._fc.in_features
            self.obs_encoder._fc = nn.Identity()
        elif self.obs_encoder_type.startswith("resnet"):
            model_constructor = getattr(models, self.obs_encoder_type)
            self.obs_encoder = model_constructor(pretrained=False)
            self.num_obs_features = self.obs_encoder.fc.in_features
            self.obs_encoder.fc = nn.Identity()
        elif self.obs_encoder_type.startswith("vit"):
            model_constructor = getattr(models, self.obs_encoder_type)
            self.obs_encoder = model_constructor(pretrained=False)
            self.num_obs_features = self.obs_encoder.hidden_dim
            self.obs_encoder.heads = nn.Identity()
        else:
            raise NotImplementedError(f"Observation encoder type {self.obs_encoder_type} not implemented")

        # Compress observation encodings to encoder_feat_dim
        if self.num_obs_features != self.encoder_feat_dim:
            self.compress_obs_enc = nn.Linear(self.num_obs_features, self.encoder_feat_dim)
        else:
            self.compress_obs_enc = nn.Identity()

    def forward(self, obs):
        """
        Args:
            obs: (B*N, 3, H, W)
        Returns:
            embeddings: (B*N, encoder_feat_dim)
        """
        # Pre-processing
        if self.do_rgb_normalize:
            obs = (obs - self.mean) / self.std
        if self.do_resize:
            obs = F.interpolate(obs, size=(224, 224), mode='bilinear', align_corners=False)

        # Observation Encoding
        if self.obs_encoder_type.startswith("efficientnet"):
            obs_enc = self.obs_encoder.extract_features(obs)
            obs_enc = self.obs_encoder._avg_pooling(obs_enc)
            obs_enc = obs_enc.flatten(start_dim=1)
        elif self.obs_encoder_type.startswith("resnet"):
            x = self.obs_encoder.conv1(obs)
            x = self.obs_encoder.bn1(x)
            x = self.obs_encoder.relu(x)
            x = self.obs_encoder.maxpool(x)
            x = self.obs_encoder.layer1(x)
            x = self.obs_encoder.layer2(x)
            x = self.obs_encoder.layer3(x)
            x = self.obs_encoder.layer4(x)
            x = self.obs_encoder.avgpool(x)
            obs_enc = torch.flatten(x, 1)
        elif self.obs_encoder_type.startswith("vit"):
            obs_enc = self.obs_encoder(obs)  # Returns class token embedding
        else:
            raise NotImplementedError(f"Observation encoder type {self.obs_encoder_type} not implemented")

        # Compress embeddings
        obs_enc = self.compress_obs_enc(obs_enc)

        return obs_enc

class VideoJEPA(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Initialize the online encoder
        self.online_encoder = ImageEncoder(cfg)

        # Initialize the target encoder as a copy of the online encoder
        self.target_encoder = copy.deepcopy(self.online_encoder)
        # Freeze the target encoder parameters (they will be updated via EMA)
        for param in self.target_encoder.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target_encoder(self, decay=0.998):
        # EMA update for the target encoder
        for param_o, param_t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_t.data.mul_(decay).add_(param_o.data * (1 - decay))

    def forward(self, input_obs, target_obs):
        """
        Args:
            input_obs: (B, N_in, 3, H, W) tensor of past frames
            target_obs: (B, N_target, 3, H, W) tensor of future frames
        Returns:
            online_embeddings: (B, N_in, encoder_feat_dim)
            target_embeddings: (B, N_target, encoder_feat_dim)
        """
        # Process input observations through online encoder
        B, N_in, C, H, W = input_obs.shape
        input_obs_flat = input_obs.view(B * N_in, C, H, W)
        online_embeddings_flat = self.online_encoder(input_obs_flat)
        online_embeddings = online_embeddings_flat.view(B, N_in, -1)

        # Process target observations through target encoder with stop gradient
        with torch.no_grad():
            B, N_target, C, H, W = target_obs.shape
            target_obs_flat = target_obs.view(B * N_target, C, H, W)
            target_embeddings_flat = self.target_encoder(target_obs_flat)
            target_embeddings = target_embeddings_flat.view(B, N_target, -1)

        return online_embeddings, target_embeddings
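The commit does not include the training objective that ties these two outputs together. The sketch below shows one way a JEPA-style step could use them; the smooth L1 loss, the pooling over frames, and the name jepa_step are illustrative assumptions, not the repository's actual loss.

import torch.nn.functional as F

def jepa_step(model, batch):
    # Hypothetical training step; the real objective is not part of this commit.
    online_emb, target_emb = model(batch['input_frames'], batch['target_frames'])
    # Pool each clip's frame embeddings and regress online towards target (assumed objective).
    loss = F.smooth_l1_loss(online_emb.mean(dim=1), target_emb.mean(dim=1))
    return loss

# After each optimizer step, the frozen target encoder drifts toward the online encoder,
# e.g. with a decay taken from cfg.model.ema_range:
# model.update_target_encoder(decay=0.998)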
@@ -0,0 +1,33 @@
# data/datamodule.py

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from data.citywalk_dataset import CityWalkSampler
from data.citywalk_pretrain_dataset import VideoJEPADataset

class CityWalkPretrainDataModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.batch_size = cfg.training.batch_size
        self.num_workers = cfg.data.num_workers

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = VideoJEPADataset(self.cfg, mode='train')
            self.val_dataset = VideoJEPADataset(self.cfg, mode='val')

        if stage == 'test' or stage is None:
            self.test_dataset = VideoJEPADataset(self.cfg, mode='test')

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, sampler=CityWalkSampler(self.train_dataset))

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)
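To round out the picture, a minimal sketch of driving the datamodule with a Lightning Trainer. VideoJEPAPretrainModule stands in for a LightningModule wrapping VideoJEPA and the loss; it is hypothetical and not among the files shown in this commit.

import pytorch_lightning as pl

# Hypothetical wiring; only CityWalkPretrainDataModule is defined in this commit.
datamodule = CityWalkPretrainDataModule(cfg)
module = VideoJEPAPretrainModule(cfg)  # assumed LightningModule around VideoJEPA
trainer = pl.Trainer(max_epochs=cfg.training.max_epochs, devices=cfg.training.gpus)
trainer.fit(module, datamodule=datamodule)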