Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 325 additions & 8 deletions csp_gateway/client/client.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions csp_gateway/server/demo/omnibus.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
MountPerspectiveTables,
MountRestRoutes,
MountWebSocketRoutes,
Stage,
State,
)
from csp_gateway.server.config import load_gateway
Expand Down Expand Up @@ -127,6 +128,9 @@ class ExampleGatewayChannels(GatewayChannels):
basket: Dict[ExampleEnum, ts[ExampleData]] = None
str_basket: Dict[str, ts[ExampleData]] = None

# Staging can be added via annotation or the `set_stage` API in the module's `connect` method
example_with_stage: Annotated[ts[ExampleData], Stage()] = None

# FIXME
# basket_list: Dict[ExampleEnum, ts[[ExampleData]]] = None
# NOTE: this second one is not populated with data so will 404
Expand Down Expand Up @@ -188,6 +192,7 @@ def connect(self, channels: ExampleGatewayChannels):
channels.set_channel(ExampleGatewayChannels.example_list, data_list)
channels.set_channel(ExampleGatewayChannels.example_with_state, data)
channels.set_channel(ExampleGatewayChannels.example_with_state_multiple, data)
channels.set_channel(ExampleGatewayChannels.example_with_stage, data)

# Generic channel for sending data from non-csp sources
channels.add_send_channel(ExampleGatewayChannels.example)
Expand All @@ -202,6 +207,9 @@ def connect(self, channels: ExampleGatewayChannels):
("id",),
)

# Staging via `set_stage` — enables stage_add/remove/release/list/lookup APIs
channels.set_stage(ExampleGatewayChannels.example)

# Create some data streams for dict baskets
data_a = self.subscribe(
csp.timer(interval=self.interval, value=True),
Expand Down
1 change: 1 addition & 0 deletions csp_gateway/server/gateway/csp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
on_request_node_dict_basket,
)
from .module import Module
from .stage import *
from .state import *
175 changes: 173 additions & 2 deletions csp_gateway/server/gateway/csp/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
named_wait_for_next_node,
named_wait_for_next_node_dict_basket,
)
from .state import State, build_track_state_node
from .stage import Stage, _StageManager, build_staging_node
from .state import State, _StateManager, build_track_state_node

if TYPE_CHECKING:
from csp_gateway.utils import Query
Expand Down Expand Up @@ -186,6 +187,7 @@ def __new__(mcs: Any, name: Any, bases: Any, namespace: Any, **kwargs: Any) -> A
_add_field_attributes(cls)
ts_pydantic_field_types = {}
declared_states: Dict[str, _StateSpec] = {}
declared_stages: List[str] = []
for field_name, field_type in cls.model_fields.items():
# Validate that timeseries types contain structs or list of structs
outer_type = field_type.annotation
Expand All @@ -206,8 +208,12 @@ def __new__(mcs: Any, name: Any, bases: Any, namespace: Any, **kwargs: Any) -> A
keyby=_normalize_keyby(meta._meta_keyby),
indexer=meta._meta_indexer,
)
elif isinstance(meta, Stage):
if field_name not in declared_stages:
declared_stages.append(field_name)

cls._declared_states = declared_states
cls._declared_stages = declared_stages

ts_pydantic_field_types[_CSP_ENGINE_CYCLE_TIMESTAMP_FIELD] = (Optional[datetime], None)
dynamic_pydantic_model = create_model("_snapshot_model", __base__=_SnapshotModelBaseClass, **ts_pydantic_field_types)
Expand Down Expand Up @@ -235,6 +241,9 @@ class Channels(BaseModel, metaclass=ChannelsMetaclass):
# alias -> _StateSpec(source_field, keyby, indexer)
_declared_states: Dict[str, _StateSpec] = {}

# Populated by ChannelsMetaclass from Annotated[ts[X], Stage()] markers.
_declared_stages: List[str] = []

model_config = dict(arbitrary_types_allowed=True) # (for FeedbackOutputDef)

_finalized: bool = PrivateAttr(default=False)
Expand Down Expand Up @@ -265,6 +274,11 @@ class Channels(BaseModel, metaclass=ChannelsMetaclass):
_next_requests: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)
_send_channels: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)

# Staging: channel_name -> _StageManager instance
_stages: Dict[str, _StageManager] = PrivateAttr(default_factory=dict)
# Staging: channel_name -> GenericPushAdapter (release trigger)
_stage_triggers: Dict[str, Any] = PrivateAttr(default_factory=dict)

# inside context, the module being attached
_module_being_attached: Any = PrivateAttr(None)
# inside context, the requirements of the module being attached
Expand Down Expand Up @@ -410,6 +424,8 @@ def _finalize(self) -> None:
if not self._finalized:
# Auto-wire state collections declared via annotations.
self._wire_declared_states()
# Auto-wire staging declared via annotations.
self._wire_declared_stages()

