forked from StanfordVL/OmniGibson
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_data_collection.py
143 lines (123 loc) · 4.32 KB
/
test_data_collection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import tempfile

import pytest
import torch as th

import omnigibson as og
from omnigibson.envs import DataCollectionWrapper, DataPlaybackWrapper
from omnigibson.macros import gm
from omnigibson.objects import DatasetObject
def _make_temp_path(suffix):
    """Create an empty temporary file under ``og.tempdir`` and return its path.

    ``tempfile.mkstemp`` also returns an open OS-level file descriptor. Only the
    path is handed to the data wrappers, so the descriptor is closed right away
    instead of leaking for the rest of the test session.
    """
    fd, path = tempfile.mkstemp(suffix, dir=og.tempdir)
    os.close(fd)
    return path


def _step_random_actions(env, n_steps=5):
    """Step ``env`` ``n_steps`` times with actions sampled from robot 0's action space."""
    for _ in range(n_steps):
        env.step(env.robots[0].action_space.sample())


def test_data_collect_and_playback():
    """Record two episodes with DataCollectionWrapper, then replay them with DataPlaybackWrapper.

    Each recorded episode interleaves random-action steps with scene mutations:
    adding and removing an object, and generating and clearing water particles.
    The recorded HDF5 file is then played back with extra robot and external
    vision sensors, and the playback is saved to a second HDF5 file.

    The test makes no explicit assertions; it passes if the whole
    collect -> save -> clear -> playback -> save pipeline runs without raising.
    """
    cfg = {
        "env": {
            "external_sensors": [],
        },
        "scene": {
            "type": "InteractiveTraversableScene",
            "scene_model": "Rs_int",
            "load_object_categories": ["floors", "breakfast_table"],
        },
        "robots": [
            {
                "type": "Fetch",
                "obs_modalities": [],
            }
        ],
        # Task kwargs
        "task": {
            "type": "BehaviorTask",
            # BehaviorTask-specific
            "activity_name": "assembling_gift_baskets",
            "online_object_sampling": True,
        },
    }

    if og.sim is None:
        # No simulator exists yet, so set the global macros before one is created:
        # object states on, GPU dynamics on (needed for cloth), flatcache on,
        # transition rules off.
        gm.ENABLE_OBJECT_STATES = True
        gm.USE_GPU_DYNAMICS = True
        gm.ENABLE_FLATCACHE = True
        gm.ENABLE_TRANSITION_RULES = False
    else:
        # A simulator already exists (presumably from an earlier test in the same
        # session); make sure it is stopped before building a new environment.
        og.sim.stop()

    # Create temp files to hold the collected data and the played-back data
    collect_hdf5_path = _make_temp_path("test_data_collection.hdf5")
    playback_hdf5_path = _make_temp_path("test_data_playback.hdf5")

    # Create the environment (wrapped as a DataCollection env)
    env = og.Environment(configs=cfg)
    env = DataCollectionWrapper(
        env=env,
        output_path=collect_hdf5_path,
        only_successes=False,
    )

    # Record 2 episodes
    for _ in range(2):
        env.reset()
        _step_random_actions(env)

        # Manually add an object (a banana) and place it at (10, 10, 10)
        obj = DatasetObject(name="banana", category="banana")
        env.scene.add_object(obj)
        obj.set_position(th.ones(3, dtype=th.float32) * 10.0)
        _step_random_actions(env)

        # Manually remove the added object
        env.scene.remove_object(obj)
        _step_random_actions(env)

        # Add 10 water particles at random positions in [0, 10)^3
        water = env.scene.get_system("water")
        pos = th.rand(10, 3, dtype=th.float32) * 10.0
        water.generate_particles(positions=pos)
        _step_random_actions(env)

        # Clear the system
        env.scene.clear_system("water")
        _step_random_actions(env)

    # Save this data
    env.save_data()

    # Clear the sim
    og.clear(
        physics_dt=0.001,
        rendering_dt=0.001,
        sim_step_dt=0.001,
    )

    # Define robot sensor config and external sensors to use during playback
    robot_sensor_config = {
        "VisionSensor": {
            "sensor_kwargs": {
                "image_height": 128,
                "image_width": 128,
            },
        },
    }
    external_sensors_config = [
        {
            "sensor_type": "VisionSensor",
            "name": "external_sensor0",
            "relative_prim_path": "/robot0/root_link/external_sensor0",
            "modalities": ["rgb", "seg_semantic"],
            "sensor_kwargs": {
                "image_height": 128,
                "image_width": 128,
                "focal_length": 12.0,
            },
            "local_position": th.tensor([-0.26549, -0.30288, 1.0 + 0.861], dtype=th.float32),
            "local_orientation": th.tensor([0.36165891, -0.24745751, -0.50752921, 0.74187715], dtype=th.float32),
        },
    ]

    # Create a playback env and playback the data, collecting obs along the way
    env = DataPlaybackWrapper.create_from_hdf5(
        input_path=collect_hdf5_path,
        output_path=playback_hdf5_path,
        robot_obs_modalities=["proprio", "rgb", "depth_linear"],
        robot_sensor_config=robot_sensor_config,
        external_sensors_config=external_sensors_config,
        n_render_iterations=1,
        only_successes=False,
    )
    env.playback_dataset(record=True)
    env.save_data()