
Commit 84c3579

Merge pull request martius-lab#12 from martius-lab/gymnasium_update: Gymnasium update
2 parents: a8099c3 + 8e95a42

25 files changed: +491 −488 lines

.github/workflows/python-app.yml (+1 −1)

@@ -70,7 +70,7 @@ jobs:
         python3 -m pip install --upgrade pip
         pip3 install -e .
         pip3 install -r requirements.txt
-        pip3 install myosuite==2.1.3
+        pip3 install myosuite==2.5.0
         pip3 install pytest

     - name: Run Test environment

deprl/custom_distributed.py (+2)

@@ -264,6 +264,8 @@ def distribute(
         env=environment, parallel=parallel, sequential=sequential
     )

+    if "header" in tonic_conf:
+        exec(tonic_conf["header"])
     dummy_environment = build_env_from_dict(build_dict)
     max_episode_steps = dummy_environment._max_episode_steps
     del dummy_environment
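
The two added lines let a YAML-supplied "header" string run before the dummy environment is built, so imports such as myosuite's gym shim are in scope when the environment name is resolved. A minimal sketch of the mechanism, with tonic_conf as a hand-built stand-in for the parsed YAML block (not the repo's actual config loader):

    # Sketch: tonic_conf stands in for the parsed "tonic" section of a YAML config.
    tonic_conf = {
        "header": "import deprl, myosuite; from myosuite.utils import gym",
    }
    if "header" in tonic_conf:
        exec(tonic_conf["header"])  # runs the imports before any gym.make call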

deprl/dep_controller.py (+1 −1)

@@ -2,7 +2,7 @@
 import os
 from collections import deque

-import gym
+import gymnasium as gym
 import torch

 torch.set_default_dtype(torch.float32)

deprl/env_wrappers/gym_wrapper.py (+4 −1)

@@ -57,4 +57,7 @@ def muscle_activity(self):

     @property
     def _max_episode_steps(self):
-        return self.unwrapped.max_episode_steps
+        if hasattr(self.unwrapped, "max_episode_steps"):
+            return self.unwrapped.max_episode_steps
+        else:
+            return self.unwrapped.horizon
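
The new fallback covers environments that expose horizon instead of max_episode_steps, as some MuJoCo-style myosuite tasks do. The same lookup as a standalone helper, a sketch under that assumption (episode_length is hypothetical, not part of the diff):

    # Hypothetical helper mirroring the property above.
    def episode_length(env):
        base = env.unwrapped
        if hasattr(base, "max_episode_steps"):
            return base.max_episode_steps
        return base.horizon  # assumed fallback attribute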

deprl/env_wrappers/scone_wrapper.py (+8 −1)

@@ -72,7 +72,14 @@ def _inner_step(self, action):
         done = self.unwrapped._get_done()
         self.unwrapped.time += self.step_size
         self.unwrapped.total_reward += reward
-        return obs, reward, done, {}
+        truncated = (
+            self.unwrapped.time / self.step_size
+        ) < self._max_episode_steps
+        return obs, reward, done, truncated, {}
+
+    def reset(self, *args, **kwargs):
+        obs = super().reset()
+        return obs, obs

     @property
     def _max_episode_steps(self):
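
Both changes track the Gymnasium API: step returns a five-tuple with separate terminated and truncated flags, and reset returns an (observation, info) pair. A minimal sketch of the target signatures on a stock task:

    # Gymnasium call signatures (a sketch on a standard environment).
    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    obs, info = env.reset(seed=0)  # reset -> (observation, info)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated  # legacy single flag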

deprl/env_wrappers/wrappers.py (+11 −3)

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod

-import gym
+import gymnasium as gym
 import numpy as np

 import deprl  # noqa
@@ -89,6 +89,8 @@ def __init__(self, *args, **kwargs):

     def reset(self, **kwargs):
         observation = super().reset(**kwargs)
+        if len(observation) == 2 and type(observation) is tuple:
+            observation = observation[0]
         if not np.any(np.isnan(observation)):
             self.last_observation = observation.copy()
         else:
@@ -97,10 +99,16 @@ def reset(self, **kwargs):

     def step(self, action):
         try:
-            observation, reward, done, info = self._inner_step(action)
+            (
+                observation,
+                reward,
+                terminated,
+                truncated,
+                info,
+            ) = self._inner_step(action)
             if np.any(np.isnan(observation)):
                 raise self.error("NaN detected! Resetting.")
-
+            done = terminated or truncated
         except self.error as e:
             logger.log(f"Simulator exception thrown: {e}")
             observation = self.last_observation
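
The unpack-then-recombine pattern keeps downstream code on the legacy single done flag. A hedged compatibility shim for wrappers that must also accept old four-tuple environments could look like this (unpack_step is hypothetical, not part of this diff):

    # Hypothetical shim for mixed 4-/5-tuple step returns.
    def unpack_step(result):
        if len(result) == 5:  # Gymnasium-style
            obs, reward, terminated, truncated, info = result
        else:  # legacy Gym 4-tuple
            obs, reward, done, info = result
            terminated, truncated = done, False
        return obs, reward, terminated, truncated, info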

deprl/log.py (+1 −1)

@@ -2,9 +2,9 @@
 import os
 import time

-import wandb
 import yaml

+import wandb
 from deprl.vendor.tonic import utils
deprl/vendor/tonic/environments/builders.py (+22 −6)

@@ -3,36 +3,49 @@
 import os
 from types import SimpleNamespace

-import gym.wrappers
 import numpy as np
+from gymnasium import wrappers
+
+try:
+    from myosuite.utils import gym
+except ModuleNotFoundError:
+    pass
+

 from deprl.vendor.tonic import environments
 from deprl.vendor.tonic.utils import logger


 def gym_environment(*args, **kwargs):
     """Returns a wrapped Gym environment."""
+    if "header" in kwargs:
+        kwargs.pop("header")

     def _builder(*args, **kwargs):
         return gym.make(*args, **kwargs)

-    return build_environment(_builder, *args, **kwargs)
+    return build_environment(_builder, *args, **kwargs, header=None)


 def bullet_environment(*args, **kwargs):
     """Returns a wrapped PyBullet environment."""
+    if "header" in kwargs:
+        kwargs.pop("header")

     def _builder(*args, **kwargs):
         import pybullet_envs  # noqa

         return gym.make(*args, **kwargs)

-    return build_environment(_builder, *args, **kwargs)
+    return build_environment(_builder, *args, **kwargs, header=None)


 def control_suite_environment(*args, **kwargs):
     """Returns a wrapped Control Suite environment."""

+    if "header" in kwargs:
+        kwargs.pop("header")
+
     def _builder(name, *args, **kwargs):
         domain, task = name.split("-")
         environment = ControlSuiteEnvironment(
@@ -42,9 +55,9 @@ def _builder(name, *args, **kwargs):
         environment.spec = SimpleNamespace(
             max_episode_steps=time_limit, id="ostrichrl-dmcontrol"
         )
-        return gym.wrappers.TimeLimit(environment, time_limit)
+        return wrappers.TimeLimit(environment, time_limit)

-    return build_environment(_builder, *args, **kwargs)
+    return build_environment(_builder, *args, **kwargs, header=None)


 def build_environment(
@@ -54,6 +67,7 @@ def build_environment(
     time_feature=False,
     max_episode_steps="default",
     scaled_actions=True,
+    header=None,
     *args,
     **kwargs,
 ):
@@ -62,6 +76,8 @@ def build_environment(
     time_feature=True, see https://arxiv.org/pdf/1712.00378.pdf for more
     details.
     """
+    if header is not None:
+        exec(header)

     # Build the environment.
     environment = builder(name, *args, **kwargs)
@@ -81,7 +97,7 @@

     # Remove the TimeLimit wrapper if needed.
     if not terminal_timeouts:
-        if type(environment) == gym.wrappers.TimeLimit:
+        if type(environment) == wrappers.TimeLimit:
             environment = environment.env

     # Add time as a feature if needed.
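
The guarded import prefers myosuite's gym shim, which resolves to a Gymnasium-compatible module, and silently skips it when myosuite is absent. In isolation, with an assumed plain-gymnasium fallback instead of the diff's bare pass:

    # The guarded import above, in isolation; the gymnasium fallback is an
    # illustrative assumption (the diff simply passes).
    try:
        from myosuite.utils import gym
    except ModuleNotFoundError:
        import gymnasium as gym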

deprl/vendor/tonic/environments/wrappers.py (+1 −1)

@@ -1,6 +1,6 @@
 """Environment wrappers."""

-import gym
+import gymnasium as gym
 import numpy as np

examples/example_load_baseline_myosuite.py (+4 −5)

@@ -3,13 +3,14 @@

 import time

-import gym
 import myosuite  # noqa
+from myosuite.utils import gym

 import deprl
+from deprl import env_wrappers

-# create the sconegym env
-env = gym.make("myoChallengeChaseTagP1-v0")
+env = gym.make("myoLegWalk-v0", reset_type="random")
+env = env_wrappers.GymWrapper(env)
 policy = deprl.load_baseline(env)

 env.seed(0)
@@ -36,5 +37,3 @@
         )
         env.reset()
         break
-
-env.close()

examples/example_only_dep_myosuite.py (+1 −1)

@@ -1,7 +1,7 @@
 import time

-import gym
 import myosuite  # noqa
+from myosuite.utils import gym

 from deprl import env_wrappers
 from deprl.dep_controller import DEP

experiments/hyfydy/scone_walk_opensim_h0918.yaml (+3 −3)

@@ -2,7 +2,7 @@ tonic:
   after_training: ''
   header: "import deprl, gym, sconegym"
   agent: "deprl.custom_agents.dep_factory(3, deprl.custom_mpo_torch.TunedMPO())(replay=deprl.custom_replay_buffers.AdaptiveEnergyBuffer(return_steps=1,
-    batch_size=256, steps_between_batches=1000, batch_iterations=30, steps_before_batches=2e5,
+    batch_size=256, steps_between_batches=1000, batch_iterations=30, steps_before_batches=1000,
     num_acts=18))"
   before_training: ''
   checkpoint: "last"
@@ -11,8 +11,8 @@ tonic:
   name: "sconewalk_h0918_osimv1"
   resume: true
   seed: 0
-  parallel: 20
-  sequential: 10
+  parallel: 1
+  sequential: 1
   test_environment: null
   trainer: "deprl.custom_trainer.Trainer(steps=int(5e8), epoch_steps=int(2e5), save_steps=int(1e6))"

experiments/myosuite_training_files/myoChaseTag.yaml (+1 −1)

@@ -29,7 +29,7 @@ tonic:
     reset_type='random')
   environment_name: deprl_baseline_chasetag
   full_save: 1
-  header: import deprl, gym, myosuite
+  header: import deprl, myosuite; from myosuite.utils import gym
   name: myoChasetag
   parallel: 20
   path: ./output

experiments/myosuite_training_files/myoLegWalk.yaml (+1 −1)

@@ -28,7 +28,7 @@ tonic:
   environment: deprl.environments.Gym('myoLegWalk-v0', scaled_actions=False, reset_type='random')
   environment_name: deprl_baseline
   full_save: 1
-  header: import deprl, gym, myosuite
+  header: import deprl, myosuite; from myosuite.utils import gym
   name: myoLeg
   parallel: 20
   resume: 1

experiments/myosuite_training_files/myoRelocate.yaml (+1 −1)

@@ -23,7 +23,7 @@ tonic:
   environment: deprl.environments.Gym('myoChallengeRelocateP1-v0', scaled_actions=False)
   environment_name: deprl_baseline_relocate
   full_save: 1
-  header: import deprl, gym, myosuite
+  header: import deprl, myosuite; from myosuite.utils import gym
   name: Relocate
   parallel: 20
   resume: 1

(new file, +39; filename not shown in this diff view)

@@ -0,0 +1,39 @@
+DEP:
+  bias_rate: 0.002
+  buffer_size: 200
+  intervention_length: 5
+  intervention_proba: 0.0004
+  kappa: 1169.7
+  normalization: independent
+  q_norm_selector: l2
+  regularization: 32
+  s4avg: 2
+  sensor_delay: 1
+  tau: 40
+  test_episode_every: 3
+  time_dist: 5
+  with_learning: true
+env_args: {}
+mpo_args:
+  hidden_size: 1024
+  lr_actor: 3.53e-05
+  lr_critic: 6.081e-05
+  lr_dual: 0.00213
+tonic:
+  after_training: ''
+  agent: deprl.custom_agents.dep_factory(3, deprl.custom_mpo_torch.TunedMPO())(replay=deprl.replays.buffers.Buffer(return_steps=3,
+    batch_size=256, steps_between_batches=1000, batch_iterations=30, steps_before_batches=2e5))
+  before_training: ''
+  checkpoint: last
+  environment: deprl.environments.Gym('myoChallengeRunTrackP1-v0', scaled_actions=False)
+  environment_name: deprl_baseline_runtrack
+  full_save: 1
+  header: import deprl, myosuite; from myosuite.utils import gym
+  name: myoLeg
+  parallel: 20
+  resume: 1
+  seed: 0
+  sequential: 10
+  test_environment: null
+  trainer: deprl.custom_trainer.Trainer(steps=int(1e8), epoch_steps=int(2e5), save_steps=int(1e6))
+  working_dir: ./baselines_DEPRL
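
This new config mirrors the other myosuite YAMLs, including the updated header string. A hedged sketch of how such a file feeds the header exec added in deprl/custom_distributed.py above (the filename is hypothetical, since this diff view omits it):

    # Sketch: parse the config and run its header, as custom_distributed.py now does.
    import yaml

    with open("myo_run_track.yaml") as f:  # hypothetical filename
        conf = yaml.safe_load(f)

    tonic_conf = conf["tonic"]
    if "header" in tonic_conf:
        exec(tonic_conf["header"])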