# first ensure everything is provided
for (
Expand Down Expand Up @@ -529,6 +545,37 @@ def _wire_declared_states(self) -> None:
edge = self.get_channel(spec.source_field)
self._wire_state_edge(alias, edge, spec.keyby, spec.indexer)

def _wire_declared_stages(self) -> None:
"""Auto-wire staging for channels declared via Annotated[ts[X], Stage()]."""
for field_name in self.__class__._declared_stages:
if field_name in self._stages:
continue # already enabled via set_stage in connect

if not self._delayed_edge_providers.get(field_name):
continue # no provider for this channel

# Determine element type from the channel's ts type
tstype = self.get_outer_type(field_name)
if is_dict_basket(tstype):
element_type = get_dict_basket_value_type(tstype)
elif isTsType(tstype):
element_type = tstype.typ
else:
continue

# Unwrap List[T] -> T
from typing import get_args as _get_args, get_origin as _get_origin

if _get_origin(element_type) is list:
element_type = _get_args(element_type)[0]

stage, push_adapter = build_staging_node(element_type)
self._stages[field_name] = stage
self._stage_triggers[field_name] = push_adapter

# Wire the push adapter's output as an additional provider for this channel
self._delayed_edge_providers[field_name].append((None, push_adapter.out()))

def _bind_delayed_channel(self, field, list_of_edges_and_modules, indexer=None):
tstype = self.get_outer_type(field)
# Make sure a getter node exists first
Expand Down Expand Up @@ -893,7 +940,7 @@ def get_state(self, field: str, indexer: Union[str, int] = None) -> Any:
delayed = self._delayed_state_edges.get((field, indexer))
if delayed is None:
element_type = self._pending_state_element_types[field]
delayed = DelayedEdge(ts[State[element_type]])
delayed = DelayedEdge(ts[_StateManager[element_type]])
delayed.__name__ = "s_{}".format(field)
self._delayed_state_edges[field, indexer] = delayed
return delayed
Expand Down Expand Up @@ -1128,6 +1175,130 @@ def send(self, field: str, value: Any, indexer: Union[str, int] = None) -> None:
# basket, push into first item of tuple
send_channel[0].push_tick(value)

# ------------------------------------------------------------------
# Staging API
# ------------------------------------------------------------------

def set_stage(self, field: str) -> None:
"""Enable staging mode for a channel.

Once staging is enabled, the channel gains stage_add/stage_remove/
stage_release/stage_list/stage_lookup methods. Released items are
injected into the channel as a single tick via a feedback edge.

Must be called during module connect (before finalization).
"""
if field in self._stages:
return # already enabled

if not self._validate_field_name(field):
raise AttributeError("{} has no channel: {}".format(self.__class__, field))

# Determine element type from the channel's ts type
tstype = self.get_outer_type(field)
if is_dict_basket(tstype):
element_type = get_dict_basket_value_type(tstype)
elif isTsType(tstype):
element_type = tstype.typ
else:
raise TypeError(f"Cannot enable staging on non-timeseries field: {field}")

# Unwrap List[T] -> T
from typing import get_args as _get_args, get_origin as _get_origin

if _get_origin(element_type) is list:
element_type = _get_args(element_type)[0]

stage, push_adapter = build_staging_node(element_type)
self._stages[field] = stage
self._stage_triggers[field] = push_adapter

# Wire the push adapter's output as an additional provider for this channel
module = self._module_being_attached
self._delayed_edge_providers[field].append((module, push_adapter.out()))

def staged_channels(self) -> List[str]:
"""Return list of channel names that have staging enabled."""
return list(self._stages.keys())

def stage_add(
self,
field: str,
struct: Any = None,
staging_ids: Optional[List[str]] = None,
) -> List[str]:
"""Add a struct to staging area(s) for a channel.

See STAGE.md for full semantics.
"""
if field not in self._stages:
raise NoProviderException(f"No staging enabled for channel: {field}")
return self._stages[field].stage_add(struct, staging_ids)

def stage_remove(
self,
field: str,
struct: Any = None,
staging_ids: Optional[List[str]] = None,
) -> List[str]:
"""Remove struct(s) from staging area(s).

See STAGE.md for full semantics.
"""
if field not in self._stages:
raise NoProviderException(f"No staging enabled for channel: {field}")
return self._stages[field].stage_remove(struct, staging_ids)

def stage_release(
self,
field: str,
staging_ids: Optional[List[str]] = None,
) -> Dict[str, List[Any]]:
"""Release staged structs into the channel.

Released items are pushed into the csp graph individually.
Returns dict mapping staging_id -> list of released structs.
"""
if field not in self._stages:
raise NoProviderException(f"No staging enabled for channel: {field}")

stage = self._stages[field]
released = stage.stage_release(staging_ids)

# Push each released item into the graph via the push adapter
push_adapter = self._stage_triggers[field]
for items in released.values():
for item in items:
push_adapter.push_tick(item)

return released

def stage_list(
self,
field: str,
staging_id: Optional[str] = None,
) -> List[str]:
"""List staging IDs for a channel.

See STAGE.md for full semantics.
"""
if field not in self._stages:
raise NoProviderException(f"No staging enabled for channel: {field}")
return self._stages[field].stage_list(staging_id)

def stage_lookup(
self,
field: str,
staging_id: Optional[str] = None,
) -> Dict[str, List[Any]]:
"""Look up contents of staging area(s).

Returns dict mapping staging_id -> list of structs.
"""
if field not in self._stages:
raise NoProviderException(f"No staging enabled for channel: {field}")
return self._stages[field].stage_lookup(staging_id)

def _check(self, field: str, where: Dict, kind: str, indexer: Union[str, int] = None) -> None:
if (field, indexer) not in where:
# TODO should only be called once the graph is started
Expand Down
Loading
Loading