Skip to content

Commit 7a10ace

Browse files
committed
add input_target coordinate input
1 parent da5049b commit 7a10ace

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

config/toy.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ training:
1010
amp: true
1111
normalize_step_length: true
1212
resume: false
13-
direction_loss_weight: 5.0
13+
direction_loss_weight: 3.0
1414
distance_loss_weight: 1.0
1515
angle_loss_weight: 1.0
1616

@@ -21,7 +21,7 @@ scheduler:
2121

2222
optimizer:
2323
name: 'adamw'
24-
lr: 1e-5
24+
lr: 1e-4
2525

2626
model:
2727
do_rgb_normalize: true
@@ -35,14 +35,14 @@ model:
3535
output_coordinate_repr: 'euclidean'
3636

3737
cord_embedding:
38-
type: 'target'
38+
type: 'input_target'
3939
num_freqs: 6
4040
include_input: true
4141

4242
encoder_feat_dim: 768
4343

4444
decoder:
45-
type: 'diff_policy'
45+
type: 'attention'
4646
len_traj_pred: 5
4747
num_heads: 4
4848
num_layers: 4

data/citywalk_dataset.py

+5
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def __getitem__(self, index):
159159
transformed_input_positions = self.input2target(input_poses, target_pose)
160160
elif self.cfg.model.cord_embedding.type == 'target':
161161
transformed_input_positions = self.transform_target_pose(target_pose, current_pose)[np.newaxis, [0, 2]]
162+
elif self.cfg.model.cord_embedding.type == 'input_target':
163+
transformed_input_positions = np.concatenate([
164+
self.input2target(input_poses, target_pose),
165+
self.transform_target_pose(target_pose, current_pose)[np.newaxis, [0, 2]]
166+
], axis=0)
162167
else:
163168
raise NotImplementedError(f"Coordinate embedding type {self.cfg.model.cord_embedding} not implemented")
164169
waypoints_transformed = self.transform_waypoints(waypoint_poses, current_pose)

model/urban_nav.py

+3
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def __init__(self, cfg):
6767
elif self.cord_embedding_type == 'target':
6868
self.cord_embedding = PolarEmbedding(cfg)
6969
self.dim_cord_embedding = self.cord_embedding.out_dim
70+
elif self.cord_embedding_type == 'input_target':
71+
self.cord_embedding = PolarEmbedding(cfg)
72+
self.dim_cord_embedding = self.cord_embedding.out_dim * (self.context_size + 1)
7073
else:
7174
raise NotImplementedError(f"Coordinate embedding type {self.cord_embedding_type} not implemented")
7275

pl_modules/urban_nav_module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, cfg):
2222
raise ValueError(f"Unsupported coordinate representation: {self.output_coordinate_repr}")
2323

2424
self.decoder = cfg.model.decoder.type
25-
if self.decoder not in ["diff_policy", "transformer"]:
25+
if self.decoder not in ["diff_policy", "attention"]:
2626
raise ValueError(f"Unsupported decoder: {self.decoder}")
2727

2828
# Direction loss weight (you can adjust this value in your cfg)

0 commit comments

Comments
 (0)