Commit

tdmpc2 bug fix for dtype error when running anymalc-reach (#808)
t-sekai authored Jan 28, 2025
1 parent 3e948a5 commit e303b5e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/baselines/tdmpc2/common/math.py
@@ -71,7 +71,7 @@ def two_hot(x, cfg):
         return symlog(x)
     x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
     bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
-    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
+    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1).to(torch.float32)
     soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
     soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
     soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
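
The commit does not say which dtype triggered the error. A plausible reading, offered here only as an assumption, is that `x` arrived as something other than float32 (for example float64), so `bin_offset` inherited that dtype while `soft_two_hot` was created as float32, and `scatter_` requires the source and destination dtypes to match. The sketch below reproduces that situation in isolation; `VMIN`, `BIN_SIZE`, `NUM_BINS`, and `scatter_two_hot` are hypothetical stand-ins, not names from the repository.

```python
import torch

# Hypothetical stand-ins for cfg.vmin, cfg.bin_size and cfg.num_bins.
VMIN, BIN_SIZE, NUM_BINS = -10.0, 0.2, 101


def scatter_two_hot(x, cast_offset):
    """Mirror the scatter step of two_hot, optionally casting bin_offset."""
    bin_idx = torch.floor((x - VMIN) / BIN_SIZE).long()
    bin_offset = ((x - VMIN) / BIN_SIZE - bin_idx.float()).unsqueeze(-1)
    if cast_offset:
        # The one-line change from the commit.
        bin_offset = bin_offset.to(torch.float32)
    # torch.zeros uses the default dtype, which is float32 unless changed.
    out = torch.zeros(x.size(0), NUM_BINS, device=x.device)
    out.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
    out.scatter_(1, (bin_idx.unsqueeze(1) + 1) % NUM_BINS, bin_offset)
    return out


# Assumed trigger: a float64 input (the real cause is not stated in the commit).
x64 = torch.tensor([0.5, -3.25], dtype=torch.float64)

try:
    scatter_two_hot(x64, cast_offset=False)
    raised = False
except RuntimeError:
    # scatter_ rejects a float64 source for a float32 destination.
    raised = True

fixed = scatter_two_hot(x64, cast_offset=True)
```

With the cast in place, each row of `fixed` holds two non-zero weights that sum to one, and the result stays float32 regardless of the input dtype.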