Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Since [`bash -l`](https://www.gnu.org/software/bash/manual/bash.html#Invoking-Ba
Files will be copied to the remote directory via SSH channels before jobs start and copied back after jobs finish.
To use SSH, one needs to provide necessary parameters in {dargs:argument}`remote_profile <machine[SSHContext]/remote_profile>`, such as {dargs:argument}`username <machine[SSHContext]/remote_profile/hostname>` and {dargs:argument}`hostname <username[SSHContext]/remote_profile/hostname>`.

By default, DPDispatcher requires {dargs:argument}`remote_root <machine/remote_root>` to already exist on the remote machine, which helps catch typos in remote paths. If you want DPDispatcher to recursively create that directory tree for you, set {dargs:argument}`create_remote_root <machine[SSHContext]/create_remote_root>` to `true`.

It's suggested to generate [SSH keys](https://help.ubuntu.com/community/SSH/OpenSSH/Keys) and transfer the public key to the remote server in advance, which is more secure than password authentication.

Note that `SSH` context is [non-login](https://www.gnu.org/software/bash/manual/html_node/Bash-Startup-Files.html), so `bash_profile` files will not be executed outside the submission script.
Expand Down
1 change: 1 addition & 0 deletions doc/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ where `machine.json` is
"context_type": "SSHContext",
"local_root": "/home/user123/workplace/22_new_project/",
"remote_root": "/home/user123/dpdispatcher_work_dir/",
"create_remote_root": true,
"remote_profile": {
"hostname": "39.106.xx.xxx",
"username": "user123",
Expand Down
53 changes: 43 additions & 10 deletions dpdispatcher/contexts/ssh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def __init__(
remote_root,
remote_profile,
clean_asynchronously=False,
create_remote_root=False,
*args,
**kwargs,
):
Expand All @@ -480,6 +481,7 @@ def __init__(

# self.job_uuid = None
self.clean_asynchronously = clean_asynchronously
self.create_remote_root = create_remote_root
# self.job_uuid = job_uuid
# if job_uuid:
# self.job_uuid=job_uuid
Expand All @@ -488,10 +490,7 @@ def __init__(
self.ssh_session = SSHSession(**remote_profile)
# self.temp_remote_root = os.path.join(self.ssh_session.get_session_root())
self.ssh_session.ensure_alive()
try:
self.sftp.mkdir(self.temp_remote_root)
except OSError:
pass
self._mkdir(self.temp_remote_root, recursive=self.create_remote_root)

@classmethod
def load_from_dict(cls, context_dict):
Expand All @@ -511,12 +510,14 @@ def load_from_dict(cls, context_dict):
remote_root = context_dict["remote_root"]
remote_profile = context_dict["remote_profile"]
clean_asynchronously = context_dict.get("clean_asynchronously", False)
create_remote_root = context_dict.get("create_remote_root", False)

ssh_context = cls(
local_root=local_root,
remote_root=remote_root,
remote_profile=remote_profile,
clean_asynchronously=clean_asynchronously,
create_remote_root=create_remote_root,
)
# local_root = jdata['local_root']
# ssh_session = SSHSession(**input)
Expand All @@ -541,6 +542,28 @@ def close(self):
def get_job_root(self):
return self.remote_root

def _mkdir(self, remote_dir, recursive=False):
if not remote_dir:
return

sftp = self.sftp
if not recursive:
try:
sftp.mkdir(remote_dir)
except OSError:
pass
return

path = pathlib.PurePosixPath(remote_dir)
current = path.root if path.is_absolute() else ""
parts = path.parts[1:] if path.is_absolute() else path.parts
for part in parts:
current = pathlib.PurePosixPath(current, part).as_posix()
try:
sftp.mkdir(current)
except OSError:
pass

def bind_submission(self, submission):
assert self.ssh_session is not None
assert self.ssh_session.ssh is not None
Expand Down Expand Up @@ -572,11 +595,7 @@ def bind_submission(self, submission):
# if the new directory exists and the old directory does not contain files, then move the old directory
self._rmtree(old_remote_root)

sftp = self.ssh_session.ssh.open_sftp()
try:
sftp.mkdir(self.remote_root)
except OSError:
pass
self._mkdir(self.remote_root, recursive=self.create_remote_root)

# self.job_uuid = submission.submission_hash
# dlog.debug("debug:SSHContext.bind_submission"
Expand Down Expand Up @@ -1013,8 +1032,22 @@ def machine_subfields(cls) -> List[Argument]:
list[Argument]
machine subfields
"""
doc_create_remote_root = (
"Whether DPDispatcher should recursively create the configured SSH remote_root "
"when parent directories do not already exist. Keep this disabled by default "
"to avoid silently creating directories for a mistyped path."
)
doc_remote_profile = "SSH connection settings for the remote machine, including authentication, timeouts, and optional proxy/jump-host behavior."
remote_profile_format = SSHSession.arginfo()
remote_profile_format.name = "remote_profile"
remote_profile_format.doc = doc_remote_profile
return [remote_profile_format]
return [
Argument(
"create_remote_root",
bool,
optional=True,
default=False,
doc=doc_create_remote_root,
),
remote_profile_format,
]
4 changes: 4 additions & 0 deletions dpdispatcher/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def serialize(self, if_empty_remote_profile=False):
machine_dict["context_type"] = self.context.__class__.__name__
machine_dict["local_root"] = self.context.init_local_root
machine_dict["remote_root"] = self.context.init_remote_root
if hasattr(self.context, "clean_asynchronously"):
machine_dict["clean_asynchronously"] = self.context.clean_asynchronously
if hasattr(self.context, "create_remote_root"):
machine_dict["create_remote_root"] = self.context.create_remote_root
Comment on lines +178 to +181
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify if LocalContext defines clean_asynchronously attribute
rg -n "clean_asynchronously" dpdispatcher/contexts/local_context.py

# Check if the test passes with current implementation
rg -nA5 "def test_machine_argcheck" tests/test_argcheck.py

Repository: deepmodeling/dpdispatcher

Length of output: 297


🏁 Script executed:

#!/bin/bash
# Get the full test_machine_argcheck test
sed -n '12,35p' tests/test_argcheck.py

Repository: deepmodeling/dpdispatcher

Length of output: 838


🏁 Script executed:

#!/bin/bash
# Get the serialize method in machine.py around lines 178-181 and show more context
sed -n '160,200p' dpdispatcher/machine.py

Repository: deepmodeling/dpdispatcher

Length of output: 1824


Serialization issue: LocalContext machines will not include clean_asynchronously in serialized output, but the test expects it.

The serialize() method conditionally includes clean_asynchronously only via hasattr(self.context, "clean_asynchronously"). However, LocalContext does not define this attribute, so for LocalContext machines the field will not be included in the serialized dict. Yet test_machine_argcheck expects "clean_asynchronously": False in the output for a LocalContext machine (line 29). This will cause the test to fail.

Either:

  1. LocalContext needs to define clean_asynchronously as an attribute, or
  2. The test expectation is incorrect

Additionally, the serialize() method lacks type hints. Per the coding guidelines, add proper type annotations: def serialize(self, if_empty_remote_profile: bool = False) -> dict:

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@dpdispatcher/machine.py` around lines 178 - 181, serialize() currently only
adds clean_asynchronously when the context object has that attribute, which
omits the key for LocalContext and breaks test_machine_argcheck; update
serialize(self, if_empty_remote_profile: bool = False) -> dict to always include
"clean_asynchronously" (use getattr(self.context, "clean_asynchronously", False)
or similar) and also include create_remote_root consistently (use getattr with a
sensible default), and add the requested type annotation to the serialize method
signature so it returns a dict with the optional if_empty_remote_profile boolean
parameter; this keeps tests expecting "clean_asynchronously": False working and
satisfies the typing guideline.

if not if_empty_remote_profile:
machine_dict["remote_profile"] = self.context.remote_profile
else:
Expand Down
1 change: 1 addition & 0 deletions examples/machine/ssh_proxy_command.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"context_type": "SSHContext",
"local_root": "./",
"remote_root": "/home/user/work",
"create_remote_root": true,
"remote_profile": {
"hostname": "internal-server.company.com",
"username": "user",
Expand Down
63 changes: 63 additions & 0 deletions tests/test_argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,69 @@ def test_machine_argcheck(self):
}
self.assertDictEqual(norm_dict, expected_dict)

def test_ssh_machine_argcheck(self):
from .context import SSHContext

original_init = SSHContext.__init__

def fake_init(
self,
local_root,
remote_root,
remote_profile,
clean_asynchronously=False,
create_remote_root=False,
*args,
**kwargs,
):
self.init_local_root = local_root
self.init_remote_root = remote_root
self.remote_profile = remote_profile
self.clean_asynchronously = clean_asynchronously
self.create_remote_root = create_remote_root

SSHContext.__init__ = fake_init
try:
norm_dict = Machine.load_from_dict(
{
"batch_type": "slurm",
"context_type": "ssh",
"local_root": "./",
"remote_root": "/some/path",
"remote_profile": {
"hostname": "host",
"username": "user",
},
"create_remote_root": True,
}
).serialize()
finally:
SSHContext.__init__ = original_init

expected_dict = {
"batch_type": "Slurm",
"context_type": "SSHContext",
"local_root": "./",
"remote_root": "/some/path",
"remote_profile": {
"hostname": "host",
"username": "user",
"port": 22,
"key_filename": None,
"passphrase": None,
"timeout": 10,
"totp_secret": None,
"tar_compress": True,
"look_for_keys": True,
"execute_command": None,
"proxy_command": None,
},
"clean_asynchronously": False,
"create_remote_root": True,
"retry_count": 3,
}
self.assertDictEqual(norm_dict, expected_dict)

def test_resources_argcheck(self):
norm_dict = Resources.load_from_dict(
{
Expand Down
93 changes: 93 additions & 0 deletions tests/test_ssh_create_remote_root.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
import sys
import unittest
from unittest.mock import MagicMock

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
__package__ = "tests"

from .context import SSHContext, setUpModule # noqa: F401


class TestSSHCreateRemoteRoot(unittest.TestCase):
def test_recursive_mkdir_disabled_by_default(self):
calls = []
context = SSHContext.__new__(SSHContext)
context.ssh_session = MagicMock()
context.ssh_session.sftp = MagicMock()
context.ssh_session.sftp.mkdir.side_effect = lambda path: calls.append(path)

context._mkdir("/data/home/user/work", recursive=False)

self.assertEqual(calls, ["/data/home/user/work"])

def test_recursive_mkdir_creates_missing_parents(self):
calls = []
context = SSHContext.__new__(SSHContext)
context.ssh_session = MagicMock()
context.ssh_session.sftp = MagicMock()

def mkdir(path):
calls.append(path)
if path in {"/data", "/data/home/user/work"}:
raise OSError("already exists")

context.ssh_session.sftp.mkdir.side_effect = mkdir

context._mkdir("/data/home/user/work", recursive=True)

self.assertEqual(
calls,
[
"/data",
"/data/home",
"/data/home/user",
"/data/home/user/work",
],
)

def test_machine_roundtrip_keeps_create_remote_root(self):
machine_dict = {
"batch_type": "Shell",
"context_type": "SSHContext",
"local_root": "./",
"remote_root": "/some/path",
"clean_asynchronously": False,
"create_remote_root": True,
"remote_profile": {
"hostname": "example.com",
"username": "alice",
},
}

from .context import Machine

original_init = SSHContext.__init__

def fake_init(
self,
local_root,
remote_root,
remote_profile,
clean_asynchronously=False,
create_remote_root=False,
*args,
**kwargs,
):
self.init_local_root = local_root
self.init_remote_root = remote_root
self.remote_profile = remote_profile
self.clean_asynchronously = clean_asynchronously
self.create_remote_root = create_remote_root

SSHContext.__init__ = fake_init
try:
machine = Machine.load_from_dict(machine_dict)
serialized = machine.serialize()
finally:
SSHContext.__init__ = original_init

self.assertTrue(serialized["create_remote_root"])
self.assertFalse(serialized["clean_asynchronously"])
self.assertEqual(serialized["remote_root"], "/some/path")
self.assertEqual(serialized["remote_profile"]["hostname"], "example.com")
Loading