
feat(pt_expt): support .pt training checkpoints in DeepEval#5423

Open
wanghan-iapcm wants to merge 3 commits into deepmodeling:master from wanghan-iapcm:feat-pt-expt-load-pt-checkpoint

Conversation

@wanghan-iapcm
Collaborator

@wanghan-iapcm wanghan-iapcm commented Apr 26, 2026

Summary

  • dp --pt-expt test -m foo.pt previously rejected .pt files (only .pt2 / .pte were supported); dp --pt test -m foo.pt on a pt_expt-trained checkpoint silently loaded random weights because the dpmodel .w/.b naming doesn't match the legacy pt backend's .matrix/.bias.
  • This PR makes .pt training checkpoints first-class for inference under the pt_expt backend.

Changes

  • Backend.detect_backend_by_model sniffs .pt content and routes by parameter naming: .w/.b → pt-expt, .matrix/.bias → pt. Bogus .pt falls back to suffix dispatch (pt). Backwards compatible with all existing pt-trained .pt checkpoints.
  • pt_expt.DeepEval._load_pt reconstructs the model from _extra_state[\"model_params\"], loads the state-dict via ModelWrapper, and exposes an eager forward_common_lower runner with the same signature as the AOTI/exported module so the existing eval() path is unchanged. Spin-aware (7-arg) and non-spin (6-arg) variants. Multi-task .pt selects a head and remaps keys. Populates metadata (default_fparam, dim_fparam/aparam, …) so eval helpers behave the same as the .pt2/.pte path.
  • pt_expt.get_model learns get_spin_model (mirrors dpmodel) so spin checkpoints can be reconstructed from model_params (previously it silently returned a non-spin EnergyModel).
  • Dispatch: pt_expt's DeepEval ctor now explicitly accepts .pt2/.pte/.pt and raises an actionable ValueError for anything else (was: implicit fallthrough to .pte loader → cryptic torch error).
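The routing rule in the first bullet can be sketched on a plain dict. `sniff_backend` is a hypothetical stand-in for the logic in `Backend.detect_backend_by_model`, with the dict playing the role of the `torch.load`-ed checkpoint:

```python
# Hypothetical sketch of the key-suffix sniffing described above.
# A plain dict stands in for the loaded torch checkpoint; the real
# routing lives in Backend.detect_backend_by_model.
def sniff_backend(checkpoint) -> str:
    sd = checkpoint
    if isinstance(sd, dict) and "model" in sd:
        sd = sd["model"]
    keys = list(sd.keys()) if hasattr(sd, "keys") else []
    has_pt_expt = any(k.endswith(".w") or k.endswith(".b") for k in keys)
    has_pt = any(k.endswith(".matrix") or k.endswith(".bias") for k in keys)
    # dpmodel-style .w/.b keys route to pt-expt; anything else
    # (including bogus content) falls back to the legacy pt backend.
    return "pt-expt" if has_pt_expt and not has_pt else "pt"

print(sniff_backend({"model": {"fitting.0.w": 0, "fitting.0.b": 0}}))          # pt-expt
print(sniff_backend({"model": {"fitting.0.matrix": 0, "fitting.0.bias": 0}}))  # pt
```

In the real implementation any exception during loading falls through to suffix-based dispatch, which keeps bogus `.pt` files on the `pt` path.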

Tests

source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (21 tests):

  • Dispatch sniffing — pt_expt-style .pt routes to pt-expt; pt-style .pt routes to pt; bogus .pt falls back to suffix.
  • Single-task .pt — metadata accessors, DeepPot(.pt).eval(...) parity vs direct forward at 1e-10, .pth rejection.
  • Multi-task .pt — head selection parity, missing-head error, no-default-no-head error.
  • Spin .pt — metadata flags, eager-reference parity, missing-spin-arg error.
  • Spin + fparam .pt — default fparam matches explicit; varying fparam changes output.
  • Spin + aparam .pt — aparam takes effect; missing-aparam raises.
  • Spin multi-task .pt — each head matches its own eager reference; distinct heads produce distinct outputs.
  • Cross-format .pt vs .pte consistency at 1e-10 for vanilla spin, default fparam, and aparam (all with atomic=True).

Test plan

  • pytest source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py -v
  • pytest source/tests/pt_expt/infer/ -v (regression: existing .pt2/.pte paths)
  • Train a small example with dp --pt-expt train, then verify that dp --pt-expt test -m model.ckpt-N.pt produces metrics identical to dp --pt-expt test -m frozen.pt2.

Known limitations

  • pt_expt training itself still has no multi-task or multi-task-spin path; _load_pt handles such checkpoints, but a user can't currently produce one via dp --pt-expt train. Tests construct them synthetically.
  • _load_pt's exported_module is a Python closure (eager), not a real torch.nn.Module. Sufficient for dp test, but eval_descriptor / eval_typeebd / eval_fitting_last_layer won't work from a .pt (only from .pt2/.pte).
  • NoPBC .pt vs .pte consistency is not separately asserted (same eager code path as PBC).

Summary by CodeRabbit

  • New Features

    • Support for loading .pt training checkpoints and detecting the correct backend by inspecting checkpoint contents.
    • Multi-task model head selection when loading checkpoints.
    • New spin-energy model construction and inference with consistent exported behavior for spin vs non-spin.
  • Tests

    • End-to-end tests for checkpoint detection, multi-head routing, eager inference consistency, and spin-model numerical agreement.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 98aee78a86


Comment thread deepmd/backend/backend.py
    try:
        import torch

        sd = torch.load(filename, map_location="cpu", weights_only=False)

P1: Avoid unpickling arbitrary .pt during backend detection

detect_backend_by_model now deserializes .pt files with torch.load(..., weights_only=False) before selecting a backend, which allows pickle payload execution from attacker-controlled files. In practice, simply running dp test -m untrusted.pt can execute code during format sniffing, even before model loading proceeds. Detection only needs state-dict key names, so this should use a safe load mode (for example weights_only=True) and fail closed when the payload is not a plain checkpoint dict.
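The fail-closed principle behind this finding can be illustrated with the stdlib pickle module alone. `SafeUnpickler` below is a hypothetical illustration of what a restricted load mode does (refuse to resolve globals, which are the vector for code execution); it is not torch's actual `weights_only` implementation:

```python
import io
import pickle

class SafeUnpickler(pickle.Unpickler):
    # Refuse to resolve any global: pickles that reference importable
    # callables fail to load instead of executing code.
    def find_class(self, module, name):
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

def safe_loads(data: bytes):
    return SafeUnpickler(io.BytesIO(data)).load()

# A checkpoint-shaped dict of primitives round-trips fine...
ok = safe_loads(pickle.dumps({"fitting.0.w": 1.0}))
# ...but a pickle referencing a callable is rejected, not executed.
try:
    safe_loads(pickle.dumps(print))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```

Detection only reads state-dict key names, so a restricted load is sufficient there; the unrestricted load belongs in the explicit model-loading path, after the user has chosen to trust the file.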


    get_model,
)

state_dict = torch.load(model_file, map_location=DEVICE, weights_only=False)

P1: Load pt_expt checkpoints without unsafe pickle execution

The new .pt inference path in pt_expt.DeepEval also uses torch.load(..., weights_only=False), so evaluating an untrusted checkpoint via dp --pt-expt test -m file.pt can execute arbitrary Python objects at deserialize time. This path only needs tensor weights and _extra_state metadata for ModelWrapper.load_state_dict, so full pickle loading is unnecessary and introduces a security regression compared with safe checkpoint loading.


@coderabbitai
Contributor

coderabbitai Bot commented Apr 26, 2026

📝 Walkthrough

Adds .pt checkpoint detection and eager loading for the pt_expt backend (including multi-head selection), introduces a spin-aware model construction path, updates backend sniffing to inspect .pt contents, and adds comprehensive tests covering detection, loading, multi-head routing, compiled-layout variants, and spin inference.

Changes

  • Backend detection (deepmd/backend/backend.py): Lowercases filenames for suffix matching; for .pt files attempts torch.load to inspect checkpoint/state-dict key suffixes and choose between pt-expt and pt, returning the plugin when identified; falls back to suffix logic on exceptions.
  • pt_expt inference loader (deepmd/pt_expt/infer/deep_eval.py): Adds DeepEval._load_pt and extends DeepEval.__init__ to accept .pt plus an optional head; eagerly loads checkpoints, selects the multi-task head (default "Default"), remaps state_dict keys to runtime prefixes, constructs and loads the runtime model, and exposes an eager exported_module matching downstream forward outputs (spin vs non-spin).
  • Model construction, spin (deepmd/pt_expt/model/get_model.py): Adds get_spin_model and makes get_model return a spin-wrapped SpinEnergyModel when "spin" is present: augments type_map for virtual spin atoms, computes/exposes exclude-type lists, ensures descriptor defaults, and duplicates descriptor selection for certain descriptor types.
  • Tests, end-to-end (source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py): New tests that synthesize .pt checkpoints (plain, compiled-layout, multi-head, spin), verify backend detection routing, validate _load_pt and multi-head behavior (including error cases), and assert numeric agreement between eager EnergyModel forward and loaded-backend inference across energy/forces/virial/atomic outputs.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant BackendDetection as Backend Detection
    participant TorchLoad as torch.load()
    participant DeepEval as DeepEval Loader
    participant ModelBuilder as Model Construction
    participant Inference as Inference Engine

    User->>BackendDetection: detect_backend_by_model(.pt file)
    BackendDetection->>TorchLoad: torch.load(checkpoint)
    TorchLoad-->>BackendDetection: checkpoint dict
    BackendDetection->>BackendDetection: inspect state_dict key suffixes
    BackendDetection-->>User: return backend (pt or pt-expt)

    User->>DeepEval: DeepEval(model_file=.pt, head=?)
    DeepEval->>TorchLoad: torch.load(checkpoint)
    TorchLoad-->>DeepEval: checkpoint dict
    DeepEval->>DeepEval: extract model dict, select head, remap keys
    DeepEval->>ModelBuilder: get_model(config)
    ModelBuilder-->>DeepEval: model instance (standard or spin)
    DeepEval->>DeepEval: load weights, build exported_module
    User->>Inference: exported_module(forward args)
    Inference-->>User: energy, forces, virial, (atomic/spin outputs)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 41.07%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: the title 'feat(pt_expt): support .pt training checkpoints in DeepEval' directly and concisely describes the main change, adding support for .pt training checkpoints to the pt_expt backend's DeepEval inference loader.
  • Linked Issues check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: check skipped because no linked issues were found for this pull request.


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (5)
deepmd/pt_expt/infer/deep_eval.py (2)

252-260: str.replace is unbounded; hoist prefix out of the loop.

Two minor robustness/readability issues in the head-renaming loop:

  1. key.replace(prefix, "model.Default.") rewrites every occurrence of prefix in the key. If a head name happens to appear deeper in a key path (or in any state-dict key derived from a user-supplied identifier), keys silently get double-rewritten. Slice the prefix instead.
  2. prefix = f"model.{head}." is recomputed on every iteration.
♻️ Proposed fix
             # Restrict state_dict to the chosen head and rename to "Default".
             head_state = {"_extra_state": state_dict["_extra_state"]}
+            prefix = f"model.{head}."
             for key, value in state_dict.items():
-                prefix = f"model.{head}."
                 if key.startswith(prefix):
-                    head_state[key.replace(prefix, "model.Default.")] = (
+                    new_key = "model.Default." + key[len(prefix) :]
+                    head_state[new_key] = (
                         value.clone() if torch.is_tensor(value) else value
                     )
             state_dict = head_state
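The double-rewrite hazard is easy to demonstrate with plain strings. The head name and key below are contrived so the prefix recurs deeper in the key path:

```python
head = "a"
prefix = f"model.{head}."
# Contrived key where the prefix recurs deeper in the path.
key = "model.a.branch.model.a.w"

# str.replace rewrites every occurrence of the prefix...
via_replace = key.replace(prefix, "model.Default.")
# ...while slicing rewrites only the leading one.
via_slice = (
    "model.Default." + key[len(prefix):] if key.startswith(prefix) else key
)

print(via_replace)  # model.Default.branch.model.Default.w  (double-rewritten)
print(via_slice)    # model.Default.branch.model.a.w
```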

223-225: Inconsistent DEVICE import: should be deepmd.pt_expt.utils.env.

This file already imports DEVICE from deepmd.pt_expt.utils.env (line 813, 982). Pulling it from deepmd.pt.utils.env here is inconsistent and creates an unnecessary dependency from pt_expt on pt. If the two backends ever diverge on device defaults this becomes a subtle bug.

♻️ Proposed fix
-        from deepmd.pt.utils.env import (
+        from deepmd.pt_expt.utils.env import (
             DEVICE,
         )
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (2)

475-495: Prefer strict=True in cross-format consistency zips.

dp.eval(...) returns a fixed-arity tuple matching the request defs, and the name lists in these consistency loops (here, lines 567-580, and 659-672) hard-code 7 entries to mirror the spin-with-atomic case. With strict=False, if dp.eval ever changes arity (e.g., a new output is added) the loop silently truncates and consistency for new fields is no longer asserted. strict=True would force the tests to be updated.

-        for name, a, b in zip(
+        for name, a, b in zip(
             (
                 "energy",
                 ...
                 "mask_mag",
             ),
             out_pt,
             out_pte,
-            strict=False,
+            strict=True,
         ):

(Same applies at lines 567-580 and 659-672.)


401-407: os.rmdir will fail on residual files; prefer shutil.rmtree.

_make_spin_files only puts .pt and .pte into tmpdir, so today this works. But if a future change writes any auxiliary file (e.g., a sidecar .json from deserialize_to_file), os.rmdir will raise OSError and leak the directory. shutil.rmtree(cls.files["tmpdir"], ignore_errors=True) handles both this and any partial-creation cleanup uniformly.

♻️ Proposed fix
+    import shutil
+
     @classmethod
     def tearDownClass(cls) -> None:
-        for ext in (".pt", ".pte"):
-            path = cls.files[ext]
-            if os.path.exists(path):
-                os.unlink(path)
-        os.rmdir(cls.files["tmpdir"])
+        shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)
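The difference is easy to reproduce with a throwaway directory (the sidecar file name is made up):

```python
import os
import shutil
import tempfile

tmpdir = tempfile.mkdtemp()
# Simulate a residual file that a future change might leave behind.
open(os.path.join(tmpdir, "sidecar.json"), "w").close()

try:
    os.rmdir(tmpdir)  # only removes *empty* directories
    rmdir_failed = False
except OSError:
    rmdir_failed = True

shutil.rmtree(tmpdir, ignore_errors=True)  # removes the tree regardless
```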
deepmd/backend/backend.py (1)

122-124: Simplify backend lookup.

Backend.get_backends() already returns a dict[str, type[Backend]], so the linear scan can be a single dict lookup:

♻️ Proposed refactor
-                for key, backend in Backend.get_backends().items():
-                    if key == target_name:
-                        return backend
+                backend = Backend.get_backends().get(target_name)
+                if backend is not None:
+                    return backend
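The equivalence, with stand-in string values in place of the real backend classes:

```python
backends = {"pt": "PTBackend", "pt-expt": "PtExptBackend"}  # stand-ins

def lookup_scan(name):
    # The linear scan from the review comment...
    for key, backend in backends.items():
        if key == name:
            return backend
    return None

def lookup_get(name):
    # ...collapses to a single dict access with the same semantics.
    return backends.get(name)
```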
📥 Commits

Reviewing files that changed from the base of the PR and between d14233e and 98aee78.

📒 Files selected for processing (4)
  • deepmd/backend/backend.py
  • deepmd/pt_expt/infer/deep_eval.py
  • deepmd/pt_expt/model/get_model.py
  • source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py

Comment thread deepmd/backend/backend.py
Comment on lines +108 to +127
if filename_lower.endswith(".pt"):
    try:
        import torch

        sd = torch.load(filename, map_location="cpu", weights_only=False)
        if isinstance(sd, dict) and "model" in sd:
            sd = sd["model"]
        keys = list(sd.keys()) if hasattr(sd, "keys") else []
        has_pt_expt = any(k.endswith(".w") or k.endswith(".b") for k in keys)
        has_pt = any(k.endswith(".matrix") or k.endswith(".bias") for k in keys)
        if has_pt_expt and not has_pt:
            target_name = "pt-expt"
        else:
            target_name = "pt"
        for key, backend in Backend.get_backends().items():
            if key == target_name:
                return backend
    except Exception:
        # Fall through to suffix matching if sniffing fails.
        pass

⚠️ Potential issue | 🟡 Minor


Use weights_only=True for backend sniffing.

detect_backend_by_model is called during dispatch (e.g. dp test -m foo.pt) before the user explicitly trusts the checkpoint. Loading with weights_only=False executes pickle, enabling RCE on a malicious .pt file. Since this branch only inspects state_dict.keys() to disambiguate between pt and pt-expt backends, a stricter load is sufficient and safer:

Suggested change
-                sd = torch.load(filename, map_location="cpu", weights_only=False)
+                # Only state_dict keys are inspected; weights_only=True avoids
+                # executing pickle code on untrusted checkpoints during
+                # backend dispatch.
+                sd = torch.load(filename, map_location="cpu", weights_only=True)

If weights_only=True fails (e.g., on checkpoints with non-standard _extra_state objects), the except Exception: pass block gracefully falls through to suffix-based backend detection. The actual model loading in deepmd/pt_expt/infer/deep_eval.py (line 230) already uses weights_only=False once the backend is selected—the right place to accept that risk.

🧰 Tools
🪛 Ruff (0.15.11)

[error] 125-127: try-except-pass detected, consider logging the exception

(S110)


[warning] 125-125: Do not catch blind exception: Exception

(BLE001)


Comment on lines +251 to +309
@classmethod
def setUpClass(cls) -> None:
    # Build two single-task models with the same architecture but
    # different seeds, then save a multi-task-style checkpoint.
    cls.model_a, params_a = _build_model_and_params(rcut=4.0)
    cls.model_b, params_b = _build_model_and_params(rcut=4.0)

    # Multi-task model_params layout used by pt_expt training.
    model_params = {"model_dict": {"head_a": params_a, "head_b": params_b}}

    wrapper = ModelWrapper(
        {"head_a": cls.model_a, "head_b": cls.model_b},
        model_params=model_params,
    )
    cls.pt_path = tempfile.NamedTemporaryFile(suffix=".pt", delete=False).name
    torch.save({"model": wrapper.state_dict()}, cls.pt_path)

@classmethod
def tearDownClass(cls) -> None:
    if os.path.exists(cls.pt_path):
        os.unlink(cls.pt_path)

def test_select_head_matches_single_task_forward(self) -> None:
    rng = np.random.default_rng(GLOBAL_SEED + 1)
    natoms = 4
    coords = rng.random((1, natoms, 3)) * 8.0
    cells = np.eye(3).reshape(1, 9) * 10.0
    atom_types = np.array([i % 2 for i in range(natoms)], dtype=np.int32)

    for head, src in (("head_a", self.model_a), ("head_b", self.model_b)):
        # Build a DeepPot wrapping this DeepEval for end-to-end eval.
        dp = DeepPot(self.pt_path, head=head)
        de = dp.deep_eval
        e, f, v = dp.eval(coords, cells, atom_types, atomic=False)

        coord_t = torch.tensor(
            coords, dtype=torch.float64, device=DEVICE
        ).requires_grad_(True)
        atype_t = torch.tensor(
            atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE
        )
        cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE)
        ref = src.forward(coord_t, atype_t, cell_t, do_atomic_virial=False)

        np.testing.assert_allclose(
            e,
            ref["energy"].detach().cpu().numpy(),
            rtol=1e-10,
            atol=1e-10,
            err_msg=f"head={head}, energy",
        )
        np.testing.assert_allclose(
            f,
            ref["force"].detach().cpu().numpy(),
            rtol=1e-10,
            atol=1e-10,
            err_msg=f"head={head}, force",
        )
        self.assertEqual(de.get_type_map(), src.get_type_map())

⚠️ Potential issue | 🟠 Major

Heads model_a and model_b are weight-identical — head selection is not actually validated.

_build_model_and_params() constructs DescrptSeA(..., seed=GLOBAL_SEED) and EnergyFittingNet(..., seed=GLOBAL_SEED), so both calls produce models with the same initial weights. Consequently test_select_head_matches_single_task_forward would still pass even if _load_pt accidentally loaded head_b's weights when head="head_a" was requested (or vice versa). The test asserts dispatch plumbing only, not correctness of weight selection.

The spin counterpart TestPtExptLoadPtSpinMultiTask (lines 704-716) gets this right by using distinct seeds (42 / 7) plus an explicit test_distinct_heads_produce_distinct_outputs sanity check. Suggest mirroring that pattern here:

🐛 Proposed fix
-        cls.model_a, params_a = _build_model_and_params(rcut=4.0)
-        cls.model_b, params_b = _build_model_and_params(rcut=4.0)
+        cls.model_a, params_a = _build_model_and_params(rcut=4.0, seed=42)
+        cls.model_b, params_b = _build_model_and_params(rcut=4.0, seed=7)

…and add a distinct-outputs assertion:

+    def test_distinct_heads_produce_distinct_outputs(self) -> None:
+        rng = np.random.default_rng(GLOBAL_SEED + 2)
+        natoms = 4
+        coords = rng.random((1, natoms, 3)) * 8.0
+        cells = np.eye(3).reshape(1, 9) * 10.0
+        atom_types = np.array([i % 2 for i in range(natoms)], dtype=np.int32)
+        e_a = DeepPot(self.pt_path, head="head_a").eval(
+            coords, cells, atom_types, atomic=False
+        )[0]
+        e_b = DeepPot(self.pt_path, head="head_b").eval(
+            coords, cells, atom_types, atomic=False
+        )[0]
+        self.assertFalse(np.allclose(e_a, e_b))

(_build_model_and_params will need a seed parameter forwarded to DescrptSeA / EnergyFittingNet.)

🧰 Tools
🪛 Ruff (0.15.11)

[warning] 284-284: Unpacked variable v is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` around lines 251
- 309, The test builds two heads with identical random seeds so weight selection
isn’t actually validated; change _build_model_and_params to accept a seed
parameter (forward it into DescrptSeA and EnergyFittingNet), call
_build_model_and_params twice with distinct seeds when creating
cls.model_a/params_a and cls.model_b/params_b in setUpClass, and add an explicit
distinct-outputs assertion (like test_distinct_heads_produce_distinct_outputs)
that verifies the two heads produce different energies/forces for the same input
to ensure head selection actually loads different weights (update any callers of
_build_model_and_params accordingly).

@wanghan-iapcm wanghan-iapcm requested a review from anyangml April 26, 2026 10:21
codecov Bot commented Apr 26, 2026

Codecov Report

❌ Patch coverage is 98.03922% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.38%. Comparing base (d14233e) to head (7158830).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/pt_expt/infer/deep_eval.py 98.43% 1 Missing ⚠️
deepmd/pt_expt/model/get_model.py 94.73% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5423      +/-   ##
==========================================
+ Coverage   82.36%   82.38%   +0.02%     
==========================================
  Files         824      824              
  Lines       87109    87209     +100     
  Branches     4197     4197              
==========================================
+ Hits        71743    71848     +105     
+ Misses      14091    14087       -4     
+ Partials     1275     1274       -1     

☔ View full report in Codecov by Sentry.

Real training-produced `.pt` checkpoints have `model.{head}.original_model.X`
for the trained weights and `model.{head}.compiled_forward_lower.*`
for the compiled-graph constants. Previously `_load_pt` did a strict
`load_state_dict` against a plain `get_model(model_params)` and failed.

Fix: strip the `original_model.` infix and drop all
`compiled_forward_lower.*` keys before loading. Works for both
single-task and multi-task layouts. Tests synthesise the wrapped
layout directly to avoid a real `torch.compile` invocation.
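
The cleanup described above can be sketched as a plain dict transform. This is illustrative only: `strip_compiled_wrapper_keys` is a hypothetical name, and the real `_load_pt` additionally handles `_extra_state`, tensors, and multi-task head selection.

```python
def strip_compiled_wrapper_keys(state_dict: dict) -> dict:
    """Drop compiled-graph constants and remove the `original_model.` infix."""
    cleaned = {}
    for key, value in state_dict.items():
        # `model.{head}.compiled_forward_lower.*` holds compiled-graph
        # constants, not trained weights; skip them entirely.
        if ".compiled_forward_lower." in key:
            continue
        # `model.{head}.original_model.X` -> `model.{head}.X`
        cleaned[key.replace(".original_model.", ".", 1)] = value
    return cleaned
```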
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)

580-585: Use shutil.rmtree for the tempdir cleanup.

os.rmdir only succeeds if the directory is empty; if any future change adds an extra artifact (e.g. a .lock file from torch save, or a partial write on a failing test), tearDownClass will raise and mask the actual test failure. Switching to shutil.rmtree(cls.files["tmpdir"], ignore_errors=True) makes cleanup robust without changing behavior in the happy path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` around lines 580
- 585, tearDownClass currently uses os.rmdir to remove cls.files["tmpdir"],
which will fail if the directory is not empty; replace the os.rmdir call with
shutil.rmtree(cls.files["tmpdir"], ignore_errors=True) to make cleanup robust
and avoid masking test failures, and ensure shutil is imported at top of the
test module if not already; reference the tearDownClass method and the
cls.files["tmpdir"] usage when making the change.
deepmd/pt_expt/infer/deep_eval.py (1)

251-261: Constrain head-prefix replacement to the leading occurrence.

key.replace(prefix, "model.Default.") rewrites every occurrence of model.{head}. in the key, not just the leading one. The loop already gated on startswith(prefix), so this is harmless for current key shapes, but it's a defensive landmine if a head name (e.g. "head") ever appears later in the key (e.g. nested module names). Safer to slice or pin count=1:

♻️ Proposed fix
-            head_state = {"_extra_state": state_dict["_extra_state"]}
-            for key, value in state_dict.items():
-                prefix = f"model.{head}."
-                if key.startswith(prefix):
-                    head_state[key.replace(prefix, "model.Default.")] = (
-                        value.clone() if torch.is_tensor(value) else value
-                    )
+            prefix = f"model.{head}."
+            head_state = {"_extra_state": state_dict["_extra_state"]}
+            for key, value in state_dict.items():
+                if key.startswith(prefix):
+                    new_key = "model.Default." + key[len(prefix):]
+                    head_state[new_key] = (
+                        value.clone() if torch.is_tensor(value) else value
+                    )

Also moves prefix out of the per-iteration body.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/infer/deep_eval.py` around lines 251 - 261, The replacement of
the head prefix in the loop may replace non-leading occurrences; compute prefix
= f"model.{head}." once before the loop, and when a key startswith(prefix)
produce the new key by only removing the leading prefix (e.g., new_key =
"model.Default." + key[len(prefix):] or use replace with count=1) before
inserting into head_state, preserving the clone behavior for tensors and leaving
"_extra_state" handling as-is (affects variables: head_params, state_dict,
head_state, prefix).
ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: d919d944-10aa-4a7a-b592-de89cb024aa8

📥 Commits

Reviewing files that changed from the base of the PR and between 98aee78 and 4bfd8f1.

📒 Files selected for processing (2)
  • deepmd/pt_expt/infer/deep_eval.py
  • source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py

# Build a DeepPot wrapping this DeepEval for end-to-end eval.
dp = DeepPot(self.pt_path, head=head)
de = dp.deep_eval
e, f, v = dp.eval(coords, cells, atom_types, atomic=False)
Contributor
⚠️ Potential issue | 🟡 Minor

Silence RUF059 on the unused unpacked outputs.

Ruff flags unused names in tuple unpacks at lines 426 (v), 472 (v), 613 (av, mm), and 918 (ae, av, mm). Per the repo's coding guidelines, ruff check . must pass before commit, so prefix the unused names with _ (or assert against them — for test_eval_pbc_atomic_matches_reference it would actually be worth covering av since atomic=True).

♻️ Example fix at line 613
-        e, f, v, ae, av, fm, mm = dp.eval(
+        e, f, v, ae, _av, fm, _mm = dp.eval(
             self.COORD, self.BOX, self.ATYPE, atomic=True, spin=self.SPIN
         )

As per coding guidelines: "Install linter and run ruff check . before committing changes or the CI will fail".

Also applies to: 472-472, 613-613, 918-918

🧰 Tools
🪛 Ruff (0.15.11)

[warning] 426-426: Unpacked variable v is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` at line 426, The
tuple unpack from dp.eval yields unused variables that trigger RUF059; update
the unpackings to prefix unused names with an underscore (e.g. change v → _v, av
→ _av, mm → _mm, ae → _ae) or assert against them where the test should exercise
those values (e.g. in test_eval_pbc_atomic_matches_reference keep or assert on
av when atomic=True), ensuring all four occurrences (the dp.eval unpack at
dp.eval(...), and the unpack sites noted) are adjusted so ruff no longer reports
unused-variable warnings.

The exported `.pte` and eager `.pt` paths produce identical energy /
force / virial / atom_energy / force_mag / mask_mag outputs for spin
models, but per-atom virial diverges. The reduced virial (which is the
sum of per-atom virials including the virtual-atom contribution) still
matches, so the divergence is in the per-extended-atom split, not the
totals. Pin this as a known limitation; revisit once the export-time
spin atom-virial path is reconciled with the eager path.
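
The "totals match, per-atom split differs" situation is easy to reproduce with synthetic data. A numpy sketch with illustrative shapes, not the model's actual virial code:

```python
import numpy as np

# Two per-atom virial splits can disagree while their reduced (summed)
# virials agree. Shapes are illustrative: (nframes=1, nall extended
# atoms, 9 virial components).
rng = np.random.default_rng(0)
nall = 6
atom_virial_a = rng.random((1, nall, 9))

# Zero-sum perturbation over the atom axis: moves contributions between
# atoms (including virtual atoms) without changing the total.
shift = rng.random((1, nall, 9))
shift -= shift.mean(axis=1, keepdims=True)
atom_virial_b = atom_virial_a + shift

# Reduced virials match to tight tolerance ...
np.testing.assert_allclose(
    atom_virial_a.sum(axis=1), atom_virial_b.sum(axis=1), rtol=1e-12, atol=1e-12
)
# ... while the per-atom splits differ.
assert not np.allclose(atom_virial_a, atom_virial_b)
```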
Contributor

@coderabbitai coderabbitai Bot left a comment
🧹 Nitpick comments (1)
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py (1)

579-585: os.rmdir is fragile for the spin scratch dir.

tearDownClass only unlinks .pt/.pte then os.rmdir(tmpdir). If deserialize_to_file ever drops a sidecar (cache, journal, .pte/-as-directory layout, etc.) the os.rmdir call raises OSError and leaks the temp tree. shutil.rmtree(cls.files["tmpdir"], ignore_errors=True) is both shorter and robust to layout changes.

♻️ Proposed refactor
+import shutil
@@
     @classmethod
     def tearDownClass(cls) -> None:
-        for ext in (".pt", ".pte"):
-            path = cls.files[ext]
-            if os.path.exists(path):
-                os.unlink(path)
-        os.rmdir(cls.files["tmpdir"])
+        shutil.rmtree(cls.files["tmpdir"], ignore_errors=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py` around lines 579
- 585, tearDownClass currently unlinks only .pt/.pte files then calls
os.rmdir(cls.files["tmpdir"]), which will raise OSError and leak the temp
directory if any extra sidecar files or nested dirs exist (e.g., created by
deserialize_to_file); replace the fragile os.rmdir call with
shutil.rmtree(cls.files["tmpdir"], ignore_errors=True) and add the shutil import
so the tearDownClass cleanup always removes the entire tmpdir regardless of
layout while remaining tolerant of errors.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7cd4b238-4f71-4f89-8118-9021b50980c0

📥 Commits

Reviewing files that changed from the base of the PR and between 4bfd8f1 and 7158830.

📒 Files selected for processing (1)
  • source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py

Contributor

Copilot AI left a comment
Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds first-class support for loading .pt training checkpoints for inference in the pt_expt backend, including backend auto-detection for shared .pt suffixes.

Changes:

  • Add .pt sniffing in Backend.detect_backend_by_model to route .pt files to pt vs pt-expt based on state-dict key naming.
  • Implement .pt checkpoint loading in pt_expt.DeepEval (including multitask head selection, compiled-wrapper key cleanup, and eager runner shims).
  • Add a comprehensive pt_expt inference test suite covering routing, spin, multitask, aparam/fparam behavior, and .pt vs .pte consistency.
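
The naming-based routing in the first bullet can be sketched as follows. `sniff_pt_backend` is a hypothetical helper for illustration, not the actual `detect_backend_by_model` code:

```python
def sniff_pt_backend(checkpoint: dict) -> str:
    """Route a .pt checkpoint by parameter naming (dispatch-rule sketch).

    dpmodel-style ".w"/".b" parameters -> "pt-expt";
    legacy ".matrix"/".bias" parameters -> "pt";
    anything unrecognized falls back to suffix dispatch ("pt").
    """
    state = checkpoint.get("model", checkpoint) if isinstance(checkpoint, dict) else {}
    if isinstance(state, dict):
        names = [k for k in state if isinstance(k, str)]
        if any(k.endswith((".w", ".b")) for k in names):
            return "pt-expt"
        if any(k.endswith((".matrix", ".bias")) for k in names):
            return "pt"
    return "pt"
```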

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.

File Description
source/tests/pt_expt/infer/test_deep_eval_pt_checkpoint.py Adds end-to-end tests for .pt dispatch + pt_expt .pt inference correctness across single/multi-task and spin variants.
deepmd/pt_expt/model/get_model.py Adds get_spin_model and updates get_model to construct spin models correctly from config.
deepmd/pt_expt/infer/deep_eval.py Extends pt_expt inference to accept .pt checkpoints and reconstruct eager runners compatible with existing eval paths.
deepmd/backend/backend.py Implements .pt content sniffing to disambiguate backend routing between pt and pt-expt.


Comment thread deepmd/backend/backend.py
try:
    import torch

    sd = torch.load(filename, map_location="cpu", weights_only=False)
Copilot AI Apr 26, 2026
torch.load(..., weights_only=False) on a user-provided path expands the attack surface for arbitrary code execution via pickle, especially since this runs during backend detection (before the user has opted into a given backend). Prefer weights_only=True for sniffing (it should still load tensor-only/dict checkpoints), and if it fails, fall back to suffix dispatch as you already do. If you must support older/odd checkpoints that require full unpickling, consider a guarded fallback with a clear warning or an explicit opt-in.

Suggested change
-sd = torch.load(filename, map_location="cpu", weights_only=False)
+sd = torch.load(filename, map_location="cpu", weights_only=True)

deepcopy,
)

from deepmd.pt.utils.env import (
Copilot AI Apr 26, 2026
This loader is part of the pt_expt backend but imports DEVICE from deepmd.pt.utils.env (legacy pt backend). That can put tensors on an unintended device (or diverge from pt_expt’s device policy), causing incorrect placement / failures. Import DEVICE from deepmd.pt_expt.utils.env (or reuse an already-defined pt_expt device constant) to keep checkpoint loading consistent with the pt_expt runtime.

Suggested change
-from deepmd.pt.utils.env import (
+from deepmd.pt_expt.utils.env import (

Comment on lines +257 to +259
head_state[key.replace(prefix, "model.Default.")] = (
    value.clone() if torch.is_tensor(value) else value
)
Copilot AI Apr 26, 2026
Cloning every tensor when selecting a multitask head can significantly increase memory/time for large checkpoints, and it’s not needed here since you’re only rebinding references into a new dict (not mutating tensors in-place). Assign the tensor directly unless there’s a specific downstream mutation that requires isolation.

Suggested change
-head_state[key.replace(prefix, "model.Default.")] = (
-    value.clone() if torch.is_tensor(value) else value
-)
+head_state[key.replace(prefix, "model.Default.")] = value

Comment on lines +233 to +234
model_params = deepcopy(state_dict["_extra_state"]["model_params"])

Copilot AI Apr 26, 2026
If a .pt file is missing _extra_state / model_params (e.g., a non-pt_expt checkpoint, a partially-saved checkpoint, or a hand-edited artifact), this will raise a KeyError with a non-actionable message. Add an explicit check and raise a ValueError that explains what structure is expected for a pt_expt training checkpoint and how to proceed (e.g., ‘use the pt backend’ / ‘export to .pte/.pt2’ / ‘retrain with pt_expt’).

Suggested change
-model_params = deepcopy(state_dict["_extra_state"]["model_params"])
+checkpoint_extra_state: Any | None = None
+if isinstance(state_dict, dict):
+    checkpoint_extra_state = state_dict.get("_extra_state")
+if not (
+    isinstance(state_dict, dict)
+    and isinstance(checkpoint_extra_state, dict)
+    and "model_params" in checkpoint_extra_state
+):
+    raise ValueError(
+        f"Invalid .pt file '{model_file}': expected a pt_expt training "
+        "checkpoint containing '_extra_state' with nested "
+        "'model_params'. The provided file does not have the expected "
+        "checkpoint structure. If this is a different PyTorch "
+        "checkpoint, load it with the 'pt' backend instead. If this is "
+        "an exported model, use a '.pte' or '.pt2' artifact. Otherwise, "
+        "retrain or re-export the model with pt_expt to create a "
+        "compatible training checkpoint."
+    )
+model_params = deepcopy(checkpoint_extra_state["model_params"])

Comment thread deepmd/backend/backend.py
Comment on lines +122 to +124
for key, backend in Backend.get_backends().items():
    if key == target_name:
        return backend
Copilot AI Apr 26, 2026
This manual lookup loop can be simplified to a direct backend lookup (e.g., Backend.get_backend(target_name)), which is clearer and avoids re-iterating the registry.

Suggested change
-for key, backend in Backend.get_backends().items():
-    if key == target_name:
-        return backend
+return Backend.get_backend(target_name)

Comment on lines +580 to +585
def tearDownClass(cls) -> None:
    for ext in (".pt", ".pte"):
        path = cls.files[ext]
        if os.path.exists(path):
            os.unlink(path)
    os.rmdir(cls.files["tmpdir"])
Copilot AI Apr 26, 2026
os.rmdir will fail if anything unexpected ends up in the temp directory (e.g., platform-specific artifacts, or future test additions). Using shutil.rmtree(tmpdir, ignore_errors=True) (or equivalent) makes teardown more robust and reduces test flakiness.

Collaborator

@anyangml anyangml left a comment
LGTM
