Skip to content

Commit

Permalink
updates for gymnasium compat
Browse files Browse the repository at this point in the history
  • Loading branch information
Pierre Schumacher committed Aug 13, 2024
1 parent ddb31fd commit 276c667
Show file tree
Hide file tree
Showing 11 changed files with 436 additions and 413 deletions.
9 changes: 8 additions & 1 deletion deprl/env_wrappers/scone_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,14 @@ def _inner_step(self, action):
done = self.unwrapped._get_done()
self.unwrapped.time += self.step_size
self.unwrapped.total_reward += reward
return obs, reward, done, {}
truncated = (
self.unwrapped.time / self.step_size
) < self._max_episode_steps
return obs, reward, done, truncated, {}

def reset(self, *args, **kwargs):
obs = super().reset()
return obs, obs

@property
def _max_episode_steps(self):
Expand Down
4 changes: 3 additions & 1 deletion deprl/env_wrappers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def reset(self, **kwargs):
observation = super().reset(**kwargs)[0]
observation = super().reset(**kwargs)
if len(observation) == 2 and type(observation) is tuple:
observation = observation[0]
if not np.any(np.isnan(observation)):
self.last_observation = observation.copy()
else:
Expand Down
7 changes: 6 additions & 1 deletion deprl/vendor/tonic/environments/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

import gym.wrappers
import numpy as np
from myosuite.utils import gym

try:
from myosuite.utils import gym
except ModuleNotFoundError:
pass


from deprl.vendor.tonic import environments
from deprl.vendor.tonic.utils import logger
Expand Down
6 changes: 3 additions & 3 deletions experiments/hyfydy/scone_walk_opensim_h0918.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ tonic:
after_training: ''
header: "import deprl, gym, sconegym"
agent: "deprl.custom_agents.dep_factory(3, deprl.custom_mpo_torch.TunedMPO())(replay=deprl.custom_replay_buffers.AdaptiveEnergyBuffer(return_steps=1,
batch_size=256, steps_between_batches=1000, batch_iterations=30, steps_before_batches=2e5,
batch_size=256, steps_between_batches=1000, batch_iterations=30, steps_before_batches=1000,
num_acts=18))"
before_training: ''
checkpoint: "last"
Expand All @@ -11,8 +11,8 @@ tonic:
name: "sconewalk_h0918_osimv1"
resume: true
seed: 0
parallel: 20
sequential: 10
parallel: 1
sequential: 1
test_environment: null
trainer: "deprl.custom_trainer.Trainer(steps=int(5e8), epoch_steps=int(2e5), save_steps=int(1e6))"

Expand Down
2 changes: 1 addition & 1 deletion experiments/myosuite_training_files/myoChaseTag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ tonic:
reset_type='random')
environment_name: deprl_baseline_chasetag
full_save: 1
header: import deprl, gym, myosuite
header: import deprl, myosuite; from myosuite.utils import gym
name: myoChasetag
parallel: 20
path: ./output
Expand Down
2 changes: 1 addition & 1 deletion experiments/myosuite_training_files/myoRelocate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ tonic:
environment: deprl.environments.Gym('myoChallengeRelocateP1-v0', scaled_actions=False)
environment_name: deprl_baseline_relocate
full_save: 1
header: import deprl, gym, myosuite
header: import deprl, myosuite; from myosuite.utils import gym
name: Relocate
parallel: 20
resume: 1
Expand Down
39 changes: 39 additions & 0 deletions experiments/myosuite_training_files/myoRunTrack.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
DEP:
bias_rate: 0.002
buffer_size: 200
intervention_length: 5
intervention_proba: 0.0004
kappa: 1169.7
normalization: independent
q_norm_selector: l2
regularization: 32
s4avg: 2
sensor_delay: 1
tau: 40
test_episode_every: 3
time_dist: 5
with_learning: true
env_args: {}
mpo_args:
hidden_size: 1024
lr_actor: 3.53e-05
lr_critic: 6.081e-05
lr_dual: 0.00213
tonic:
after_training: ''
agent: deprl.custom_agents.dep_factory(3, deprl.custom_mpo_torch.TunedMPO())(replay=deprl.replays.buffers.Buffer(return_steps=3,
batch_size=256, steps_between_batches=1000, batch_iterations=30, steps_before_batches=2e5))
before_training: ''
checkpoint: last
environment: deprl.environments.Gym('myoChallengeRunTrackP1-v0', scaled_actions=False)
environment_name: deprl_baseline_runtrack
full_save: 1
header: import deprl, myosuite; from myosuite.utils import gym
name: myoLeg
parallel: 20
resume: 1
seed: 0
sequential: 10
test_environment: null
trainer: deprl.custom_trainer.Trainer(steps=int(1e8), epoch_steps=int(2e5), save_steps=int(1e6))
working_dir: ./baselines_DEPRL
741 changes: 356 additions & 385 deletions poetry.lock

