27 changes: 25 additions & 2 deletions deepmd/backend/backend.py
@@ -101,10 +101,33 @@ def detect_backend_by_model(filename: str) -> type["Backend"]:
filename : str
The model file name
"""
filename = str(filename).lower()
filename_lower = str(filename).lower()
# `.pt` is shared between the pt and pt_expt backends. They use
# different parameter naming (pt: `.matrix`/`.bias`, pt_expt:
# `.w`/`.b`), so peek at the state-dict keys to disambiguate.
if filename_lower.endswith(".pt"):
try:
import torch

sd = torch.load(filename, map_location="cpu", weights_only=False)
P1: Avoid unpickling arbitrary .pt during backend detection

detect_backend_by_model now deserializes .pt files with torch.load(..., weights_only=False) before selecting a backend, which allows pickle payload execution from attacker-controlled files. In practice, simply running dp test -m untrusted.pt can execute code during format sniffing, even before model loading proceeds. Detection only needs state-dict key names, so this should use a safe load mode (for example weights_only=True) and fail closed when the payload is not a plain checkpoint dict.
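The fix this comment asks for can be sketched as follows (a hypothetical helper, not the committed code; assumes a torch version with `weights_only` support, 1.13+; the key-suffix logic mirrors the diff):

```python
def classify_keys(keys) -> str:
    """Map parameter-name suffixes to a backend name (pt vs pt-expt)."""
    has_expt = any(k.endswith((".w", ".b")) for k in keys)
    has_pt = any(k.endswith((".matrix", ".bias")) for k in keys)
    return "pt-expt" if has_expt and not has_pt else "pt"


def detect_pt_flavor(path: str) -> str:
    """Sketch of the sniffing step with a restricted unpickler."""
    import torch  # local import, mirroring the diff

    try:
        # weights_only=True only admits tensors, dicts, and primitive types,
        # so a malicious pickle payload cannot execute during detection.
        sd = torch.load(path, map_location="cpu", weights_only=True)
    except Exception:
        return "pt"  # fail closed; suffix dispatch picks the default backend
    if isinstance(sd, dict) and "model" in sd:
        sd = sd["model"]
    return classify_keys(list(sd.keys()) if hasattr(sd, "keys") else [])
```

Ambiguous or empty key sets fall back to `pt`, matching the `has_pt_expt and not has_pt` condition in the diff.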


Copilot AI Apr 26, 2026

torch.load(..., weights_only=False) on a user-provided path expands the attack surface for arbitrary code execution via pickle, especially since this runs during backend detection (before the user has opted into a given backend). Prefer weights_only=True for sniffing (it should still load tensor-only/dict checkpoints), and if it fails, fall back to suffix dispatch as you already do. If you must support older/odd checkpoints that require full unpickling, consider a guarded fallback with a clear warning or an explicit opt-in.

Suggested change
sd = torch.load(filename, map_location="cpu", weights_only=False)
sd = torch.load(filename, map_location="cpu", weights_only=True)

if isinstance(sd, dict) and "model" in sd:
sd = sd["model"]
keys = list(sd.keys()) if hasattr(sd, "keys") else []
has_pt_expt = any(k.endswith(".w") or k.endswith(".b") for k in keys)
has_pt = any(k.endswith(".matrix") or k.endswith(".bias") for k in keys)
if has_pt_expt and not has_pt:
target_name = "pt-expt"
else:
target_name = "pt"
for key, backend in Backend.get_backends().items():
if key == target_name:
return backend
Comment on lines +122 to +124
Copilot AI Apr 26, 2026

This manual lookup loop can be simplified to a direct backend lookup (e.g., Backend.get_backend(target_name)), which is clearer and avoids re-iterating the registry.

Suggested change
for key, backend in Backend.get_backends().items():
if key == target_name:
return backend
return Backend.get_backend(target_name)

except Exception:
# Fall through to suffix matching if sniffing fails.
pass
Comment on lines +108 to +127

Contributor

⚠️ Potential issue | 🟡 Minor


Use weights_only=True for backend sniffing.

detect_backend_by_model is called during dispatch (e.g. dp test -m foo.pt) before the user explicitly trusts the checkpoint. Loading with weights_only=False executes pickle, enabling RCE on a malicious .pt file. Since this branch only inspects state_dict.keys() to disambiguate between pt and pt-expt backends, a stricter load is sufficient and safer:

Suggested change
-                sd = torch.load(filename, map_location="cpu", weights_only=False)
+                # Only state_dict keys are inspected; weights_only=True avoids
+                # executing pickle code on untrusted checkpoints during
+                # backend dispatch.
+                sd = torch.load(filename, map_location="cpu", weights_only=True)

If weights_only=True fails (e.g., on checkpoints with non-standard _extra_state objects), the except Exception: pass block gracefully falls through to suffix-based backend detection. The actual model loading in deepmd/pt_expt/infer/deep_eval.py (line 230) already uses weights_only=False once the backend is selected—the right place to accept that risk.

🧰 Tools
🪛 Ruff (0.15.11)

[error] 125-127: try-except-pass detected, consider logging the exception (S110)

[warning] 125-125: Do not catch blind exception: Exception (BLE001)
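Both Ruff findings can be addressed without giving up the graceful fallback by logging the failure at debug level. A sketch (helper name hypothetical, not the committed code):

```python
import logging

log = logging.getLogger(__name__)


def sniff_or_fallback(sniffer, path):
    """Run a sniffing callable; on any failure, log it and signal fallback."""
    try:
        return sniffer(path)
    except Exception as exc:  # noqa: BLE001 - sniffing must never abort dispatch
        # Logging replaces the bare `pass`, satisfying S110 while keeping
        # the fall-through to suffix-based detection.
        log.debug(
            "Backend sniffing failed for %s; falling back to suffix matching: %s",
            path,
            exc,
        )
        return None
```

Returning `None` lets the caller proceed to the existing suffix loop unchanged.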


for backend in Backend.get_backends().values():
for suffix in backend.suffixes:
if filename.endswith(suffix):
if filename_lower.endswith(suffix):
return backend
raise ValueError(f"Cannot detect the backend of the model file {filename}.")

164 changes: 163 additions & 1 deletion deepmd/pt_expt/infer/deep_eval.py
@@ -99,8 +99,16 @@ def __init__(

if self._is_pt2:
self._load_pt2(model_file)
else:
elif model_file.endswith(".pte"):
self._load_pte(model_file)
elif model_file.endswith(".pt"):
self._load_pt(model_file, head=kwargs.get("head"))
else:
raise ValueError(
f"Unsupported model file '{model_file}' for the pt_expt "
"backend: expected `.pt2` / `.pte` (deployable archives) or "
"`.pt` (training checkpoint)."
)

if isinstance(auto_batch_size, bool):
if auto_batch_size:
@@ -206,6 +214,160 @@ def _load_pt2(self, model_file: str) -> None:
self._pt2_runner = aoti_load_package(model_file)
self.exported_module = None

def _load_pt(self, model_file: str, head: str | None = None) -> None:
"""Load a `.pt` training checkpoint (eager mode, no torch.export)."""
from copy import (
deepcopy,
)

from deepmd.pt.utils.env import (
Copilot AI Apr 26, 2026

This loader is part of the pt_expt backend but imports DEVICE from deepmd.pt.utils.env (legacy pt backend). That can put tensors on an unintended device (or diverge from pt_expt’s device policy), causing incorrect placement / failures. Import DEVICE from deepmd.pt_expt.utils.env (or reuse an already-defined pt_expt device constant) to keep checkpoint loading consistent with the pt_expt runtime.

Suggested change
from deepmd.pt.utils.env import (
from deepmd.pt_expt.utils.env import (

DEVICE,
)
from deepmd.pt_expt.model import (
get_model,
)

state_dict = torch.load(model_file, map_location=DEVICE, weights_only=False)
P1: Load pt_expt checkpoints without unsafe pickle execution

The new .pt inference path in pt_expt.DeepEval also uses torch.load(..., weights_only=False), so evaluating an untrusted checkpoint via dp --pt-expt test -m file.pt can execute arbitrary Python objects at deserialize time. This path only needs tensor weights and _extra_state metadata for ModelWrapper.load_state_dict, so full pickle loading is unnecessary and introduces a security regression compared with safe checkpoint loading.


if "model" in state_dict:
state_dict = state_dict["model"]
model_params = deepcopy(state_dict["_extra_state"]["model_params"])

Comment on lines +233 to +234

Copilot AI Apr 26, 2026
If a .pt file is missing _extra_state / model_params (e.g., a non-pt_expt checkpoint, a partially-saved checkpoint, or a hand-edited artifact), this will raise a KeyError with a non-actionable message. Add an explicit check and raise a ValueError that explains what structure is expected for a pt_expt training checkpoint and how to proceed (e.g., ‘use the pt backend’ / ‘export to .pte/.pt2’ / ‘retrain with pt_expt’).

Suggested change
model_params = deepcopy(state_dict["_extra_state"]["model_params"])
checkpoint_extra_state: Any | None = None
if isinstance(state_dict, dict):
checkpoint_extra_state = state_dict.get("_extra_state")
if not (
isinstance(state_dict, dict)
and isinstance(checkpoint_extra_state, dict)
and "model_params" in checkpoint_extra_state
):
raise ValueError(
f"Invalid .pt file '{model_file}': expected a pt_expt training "
"checkpoint containing '_extra_state' with nested "
"'model_params'. The provided file does not have the expected "
"checkpoint structure. If this is a different PyTorch "
"checkpoint, load it with the 'pt' backend instead. If this is "
"an exported model, use a '.pte' or '.pt2' artifact. Otherwise, "
"retrain or re-export the model with pt_expt to create a "
"compatible training checkpoint."
)
model_params = deepcopy(checkpoint_extra_state["model_params"])

if "model_dict" in model_params:
# Multi-task: pick the requested head (defaults to "Default" if present).
heads = list(model_params["model_dict"].keys())
if head is None:
if "Default" in heads:
head = "Default"
else:
raise ValueError(
f"Multi-task checkpoint '{model_file}' has heads "
f"{heads}; pass --head to select one."
)
if head not in heads:
raise ValueError(
f"Head '{head}' not found in checkpoint '{model_file}'. "
f"Available heads: {heads}."
)
head_params = model_params["model_dict"][head]
# Restrict state_dict to the chosen head and rename to "Default".
head_state = {"_extra_state": state_dict["_extra_state"]}
for key, value in state_dict.items():
prefix = f"model.{head}."
if key.startswith(prefix):
head_state[key.replace(prefix, "model.Default.")] = (
value.clone() if torch.is_tensor(value) else value
)
Comment on lines +257 to +259

Copilot AI Apr 26, 2026

Cloning every tensor when selecting a multitask head can significantly increase memory/time for large checkpoints, and it’s not needed here since you’re only rebinding references into a new dict (not mutating tensors in-place). Assign the tensor directly unless there’s a specific downstream mutation that requires isolation.

Suggested change
head_state[key.replace(prefix, "model.Default.")] = (
value.clone() if torch.is_tensor(value) else value
)
head_state[key.replace(prefix, "model.Default.")] = value

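The head-selection step under discussion amounts to a pure key rewrite over the state dict; it can be exercised standalone (helper name hypothetical; rebinds values directly, per the suggestion above):

```python
def select_head(state_dict: dict, head: str) -> dict:
    """Keep `_extra_state` plus `model.{head}.*` keys renamed to `model.Default.*`."""
    prefix = f"model.{head}."
    head_state = {"_extra_state": state_dict["_extra_state"]}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # Rebind the same object; no clone is needed since nothing
            # mutates the tensors in place afterwards.
            head_state["model.Default." + key[len(prefix):]] = value
    return head_state
```

Keys belonging to other heads are simply dropped, so `ModelWrapper.load_state_dict` sees only the `model.Default.*` namespace it expects.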
state_dict = head_state
model_params = head_params

model = get_model(deepcopy(model_params)).to(DEVICE)

# Load weights into a {"Default": model} wrapper to match the
# `model.Default.*` key prefix used in the saved state_dict.
from deepmd.pt_expt.train.wrapper import (
ModelWrapper,
)

wrapper = ModelWrapper(model)
wrapper.load_state_dict(state_dict)
model = wrapper.model["Default"].eval()

self._dpmodel = model
self._is_spin = (
model_params.get("type") == "spin_ener" or "spin" in model_params
)
self.rcut = model.get_rcut()
self.type_map = model.get_type_map()
if self._is_spin:
self._model_output_def = ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
"energy",
shape=[1],
reducible=True,
r_differentiable=True,
c_differentiable=True,
atomic=True,
magnetic=True,
)
]
)
)
else:
self._model_output_def = ModelOutputDef(model.atomic_output_def())
self._model_def_script = model_params
# Populate metadata so eval helpers (e.g. default_fparam fallback)
# behave the same as the .pt2/.pte path. Mirrors the fields that
# `_collect_metadata` writes into metadata.json.
self.metadata = {
"type_map": model.get_type_map(),
"rcut": model.get_rcut(),
"sel": model.get_sel(),
"dim_fparam": model.get_dim_fparam(),
"dim_aparam": model.get_dim_aparam(),
"mixed_types": model.mixed_types(),
"has_default_fparam": model.has_default_fparam(),
"default_fparam": model.get_default_fparam(),
"is_spin": self._is_spin,
}
if self._is_spin:
self.metadata["ntypes_spin"] = model.spin.get_ntypes_spin()
self.metadata["use_spin"] = [bool(v) for v in model.spin.use_spin]

# Eager runner with the same signature as the .pt2/.pte exported module.
# Use forward_common_lower (not forward_lower) to match the export-time
# output keys ("energy", "energy_redu", "energy_derv_r", ...) that
# communicate_extended_output downstream consumes.
# Non-spin: (ext_coord, ext_atype, nlist, mapping, fparam, aparam)
# Spin: (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam)
if self._is_spin:

def _eager_runner_spin(
ext_coord: torch.Tensor,
ext_atype: torch.Tensor,
ext_spin: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None,
fparam: torch.Tensor | None,
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
ext_coord = ext_coord.detach().requires_grad_(True)
return model.forward_common_lower(
ext_coord,
ext_atype,
ext_spin,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=True,
)

self.exported_module = _eager_runner_spin
else:

def _eager_runner(
ext_coord: torch.Tensor,
ext_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None,
fparam: torch.Tensor | None,
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
ext_coord = ext_coord.detach().requires_grad_(True)
return model.forward_common_lower(
ext_coord,
ext_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=True,
)

self.exported_module = _eager_runner

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
return self.rcut
38 changes: 38 additions & 0 deletions deepmd/pt_expt/model/get_model.py
@@ -37,6 +37,12 @@
from deepmd.pt_expt.model.property_model import (
PropertyModel,
)
from deepmd.pt_expt.model.spin_ener_model import (
SpinEnergyModel,
)
from deepmd.utils.spin import (
Spin,
)


def _get_standard_model_components(
@@ -162,6 +168,36 @@ def get_linear_model(model_params: dict) -> BaseModel:
)


def get_spin_model(data: dict) -> SpinEnergyModel:
"""Build a pt_expt spin energy model from a config dictionary.

Mirrors :func:`deepmd.dpmodel.model.model.get_spin_model`: expands the
type map and descriptor sel for virtual spin atoms, then wraps the
backbone EnergyModel as a :class:`SpinEnergyModel`.
"""
data = copy.deepcopy(data)
data["type_map"] += [item + "_spin" for item in data["type_map"]]
spin = Spin(
use_spin=data["spin"]["use_spin"],
virtual_scale=data["spin"]["virtual_scale"],
)
pair_exclude_types = spin.get_pair_exclude_types(
exclude_types=data.get("pair_exclude_types", None)
)
data["pair_exclude_types"] = pair_exclude_types
data["descriptor"]["exclude_types"] = pair_exclude_types
atom_exclude_types = spin.get_atom_exclude_types(
exclude_types=data.get("atom_exclude_types", None)
)
data["atom_exclude_types"] = atom_exclude_types
if "env_protection" not in data["descriptor"]:
data["descriptor"]["env_protection"] = 1e-6
if data["descriptor"]["type"] in ["se_e2_a"]:
data["descriptor"]["sel"] += data["descriptor"]["sel"]
backbone_model = get_standard_model(data)
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)


def get_model(data: dict) -> BaseModel:
"""Get a model from a config dictionary.

@@ -172,6 +208,8 @@ def get_model(data: dict) -> BaseModel:
"""
model_type = data.get("type", "standard")
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
return get_standard_model(data)
elif model_type == "linear_ener":
return get_linear_model(data)