Skip to content

Commit

Permalink
added obs normalization to the sac (Denys88#186)
Browse files Browse the repository at this point in the history
Co-authored-by: Denys Makoviichuk <[email protected]>
  • Loading branch information
Denys88 and DenSumy authored Jun 26, 2022
1 parent 7af30eb commit d5290c6
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions rl_games/algos_torch/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, base_name, params):
self.env_info['action_space'].shape,
self.replay_buffer_size,
self._device)
self.target_entropy_coef = config.get("target_entropy_coef", 0.5)
self.target_entropy_coef = config.get("target_entropy_coef", 1.0)
self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
print("Target entropy", self.target_entropy)

Expand Down Expand Up @@ -196,8 +196,9 @@ def get_weights(self):
return state

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
pass
#state = self.get_full_state_weights()
#torch_ext.save_checkpoint(fn, state)

def set_weights(self, weights):
self.model.sac_network.actor.load_state_dict(weights['actor'])
Expand Down Expand Up @@ -258,7 +259,7 @@ def update_actor_and_alpha(self, obs, step):
dist = self.model.actor(obs)
action = dist.rsample()
log_prob = dist.log_prob(action).sum(-1, keepdim=True)
entropy = dist.entropy().sum(-1, keepdim=True).mean()
entropy = -log_prob.mean() #dist.entropy().sum(-1, keepdim=True).mean()
actor_Q1, actor_Q2 = self.model.critic(obs, action)
actor_Q = torch.min(actor_Q1, actor_Q2)

Expand Down Expand Up @@ -294,7 +295,6 @@ def update(self, step):

obs = self.preproc_obs(obs)
next_obs = self.preproc_obs(next_obs)

critic_loss, critic1_loss, critic2_loss = self.update_critic(obs, action, reward, next_obs, not_done, step)

actor_loss, entropy, alpha, alpha_loss = self.update_actor_and_alpha(obs, step)
Expand All @@ -307,6 +307,7 @@ def update(self, step):
def preproc_obs(self, obs):
if isinstance(obs, dict):
obs = obs['obs']
obs = self.model.norm_obs(obs)
return obs

def cast_obs(self, obs):
Expand Down

0 comments on commit d5290c6

Please sign in to comment.