Skip to content
Merged
19 changes: 16 additions & 3 deletions csp_gateway/server/demo/omnibus.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,14 @@ class ExampleGatewayChannels(GatewayChannels):
example_list: ts[List[ExampleData]] = None
never_ticks: ts[ExampleData] = None

s_example: ts[State[ExampleData]] = None
# State fields can be added via annotation or the `set_state` API in the module's `connect` method
example_with_state: Annotated[ts[ExampleData], State(("id", "x"))] = None
example_with_state_multiple: Annotated[
ts[ExampleData],
State(("id", "x")),
State(("id", "y"), alias="example_with_state_alternative"),
] = None
# example: Set in connect with a different schema, to show flexibility of state definition

basket: Dict[ExampleEnum, ts[ExampleData]] = None
str_basket: Dict[str, ts[ExampleData]] = None
Expand Down Expand Up @@ -179,6 +186,8 @@ def connect(self, channels: ExampleGatewayChannels):
# Channels set via `set_channel`
channels.set_channel(ExampleGatewayChannels.example, data)
channels.set_channel(ExampleGatewayChannels.example_list, data_list)
channels.set_channel(ExampleGatewayChannels.example_with_state, data)
channels.set_channel(ExampleGatewayChannels.example_with_state_multiple, data)

# Generic channel for sending data from non-csp sources
channels.add_send_channel(ExampleGatewayChannels.example)
Expand All @@ -187,8 +196,12 @@ def connect(self, channels: ExampleGatewayChannels):
channels.add_send_channel(ExampleGatewayChannels.basket, ExampleEnum.C)
channels.add_send_channel(ExampleGatewayChannels.basket)

# Rudimentary state accumulation via `set_state`
channels.set_state(ExampleGatewayChannels.example, "id")
# State accumulation via `set_state`
channels.set_state(
data,
"example",
Comment thread
timkpaine marked this conversation as resolved.
Outdated
("id",),
)