Large diffs are not rendered by default.

27 changes: 12 additions & 15 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,46 @@ gitdb==4.0.11 ; python_version >= "3.9" and python_full_version <= "3.11.5"
gitpython==3.1.43 ; python_version >= "3.9" and python_full_version <= "3.11.5"
gym==0.13.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
idna==3.7 ; python_version >= "3.9" and python_full_version <= "3.11.5"
intel-openmp==2021.4.0 ; python_version >= "3.9" and python_full_version <= "3.11.5" and platform_system == "Windows"
jinja2==3.1.4 ; python_version >= "3.9" and python_full_version <= "3.11.5"
markupsafe==2.1.5 ; python_version >= "3.9" and python_full_version <= "3.11.5"
mkl==2021.4.0 ; python_version >= "3.9" and python_full_version <= "3.11.5" and platform_system == "Windows"
mpmath==1.3.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
networkx==3.2.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
numpy==1.26.4 ; python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cudnn-cu12==9.1.0.70 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-nvjitlink-cu12==12.5.82 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-nvjitlink-cu12==12.6.20 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_full_version <= "3.11.5"
pandas==2.2.2 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pathtools==0.1.2 ; python_version >= "3.9" and python_full_version <= "3.11.5"
protobuf==4.25.3 ; python_version >= "3.9" and python_full_version <= "3.11.5"
protobuf==4.25.4 ; python_version >= "3.9" and python_full_version <= "3.11.5"
psutil==6.0.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pyglet==2.0.15 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pyglet==2.0.17 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pysocks==1.7.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pytz==2024.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pyyaml==6.0.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
pyyaml==6.0.2 ; python_version >= "3.9" and python_full_version <= "3.11.5"
requests==2.32.3 ; python_version >= "3.9" and python_full_version <= "3.11.5"
requests[socks]==2.32.3 ; python_version >= "3.9" and python_full_version <= "3.11.5"
scipy==1.13.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
sentry-sdk==2.9.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
sentry-sdk==2.13.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
setproctitle==1.3.3 ; python_version >= "3.9" and python_full_version <= "3.11.5"
setuptools==70.3.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
setuptools==72.1.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
six==1.16.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
smmap==5.0.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
soupsieve==2.5 ; python_version >= "3.9" and python_full_version <= "3.11.5"
sympy==1.13.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
tbb==2021.13.0 ; python_version >= "3.9" and python_full_version <= "3.11.5" and platform_system == "Windows"
soupsieve==2.6 ; python_version >= "3.9" and python_full_version <= "3.11.5"
sympy==1.13.2 ; python_version >= "3.9" and python_full_version <= "3.11.5"
termcolor==2.4.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
torch==2.3.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
tqdm==4.66.4 ; python_version >= "3.9" and python_full_version <= "3.11.5"
triton==2.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version <= "3.11.5" and python_version >= "3.9"
torch==2.4.0 ; python_version >= "3.9" and python_full_version <= "3.11.5"
tqdm==4.66.5 ; python_version >= "3.9" and python_full_version <= "3.11.5"
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version <= "3.11.5" and python_version >= "3.9"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_full_version <= "3.11.5"
tzdata==2024.1 ; python_version >= "3.9" and python_full_version <= "3.11.5"
urllib3==2.2.2 ; python_version >= "3.9" and python_full_version <= "3.11.5"
Expand Down
5 changes: 3 additions & 2 deletions tests/test_deprl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import shutil
import sys

import gym
import myosuite # noqa
import torch
from myosuite.utils import gym

import deprl
from deprl import main, play
Expand Down Expand Up @@ -79,7 +79,8 @@ def step(self, action):


if __name__ == "__main__":
# test_exception()
test_play()
test_train()
test_load_resume()
test_load_no_resume()
test_exception()
7 changes: 4 additions & 3 deletions tests/test_myosuite_baselines.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import gym
import myosuite # noqa
import numpy as np
import torch
from myosuite.utils import gym

import deprl

Expand All @@ -17,13 +17,14 @@ def helper_env_loop(env):
env.seed(SEED)
for ep in range(10):
ret = 0
obs = env.reset()
obs = env.reset()[0]
for i in range(2000):
action = policy.noisy_test_step(obs)
obs, reward, done, _ = env.step(action)
obs, reward, terminated, truncated, info = env.step(action)
# env.mj_render()
ret += reward
qpos.append(env.sim.data.qpos[1])
done = terminated or truncated
if done:
break
returns.append(ret)
Expand Down

0 comments on commit 276c667

Please sign in to comment.