Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onshape/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.35"
__version__ = "0.2.36"
12 changes: 10 additions & 2 deletions onshape/formats/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility functions common to different formats."""

from __future__ import annotations

import io
import json
import re
Expand Down Expand Up @@ -64,9 +66,15 @@ def from_json(cls, json_path: Path) -> Self:
return cls(**data)


def save_xml(path: str | Path | io.StringIO, tree: ET.ElementTree | ET.Element) -> None:
def save_xml(
path: str | Path | io.StringIO,
tree: ET.ElementTree[ET.Element | None] | ET.ElementTree[ET.Element] | ET.Element,
) -> None:
if isinstance(tree, ET.ElementTree):
tree = tree.getroot()
root = tree.getroot()
if root is None:
raise ValueError("ElementTree has no root element")
tree = root
xmlstr = minidom.parseString(ET.tostring(tree)).toprettyxml(indent=" ")
xmlstr = re.sub(r"\n\s*\n", "\n", xmlstr)
if isinstance(path, io.StringIO):
Expand Down
40 changes: 40 additions & 0 deletions onshape/formats/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ class CollisionParams:
friction: list[float] = field(default_factory=lambda: [0.8, 0.02, 0.01])


@dataclass
class SiteMetadata:
name: str = field()
body_name: str = field()
site_type: str | None = field(default=None)
pos: list[float] | None = field(default=None) # (x, y, z)
size: list[float] | None = field(default=None) # (size_x, size_y, size_z)


@dataclass
class ImuSensor:
body_name: str = field()
Expand All @@ -40,6 +49,14 @@ class ForceSensor:
noise: float | None = field(default=None)


@dataclass
class TouchSensor:
body_name: str = field()
site_name: str = field()
name: str | None = field(default=None)
noise: float | None = field(default=None)


@dataclass
class ExplicitFloorContacts:
contact_links: list[str] = field(default_factory=lambda: [])
Expand Down Expand Up @@ -78,8 +95,10 @@ class ConversionMetadata:
suffix: str | None = field(default=None)
freejoint: bool = field(default=True)
collision_params: CollisionParams = field(default_factory=lambda: CollisionParams())
sites: list[SiteMetadata] = field(default_factory=lambda: [])
imus: list[ImuSensor] = field(default_factory=lambda: [])
force_sensors: list[ForceSensor] = field(default_factory=lambda: [])
touch_sensors: list[TouchSensor] = field(default_factory=lambda: [])
collision_geometries: list[CollisionGeometry] = field(default_factory=lambda: [])
explicit_contacts: ExplicitFloorContacts | None = field(default_factory=ExplicitFloorContacts)
weld_constraints: list[WeldConstraint] = field(default_factory=lambda: [])
Expand Down Expand Up @@ -120,6 +139,8 @@ def convert_to_mjcf_metadata(metadata: ConversionMetadata) -> "ConversionMetadat
ForceSensor as ForceSensorRef,
ImuSensor as ImuSensorRef,
JointMetadata as JointMetadataRef,
SiteMetadata as SiteMetadataRef,
TouchSensor as TouchSensorRef,
WeldConstraint as WeldConstraintRef,
)

Expand Down Expand Up @@ -150,6 +171,16 @@ def convert_to_mjcf_metadata(metadata: ConversionMetadata) -> "ConversionMetadat
solref=metadata.collision_params.solref,
friction=metadata.collision_params.friction,
),
sites=[
SiteMetadataRef(
name=site.name,
body_name=site.body_name,
site_type=cast(urdf2mjcf.model.SiteType | None, site.site_type),
pos=site.pos,
size=site.size,
)
for site in metadata.sites
],
imus=[
ImuSensorRef(
body_name=imu.body_name,
Expand All @@ -170,6 +201,15 @@ def convert_to_mjcf_metadata(metadata: ConversionMetadata) -> "ConversionMetadat
)
for fs in metadata.force_sensors
],
touch_sensors=[
TouchSensorRef(
body_name=ts.body_name,
site_name=ts.site_name,
name=ts.name,
noise=ts.noise,
)
for ts in metadata.touch_sensors
],
collision_geometries=[
CollisionGeometryRef(
name=cg.name,
Expand Down
2 changes: 2 additions & 0 deletions onshape/onshape/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Defines the config class."""

from __future__ import annotations

import argparse
from dataclasses import dataclass, field
from pathlib import Path
Expand Down
10 changes: 5 additions & 5 deletions onshape/onshape/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def get_graph(
Returns:
The graph of the assembly, along with a set of the ignored joints.
"""
graph = nx.Graph()
graph: nx.Graph = nx.Graph()

if key_namer is None:
key_namer = KeyNamer(assembly)
Expand Down Expand Up @@ -376,7 +376,7 @@ def add_edge_safe(node_a: Key, node_b: Key, joint: Key) -> None:
),
)

return graph
return graph, set()


def get_digraph(graph: nx.Graph, override_central_node: Key | None = None) -> tuple[nx.DiGraph, Key]:
Expand All @@ -387,15 +387,15 @@ def get_digraph(graph: nx.Graph, override_central_node: Key | None = None) -> tu
override_central_node: The central node to use.

Returns:
The central node of the graph and the directed graph
The directed graph and the central node of the graph.
"""
# Gets the most central node as the "root" node.
central_node: Key
if override_central_node is not None:
central_node = override_central_node
else:
closeness_centrality = nx.closeness_centrality(graph)
central_node = max(closeness_centrality, key=closeness_centrality.get)
central_node = max(closeness_centrality.keys(), key=lambda k: closeness_centrality[k])
return nx.bfs_tree(graph, central_node), central_node


Expand Down Expand Up @@ -749,7 +749,7 @@ async def check_document(
key_namer = KeyNamer(assembly)

try:
graph = get_graph(assembly, key_namer, key_to_part_instance, key_to_mate_feature)
graph, _ = get_graph(assembly, key_namer, key_to_part_instance, key_to_mate_feature)

except ValueError as e:
raise FailedCheckError(
Expand Down
2 changes: 2 additions & 0 deletions onshape/onshape/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Defines functions for post-processing the downloaded URDF."""

from __future__ import annotations

import asyncio
import logging
import sys
Expand Down
8 changes: 6 additions & 2 deletions onshape/passes/merge_fixed_joints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Defines functions for merging URDF parts at fixed joints."""

from __future__ import annotations

import argparse
import logging
import xml.etree.ElementTree as ET
Expand Down Expand Up @@ -238,10 +240,10 @@ def fuse_child_into_parent(root: ET.Element, joint: ET.Element, urdf_dir: Path)


def process_fixed_joints(
urdf_etree: ET.ElementTree,
urdf_etree: ET.ElementTree[ET.Element],
urdf_path: Path,
ignore_merging_fixed_joints: list[str] | None = None,
) -> ET.ElementTree:
) -> ET.ElementTree[ET.Element]:
"""Iteratively fuses all fixed joints until none remain.

This greedily fuses all child links into their parent links at fixed joints,
Expand All @@ -257,6 +259,8 @@ def process_fixed_joints(
The URDF element tree with all fixed joints fused.
"""
root = urdf_etree.getroot()
if root is None:
raise ValueError("URDF ElementTree has no root element")

ignore_set = set([] if ignore_merging_fixed_joints is None else ignore_merging_fixed_joints)
visited_set: set[str] = set()
Expand Down