# Create some data streams for dict baskets
data_a = self.subscribe(
Expand Down
214 changes: 155 additions & 59 deletions csp_gateway/server/gateway/csp/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DefaultDict,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -60,6 +61,24 @@
log = getLogger(__name__)


class _StateSpec(NamedTuple):
    """Describes a state collection attached to a Channels instance.

    One spec is stored per state alias. ``source_field`` is the channel name
    whose edge feeds the state, or None when the state was registered via
    ``set_state`` with a raw edge.
    """

    # Name of the channel field that feeds this state; None for states
    # registered directly from an edge via ``set_state``.
    source_field: Optional[str]
    # Field name(s) the state is keyed by. The registration sites in this
    # module pass the result of ``_normalize_keyby``, i.e. a tuple of names.
    keyby: Union[str, Tuple[str, ...]]
    # Optional basket key; combined with the alias to form the
    # ``(alias, indexer)`` lookup key for wired state edges and requests.
    indexer: Optional[Union[str, int]] = None


def _normalize_keyby(keyby: Union[str, Tuple[str, ...], list]) -> Tuple[str, ...]:
if isinstance(keyby, (list, tuple)):
return tuple(keyby)
return (keyby,)


class _SnapshotModelBaseClass(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", coerce_numbers_to_str=True)

Expand Down Expand Up @@ -166,13 +185,30 @@ def __new__(mcs: Any, name: Any, bases: Any, namespace: Any, **kwargs: Any) -> A

_add_field_attributes(cls)
ts_pydantic_field_types = {}
declared_states: Dict[str, _StateSpec] = {}
for field_name, field_type in cls.model_fields.items():
# Validate that timeseries types contain structs or list of structs
outer_type = field_type.annotation
ts_pydantic_field_type = _get_ts_pydantic_field_type(outer_type)
if ts_pydantic_field_type is not None:
ts_pydantic_field_types[field_name] = ts_pydantic_field_type

# Collect State(...) annotation markers from Annotated metadata.
for meta in getattr(field_type, "metadata", ()): # pydantic 2 FieldInfo.metadata
if isinstance(meta, State):
alias = meta._meta_alias or field_name
if alias in declared_states:
raise ValueError(
f"Duplicate state alias '{alias}' on {name}: already declared on field '{declared_states[alias].source_field}'"
)
declared_states[alias] = _StateSpec(
source_field=field_name,
keyby=_normalize_keyby(meta._meta_keyby),
indexer=meta._meta_indexer,
)

cls._declared_states = declared_states

ts_pydantic_field_types[_CSP_ENGINE_CYCLE_TIMESTAMP_FIELD] = (Optional[datetime], None)
dynamic_pydantic_model = create_model("_snapshot_model", __base__=_SnapshotModelBaseClass, **ts_pydantic_field_types)
cls._snapshot_model = dynamic_pydantic_model
Expand All @@ -188,11 +224,16 @@ class Channels(BaseModel, metaclass=ChannelsMetaclass):
The names of the channels match the names that are used via the different APIs, i.e. REST, WebSockets, Perspective, etc.
It is expected that developers interact with channels through these APIs (or via get_channel/set_channel in csp).

Channels that begin with ``s_`` are "state" channels, meaning that they represent a collection of messages, typically
the last message grouped by some key (i.e. security id). These are not meant to be interacted with directly, but rather
through the "state" part of the REST API.
State collections are declared via ``Annotated[ts[X], State(keyby=..., indexer=..., alias=...)]``
on a channel field (auto-wired from the channel's edge), or registered at module
connect time via ``set_state(edge, alias, keyby, indexer=None)``. State collections
Comment thread
timkpaine marked this conversation as resolved.
Outdated
are exposed through the state part of the REST API and ``state``/``query`` helpers.
"""

# Populated by ChannelsMetaclass from Annotated[ts[X], State(...)] markers.
# alias -> _StateSpec(source_field, keyby, indexer)
_declared_states: Dict[str, _StateSpec] = {}

model_config = dict(arbitrary_types_allowed=True) # (for FeedbackOutputDef)

_finalized: bool = PrivateAttr(default=False)
Expand All @@ -206,7 +247,12 @@ class Channels(BaseModel, metaclass=ChannelsMetaclass):
_override_blocks: Dict[Any, Optional[datetime]] = PrivateAttr(default_factory=dict)
_feedbacks: Dict[int, FeedbackOutputDef] = PrivateAttr(default_factory=dict)

# alias -> _StateSpec (annotation-declared and set_state-registered combined)
_states: Dict[str, _StateSpec] = PrivateAttr(default_factory=dict)
# (alias, indexer) -> ConcurrentFutureAdapter trigger
_state_requests: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)
# (alias, indexer) -> bound state Edge
_state_edges: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)
_last_requests: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)
_next_requests: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)
_send_channels: Dict[Tuple[str, Optional[Union[str, int]]], Any] = PrivateAttr(default_factory=dict)
Expand All @@ -233,6 +279,20 @@ class Channels(BaseModel, metaclass=ChannelsMetaclass):
# Might be public in the future.
_null_ts: List[Tuple[str, Optional[str]]] = PrivateAttr(default_factory=list)

def model_post_init(self, __context: Any) -> None:
# Seed instance state registry with class-level declarations from annotations.
for alias, spec in self.__class__._declared_states.items():
self._states[alias] = spec

@classmethod
def state_aliases(cls) -> List[str]:
"""Return the list of state aliases declared on the class via annotations."""
return list(cls._declared_states.keys())

def all_state_aliases(self) -> List[str]:
"""Return all known state aliases (declared + dynamically registered)."""
return list(self._states.keys())

def dynamic_keys(self) -> Optional[Dict[str, List[Any]]]:
"""Define dynamic dictionary keys by field, driven by data from the channels."""
...
Expand Down Expand Up @@ -340,6 +400,9 @@ def _finalize(self) -> None:
who_requires_id[requires].add(id(module))

if not self._finalized:
# Auto-wire state collections declared via annotations.
self._wire_declared_states()

# first ensure everything is provided
for (
field,
Expand Down Expand Up @@ -435,6 +498,29 @@ def _finalize(self) -> None:
self._finalized = True
log.debug(f"Feedback count: {self._feedback_count}")

def _wire_declared_states(self) -> None:
for alias, spec in list(self._states.items()):
if spec.source_field is None:
continue # registered via set_state, already wired
if (alias, spec.indexer) in self._state_edges:
continue

# Skip if no module provides this channel (e.g. all setters disabled)
if not self._delayed_edge_providers.get(spec.source_field):
continue

tstype = self.get_outer_type(spec.source_field)
if is_dict_basket(tstype):
if spec.indexer is None:
raise NotImplementedError(
f"Annotation-declared state '{alias}' on dict basket '{spec.source_field}' "
f"requires an indexer (set indexer=... in State(...))"
)
edge = self.get_channel(spec.source_field, indexer=spec.indexer)
else:
edge = self.get_channel(spec.source_field)
self._wire_state_edge(alias, edge, spec.keyby, spec.indexer)

def _bind_delayed_channel(self, field, list_of_edges_and_modules, indexer=None):
tstype = self.get_outer_type(field)
# Make sure a getter node exists first
Expand Down Expand Up @@ -602,10 +688,6 @@ def get_channel(

return getattr(self, field)

@classmethod
def is_state_field(cls, field):
return field.startswith("s_")

@classmethod
def get_outer_type(cls, field):
return cls.model_fields[field].annotation
Expand All @@ -619,9 +701,6 @@ def set_channel(
# add to graph
self._add_field_to_graph(field, self._module_being_attached, True, indexer)

# TODO fix ugly state field stuff
is_state_field = self.is_state_field(field)

tstype = self.get_outer_type(field)
if is_dict_basket(tstype):
_is_dict_basket = True
Expand Down Expand Up @@ -668,7 +747,7 @@ def set_channel(
else:
edge_tstypes = [edge.tstype] # type: ignore[union-attr]

if not all(edge_tstype == gateway_tstype for edge_tstype in edge_tstypes) and not is_state_field:
if not all(edge_tstype == gateway_tstype for edge_tstype in edge_tstypes):
raise TypeError("Edge type incorrect for {}: should be {}, found {}".format(field, gateway_tstype, edge_tstypes[0]))

module = self._module_being_attached
Expand All @@ -687,60 +766,81 @@ def set_channel(
self._set_last(field)
self._set_next(field)

def _ensure_state_field(self, field: str) -> str:
if not field.startswith("s_"):
return "s_{}".format(field)
return field

def set_state(
self,
edge: Edge,
field: str,
keyby: Union[str, Tuple[str, ...]],
indexer: Union[str, int] = None,
indexer: Optional[Union[str, int]] = None,
) -> None:
# grab state version of field
state_field = self._ensure_state_field(field)

# Bail if already setup
if (state_field, indexer) in self._state_requests:
return

# First ensure edge is constructed
edge = self.get_channel(field, indexer=indexer)

# And ensure the state edge is constructed
self.get_state(state_field, indexer=indexer)
"""Register a state collection from a raw edge under ``field``.

if isinstance(edge, Edge):
# instantiate state node
if get_origin(edge.tstype.typ) is list:
edge_type_name = get_args(edge.tstype.typ)[0].__name__
state_edge = build_track_state_node(csp.unroll(edge), keyby)
else:
edge_type_name = edge.tstype.typ.__name__
state_edge = build_track_state_node(edge, keyby)
``field`` is the name the state will be exposed as (parallel to
``set_channel(field, edge, indexer)``). If ``field`` is already
registered, this is a no-op when ``keyby`` and ``indexer`` match;
otherwise a ``ValueError`` is raised.
"""
if not isinstance(edge, Edge):
raise TypeError("set_state expects a csp Edge as the first argument; got {}".format(type(edge)))

keyby = _normalize_keyby(keyby)
existing = self._states.get(field)
if existing is not None:
if existing.keyby != keyby or existing.indexer != indexer:
raise ValueError(
f"State '{field}' already registered with "
f"keyby={existing.keyby!r}, indexer={existing.indexer!r}; "
f"cannot redefine with keyby={keyby!r}, indexer={indexer!r}"
)
if (field, indexer) in self._state_edges:
return # already wired

state_edge.nodedef.__name__ = "State[{}]".format(edge_type_name)
self._states[field] = _StateSpec(source_field=None, keyby=keyby, indexer=indexer)
self._wire_state_edge(field, edge, keyby, indexer)

# register for use inside other csp nodes
self.set_channel(state_field, state_edge, indexer=indexer)
def _wire_state_edge(
self,
field: str,
edge: Edge,
keyby: Union[str, Tuple[str, ...]],
indexer: Optional[Union[str, int]],
) -> None:
if get_origin(edge.tstype.typ) is list:
edge_type_name = get_args(edge.tstype.typ)[0].__name__
state_edge = build_track_state_node(csp.unroll(edge), keyby)
else:
edge_type_name = edge.tstype.typ.__name__
state_edge = build_track_state_node(edge, keyby)

# setup ad-hoc querying
trigger = ConcurrentFutureAdapter(name="RequestState<{}>".format(edge_type_name))
state_edge.nodedef.__name__ = "State[{}]".format(edge_type_name)

named_on_request_node("QueryState<{}>".format(edge_type_name))(state_edge, trigger.out())
trigger = ConcurrentFutureAdapter(name="RequestState<{}>".format(edge_type_name))
named_on_request_node("QueryState<{}>".format(edge_type_name))(state_edge, trigger.out())

# register the trigger
self._state_requests[state_field, indexer] = trigger
else:
# TODO
raise NotImplementedError()
self._state_requests[field, indexer] = trigger
self._state_edges[field, indexer] = state_edge

def get_state(self, field: str, indexer: Union[str, int] = None) -> Any:
# grab state version of field
state_field = self._ensure_state_field(field)

return self.get_channel(state_field, indexer=indexer)
"""Return the underlying state Edge for ``field`` (csp-graph use)."""
if field not in self._states:
raise NoProviderException("Unknown state: {}".format(field))
if (field, indexer) not in self._state_edges:
spec = self._states[field]
if spec.source_field is None:
raise NoProviderException("State '{}' (indexer={}) has not been wired yet".format(field, indexer))
# Lazily wire annotation-declared state on first access
tstype = self.get_outer_type(spec.source_field)
if is_dict_basket(tstype):
if spec.indexer is None:
raise NotImplementedError(
f"Annotation-declared state '{field}' on dict basket '{spec.source_field}' "
f"requires an indexer (set indexer=... in State(...))"
)
edge = self.get_channel(spec.source_field, indexer=spec.indexer)
else:
edge = self.get_channel(spec.source_field)
self._wire_state_edge(field, edge, spec.keyby, spec.indexer)
return self._state_edges[field, indexer]

def _set_last(self, field: str, indexer: Union[str, int] = None) -> None:
# Bail if already setup
Expand Down Expand Up @@ -923,17 +1023,13 @@ def next(self, field: str, indexer: Union[str, int] = None, *, timeout=None) ->
return result

def state(self, field: str, indexer: Union[str, int] = None, *, timeout=None) -> Any:
# grab state version of field
state_field = self._ensure_state_field(field)

self._check(state_field, self._state_requests, "state", indexer=indexer)
self._check(field, self._state_requests, "state", indexer=indexer)

# TODO checks for state tracking
# TODO make sure not called from inside graph context
state_edge = self._state_requests[state_field, indexer]
trigger = self._state_requests[field, indexer]

# trigger request for state into graph
future = state_edge.push_tick()
future = trigger.push_tick()

# wait for result
return future.result(timeout=timeout)
Expand Down
Loading
Loading