Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def add_model_path_to_list(
return paths


def sort_data_models( # noqa: PLR0912, PLR0915
def sort_data_models( # noqa: PLR0912, PLR0914, PLR0915
unsorted_data_models: list[DataModel],
sorted_data_models: SortedDataModels | None = None,
require_update_action_models: list[str] | None = None,
Expand Down Expand Up @@ -534,6 +534,7 @@ def sort_data_models( # noqa: PLR0912, PLR0915
pass

# sort on base_class dependency
seen_orders: set[tuple[str, ...]] = set()
while True:
ordered_models: list[tuple[int, DataModel]] = []
# Build lookup dict for O(1) index access instead of O(n) list.index()
Expand Down Expand Up @@ -565,6 +566,11 @@ def sort_data_models( # noqa: PLR0912, PLR0915
sorted_unresolved_models = [m[1] for m in sorted(ordered_models, key=operator.itemgetter(0))]
if sorted_unresolved_models == unresolved_references:
break
new_order = tuple(m.path for m in sorted_unresolved_models)
if new_order in seen_orders:
unresolved_references = sorted_unresolved_models
break
seen_orders.add(new_order)
unresolved_references = sorted_unresolved_models

# circular reference
Expand Down Expand Up @@ -1624,7 +1630,9 @@ def get_discriminator_field_value(

if len(discriminator_values) == 0:
for base_class in discriminator_model.base_classes:
check_paths(base_class.reference, mapping) # ty: ignore
if not base_class.reference:
continue
check_paths(base_class.reference, mapping)

if not discriminator_values:
discriminator_values = [discriminator_model.path.split("/")[-1]]
Expand Down
74 changes: 74 additions & 0 deletions tests/parser/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
if TYPE_CHECKING:
from datamodel_code_generator.parser.schema_version import JsonSchemaFeatures

from datamodel_code_generator.imports import Imports
from datamodel_code_generator.model.base import BaseClassDataType
from datamodel_code_generator.model.pydantic_v2 import BaseModel, DataModelField
from datamodel_code_generator.model.type_alias import TypeAlias, TypeAliasTypeBackport, TypeStatement
from datamodel_code_generator.parser.base import (
Expand Down Expand Up @@ -242,6 +244,78 @@ def test_sort_data_models_unresolved_raise_recursion_error() -> None:
sort_data_models(reference, recursion_count=100000)


def test_sort_data_models_circular_base_classes_no_infinite_loop() -> None:
"""Mutual base-class references must not oscillate forever in the dependency sort."""
reference_a = Reference(path="A", original_name="A", name="A")
reference_b = Reference(path="B", original_name="B", name="B")
reference = [
BaseModel(
fields=[],
reference=reference_a,
base_classes=[reference_b],
),
BaseModel(
fields=[],
reference=reference_b,
base_classes=[reference_a],
),
]

_, resolved, require_update_action_models = sort_data_models(reference)

assert set(resolved) == {"A", "B"}
assert sorted(require_update_action_models) == ["A", "B"]


def test_apply_discriminator_type_skips_base_class_without_reference() -> None:
"""Base class slots without a Reference must not be passed to check_paths."""
ref_pet = Reference(path="#/components/schemas/Pet", original_name="Pet", name="Pet")
pet_model = BaseModel(fields=[], reference=ref_pet)
ref_pet.source = pet_model
pet_model.base_classes.append(BaseClassDataType())

ref_other = Reference(path="#/components/schemas/Other", original_name="Other", name="Other")
other_model = BaseModel(fields=[], reference=ref_other)
ref_other.source = other_model

union_inner = DataType(data_types=[DataType(reference=ref_pet), DataType(reference=ref_other)])
ref_root = Reference(path="#/components/schemas/Root", original_name="Root", name="Root")
field = DataModelField(
name="u",
data_type=union_inner,
extras={
"discriminator": {
"propertyName": "petType",
"mapping": {"dog": "#/components/schemas/Other"},
}
},
)
root = BaseModel(fields=[field], reference=ref_root)
ref_root.source = root

parser = C(
data_model_type=BaseModel,
data_model_root_type=BaseModel,
data_model_field_type=DataModelField,
base_class="BaseModel",
source="",
)
union_variant_types = tuple(union_inner.data_types)
assert len(union_variant_types) == 2
assert {dt.reference.path for dt in union_variant_types} == {ref_pet.path, ref_other.path}
assert {id(dt.reference) for dt in union_variant_types} == {id(ref_pet), id(ref_other)}
pet_base_classes = pet_model.base_classes
bare_base_slot = pet_model.base_classes[-1]

parser._Parser__apply_discriminator_type([root], Imports())

assert tuple(union_inner.data_types) == union_variant_types
assert {dt.reference.path for dt in union_inner.data_types} == {ref_pet.path, ref_other.path}
assert {id(dt.reference) for dt in union_inner.data_types} == {id(ref_pet), id(ref_other)}
assert pet_model.base_classes is pet_base_classes
assert pet_model.base_classes[-1] is bare_base_slot


@pytest.mark.parametrize(
("current_module", "reference", "val"),
[
Expand Down
Loading