feat(pt_expt): support .pt training checkpoints in DeepEval #5423

base: master
`deepmd/backend/backend.py`

```diff
@@ -101,10 +101,33 @@ def detect_backend_by_model(filename: str) -> type["Backend"]:
     filename : str
         The model file name
     """
-    filename = str(filename).lower()
+    filename_lower = str(filename).lower()
+    # `.pt` is shared between the pt and pt_expt backends. They use
+    # different parameter naming (pt: `.matrix`/`.bias`, pt_expt:
+    # `.w`/`.b`), so peek at the state-dict keys to disambiguate.
+    if filename_lower.endswith(".pt"):
+        try:
+            import torch
+
+            sd = torch.load(filename, map_location="cpu", weights_only=False)
```
**Copilot** (AI) · Apr 26, 2026
This manual lookup loop can be simplified to a direct backend lookup (e.g., `Backend.get_backend(target_name)`), which is clearer and avoids re-iterating the registry.
Suggested change:

```diff
-        for key, backend in Backend.get_backends().items():
-            if key == target_name:
-                return backend
+        return Backend.get_backend(target_name)
```
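The equivalence the suggestion relies on can be shown with a minimal dict-backed registry; all names here are hypothetical stand-ins, not the actual deepmd `Backend` class:

```python
# Minimal dict-backed registry sketch (hypothetical; not the real
# deepmd Backend class). With a dict, the manual loop and the direct
# lookup return the same entry.
class Registry:
    _backends = {"pt": "PTBackend", "pt_expt": "PTExptBackend"}

    @classmethod
    def get_backends(cls):
        return dict(cls._backends)

    @classmethod
    def get_backend(cls, name):
        try:
            return cls._backends[name]
        except KeyError:
            raise ValueError(f"unknown backend: {name}") from None


def lookup_via_loop(target_name):
    # The pattern the review flags: re-iterating the registry by hand.
    for key, backend in Registry.get_backends().items():
        if key == target_name:
            return backend
    return None


assert lookup_via_loop("pt_expt") == Registry.get_backend("pt_expt")
```

The direct lookup also centralizes the unknown-name error handling in one place instead of silently returning `None` from a fallen-through loop.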
🧩 Analysis chain

🌐 Web query: `torch.load weights_only=True state_dict keys behavior PyTorch 2.6`

💡 Result:
In PyTorch 2.6, `torch.load(checkpoint_path, weights_only=True)` deserializes the saved object using a restricted unpickler that only allows tensors, primitive types (`int`, `float`, `str`, `bool`, `None`), and dictionaries. It does not allow arbitrary classes or objects unless explicitly allowlisted via `torch.serialization.add_safe_globals`.

For a typical model checkpoint saved as `torch.save({'model_state_dict': model.state_dict(), ...}, path)`, `torch.load` with `weights_only=True` returns the full dictionary with top-level keys like `'model_state_dict'`, `'optimizer_state_dict'`, `'epoch'`, etc. Each state-dict value is an `OrderedDict` mapping the model's parameter keys (e.g., `'conv1.weight'`, `'fc.bias'`) to tensors. This works because the structure uses only allowed types: `dict` → `OrderedDict` (treated as `dict`) → `str` keys → `torch.Tensor` values. Primitive metadata like `epoch` (`int`/`float`) is also allowed. Official tutorials explicitly show:

```python
checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
```

If the checkpoint contains non-allowlisted objects (e.g., a scheduler `state_dict` with `collections.Counter`, or optimizer instances), it fails with `_pickle.UnpicklingError`, requiring allowlisting or `weights_only=False` (unsafe for untrusted files). In PyTorch 2.6+, `weights_only=True` is the default, promoting safer loading of weights-only checkpoints.
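Why key inspection alone suffices for disambiguation can be illustrated torch-free. This sketch uses a hypothetical helper name and the parameter-naming conventions quoted in the diff comment (pt: `.matrix`/`.bias`, pt_expt: `.w`/`.b`); the real code loads the checkpoint first and then inspects the state-dict keys:

```python
# Torch-free sketch of naming-based disambiguation (hypothetical helper;
# the real detect_backend_by_model inspects keys of a loaded state dict).
def sniff_pt_flavor(state_dict_keys):
    for key in state_dict_keys:
        if key.endswith((".matrix", ".bias")):
            return "pt"       # pt backend parameter naming
        if key.endswith((".w", ".b")):
            return "pt_expt"  # pt_expt backend parameter naming
    return None  # unknown: let suffix-based detection decide


assert sniff_pt_flavor(["model.l0.matrix", "model.l0.bias"]) == "pt"
assert sniff_pt_flavor(["model.l0.w", "model.l0.b"]) == "pt_expt"
```

Because only the key strings are consulted, no tensor data or pickled objects ever need to be materialized for the decision.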
Citations:
- 1: https://docs.pytorch.org/docs/2.6/generated/torch.load.html
- 2: https://docs.pytorch.org/docs/2.6/notes/serialization.html
- 3: https://pytorch.org/tutorials/_sources/beginner/saving_loading_models.rst.txt
- 4: https://docs.pytorch.org/docs/2.8/_sources/notes/serialization.rst.txt
Repository: deepmodeling/deepmd-kit

🏁 Script executed:

```shell
cd deepmd && head -150 backend/backend.py | tail -60
```

Length of output: 2431

🏁 Script executed:

```shell
# Check the full function and context around lines 108-127
sed -n '95,135p' deepmd/backend/backend.py
```

Length of output: 1891

🏁 Script executed:

```shell
# Find torch version requirements in the repo
find . -name "*.txt" -o -name "*.toml" -o -name "*.cfg" -o -name "*.yml" | head -20 | xargs grep -l "torch" 2>/dev/null | head -5
```

Length of output: 207

🏁 Script executed:

```shell
# Check if there are other calls to torch.load in the codebase
rg "torch\.load" --type py -A 2 -B 2
```

Length of output: 15397

🏁 Script executed:

```shell
# Verify Backend.get_backends() structure
rg "def get_backends" --type py -A 10
```

Length of output: 1194
**Use `weights_only=True` for backend sniffing.**

`detect_backend_by_model` is called during dispatch (e.g. `dp test -m foo.pt`) before the user explicitly trusts the checkpoint. Loading with `weights_only=False` executes pickle, enabling RCE on a malicious `.pt` file. Since this branch only inspects `state_dict.keys()` to disambiguate between the pt and pt_expt backends, a stricter load is sufficient and safer:

Suggested change:

```diff
-            sd = torch.load(filename, map_location="cpu", weights_only=False)
+            # Only state_dict keys are inspected; weights_only=True avoids
+            # executing pickle code on untrusted checkpoints during
+            # backend dispatch.
+            sd = torch.load(filename, map_location="cpu", weights_only=True)
```

If `weights_only=True` fails (e.g., on checkpoints with non-standard `_extra_state` objects), the `except Exception: pass` block gracefully falls through to suffix-based backend detection. The actual model loading in `deepmd/pt_expt/infer/deep_eval.py` (line 230) already uses `weights_only=False` once the backend is selected; that is the right place to accept the risk.
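The sniff-then-fall-back control flow described above can be sketched without torch. Here `load_state_dict` is an injected stand-in for `torch.load(..., weights_only=True)`, and all names are illustrative rather than the actual deepmd API:

```python
# Illustrative sketch of safe sniffing with graceful fallback.
# `load_state_dict` stands in for torch.load(..., weights_only=True);
# a non-plain or malformed checkpoint makes it raise, and detection
# falls through to suffix-based dispatch instead of executing pickle.
def detect_backend(filename, load_state_dict):
    name = str(filename).lower()
    if name.endswith(".pt"):
        try:
            sd = load_state_dict(filename)
            if isinstance(sd, dict) and any(
                str(k).endswith((".w", ".b")) for k in sd
            ):
                return "pt_expt"
            return "pt"
        except Exception:
            pass  # fail closed: fall through to suffix detection
    if name.endswith((".pt", ".pth")):
        return "pt"
    raise ValueError(f"cannot detect backend for '{filename}'")


def refusing_loader(path):
    # Simulates the restricted unpickler rejecting a pickled payload.
    raise RuntimeError("not a plain tensor checkpoint")


assert detect_backend("model.pt", refusing_loader) == "pt"
assert detect_backend("model.pt", lambda p: {"fc.w": 0}) == "pt_expt"
```

The key property is that a hostile payload can at worst make the loader raise, never run code, and detection still completes via the suffix path.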
🧰 Tools

🪛 Ruff (0.15.11)

[error] 125-127: `try`-`except`-`pass` detected, consider logging the exception (S110)

[warning] 125-125: Do not catch blind exception: `Exception` (BLE001)
🤖 Prompt for AI Agents

```
Verify each finding against the current code and only fix it if needed.
In deepmd/backend/backend.py around lines 108-127, the torch.load call in
detect_backend_by_model currently uses weights_only=False, which can execute
pickled code and cause RCE; change the torch.load(...) invocation in that block
to torch.load(filename, map_location="cpu", weights_only=True) so only tensor
weights are loaded when sniffing the backend, keeping the existing try/except
fallback to suffix-based detection; leave the later real model load in
deepmd/pt_expt/infer/deep_eval.py (which uses weights_only=False) unchanged.
```
`deepmd/pt_expt/infer/deep_eval.py`

```diff
@@ -99,8 +99,16 @@ def __init__(
         if self._is_pt2:
             self._load_pt2(model_file)
-        else:
+        elif model_file.endswith(".pte"):
             self._load_pte(model_file)
+        elif model_file.endswith(".pt"):
+            self._load_pt(model_file, head=kwargs.get("head"))
+        else:
+            raise ValueError(
+                f"Unsupported model file '{model_file}' for the pt_expt "
+                "backend: expected `.pt2` / `.pte` (deployable archives) or "
+                "`.pt` (training checkpoint)."
+            )

         if isinstance(auto_batch_size, bool):
             if auto_batch_size:
```
```diff
@@ -206,6 +214,178 @@ def _load_pt2(self, model_file: str) -> None:
         self._pt2_runner = aoti_load_package(model_file)
         self.exported_module = None

+    def _load_pt(self, model_file: str, head: str | None = None) -> None:
+        """Load a `.pt` training checkpoint (eager mode, no torch.export)."""
+        from copy import (
+            deepcopy,
+        )
+
+        from deepmd.pt.utils.env import (
```

Suggested change:

```diff
-        from deepmd.pt.utils.env import (
+        from deepmd.pt_expt.utils.env import (
```
**Load pt_expt checkpoints without unsafe pickle execution**

The new `.pt` inference path in `pt_expt.DeepEval` also uses `torch.load(..., weights_only=False)`, so evaluating an untrusted checkpoint via `dp --pt-expt test -m file.pt` can execute arbitrary Python objects at deserialize time. This path only needs tensor weights and `_extra_state` metadata for `ModelWrapper.load_state_dict`, so full pickle loading is unnecessary and introduces a security regression compared with safe checkpoint loading.
**Copilot** (AI) · Apr 26, 2026
If a `.pt` file is missing `_extra_state` / `model_params` (e.g., a non-pt_expt checkpoint, a partially saved checkpoint, or a hand-edited artifact), this will raise a `KeyError` with a non-actionable message. Add an explicit check and raise a `ValueError` that explains what structure a pt_expt training checkpoint is expected to have and how to proceed (e.g., "use the pt backend", "export to `.pte`/`.pt2`", or "retrain with pt_expt").
Suggested change:

```diff
-        model_params = deepcopy(state_dict["_extra_state"]["model_params"])
+        checkpoint_extra_state: Any | None = None
+        if isinstance(state_dict, dict):
+            checkpoint_extra_state = state_dict.get("_extra_state")
+        if not (
+            isinstance(state_dict, dict)
+            and isinstance(checkpoint_extra_state, dict)
+            and "model_params" in checkpoint_extra_state
+        ):
+            raise ValueError(
+                f"Invalid .pt file '{model_file}': expected a pt_expt training "
+                "checkpoint containing '_extra_state' with nested "
+                "'model_params'. The provided file does not have the expected "
+                "checkpoint structure. If this is a different PyTorch "
+                "checkpoint, load it with the 'pt' backend instead. If this is "
+                "an exported model, use a '.pte' or '.pt2' artifact. Otherwise, "
+                "retrain or re-export the model with pt_expt to create a "
+                "compatible training checkpoint."
+            )
+        model_params = deepcopy(checkpoint_extra_state["model_params"])
```
**Copilot** (AI) · Apr 26, 2026
Cloning every tensor when selecting a multitask head can significantly increase memory and time for large checkpoints, and it is not needed here: the loop only rebinds references into a new dict, it does not mutate tensors in place. Assign the tensor directly unless a specific downstream mutation requires isolation.
Suggested change:

```diff
-                    head_state[key.replace(prefix, "model.Default.")] = (
-                        value.clone() if torch.is_tensor(value) else value
-                    )
+                    head_state[key.replace(prefix, "model.Default.")] = value
```
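The rebinding-vs-copy point can be demonstrated without torch; `FakeTensor` is a stand-in whose `clone()` mimics `torch.Tensor.clone` for illustration:

```python
# Moving a value into a new dict only rebinds a reference; no tensor
# data is copied. clone() is only needed when a later in-place mutation
# must not affect the original checkpoint's tensors.
class FakeTensor:
    def __init__(self, data):
        self.data = list(data)

    def clone(self):
        # Mimics torch.Tensor.clone: allocates an independent copy.
        return FakeTensor(self.data)


state = {"model.Default0.fc.w": FakeTensor([1.0, 2.0])}
head_state = {}
for key, value in state.items():
    # Rebinding: head_state holds the exact same object as state.
    head_state[key.replace("model.Default0.", "model.Default.")] = value

shared = head_state["model.Default.fc.w"]
assert shared is state["model.Default0.fc.w"]  # same object, no copy made
assert shared.clone() is not shared            # clone() would allocate anew
```

Since `load_state_dict` then copies data into the model's own parameters anyway, the intermediate dict never needs isolated tensor storage.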
**Unsafe deserialization of `.pt` during backend detection**

`detect_backend_by_model` now deserializes `.pt` files with `torch.load(..., weights_only=False)` before selecting a backend, which allows pickle payload execution from attacker-controlled files. In practice, simply running `dp test -m untrusted.pt` can execute code during format sniffing, even before model loading proceeds. Detection only needs state-dict key names, so this should use a safe load mode (for example `weights_only=True`) and fail closed when the payload is not a plain checkpoint dict.