Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 181 additions & 49 deletions deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,68 +138,141 @@ def _init_from_model_json(self, model_json_str: str) -> None:
self._dpmodel = BaseModel.deserialize(model_data)
self._is_spin = False

self.rcut = self._dpmodel.get_rcut()
self.type_map = self._dpmodel.get_type_map()
self._rcut = self._dpmodel.get_rcut()
self._type_map = self._dpmodel.get_type_map()
self._sel = list(self._dpmodel.get_sel())
self._mixed_types = bool(self._dpmodel.mixed_types())
if self._is_spin:
self._model_output_def = ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
"energy",
shape=[1],
reducible=True,
r_differentiable=True,
c_differentiable=True,
atomic=True,
magnetic=True,
)
]
)
)
spin_fitting_defs = self._dpmodel.model_output_def().def_outp.get_data()
# Keep only physical fitting outputs; mask is derived by ModelOutputDef.
fitting_defs = [
vdef for name, vdef in spin_fitting_defs.items() if name != "mask"
]
self._model_output_def = ModelOutputDef(FittingOutputDef(fitting_defs))
else:
self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def())

def _init_from_metadata(self) -> None:
    """Initialize DeepEval from ``metadata.json`` alone.

    Used when the ``.pt2`` / ``.pte`` archive ships no ``model.json``
    (e.g. for backends that do not travel through the dpmodel round-trip).
    The metadata contract is the same one the C++ ``DeepPotPTExpt``
    reader consumes, so anything that validates against the C++ side
    automatically validates here.

    ``self._dpmodel`` is left as ``None`` to signal the metadata-only
    mode. Inference does not need it: it runs through
    ``aoti_load_package`` / the exported module and uses plain
    attributes (``self._rcut``, ``self._sel``, ``self._mixed_types``,
    ``self._model_output_def``) for all metadata-level queries.
    """
    # Metadata-only mode: no deserialised dpmodel instance exists.
    self._dpmodel = None

    meta = self.metadata
    self._is_spin = bool(meta.get("is_spin", False))
    self._rcut = float(meta["rcut"])
    self._type_map = list(meta["type_map"])
    self._sel = [int(s) for s in meta["sel"]]
    self._mixed_types = bool(meta["mixed_types"])

    # Rebuild the fitting output definitions straight from the
    # serialised dicts; absent optional keys fall back to the same
    # defaults the exporter would have written.
    fitting_defs = [
        OutputVariableDef(
            name=vdef["name"],
            shape=list(vdef["shape"]),
            reducible=vdef.get("reducible", False),
            r_differentiable=vdef.get("r_differentiable", False),
            c_differentiable=vdef.get("c_differentiable", False),
            atomic=vdef.get("atomic", True),
            category=int(vdef.get("category", OutputVariableCategory.OUT.value)),
            r_hessian=vdef.get("r_hessian", False),
            magnetic=vdef.get("magnetic", False),
            intensive=vdef.get("intensive", False),
        )
        for vdef in meta["fitting_output_defs"]
    ]
    self._model_output_def = ModelOutputDef(FittingOutputDef(fitting_defs))

def _load_pte(self, model_file: str) -> None:
    """Load a .pte (torch.export) model file.

    ``model.json`` is optional: when present it is used to reconstruct
    the dpmodel instance (enabling dpmodel-level introspection such as
    ``eval_descriptor``); when absent we fall back to pure metadata
    mode via :meth:`_init_from_metadata`. ``metadata.json`` is the
    only contract the inference path actually requires.

    Parameters
    ----------
    model_file : str
        Path to the ``.pte`` archive written by ``torch.export``.

    Raises
    ------
    ValueError
        If the archive ships no ``metadata.json`` entry.
    """
    extra_files = {
        "model.json": "",
        "model_def_script.json": "",
        "metadata.json": "",
    }
    exported = torch.export.load(model_file, extra_files=extra_files)
    self.exported_module = exported.module()

    mds = extra_files["model_def_script.json"]
    self._model_def_script = json.loads(mds) if mds else {}

    # metadata.json is mandatory: it is the only contract the
    # inference path requires, so fail fast when it is absent.
    md = extra_files["metadata.json"]
    if not md:
        raise ValueError(
            f"Invalid .pte file '{model_file}': missing 'metadata.json'"
        )
    self.metadata = json.loads(md)

    # model.json is optional: with it we get the full dpmodel
    # round-trip; without it we run in metadata-only mode.
    model_json_str = extra_files["model.json"]
    if model_json_str:
        self._init_from_model_json(model_json_str)
    else:
        self._init_from_metadata()

