Skip to content

Commit efa3476

Browse files
committed
Merge branch 'main' of https://github.com/mar10/nutree
2 parents eedc034 + 3166ed7 commit efa3476

File tree

11 files changed

+544
-463
lines changed

11 files changed

+544
-463
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
33
# Ruff version.
4-
rev: v0.6.6
4+
rev: v0.9.5
55
hooks:
66
# Run the linter.
77
- id: ruff

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
# Changelog
22

3-
## 1.0.1 (unreleased)
3+
## 1.1.0 (unreleased)
44

5+
- DEPRECATE: `TypedTree.iter_by_type()`. Use `iterator(.., kind)`instead.
6+
- New methods `TypedTree.iterator(..., kind=ANY_KIND)`,
7+
`TypedNode.iterator(..., kind=ANY_KIND)`,
8+
and `TypedTree.count_descendants(leaves_only=False, kind=ANY_KIND)`
9+
510
## 1.0.0 (2024-12-27)
611
- Add benchmarks (using [Benchman](https://github.com/mar10/benchman)).
712
- Drop support for Python 3.8

Pipfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pytest = "*"
1717
pytest-cov = "*"
1818
PyYAML = "*"
1919
rdflib = "*"
20-
ruff = "*"
20+
ruff = "~=0.9"
2121
setuptools = ">=42.0"
2222
Sphinx = "*"
2323
sphinx_rtd_theme = "*"

Pipfile.lock

Lines changed: 450 additions & 438 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/sphinx/ug_graphs.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ Navigation methods are type-aware now::
356356
assert cause1.get_index() == 0
357357
assert cause1.get_index(any_kind=True) == 2
358358

359-
assert len(list(tree.iter_by_type("effect"))) == 3
359+
assert len(list(tree.iterator(kind="effect"))) == 3
360360

361361
Keep in mind that a tree node is unique within a tree, but may reference identical
362362
data objects, so these `clones` could exist at different locations of tree.

nutree/tree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,9 +1022,9 @@ def _self_check(self) -> Literal[True]:
10221022
# assert node._data_id == self.calc_data_id(node.data), node
10231023
assert node._data_id in self._nodes_by_data_id, node
10241024
assert node._node_id == id(node), f"{node}: {node._node_id} != {id(node)}"
1025-
assert (
1026-
node._children is None or len(node._children) > 0
1027-
), f"{node}: {node._children}"
1025+
assert node._children is None or len(node._children) > 0, (
1026+
f"{node}: {node._children}"
1027+
)
10281028

10291029
assert len(self._node_by_id) == len(node_list)
10301030

nutree/tree_generator.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class Randomizer(ABC):
4343
"""
4444

4545
def __init__(self, *, probability: float = 1.0) -> None:
46-
assert (
47-
isinstance(probability, float) and 0.0 <= probability <= 1.0
48-
), f"probality must be in the range [0.0 .. 1.0]: {probability}"
46+
assert isinstance(probability, float) and 0.0 <= probability <= 1.0, (
47+
f"probality must be in the range [0.0 .. 1.0]: {probability}"
48+
)
4949
self.probability = probability
5050

5151
def _skip_value(self) -> bool:
@@ -84,9 +84,9 @@ def __init__(
8484
none_value: Any = None,
8585
) -> None:
8686
super().__init__(probability=probability)
87-
assert type(min_val) is type(
88-
max_val
89-
), f"min_val and max_val must be of the same type: {min_val}, {max_val}"
87+
assert type(min_val) is type(max_val), (
88+
f"min_val and max_val must be of the same type: {min_val}, {max_val}"
89+
)
9090
self.is_float = isinstance(min_val, float)
9191
self.min = min_val
9292
self.max = max_val
@@ -129,19 +129,19 @@ def __init__(
129129
) -> None:
130130
super().__init__(probability=probability)
131131
assert isinstance(min_dt, date), f"min_dt must be a date: {min_dt}"
132-
assert isinstance(
133-
max_dt, (date, int)
134-
), f"max_dt must be a date or int: {max_dt}"
132+
assert isinstance(max_dt, (date, int)), (
133+
f"max_dt must be a date or int: {max_dt}"
134+
)
135135

136136
if isinstance(max_dt, int):
137137
self.delta_days = max_dt
138138
max_dt = min_dt + timedelta(days=self.delta_days)
139139
else:
140140
self.delta_days = (max_dt - min_dt).days
141141

142-
assert (
143-
max_dt > min_dt
144-
), f"max_dt must be greater than min_dt: {min_dt}, {max_dt}"
142+
assert max_dt > min_dt, (
143+
f"max_dt must be greater than min_dt: {min_dt}, {max_dt}"
144+
)
145145

146146
self.min = min_dt
147147
self.max = max_dt

nutree/typed_tree.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ROOT_NODE_ID,
2828
DataIdType,
2929
DeserializeMapperType,
30+
IterMethod,
3031
KeyMapType,
3132
MapperCallbackType,
3233
PredicateCallbackType,
@@ -137,12 +138,44 @@ def last_child(self, kind: str | type[ANY_KIND]) -> Self | None:
137138
return n
138139
return None
139140

141+
def iterator(
142+
self,
143+
method: IterMethod = IterMethod.PRE_ORDER,
144+
*,
145+
add_self=False,
146+
kind: str | type[ANY_KIND] = ANY_KIND,
147+
) -> Iterator[TypedNode[TData]]:
148+
"""Return an iterator that walks the tree in the specified order."""
149+
if kind is ANY_KIND:
150+
yield from super().iterator(method=method, add_self=add_self)
151+
return
152+
153+
if add_self and self.kind == kind:
154+
yield self
155+
for n in super().iterator(method=method, add_self=False):
156+
if n.kind == kind:
157+
yield n
158+
return
159+
140160
def has_children(self, kind: str | type[ANY_KIND]) -> bool:
141161
"""Return true if this node has one or more children."""
142162
if kind is ANY_KIND:
143163
return bool(self._children)
144164
return len(self.get_children(kind)) > 1
145165

166+
def count_descendants(
167+
self, *, leaves_only=False, kind: str | type[ANY_KIND] = ANY_KIND
168+
) -> int:
169+
"""Return number of descendant nodes, not counting self."""
170+
if kind is ANY_KIND:
171+
return super().count_descendants(leaves_only=leaves_only)
172+
all = not leaves_only
173+
i = 0
174+
for node in self.iterator():
175+
if (all or not node._children) and node.kind == kind:
176+
i += 1
177+
return i
178+
146179
def get_siblings(self, *, add_self=False, any_kind=False) -> list[Self]:
147180
"""Return a list of all sibling entries of self (excluding self) if any."""
148181
if any_kind:
@@ -630,13 +663,30 @@ def last_child(self, kind: str | type[ANY_KIND]) -> TypedNode[TData] | None:
630663
return self.system_root.last_child(kind=kind)
631664

632665
def iter_by_type(self, kind: str | type[ANY_KIND]) -> Iterator[TypedNode[TData]]:
666+
"""@deprecated: Use :meth:`iterator` with `kind` argument instead."""
667+
yield from self.iterator(kind=kind)
668+
669+
def iterator(
670+
self,
671+
method: IterMethod = IterMethod.PRE_ORDER,
672+
*,
673+
kind: str | type[ANY_KIND] = ANY_KIND,
674+
) -> Iterator[TypedNode[TData]]:
633675
if kind == ANY_KIND:
634-
yield from self.iterator()
635-
for n in self.iterator():
676+
yield from super().iterator(method=method)
677+
return
678+
679+
for n in super().iterator(method=method):
636680
if n._kind == kind:
637681
yield n
638682
return
639683

684+
def count_descendants(
685+
self, *, leaves_only=False, kind: str | type[ANY_KIND] = ANY_KIND
686+
) -> int:
687+
"""Return number of nodes, optionally restricted to type."""
688+
return self.system_root.count_descendants(leaves_only=leaves_only, kind=kind)
689+
640690
def save(
641691
self,
642692
target: IO[str] | str | Path,

tests/test_objects.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ class FrozenItem:
285285

286286
assert isinstance(dict_node.data, FrozenItem)
287287
assert dict_node.data is item, "dataclass should be stored as reference"
288-
assert (
289-
dict_node.price == 12.34
290-
), "should support attribute access via forwardinging"
288+
assert dict_node.price == 12.34, (
289+
"should support attribute access via forwardinging"
290+
)
291291
with pytest.raises(AttributeError):
292292
_ = dict_node.foo
293293

tests/test_rdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_typed_tree(self):
6464
assert cause1.get_index() == 0
6565
assert cause1.get_index(any_kind=True) == 2
6666

67-
assert len(list(tree.iter_by_type("effect"))) == 3
67+
assert len(list(tree.iterator(kind="effect"))) == 3
6868

6969
# tree.print()
7070
# print()

0 commit comments

Comments
 (0)