Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ jobs:
env:
SECRET_OE_LICENSE: ${{ secrets.OE_LICENSE }}

- name: Run mypy
if: ${{ !contains(matrix.environment, 'examples') }}
run: pixi run -e ${{ matrix.environment }} run_mypy

- name: Run tests
if: ${{ !contains(matrix.environment, 'examples') }}
run: pixi run -e ${{ matrix.environment }} run_tests
Expand Down
51 changes: 34 additions & 17 deletions openff/interchange/_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Annotated, Any

import numpy
from pydantic_core import core_schema, CoreSchema
from annotated_types import Gt
from openff.toolkit import Quantity
from openff.units import Quantity
from pydantic import (
GetCoreSchemaHandler,
AfterValidator,
BeforeValidator,
ValidationInfo,
Expand Down Expand Up @@ -124,23 +126,38 @@ def quantity_json_serializer(
}


# Pydantic v2 likes to marry validators and serializers to types with Annotated
# https://docs.pydantic.dev/latest/concepts/validators/#annotated-validators
_Quantity = Annotated[
Quantity,
WrapValidator(quantity_validator),
WrapSerializer(quantity_json_serializer),
]
class _Quantity(Quantity):
@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: Any,
handler: GetCoreSchemaHandler,
) -> CoreSchema:
python_schema = core_schema.with_info_wrap_validator_function(
function=quantity_validator,
schema=core_schema.any_schema(),
)
serialization_schema = core_schema.wrap_serializer_function_ser_schema(
function=quantity_json_serializer,
schema=core_schema.any_schema(),
)

return core_schema.json_or_python_schema(
python_schema=python_schema,
json_schema=python_schema,
serialization=serialization_schema,
)


_DimensionlessQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_dimensionless),
WrapSerializer(quantity_json_serializer),
]

_DistanceQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_distance),
WrapSerializer(quantity_json_serializer),
Expand All @@ -149,7 +166,7 @@ def quantity_json_serializer(
_LengthQuantity = _DistanceQuantity

_VelocityQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_velocity),
WrapSerializer(quantity_json_serializer),
Expand All @@ -163,28 +180,28 @@ def quantity_json_serializer(
]

_TemperatureQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_temperature),
WrapSerializer(quantity_json_serializer),
]

_DegreeQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_degree),
WrapSerializer(quantity_json_serializer),
]

_ElementaryChargeQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_elementary_charge),
WrapSerializer(quantity_json_serializer),
]

_kJMolQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_kj_mol),
WrapSerializer(quantity_json_serializer),
Expand All @@ -209,7 +226,7 @@ def _duck_to_nanometer(value: Any):


_PositionsQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_nanometer),
AfterValidator(_is_positions_shape),
Expand Down Expand Up @@ -246,7 +263,7 @@ def _unwrap_list_of_openmm_quantities(value: Any):


_BoxQuantity = Annotated[
Quantity,
_Quantity,
WrapValidator(quantity_validator),
AfterValidator(_is_distance),
AfterValidator(_is_box_shape),
Expand Down
4 changes: 2 additions & 2 deletions openff/interchange/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class _BaseModel(BaseModel):
)

def model_dump(self, **kwargs) -> dict[str, Any]:
return super().model_dump(serialize_as_any=True, **kwargs)
return super().model_dump(**kwargs)

def model_dump_json(self, **kwargs) -> str:
return super().model_dump_json(serialize_as_any=True, **kwargs)
return super().model_dump_json(**kwargs)
Loading
Loading