def _load_pt2(self, model_file: str) -> None:
"""Load a .pt2 (AOTInductor) model file."""
"""Load a .pt2 (AOTInductor) model file.

``model.json`` is optional — it only enables the dpmodel
round-trip (used by ``eval_descriptor``, ``eval_typeebd``, etc.).
Pure AOTI inference (``DeepPot.eval`` / ``dp test`` / ASE
calculator) only needs ``metadata.json``, matching the contract
the C++ ``DeepPotPTExpt`` reader enforces.

Archive entries are located under ``model/extra/`` so that the
PyTorch 2.11 ``load_pt2`` loader accepts the archive without the
"outdated pt2 file" fallback warning.
"""
import zipfile

from torch._inductor import (
aoti_load_package,
)

from deepmd.pt_expt.utils.serialization import (
PT2_EXTRA_PREFIX,
)

md_entry = PT2_EXTRA_PREFIX + "metadata.json"
model_json_entry = PT2_EXTRA_PREFIX + "model.json"
mds_entry = PT2_EXTRA_PREFIX + "model_def_script.json"

# Read metadata from the .pt2 ZIP archive
with zipfile.ZipFile(model_file, "r") as zf:
names = zf.namelist()
if "extra/model.json" not in names:
if md_entry not in names:
raise ValueError(
f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'"
f"Invalid .pt2 file '{model_file}': missing '{md_entry}'"
)
model_json_str = zf.read("extra/model.json").decode("utf-8")
md = zf.read(md_entry).decode("utf-8")
model_json_str = ""
if model_json_entry in names:
model_json_str = zf.read(model_json_entry).decode("utf-8")
mds = ""
if "extra/model_def_script.json" in names:
mds = zf.read("extra/model_def_script.json").decode("utf-8")
md = ""
if "extra/metadata.json" in names:
md = zf.read("extra/metadata.json").decode("utf-8")
if mds_entry in names:
mds = zf.read(mds_entry).decode("utf-8")

self._init_from_model_json(model_json_str)
self.metadata = json.loads(md)
self._model_def_script = json.loads(mds) if mds else {}
self.metadata = json.loads(md) if md else {}
if model_json_str:
self._init_from_model_json(model_json_str)
else:
self._init_from_metadata()

# Load the AOTInductor model package (.pt2 ZIP archive).
# Uses torch._inductor.aoti_load_package (private API, stable since PyTorch 2.6).
Expand All @@ -208,28 +281,41 @@ def _load_pt2(self, model_file: str) -> None:

def get_rcut(self) -> float:
    """Get the cutoff radius of this model."""
    # _rcut is set by both loader paths (_init_from_model_json /
    # _init_from_metadata); the stale `self.rcut` return was dead code.
    return self._rcut

def get_ntypes(self) -> int:
    """Get the number of atom types of this model."""
    # _type_map is set by both loader paths; the stale `self.type_map`
    # return was unreachable diff residue.
    return len(self._type_map)

def get_type_map(self) -> list[str]:
    """Get the type map (element name of the atom types) of this model."""
    # _type_map is set by both loader paths; the stale `self.type_map`
    # return was unreachable diff residue.
    return self._type_map

def get_dim_fparam(self) -> int:
    """Get the number (dimension) of frame parameters of this DP.

    Falls back to the ``dim_fparam`` metadata field when the archive
    was loaded without ``model.json`` (metadata-only mode).
    """
    if self._dpmodel is not None:
        return self._dpmodel.get_dim_fparam()
    return int(self.metadata["dim_fparam"])

def get_dim_aparam(self) -> int:
    """Get the number (dimension) of atomic parameters of this DP.

    Falls back to the ``dim_aparam`` metadata field when the archive
    was loaded without ``model.json`` (metadata-only mode).
    """
    if self._dpmodel is not None:
        return self._dpmodel.get_dim_aparam()
    return int(self.metadata["dim_aparam"])

@property
def model_type(self) -> type["DeepEvalWrapper"]:
"""The the evaluator of the model type."""
model_output_type = self._dpmodel.model_output_type()
"""The evaluator of the model type."""
if self._dpmodel is not None:
model_output_type = self._dpmodel.model_output_type()
else:
# Metadata-only mode: derive the output-type set from the
# fitting_output_defs names. `model_output_type()` on a
# dpmodel is the same set — just the base output names, not
# their derived `*_redu` / `*_derv_*` twins.
model_output_type = [
d.name for d in self._model_output_def.def_outp.get_data().values()
]
if "energy" in model_output_type:
return DeepPot
elif "dos" in model_output_type:
Expand All @@ -250,7 +336,12 @@ def get_sel_type(self) -> list[int]:
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return self._dpmodel.get_sel_type()
if self._dpmodel is not None:
return self._dpmodel.get_sel_type()
# Metadata-only mode: read the `sel_type` field populated by
# `_collect_metadata`. Missing field → `[]` (every type
# selected), matching the dpmodel default for energy models.
return [int(t) for t in self.metadata.get("sel_type", [])]

def get_numb_dos(self) -> int:
"""Get the number of DOS."""
Expand All @@ -262,13 +353,15 @@ def get_has_efield(self) -> bool:

