feat(pt_expt): support .pt training checkpoints in DeepEval #5423
wanghan-iapcm wants to merge 3 commits into deepmodeling:master
Conversation
`dp --pt-expt test -m foo.pt` previously rejected `.pt` files (only `.pt2` / `.pte` were supported), and `dp --pt test -m foo.pt` on a pt_expt-trained checkpoint silently loaded random weights because the state-dict layout (dpmodel `.w`/`.b` keys) doesn't match the legacy pt backend's expectations.

- `Backend.detect_backend_by_model` sniffs `.pt` content so files with `.w`/`.b` keys (pt_expt) route to the pt_expt DeepEval and files with `.matrix`/`.bias` keys (pt) keep routing to pt (sketched after this list).
- `pt_expt.DeepEval._load_pt` reconstructs the model from `_extra_state["model_params"]`, loads the state-dict via `ModelWrapper`, and exposes an eager `forward_common_lower` runner with the same signature as the AOTI/exported module so the existing `eval()` path is unchanged. Spin-aware and non-spin variants; multi-task `.pt` selects a head and remaps keys.
- `pt_expt.get_model` learns `get_spin_model` (mirrors dpmodel) so spin checkpoints can be reconstructed from `model_params`.
- Tests cover dispatch sniffing, single-task / multi-task / spin / spin-multi-task `.pt` parity vs eager forward, fparam / aparam, and `.pt` vs `.pte` cross-format consistency at 1e-10.
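A minimal sketch of the routing predicate described above; the key paths are made-up examples, only the `.w`/`.b` vs `.matrix`/`.bias` suffix rules come from the PR:

```python
def route_by_keys(keys: list[str]) -> str:
    # dpmodel naming uses .w/.b leaf keys; the legacy pt backend uses .matrix/.bias.
    has_pt_expt = any(k.endswith((".w", ".b")) for k in keys)
    has_pt = any(k.endswith((".matrix", ".bias")) for k in keys)
    return "pt-expt" if has_pt_expt and not has_pt else "pt"

assert route_by_keys(["model.Default.fitting_net.nets.0.w"]) == "pt-expt"
assert route_by_keys(["model.Default.fitting_net.nets.0.matrix"]) == "pt"
```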
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 98aee78a86
```python
            try:
                import torch

                sd = torch.load(filename, map_location="cpu", weights_only=False)
```
**Avoid unpickling arbitrary `.pt` during backend detection**

`detect_backend_by_model` now deserializes `.pt` files with `torch.load(..., weights_only=False)` before selecting a backend, which allows pickle payload execution from attacker-controlled files. In practice, simply running `dp test -m untrusted.pt` can execute code during format sniffing, even before model loading proceeds. Detection only needs state-dict key names, so this should use a safe load mode (for example `weights_only=True`) and fail closed when the payload is not a plain checkpoint dict.
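A minimal sketch of the fail-closed sniffing this comment asks for (the helper name is hypothetical; the fallback behavior mirrors the existing try/except in the diff):

```python
import torch

def sniff_state_dict_keys(filename: str) -> list[str]:
    """Read state-dict key names without executing pickle payloads."""
    try:
        # weights_only=True restricts unpickling to tensors, dicts and
        # primitives, so a malicious pickle cannot run code here.
        sd = torch.load(filename, map_location="cpu", weights_only=True)
    except Exception:
        return []  # fail closed: caller falls back to suffix dispatch
    if isinstance(sd, dict) and "model" in sd:
        sd = sd["model"]
    return list(sd.keys()) if hasattr(sd, "keys") else []
```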
```python
            get_model,
        )

        state_dict = torch.load(model_file, map_location=DEVICE, weights_only=False)
```
**Load pt_expt checkpoints without unsafe pickle execution**

The new `.pt` inference path in `pt_expt.DeepEval` also uses `torch.load(..., weights_only=False)`, so evaluating an untrusted checkpoint via `dp --pt-expt test -m file.pt` can execute arbitrary Python objects at deserialize time. This path only needs tensor weights and `_extra_state` metadata for `ModelWrapper.load_state_dict`, so full pickle loading is unnecessary and introduces a security regression compared with safe checkpoint loading.
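A hedged sketch of what safe loading could look like here, assuming `model_params` inside `_extra_state` holds only dicts, lists and primitives (anything else would need explicit allowlisting, e.g. `torch.serialization.add_safe_globals`, available since PyTorch 2.4):

```python
import torch

def load_pt_expt_checkpoint(model_file: str, device: str = "cpu") -> dict:
    """Load a pt_expt training checkpoint without full pickle execution."""
    try:
        return torch.load(model_file, map_location=device, weights_only=True)
    except Exception as e:
        # Refuse to silently widen to weights_only=False on untrusted input.
        raise ValueError(
            f"{model_file} is not a plain tensor/dict checkpoint; "
            "allowlist the required classes explicitly instead"
        ) from e
```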
📝 Walkthrough

Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant BackendDetection as Backend Detection
    participant TorchLoad as torch.load()
    participant DeepEval as DeepEval Loader
    participant ModelBuilder as Model Construction
    participant Inference as Inference Engine
    User->>BackendDetection: detect_backend_by_model(.pt file)
    BackendDetection->>TorchLoad: torch.load(checkpoint)
    TorchLoad-->>BackendDetection: checkpoint dict
    BackendDetection->>BackendDetection: inspect state_dict key suffixes
    BackendDetection-->>User: return backend (pt or pt-expt)
    User->>DeepEval: DeepEval(model_file=.pt, head=?)
    DeepEval->>TorchLoad: torch.load(checkpoint)
    TorchLoad-->>DeepEval: checkpoint dict
    DeepEval->>DeepEval: extract model dict, select head, remap keys
    DeepEval->>ModelBuilder: get_model(config)
    ModelBuilder-->>DeepEval: model instance (standard or spin)
    DeepEval->>DeepEval: load weights, build exported_module
    User->>Inference: exported_module(forward args)
    Inference-->>User: energy, forces, virial, (atomic/spin outputs)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (5)
deepmd/pt_expt/infer/deep_eval.py (2)
**252-260**: `str.replace` is unbounded; hoist `prefix` out of the loop. Two minor robustness/readability issues in the head-renaming loop:

- `key.replace(prefix, "model.Default.")` rewrites every occurrence of `prefix` in the key. If a head name happens to appear deeper in a key path (or in any state-dict key derived from a user-supplied identifier), keys silently get double-rewritten. Slice the prefix instead.
- `prefix = f"model.{head}."` is recomputed on every iteration.

♻️ Proposed fix

```diff
         # Restrict state_dict to the chosen head and rename to "Default".
         head_state = {"_extra_state": state_dict["_extra_state"]}
+        prefix = f"model.{head}."
         for key, value in state_dict.items():
-            prefix = f"model.{head}."
             if key.startswith(prefix):
-                head_state[key.replace(prefix, "model.Default.")] = (
+                new_key = "model.Default." + key[len(prefix) :]
+                head_state[new_key] = (
                     value.clone() if torch.is_tensor(value) else value
                 )
         state_dict = head_state
```
**223-225**: Inconsistent `DEVICE` import: should be `deepmd.pt_expt.utils.env`. This file already imports `DEVICE` from `deepmd.pt_expt.utils.env` (lines 813, 982). Pulling it from `deepmd.pt.utils.env` here is inconsistent and creates an unnecessary dependency from `pt_expt` on `pt`. If the two backends ever diverge on device defaults this becomes a subtle bug.

♻️ Proposed fix

```diff
-    from deepmd.pt.utils.env import (
+    from deepmd.pt_expt.utils.env import (
         DEVICE,
     )
```
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (2)
**475-495**: Prefer `strict=True` in cross-format consistency zips.

`dp.eval(...)` returns a fixed-arity tuple matching the request defs, and the name lists in these consistency loops (here, lines 567-580, and 659-672) hard-code 7 entries to mirror the spin-with-atomic case. With `strict=False`, if `dp.eval` ever changes arity (e.g., a new output is added) the loop silently truncates and consistency for new fields is no longer asserted. `strict=True` would force the tests to be updated.

```diff
-        for name, a, b in zip(
+        for name, a, b in zip(
             (
                 "energy",
                 ...
                 "mask_mag",
             ),
             out_pt,
             out_pte,
-            strict=False,
+            strict=True,
         ):
```

(Same applies at lines 567-580 and 659-672.)
**401-407**: `os.rmdir` will fail on residual files; prefer `shutil.rmtree`. `_make_spin_files` only puts `.pt` and `.pte` into `tmpdir`, so today this works. But if a future change writes any auxiliary file (e.g., a sidecar `.json` from `deserialize_to_file`), `os.rmdir` will raise `OSError` and leak the directory. `shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)` handles both this and any partial-creation cleanup uniformly.

♻️ Proposed fix

```diff
+import shutil
+
     @classmethod
     def tearDownClass(cls) -> None:
-        for ext in (".pt", ".pte"):
-            path = cls.files[ext]
-            if os.path.exists(path):
-                os.unlink(path)
-        os.rmdir(cls.files["tmpdir"])
+        shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)
```
deepmd/backend/backend.py (1)
**122-124**: Simplify backend lookup. `Backend.get_backends()` already returns a `dict[str, type[Backend]]`, so the linear scan can be a single dict lookup:

♻️ Proposed refactor

```diff
-                for key, backend in Backend.get_backends().items():
-                    if key == target_name:
-                        return backend
+                backend = Backend.get_backends().get(target_name)
+                if backend is not None:
+                    return backend
```
```python
        if filename_lower.endswith(".pt"):
            try:
                import torch

                sd = torch.load(filename, map_location="cpu", weights_only=False)
                if isinstance(sd, dict) and "model" in sd:
                    sd = sd["model"]
                keys = list(sd.keys()) if hasattr(sd, "keys") else []
                has_pt_expt = any(k.endswith(".w") or k.endswith(".b") for k in keys)
                has_pt = any(k.endswith(".matrix") or k.endswith(".bias") for k in keys)
                if has_pt_expt and not has_pt:
                    target_name = "pt-expt"
                else:
                    target_name = "pt"
                for key, backend in Backend.get_backends().items():
                    if key == target_name:
                        return backend
            except Exception:
                # Fall through to suffix matching if sniffing fails.
                pass
```
🧩 Analysis chain

🌐 Web query: `torch.load weights_only=True state_dict keys behavior PyTorch 2.6`

💡 Result:

In PyTorch 2.6, `torch.load(checkpoint_path, weights_only=True)` deserializes the saved object using a restricted unpickler that only allows tensors, primitive types (int, float, str, bool, None), and dictionaries. It does not allow arbitrary classes or objects unless explicitly allowlisted via `torch.serialization.add_safe_globals`.

For a typical model checkpoint saved as `torch.save({'model_state_dict': model.state_dict(), ...}, path)`, `torch.load` with `weights_only=True` returns the full dictionary with top-level keys like `'model_state_dict'`, `'optimizer_state_dict'`, `'epoch'`, etc. Each state_dict value is an `OrderedDict` containing the model's parameter keys (e.g., `'conv1.weight'`, `'fc.bias'`) mapped to tensors. This works because the structure uses only allowed types: dict → OrderedDict (treated as dict) → str keys → `torch.Tensor` values. Primitive metadata like `epoch` (int/float) is also allowed. Examples from official tutorials explicitly show:

```python
checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
```

If the checkpoint contains non-allowlisted objects (e.g., a scheduler state_dict with `collections.Counter`, or optimizer instances), it fails with `_pickle.UnpicklingError`, requiring allowlisting or `weights_only=False` (unsafe for untrusted files). In PyTorch 2.6+, `weights_only=True` is the default, promoting safer loading of weights-only checkpoints. (A minimal self-check follows the citations below.)

Citations:
Citations:
- 1: https://docs.pytorch.org/docs/2.6/generated/torch.load.html
- 2: https://docs.pytorch.org/docs/2.6/notes/serialization.html
- 3: https://pytorch.org/tutorials/_sources/beginner/saving_loading_models.rst.txt
- 4: https://docs.pytorch.org/docs/2.8/_sources/notes/serialization.rst.txt
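A small self-check of the behavior summarized above (the file name is illustrative):

```python
import torch

# weights_only=True happily round-trips a plain nested dict of tensors,
# which is all that backend sniffing needs.
torch.save(
    {"model": {"fc.weight": torch.zeros(2, 2), "fc.bias": torch.zeros(2)}},
    "ckpt.pt",
)
sd = torch.load("ckpt.pt", map_location="cpu", weights_only=True)
print(sorted(sd["model"].keys()))  # ['fc.bias', 'fc.weight']
```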
Use weights_only=True for backend sniffing.
detect_backend_by_model is called during dispatch (e.g. dp test -m foo.pt) before the user explicitly trusts the checkpoint. Loading with weights_only=False executes pickle, enabling RCE on a malicious .pt file. Since this branch only inspects state_dict.keys() to disambiguate between pt and pt-expt backends, a stricter load is sufficient and safer:
Suggested change

```diff
-                sd = torch.load(filename, map_location="cpu", weights_only=False)
+                # Only state_dict keys are inspected; weights_only=True avoids
+                # executing pickle code on untrusted checkpoints during
+                # backend dispatch.
+                sd = torch.load(filename, map_location="cpu", weights_only=True)
```

If `weights_only=True` fails (e.g., on checkpoints with non-standard `_extra_state` objects), the `except Exception: pass` block gracefully falls through to suffix-based backend detection. The actual model loading in `deepmd/pt_expt/infer/deep_eval.py` (line 230) already uses `weights_only=False` once the backend is selected—the right place to accept that risk.
🧰 Tools

🪛 Ruff (0.15.11)

- [error] 125-127: `try`-`except`-`pass` detected, consider logging the exception (S110)
- [warning] 125-125: Do not catch blind exception: `Exception` (BLE001)
```python
    @classmethod
    def setUpClass(cls) -> None:
        # Build two single-task models with the same architecture but
        # different seeds, then save a multi-task-style checkpoint.
        cls.model_a, params_a = _build_model_and_params(rcut=4.0)
        cls.model_b, params_b = _build_model_and_params(rcut=4.0)

        # Multi-task model_params layout used by pt_expt training.
        model_params = {"model_dict": {"head_a": params_a, "head_b": params_b}}

        wrapper = ModelWrapper(
            {"head_a": cls.model_a, "head_b": cls.model_b},
            model_params=model_params,
        )
        cls.pt_path = tempfile.NamedTemporaryFile(suffix=".pt", delete=False).name
        torch.save({"model": wrapper.state_dict()}, cls.pt_path)

    @classmethod
    def tearDownClass(cls) -> None:
        if os.path.exists(cls.pt_path):
            os.unlink(cls.pt_path)

    def test_select_head_matches_single_task_forward(self) -> None:
        rng = np.random.default_rng(GLOBAL_SEED + 1)
        natoms = 4
        coords = rng.random((1, natoms, 3)) * 8.0
        cells = np.eye(3).reshape(1, 9) * 10.0
        atom_types = np.array([i % 2 for i in range(natoms)], dtype=np.int32)

        for head, src in (("head_a", self.model_a), ("head_b", self.model_b)):
            # Build a DeepPot wrapping this DeepEval for end-to-end eval.
            dp = DeepPot(self.pt_path, head=head)
            de = dp.deep_eval
            e, f, v = dp.eval(coords, cells, atom_types, atomic=False)

            coord_t = torch.tensor(
                coords, dtype=torch.float64, device=DEVICE
            ).requires_grad_(True)
            atype_t = torch.tensor(
                atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE
            )
            cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE)
            ref = src.forward(coord_t, atype_t, cell_t, do_atomic_virial=False)

            np.testing.assert_allclose(
                e,
                ref["energy"].detach().cpu().numpy(),
                rtol=1e-10,
                atol=1e-10,
                err_msg=f"head={head}, energy",
            )
            np.testing.assert_allclose(
                f,
                ref["force"].detach().cpu().numpy(),
                rtol=1e-10,
                atol=1e-10,
                err_msg=f"head={head}, force",
            )
            self.assertEqual(de.get_type_map(), src.get_type_map())
```
Heads model_a and model_b are weight-identical — head selection is not actually validated.
_build_model_and_params() constructs DescrptSeA(..., seed=GLOBAL_SEED) and EnergyFittingNet(..., seed=GLOBAL_SEED), so both calls produce models with the same initial weights. Consequently test_select_head_matches_single_task_forward would still pass even if _load_pt accidentally loaded head_b's weights when head="head_a" was requested (or vice versa). The test asserts dispatch plumbing only, not correctness of weight selection.
The spin counterpart TestPtExptLoadPtSpinMultiTask (lines 704-716) gets this right by using distinct seeds (42 / 7) plus an explicit test_distinct_heads_produce_distinct_outputs sanity check. Suggest mirroring that pattern here:
🐛 Proposed fix
```diff
-        cls.model_a, params_a = _build_model_and_params(rcut=4.0)
-        cls.model_b, params_b = _build_model_and_params(rcut=4.0)
+        cls.model_a, params_a = _build_model_and_params(rcut=4.0, seed=42)
+        cls.model_b, params_b = _build_model_and_params(rcut=4.0, seed=7)
```

…and add a distinct-outputs assertion:
```diff
+    def test_distinct_heads_produce_distinct_outputs(self) -> None:
+        rng = np.random.default_rng(GLOBAL_SEED + 2)
+        natoms = 4
+        coords = rng.random((1, natoms, 3)) * 8.0
+        cells = np.eye(3).reshape(1, 9) * 10.0
+        atom_types = np.array([i % 2 for i in range(natoms)], dtype=np.int32)
+        e_a = DeepPot(self.pt_path, head="head_a").eval(
+            coords, cells, atom_types, atomic=False
+        )[0]
+        e_b = DeepPot(self.pt_path, head="head_b").eval(
+            coords, cells, atom_types, atomic=False
+        )[0]
+        self.assertFalse(np.allclose(e_a, e_b))
```
+ self.assertFalse(np.allclose(e_a, e_b))(_build_model_and_params will need a seed parameter forwarded to DescrptSeA / EnergyFittingNet.)
🧰 Tools

🪛 Ruff (0.15.11)

- [warning] 284-284: Unpacked variable `v` is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5423      +/-   ##
==========================================
+ Coverage   82.36%   82.38%   +0.02%
==========================================
  Files         824      824
  Lines       87109    87209     +100
  Branches     4197     4197
==========================================
+ Hits        71743    71848     +105
+ Misses      14091    14087       -4
+ Partials     1275     1274       -1
```

☔ View full report in Codecov by Sentry.
Real training-produced `.pt` checkpoints have `model.{head}.original_model.X`
for the trained weights and `model.{head}.compiled_forward_lower.*`
for the compiled-graph constants. Previously `_load_pt` did a strict
`load_state_dict` against a plain `get_model(model_params)` and failed.
Fix: strip the `original_model.` infix and drop all
`compiled_forward_lower.*` keys before loading. Works for both
single-task and multi-task layouts. Tests synthesise the wrapped
layout directly to avoid a real `torch.compile` invocation.
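A minimal sketch of the cleanup this commit describes (the key layout is taken from the message above; the actual `_load_pt` implementation may differ in detail):

```python
def strip_compiled_wrapper(state_dict: dict) -> dict:
    """model.{head}.original_model.X -> model.{head}.X; drop compiled constants."""
    cleaned = {}
    for key, value in state_dict.items():
        if ".compiled_forward_lower." in key:
            continue  # compiled-graph constants are not model weights
        # Remove only the first occurrence of the wrapper infix.
        cleaned[key.replace(".original_model.", ".", 1)] = value
    return cleaned
```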
Actionable comments posted: 1
🧹 Nitpick comments (2)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)
**580-585**: Use `shutil.rmtree` for the tempdir cleanup. `os.rmdir` only succeeds if the directory is empty; if any future change adds an extra artifact (e.g. a `.lock` file from torch save, or a partial write on a failing test), `tearDownClass` will raise and mask the actual test failure. Switching to `shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)` makes cleanup robust without changing behavior in the happy path.
deepmd/pt_expt/infer/deep_eval.py (1)
**251-261**: Constrain head-prefix replacement to the leading occurrence. `key.replace(prefix, "model.Default.")` rewrites every occurrence of `model.{head}.` in the key, not just the leading one. The loop already gated on `startswith(prefix)`, so this is harmless for current key shapes, but it's a defensive landmine if a head name (e.g. `"head"`) ever appears later in the key (e.g. nested module names). Safer to slice or pin `count=1`:

♻️ Proposed fix

```diff
-        head_state = {"_extra_state": state_dict["_extra_state"]}
-        for key, value in state_dict.items():
-            prefix = f"model.{head}."
-            if key.startswith(prefix):
-                head_state[key.replace(prefix, "model.Default.")] = (
-                    value.clone() if torch.is_tensor(value) else value
-                )
+        prefix = f"model.{head}."
+        head_state = {"_extra_state": state_dict["_extra_state"]}
+        for key, value in state_dict.items():
+            if key.startswith(prefix):
+                new_key = "model.Default." + key[len(prefix):]
+                head_state[new_key] = (
+                    value.clone() if torch.is_tensor(value) else value
+                )
```

Also moves `prefix` out of the per-iteration body.
```python
            # Build a DeepPot wrapping this DeepEval for end-to-end eval.
            dp = DeepPot(self.pt_path, head=head)
            de = dp.deep_eval
            e, f, v = dp.eval(coords, cells, atom_types, atomic=False)
```
Silence RUF059 on the unused unpacked outputs.
Ruff flags unused names in tuple unpacks at lines 426 (v), 472 (v), 613 (av, mm), and 918 (ae, av, mm). Per the repo's coding guidelines, ruff check . must pass before commit, so prefix the unused names with _ (or assert against them — for test_eval_pbc_atomic_matches_reference it would actually be worth covering av since atomic=True).
♻️ Example fix at line 613
```diff
-        e, f, v, ae, av, fm, mm = dp.eval(
+        e, f, v, ae, _av, fm, _mm = dp.eval(
             self.COORD, self.BOX, self.ATYPE, atomic=True, spin=self.SPIN
         )
```
)As per coding guidelines: "Install linter and run ruff check . before committing changes or the CI will fail".
Also applies to: 472-472, 613-613, 918-918
🧰 Tools

🪛 Ruff (0.15.11)

- [warning] 426-426: Unpacked variable `v` is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
The exported `.pte` and eager `.pt` paths produce identical energy / force / virial / atom_energy / force_mag / mask_mag outputs for spin models, but per-atom virial diverges. The reduced virial (which is the sum of per-atom virials including the virtual-atom contribution) still matches, so the divergence is in the per-extended-atom split, not the totals. Pin this as a known limitation; revisit once the export-time spin atom-virial path is reconciled with the eager path.
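The "totals still match" claim can be checked directly; a sketch assuming per-atom virials `av` of shape `(nframes, nall, 9)` and reduced virials `v` of shape `(nframes, 9)`:

```python
import numpy as np

def reduced_virial_matches(v: np.ndarray, av: np.ndarray, tol: float = 1e-10) -> bool:
    # Sum per-extended-atom virials (real + virtual atoms) and compare
    # against the reduced virial reported by the model.
    return np.allclose(v, av.sum(axis=1), rtol=tol, atol=tol)
```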
🧹 Nitpick comments (1)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)
**579-585**: `os.rmdir` is fragile for the spin scratch dir. `tearDownClass` only unlinks `.pt`/`.pte` then `os.rmdir(tmpdir)`. If `deserialize_to_file` ever drops a sidecar (cache, journal, `.pte`-as-directory layout, etc.) the `os.rmdir` call raises `OSError` and leaks the temp tree. `shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)` is both shorter and robust to layout changes.

♻️ Proposed refactor

```diff
+import shutil
@@
     @classmethod
     def tearDownClass(cls) -> None:
-        for ext in (".pt", ".pte"):
-            path = cls.files[ext]
-            if os.path.exists(path):
-                os.unlink(path)
-        os.rmdir(cls.files["tmpdir"])
+        shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)
```
Pull request overview
Note: Copilot was unable to run its full agentic suite in this review.
Adds first-class support for loading .pt training checkpoints for inference in the pt_expt backend, including backend auto-detection for shared .pt suffixes.
Changes:
- Add `.pt` sniffing in `Backend.detect_backend_by_model` to route `.pt` files to `pt` vs `pt-expt` based on state-dict key naming.
- Implement `.pt` checkpoint loading in `pt_expt.DeepEval` (including multitask head selection, compiled-wrapper key cleanup, and eager runner shims).
- Add a comprehensive pt_expt inference test suite covering routing, spin, multitask, aparam/fparam behavior, and `.pt` ↔ `.pte` consistency.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py | Adds end-to-end tests for .pt dispatch + pt_expt .pt inference correctness across single/multi-task and spin variants. |
| deepmd/pt_expt/model/get_model.py | Adds get_spin_model and updates get_model to construct spin models correctly from config. |
| deepmd/pt_expt/infer/deep_eval.py | Extends pt_expt inference to accept .pt checkpoints and reconstruct eager runners compatible with existing eval paths. |
| deepmd/backend/backend.py | Implements .pt content sniffing to disambiguate backend routing between pt and pt-expt. |
```python
            try:
                import torch

                sd = torch.load(filename, map_location="cpu", weights_only=False)
```
torch.load(..., weights_only=False) on a user-provided path expands the attack surface for arbitrary code execution via pickle, especially since this runs during backend detection (before the user has opted into a given backend). Prefer weights_only=True for sniffing (it should still load tensor-only/dict checkpoints), and if it fails, fall back to suffix dispatch as you already do. If you must support older/odd checkpoints that require full unpickling, consider a guarded fallback with a clear warning or an explicit opt-in.
Suggested change

```diff
-                sd = torch.load(filename, map_location="cpu", weights_only=False)
+                sd = torch.load(filename, map_location="cpu", weights_only=True)
```
```python
    deepcopy,
)

from deepmd.pt.utils.env import (
```
This loader is part of the pt_expt backend but imports DEVICE from deepmd.pt.utils.env (legacy pt backend). That can put tensors on an unintended device (or diverge from pt_expt’s device policy), causing incorrect placement / failures. Import DEVICE from deepmd.pt_expt.utils.env (or reuse an already-defined pt_expt device constant) to keep checkpoint loading consistent with the pt_expt runtime.
Suggested change

```diff
-    from deepmd.pt.utils.env import (
+    from deepmd.pt_expt.utils.env import (
```
```python
                head_state[key.replace(prefix, "model.Default.")] = (
                    value.clone() if torch.is_tensor(value) else value
                )
```
Cloning every tensor when selecting a multitask head can significantly increase memory/time for large checkpoints, and it’s not needed here since you’re only rebinding references into a new dict (not mutating tensors in-place). Assign the tensor directly unless there’s a specific downstream mutation that requires isolation.
Suggested change

```diff
-                head_state[key.replace(prefix, "model.Default.")] = (
-                    value.clone() if torch.is_tensor(value) else value
-                )
+                head_state[key.replace(prefix, "model.Default.")] = value
```
```python
        model_params = deepcopy(state_dict["_extra_state"]["model_params"])
```
If a .pt file is missing _extra_state / model_params (e.g., a non-pt_expt checkpoint, a partially-saved checkpoint, or a hand-edited artifact), this will raise a KeyError with a non-actionable message. Add an explicit check and raise a ValueError that explains what structure is expected for a pt_expt training checkpoint and how to proceed (e.g., ‘use the pt backend’ / ‘export to .pte/.pt2’ / ‘retrain with pt_expt’).
Suggested change

```diff
-        model_params = deepcopy(state_dict["_extra_state"]["model_params"])
+        checkpoint_extra_state: Any | None = None
+        if isinstance(state_dict, dict):
+            checkpoint_extra_state = state_dict.get("_extra_state")
+        if not (
+            isinstance(state_dict, dict)
+            and isinstance(checkpoint_extra_state, dict)
+            and "model_params" in checkpoint_extra_state
+        ):
+            raise ValueError(
+                f"Invalid .pt file '{model_file}': expected a pt_expt training "
+                "checkpoint containing '_extra_state' with nested "
+                "'model_params'. The provided file does not have the expected "
+                "checkpoint structure. If this is a different PyTorch "
+                "checkpoint, load it with the 'pt' backend instead. If this is "
+                "an exported model, use a '.pte' or '.pt2' artifact. Otherwise, "
+                "retrain or re-export the model with pt_expt to create a "
+                "compatible training checkpoint."
+            )
+        model_params = deepcopy(checkpoint_extra_state["model_params"])
```
```python
                for key, backend in Backend.get_backends().items():
                    if key == target_name:
                        return backend
```
This manual lookup loop can be simplified to a direct backend lookup (e.g., Backend.get_backend(target_name)), which is clearer and avoids re-iterating the registry.
Suggested change

```diff
-                for key, backend in Backend.get_backends().items():
-                    if key == target_name:
-                        return backend
+                return Backend.get_backend(target_name)
```
```python
    def tearDownClass(cls) -> None:
        for ext in (".pt", ".pte"):
            path = cls.files[ext]
            if os.path.exists(path):
                os.unlink(path)
        os.rmdir(cls.files["tmpdir"])
```
os.rmdir will fail if anything unexpected ends up in the temp directory (e.g., platform-specific artifacts, or future test additions). Using shutil.rmtree(tmpdir, ignore_errors=True) (or equivalent) makes teardown more robust and reduces test flakiness.
Summary
`dp --pt-expt test -m foo.pt` previously rejected `.pt` files (only `.pt2`/`.pte` were supported); `dp --pt test -m foo.pt` on a pt_expt-trained checkpoint silently loaded random weights because the dpmodel `.w`/`.b` naming doesn't match the legacy pt backend's `.matrix`/`.bias`. This PR makes `.pt` training checkpoints first-class for inference under the pt_expt backend.
Changes

- `Backend.detect_backend_by_model` sniffs `.pt` content and routes by parameter naming: `.w`/`.b` → pt-expt, `.matrix`/`.bias` → pt. Bogus `.pt` falls back to suffix dispatch (pt). Backwards compatible with all existing pt-trained `.pt` checkpoints (usage sketched after this list).
- `pt_expt.DeepEval._load_pt` reconstructs the model from `_extra_state["model_params"]`, loads the state-dict via `ModelWrapper`, and exposes an eager `forward_common_lower` runner with the same signature as the AOTI/exported module so the existing `eval()` path is unchanged. Spin-aware (7-arg) and non-spin (6-arg) variants. Multi-task `.pt` selects a head and remaps keys. Populates `metadata` (default_fparam, dim_fparam/aparam, …) so eval helpers behave the same as the `.pt2`/`.pte` path.
- `pt_expt.get_model` learns `get_spin_model` (mirrors dpmodel) so spin checkpoints can be reconstructed from `model_params` (previously it silently returned a non-spin `EnergyModel`).
- Accepts `.pt2`/`.pte`/`.pt` and raises an actionable `ValueError` for anything else (was: implicit fallthrough to `.pte` loader → cryptic torch error).
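A sketch of the resulting user-facing flow, assuming the public `DeepPot` entry point (the checkpoint name and import paths are illustrative):

```python
import numpy as np
from deepmd.backend.backend import Backend
from deepmd.infer import DeepPot  # assumed public entry point

coords = np.random.default_rng(0).random((1, 4, 3)) * 8.0
cells = (np.eye(3) * 10.0).reshape(1, 9)
atom_types = np.array([0, 1, 0, 1], dtype=np.int32)

# Content sniffing routes dpmodel-style .w/.b checkpoints to pt-expt ...
backend = Backend.detect_backend_by_model("model.ckpt-100.pt")

# ... so plain DeepPot usage works on a pt_expt training checkpoint.
dp = DeepPot("model.ckpt-100.pt")
e, f, v = dp.eval(coords, cells, atom_types, atomic=False)
```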
Tests

`source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` (21 tests):

- Dispatch: pt_expt-style `.pt` routes to pt-expt; pt-style `.pt` routes to pt; bogus `.pt` falls back to suffix.
- Single-task `.pt`: metadata accessors, `DeepPot(.pt).eval(...)` parity vs direct forward at 1e-10, `.pth` rejection.
- Multi-task `.pt`: head selection parity, missing-head error, no-default-no-head error.
- Spin `.pt`: metadata flags, eager-reference parity, missing-spin-arg error.
- fparam `.pt`: default fparam matches explicit; varying fparam changes output.
- aparam `.pt`: aparam takes effect; missing-aparam raises.
- Spin multi-task `.pt`: each head matches its own eager reference; distinct heads produce distinct outputs.
- `.pt` ↔ `.pte` consistency at 1e-10 for vanilla spin (atomic=True), default fparam (atomic=True), and aparam (atomic=True).
Test plan

- `pytest source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py -v`
- `pytest source/tests/pt_expt/infer/ -v` (regression: existing `.pt2`/`.pte` paths)
- `dp --pt-expt train`, then `dp --pt-expt test -m model.ckpt-N.pt` produces identical metrics to `dp --pt-expt test -m frozen.pt2`
Known limitations

- Multi-task `.pt`: `_load_pt` handles such checkpoints, but a user can't currently produce one via `dp --pt-expt train`. Tests construct them synthetically.
- `_load_pt`'s `exported_module` is a Python closure (eager), not a real `torch.nn.Module`. Sufficient for `dp test`, but `eval_descriptor`/`eval_typeebd`/`eval_fitting_last_layer` won't work from a `.pt` (only from `.pt2`/`.pte`).
- Non-PBC `.pt` ↔ `.pte` consistency is not separately asserted (same eager code path as PBC).
New Features
Tests