Skip to content

Commit

Permalink
Merge pull request UM-ARM-Lab#41 from StoneT2000/bug-fixes
Browse files Browse the repository at this point in the history
Fix some bugs in recently raised issues (capsule support, default torch device bug, and mujoco dependency)

Closes UM-ARM-Lab#40 , closes UM-ARM-Lab#39 , and closes UM-ARM-Lab#35
  • Loading branch information
LemonPi authored Aug 20, 2024
2 parents 7d4e197 + cb2c95a commit 403e1a6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
python -m pip install .[test]
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install mujoco
- name: Clone mujoco_menagerie repository into the tests/ folder
run: |
git clone https://github.com/google-deepmind/mujoco_menagerie
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ classifiers = [# Optional
dependencies = [
'absl-py',
'lxml',
"mujoco",
'numpy<2', # pybullet requires numpy<2 for testing; for future versions this may be relaxed
'pyyaml',
'torch',
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_kinematics/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _convert_transform(origin):
if origin is None:
return tf.Transform3d()
else:
rpy = torch.tensor(origin.rpy, dtype=torch.float32)
rpy = torch.tensor(origin.rpy, dtype=torch.float32, device="cpu")
return tf.Transform3d(rot=tf.quaternion_from_euler(rpy, "sxyz"), pos=origin.xyz)


Expand Down
12 changes: 11 additions & 1 deletion src/pytorch_kinematics/urdf_parser_py/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def __init__(self, radius=0.0, length=0.0):
xmlr.Attribute('length', float)
])

class Capsule(xmlr.Object):
def __init__(self, radius=0.0, length=0.0):
self.radius = radius
self.length = length

xmlr.reflect(Capsule, tag='capsule', params=[
xmlr.Attribute('radius', float),
xmlr.Attribute('length', float)
])

class Sphere(xmlr.Object):
def __init__(self, radius=0.0):
Expand Down Expand Up @@ -130,7 +139,8 @@ def __init__(self):
'box': Box,
'cylinder': Cylinder,
'sphere': Sphere,
'mesh': Mesh
'mesh': Mesh,
'capsule': Capsule
})

def from_xml(self, node, path):
Expand Down

0 comments on commit 403e1a6

Please sign in to comment.