Skip to content

Commit

Permalink
update test end step
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaaaavin committed Oct 18, 2024
1 parent 78a5f81 commit ef8626a
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions pl_modules/urban_nav_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, cfg):
self.image_std = np.array([0.229, 0.224, 0.225])

# Initialize list to store test metrics
self.test_metrics = []
self.test_metrics = {'l1_loss': [], 'arrived_accuracy': []}

def forward(self, obs, cord):
return self.model(obs, cord)
Expand Down Expand Up @@ -94,10 +94,8 @@ def test_step(self, batch, batch_idx):
accuracy = correct.sum().item() / correct.numel()

# Store the metrics
self.test_metrics.append({
'l1_loss': l1_loss,
'arrived_accuracy': accuracy
})
self.test_metrics['l1_loss'].append(l1_loss)
self.test_metrics['arrived_accuracy'].append(accuracy)

# Handle visualization
self.process_visualization(
Expand All @@ -110,10 +108,15 @@ def test_step(self, batch, batch_idx):

def on_test_epoch_end(self):
# Save the test metrics to a .npy file
test_metrics_array = np.array(self.test_metrics, dtype=object)
test_metrics_path = os.path.join(self.result_dir, 'test_metrics.npy')
np.save(test_metrics_path, test_metrics_array)
print(f"Test metrics saved to {test_metrics_path}")
l1_loss_array = np.array(self.test_metrics['l1_loss'])
accuracy_array = np.array(self.test_metrics['arrived_accuracy'])
l1_loss_save_path = os.path.join(self.result_dir, 'test_l1_loss.npy')
accuracy_save_path = os.path.join(self.result_dir, 'test_arrived_accuracy.npy')
np.save(l1_loss_save_path, l1_loss_array)
np.save(accuracy_save_path, accuracy_array)
print(f"Test mean L1 loss {l1_loss_array.mean():.4f} saved to {l1_loss_save_path}")
print(f"Test mean arrived accuracy {accuracy_array.mean():.4f} saved to {accuracy_save_path}")
# print(f"Test metrics saved to {test_metrics_path}")

def on_validation_epoch_start(self):
self.vis_count = 0
Expand Down

0 comments on commit ef8626a

Please sign in to comment.