def get_has_spin(self) -> bool:
    """Check if the model has spin atom types."""
    # _is_spin is always assigned by the loaders, so the defensive
    # getattr(..., False) form is no longer needed.
    return self._is_spin

def get_use_spin(self) -> list[bool]:
    """Get the per-type spin usage of this model.

    Returns an empty list for non-spin models. For spin models the
    per-type flags come from the dpmodel instance when available,
    otherwise from the ``use_spin`` metadata field (metadata-only mode).
    """
    if not self._is_spin:
        return []
    if self._dpmodel is not None:
        return self._dpmodel.spin.use_spin.tolist()
    return [bool(v) for v in self.metadata.get("use_spin", [])]

def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""
Expand Down Expand Up @@ -422,9 +515,9 @@ def _build_nlist_native(
"""
nframes = coords.shape[0]
natoms = coords.shape[1]
rcut = self.rcut
sel = self._dpmodel.get_sel()
mixed_types = self._dpmodel.mixed_types()
rcut = self._rcut
sel = self._sel
mixed_types = self._mixed_types

if cells is not None:
box_input = cells.reshape(nframes, 3, 3)
Expand Down Expand Up @@ -535,8 +628,8 @@ def _build_nlist_ase_single(
nlist : np.ndarray, shape (nloc, nsel)
mapping : np.ndarray, shape (nall,)
"""
sel = self._dpmodel.get_sel()
mixed_types = self._dpmodel.mixed_types()
sel = self._sel
mixed_types = self._mixed_types
nsel = sum(sel)

natoms = positions.shape[0]
Expand Down Expand Up @@ -579,7 +672,7 @@ def _build_nlist_ase_single(
ghost_remap[out_mask] = np.arange(nloc, nloc + nghost, dtype=np.int64)

# Build nlist: vectorized CSR-to-dense conversion
rcut = self.rcut
rcut = self._rcut
counts = np.diff(first_neigh)
max_nn = int(counts.max()) if counts.size > 0 else 0

Expand Down Expand Up @@ -995,13 +1088,44 @@ def get_model(self) -> torch.nn.Module:
return self.exported_module

def _is_spin_model(self) -> bool:
"""Check if the underlying dpmodel is a SpinModel."""
"""Check if the underlying model is a SpinModel.

Primary path: the :attr:`_is_spin` attribute set by the loaders
— this works for both ``model.json`` and metadata-only archives
(a spin ``.pt2`` carries ``is_spin=true`` in its metadata).

Legacy path: ``isinstance(_dpmodel, SpinModel)`` — retained for
tests that construct a non-spin archive and then swap
:attr:`_dpmodel` to a :class:`SpinModel` instance after load.
"""
if self._is_spin:
return True
if self._dpmodel is None:
return False
from deepmd.dpmodel.model.spin_model import (
SpinModel,
)

return isinstance(self._dpmodel, SpinModel)

def _require_dpmodel(self, feature: str) -> None:
"""Guard for features that need a deserialised dpmodel instance.

``eval_descriptor`` / ``eval_typeebd`` / ``eval_fitting_last_layer``
all introspect the dpmodel's internal sub-modules, which requires
``model.json`` to have been present at load time. Archives
shipped without ``model.json`` (metadata-only mode) can still run
the main ``eval`` inference path but cannot expose these hooks.
"""
if self._dpmodel is None:
raise NotImplementedError(
f"{feature} requires the dpmodel instance, which is only "
"available when the .pt2 / .pte archive contains "
"'model.json'. The loaded archive is metadata-only; "
"re-export with the full dpmodel serialisation to enable "
"this feature."
)

def eval_typeebd(self) -> np.ndarray:
"""Evaluate type embedding.

Expand All @@ -1014,7 +1138,11 @@ def eval_typeebd(self) -> np.ndarray:
------
KeyError
If the model has no type embedding networks.
NotImplementedError
If the archive was loaded in metadata-only mode.
"""
self._require_dpmodel("eval_typeebd")

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP

model = self._dpmodel
Expand Down Expand Up @@ -1058,6 +1186,8 @@ def eval_descriptor(
np.ndarray
Descriptor output, shape ``(nframes, nloc, dim_descrpt)``.
"""
self._require_dpmodel("eval_descriptor")

coords = np.array(coords)
atom_types = np.array(atom_types, dtype=np.int32)
if cells is not None:
Expand Down Expand Up @@ -1124,6 +1254,8 @@ def eval_fitting_last_layer(
np.ndarray
Middle-layer output, shape ``(nframes, nloc, neuron[-1])``.
"""
self._require_dpmodel("eval_fitting_last_layer")

coords = np.array(coords)
atom_types = np.array(atom_types, dtype=np.int32)
if cells is not None:
Expand Down
Loading
Loading