Skip to content

Commit f689ff8

Browse files
authored
Merge pull request #25 from Point72/tbg/expose-perspective-proxy-server
Exposing more utilities so that users can interact with a PerspectiveProxyRayServer
2 parents 26a9c97 + 83acac2 commit f689ff8

5 files changed

Lines changed: 53 additions & 22 deletions

File tree

raydar/dashboard/server.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def clear_table(self, tablename: str, schema) -> None:
5151
if tablename in self._tables:
5252
self._tables[tablename].clear()
5353

54-
def get_table(self, tablename: str) -> None:
55-
return self._schemas.get(tablename, None)
56-
5754
@app.get("/")
5855
async def site(self):
5956
return FileResponse(join(static_files_dir, "index.html"))

raydar/task_tracker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .task_tracker import AsyncMetadataTracker, RayTaskTracker
1+
from .task_tracker import *

raydar/task_tracker/task_tracker.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,41 @@
99
from collections.abc import Iterable
1010
from packaging.version import Version
1111
from ray.serve import shutdown
12+
from ray.serve.handle import DeploymentHandle
1213
from typing import Dict, List, Optional
1314

1415
from .schema import schema as default_schema
1516

1617
logger = logging.getLogger(__name__)
1718

18-
__all__ = ("AsyncMetadataTracker", "RayTaskTracker")
19+
__all__ = ("AsyncMetadataTracker", "RayTaskTracker", "setup_proxy_server")
1920

2021

2122
def get_callback_actor_name(name: str) -> str:
2223
return f"{name}_callback_actor"
2324

2425

26+
def setup_proxy_server(proxy_server_name="proxy", proxy_server_route_prefix="/proxy", **kwargs) -> DeploymentHandle:
27+
"""Construct a webserver, and bind it to a PerspectiveProxyRayServer.
28+
29+
Args:
30+
proxy_server_name: the name passed to ray.serve.run for the PerspectiveProxyRayServer
31+
proxy_server_route_prefix: the route_prefix passed to ray.serve.run for the PerspectiveProxyRayServer
32+
**kwargs: arguments forwarded to ray.serve.run() for the webserver
33+
34+
Returns: A DeploymentHandle for the PerspectiveProxyRayServer
35+
"""
36+
from raydar.dashboard.server import PerspectiveProxyRayServer
37+
38+
webserver = ray.serve.run(**kwargs)
39+
proxy_server = ray.serve.run(
40+
PerspectiveProxyRayServer.bind(webserver),
41+
name=proxy_server_name,
42+
route_prefix=proxy_server_route_prefix,
43+
)
44+
return proxy_server
45+
46+
2547
@ray.remote(resources={"node:__internal_head__": 0.1}, num_cpus=0)
2648
class AsyncMetadataTrackerCallback:
2749
"""
@@ -107,22 +129,18 @@ def __init__(
107129
self.client = StateApiClient(address=ray.get_runtime_context().gcs_address)
108130

109131
if self.perspective_dashboard_enabled:
110-
from raydar.dashboard.server import PerspectiveProxyRayServer, PerspectiveRayServer
132+
from raydar.dashboard.server import PerspectiveRayServer
111133

112134
kwargs = dict(
113135
target=PerspectiveRayServer.bind(),
114136
name="webserver",
115137
route_prefix="/",
116138
)
139+
117140
if Version(ray.__version__) < Version("2.10"):
118141
kwargs["port"] = os.environ.get("RAYDAR_PORT", 8000)
119142

120-
self.webserver = ray.serve.run(**kwargs)
121-
self.proxy_server = ray.serve.run(
122-
PerspectiveProxyRayServer.bind(self.webserver),
123-
name="proxy",
124-
route_prefix="/proxy",
125-
)
143+
self.proxy_server = setup_proxy_server(**kwargs)
126144
self.proxy_server.remote(
127145
"new",
128146
self.perspective_table_name,
@@ -149,14 +167,6 @@ def __init__(
149167
},
150168
)
151169

152-
def get_proxy_server(self) -> ray.serve.handle.DeploymentHandle:
153-
"""A getter for this actors proxy server attribute. Can be used to create custom perspective visuals.
154-
Returns: this actors proxy_server attribute
155-
"""
156-
if self.proxy_server:
157-
return self.proxy_server
158-
raise Exception("This task_tracker has no active proxy_server.")
159-
160170
def callback(self, tasks: Iterable[ray.ObjectRef]) -> None:
161171
"""A remote function used by this actor's processor actor attribute. Will be called by a separate actor
162172
with a collection of ray object references once those ObjectReferences are not in the "RUNNING" or
@@ -287,6 +297,14 @@ def get_df(self) -> pl.DataFrame:
287297
)
288298
return self.df
289299

300+
def get_proxy_server(self) -> ray.serve.handle.DeploymentHandle:
301+
"""A getter for this actors proxy server attribute. Can be used to create custom perspective visuals.
302+
Returns: this actors proxy_server attribute
303+
"""
304+
if self.proxy_server:
305+
return self.proxy_server
306+
raise Exception("This task_tracker has no active proxy_server.")
307+
290308
def save_df(self) -> None:
291309
"""Saves the internally maintained dataframe of task related information from the ray GCS"""
292310
self.get_df()

raydar/tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def unittest_ray_config():
3535
config = dict(
3636
num_cpus=5,
3737
include_dashboard=True,
38-
dashboard_host="0.0.0.0",
3938
)
4039
return config
4140

raydar/tests/test_task_tracker.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import pytest
22
import ray
3+
import requests
34
import time
45

5-
from raydar import RayTaskTracker
6+
from raydar import RayTaskTracker, setup_proxy_server
67

78

89
@ray.remote
@@ -21,3 +22,19 @@ def test_construction_and_dataframe(self):
2122
time.sleep(30)
2223
df = task_tracker.get_df()
2324
assert df[["name", "state"]].row(0) == ("do_some_work", "FINISHED")
25+
26+
def test_get_proxy_server(self):
27+
from raydar.dashboard.server import PerspectiveRayServer
28+
29+
kwargs = dict(
30+
target=PerspectiveRayServer.bind(),
31+
name="webserver",
32+
route_prefix="/",
33+
)
34+
server = setup_proxy_server(**kwargs)
35+
server.remote("new", "test_table", dict(a="str", b="int", c="float", d="datetime"))
36+
time.sleep(2)
37+
server.remote("update", "test_table", [dict(a="foo", b=1, c=1.0, d=time.time())])
38+
time.sleep(2)
39+
response = requests.get("http://localhost:8000/tables")
40+
assert eval(response.text) == ["test_table"]

0 commit comments

Comments
 (0)