Skip to content
5 changes: 5 additions & 0 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import attrs
import sqlalchemy as sa
from sqlalchemy.sql.expression import true
from sqlalchemy.sql.selectable import GenerativeSelect

from datachain import json
from datachain.client import Client
Expand Down Expand Up @@ -88,6 +89,10 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
def cleanup_for_tests(self):
    """Test-cleanup hook; this implementation performs no work."""
    return None

def normalize_limit_offset(self, query: GenerativeSelect) -> GenerativeSelect:
    """Adapt ``query`` to warehouse-specific LIMIT/OFFSET semantics.

    This implementation applies no adjustment: the very same ``query``
    object is handed back to the caller.
    """
    unchanged = query
    return unchanged

def _to_jsonable(self, obj: Any) -> Any:
"""Recursively convert Python/Pydantic structures into JSON-serializable
objects.
Expand Down
165 changes: 152 additions & 13 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datetime import datetime, timezone
from functools import wraps
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, cast
from uuid import uuid4

import attrs
Expand All @@ -25,7 +25,7 @@
from sqlalchemy.sql import func as f
from sqlalchemy.sql.elements import ColumnClause, ColumnElement, Label
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.selectable import Select, TableClause
from sqlalchemy.sql.selectable import GenerativeSelect, Select, TableClause
from sqlalchemy.sql.visitors import replacement_traverse

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
Expand Down Expand Up @@ -88,7 +88,6 @@

from sqlalchemy.sql.elements import KeyedColumnElement
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import GenerativeSelect
from typing_extensions import ParamSpec, Self

from datachain.catalog import Catalog
Expand Down Expand Up @@ -171,19 +170,76 @@ def step_result(
)


class QueryState:
    """Flags recording which finalizing SQL clauses the in-flight query has.

    One instance is shared by all steps while ``DatasetQuery.apply_steps``
    builds a query. Before adding its own clause, a step looks at these
    flags to decide whether the query built so far must first be wrapped in
    a subquery (a *state boundary*). Without the wrap, a ``WHERE`` that
    follows ``GROUP BY`` or an ``ORDER BY`` that follows ``LIMIT`` would act
    on the underlying rows instead of the already-finalized output.
    """

    def __init__(self, warehouse: "AbstractWarehouse") -> None:
        # The warehouse is kept so emit_boundary() can ask it to normalize
        # LIMIT/OFFSET before the query is wrapped.
        self.warehouse = warehouse
        self.reset()

    def reset(self) -> None:
        """Clear every flag, i.e. mark the query as not finalized."""
        self.grouped = self.limited = self.offset_applied = self.distinct = False

    @property
    def has_finalizing_clause(self) -> bool:
        """Whether any of the grouped/limited/offset/distinct flags is set."""
        flags = (self.grouped, self.limited, self.offset_applied, self.distinct)
        return any(flags)

    def emit_boundary(self, query_generator: "QueryGenerator") -> "QueryGenerator":
        """Wrap the in-flight query in a subquery and clear all flags.

        The generator that comes back selects every column from that
        subquery, so any earlier ``LIMIT``/``OFFSET``/``GROUP BY``/
        ``DISTINCT`` lives in the inner level and the caller's clause is
        applied to the outer level only.
        """
        # The warehouse gets a chance to apply dialect-specific LIMIT/OFFSET
        # normalization before the query is wrapped.
        inner_query: GenerativeSelect = self.warehouse.normalize_limit_offset(
            query_generator.select()
        )
        wrapped = inner_query.subquery()
        outer_query = sqlalchemy.select(*wrapped.c).select_from(wrapped)

        def select_columns(*columns):
            return outer_query.with_only_columns(*columns)

        # Flags are cleared before the new generator is constructed.
        self.reset()
        return QueryGenerator(
            func=select_columns,
            columns=tuple(outer_query.selected_columns),
        )


@frozen
class Step(ABC):
"""A query processing step (filtering, mutation, etc.)"""

resets_query_state_after_apply: ClassVar[bool] = True

@abstractmethod
def apply(
self,
query_generator: QueryGenerator,
temp_tables: list[str],
*args,
state: "QueryState",
Comment thread
shcheklein marked this conversation as resolved.
**kwargs,
) -> "StepResult":
"""Apply the processing step."""
"""Apply the processing step.

``state`` carries the in-flight query-state flags. SQL clauses
that finalize output (``filter`` after ``GROUP BY``/``LIMIT``, etc.)
inspect and update it via ``state.emit_boundary`` to materialize
a subquery boundary when needed. Steps that don't care may
ignore it.
"""

@abstractmethod
def hash_inputs(self) -> str:
Expand Down Expand Up @@ -325,7 +381,7 @@ def apply(
self,
query_generator,
temp_tables: list[str],
*args,
state: "QueryState",
**kwargs,
) -> "StepResult":
source_query = query_generator.select()
Expand Down Expand Up @@ -1128,9 +1184,12 @@ def apply(
self,
query_generator: QueryGenerator,
temp_tables: list[str],
hash_input: str,
hash_output: str,
state: "QueryState",
*,
hash_input: str = "",
hash_output: str = "",
checkpoints_enabled: bool = True,
**kwargs,
) -> "StepResult":
query = query_generator.select()

Expand Down Expand Up @@ -1820,15 +1879,22 @@ def q(*columns):

@frozen
class SQLClause(Step, ABC):
resets_query_state_after_apply: ClassVar[bool] = False

def apply(
self,
query_generator: QueryGenerator,
temp_tables: list[str],
*args,
state: "QueryState",
**kwargs,
) -> StepResult:
if self.requires_boundary(state):
query_generator = state.emit_boundary(query_generator)

query = query_generator.select()
new_query = self.apply_sql_clause(query)
new_query = state.warehouse.normalize_limit_offset(new_query)
self.update_query_state(state)

Comment thread
shcheklein marked this conversation as resolved.
def q(*columns):
return new_query.with_only_columns(*columns)
Expand All @@ -1845,9 +1911,21 @@ def parse_cols(
def apply_sql_clause(self, query: Any) -> Any:
    """Hook that applies this clause to ``query``.

    ``apply`` passes the current select to this hook and uses whatever it
    returns as the new query. This base body does nothing and yields
    ``None``; the concrete clauses override it.
    """

def requires_boundary(self, state: "QueryState") -> bool:
    """Report whether a subquery boundary is needed before this clause.

    The default answer is ``False``; ``state`` is not inspected here.
    """
    needs_boundary = False
    return needs_boundary

def update_query_state(self, state: "QueryState") -> None:
    """Clear all flags on ``state`` once this clause has been applied.

    Default behaviour: clauses such as SELECT / MUTATE / SELECT_EXCEPT /
    COUNT are noted as wrapping the prior result in a fresh subquery, so
    flags set by earlier clauses no longer describe the current query.
    """
    state.reset()
    return None


@frozen
class RegenerateSystemColumns(Step):
# This step resets state conditionally inside apply(): preserve pending flags
# for true no-ops, but reset when system-column regeneration wraps the query.
resets_query_state_after_apply: ClassVar[bool] = False

catalog: "Catalog"

def hash_inputs(self) -> str:
Expand All @@ -1857,13 +1935,15 @@ def apply(
self,
query_generator: QueryGenerator,
temp_tables: list[str],
*args,
state: "QueryState",
**kwargs,
) -> StepResult:
query = query_generator.select()
new_query = self.catalog.warehouse._regenerate_system_columns(
query, keep_existing_columns=True
)
if new_query is not query:
state.reset()

def q(*columns):
return new_query.with_only_columns(*columns)
Expand Down Expand Up @@ -1951,6 +2031,16 @@ def apply_sql_clause(self, query: Select) -> Select:
expressions = self.parse_cols(self.expressions)
return query.filter(*expressions)

def requires_boundary(self, state: "QueryState") -> bool:
    """Ask for a boundary whenever the prior output is already finalized.

    A ``WHERE`` added straight after ``GROUP BY`` / ``LIMIT`` / ``OFFSET`` /
    ``DISTINCT`` would filter the underlying rows rather than the
    finalized output, so any pending finalizing flag forces a wrap.
    """
    already_finalized = state.has_finalizing_clause
    return already_finalized

def update_query_state(self, state: "QueryState") -> None:
    """Leave ``state`` untouched: a filter does not finalize the output.

    Pending flags carry over, unless a boundary was emitted before this
    clause, in which case ``QueryState`` has already been reset.
    """
    return None


@frozen
class SQLOrderBy(SQLClause):
Expand All @@ -1963,6 +2053,15 @@ def apply_sql_clause(self, query: Select) -> Select:
args = self.parse_cols(self.args)
return query.order_by(*args)

def requires_boundary(self, state: "QueryState") -> bool:
    """Ask for a boundary when a ``LIMIT`` or ``OFFSET`` is pending.

    Ordering on the same level as an existing LIMIT/OFFSET would change
    which rows are kept, so the earlier result is wrapped first.
    """
    if state.limited:
        return True
    return state.offset_applied

def update_query_state(self, state: "QueryState") -> None:
    """Leave ``state`` untouched: ordering does not finalize the output.

    Pending flags carry over, unless a boundary was emitted before this
    clause was applied.
    """
    return None


@frozen
class SQLLimit(SQLClause):
Expand All @@ -1974,6 +2073,14 @@ def hash_inputs(self) -> str:
def apply_sql_clause(self, query: Select) -> Select:
    """Return ``query`` with a ``LIMIT`` of ``self.n`` applied."""
    limited_query = query.limit(self.n)
    return limited_query

def requires_boundary(self, state: "QueryState") -> bool:
    """Ask for a boundary only when a ``LIMIT`` is already pending.

    Stacked LIMITs must be materialized one at a time. A LIMIT that
    follows an OFFSET merges into a single ``LIMIT N OFFSET M`` clause,
    so a pending offset alone does not call for a wrap.
    """
    limit_pending = state.limited
    return limit_pending

def update_query_state(self, state: "QueryState") -> None:
    """Record on ``state`` that a ``LIMIT`` is now part of the query."""
    state.limited = True
    return None


@frozen
class SQLOffset(SQLClause):
Expand All @@ -1982,9 +2089,17 @@ class SQLOffset(SQLClause):
def hash_inputs(self) -> str:
    """Return the SHA-256 hex digest of the offset's string form."""
    encoded_offset = str(self.offset).encode()
    digest = hashlib.sha256(encoded_offset)
    return digest.hexdigest()

def apply_sql_clause(self, query: "GenerativeSelect") -> "GenerativeSelect":
    """Return ``query`` with an ``OFFSET`` of ``self.offset`` applied.

    The annotations are written as strings so the signature resolves
    whether ``GenerativeSelect`` is imported at runtime or only for type
    checking.
    """
    return query.offset(self.offset)

def requires_boundary(self, state: "QueryState") -> bool:
    """Ask for a boundary only when an ``OFFSET`` is already pending.

    A second OFFSET needs the first one materialized. An OFFSET that
    follows a LIMIT merges into ``LIMIT N OFFSET M``, so a pending limit
    alone does not call for a wrap.
    """
    offset_pending = state.offset_applied
    return offset_pending

def update_query_state(self, state: "QueryState") -> None:
    """Record on ``state`` that an ``OFFSET`` is now part of the query."""
    state.offset_applied = True
    return None


@frozen
class SQLCount(SQLClause):
Expand All @@ -2009,6 +2124,15 @@ def apply_sql_clause(self, query):

return query.distinct(*self.args)

def requires_boundary(self, state: "QueryState") -> bool:
    """Ask for a boundary whenever the prior output is already finalized.

    DISTINCT after GROUP BY / LIMIT / OFFSET should de-duplicate the
    finalized rows, not the underlying ones, and a repeated DISTINCT needs
    a boundary as well. Those four flags are exactly what
    ``QueryState.has_finalizing_clause`` reports, so the shared property is
    used instead of spelling the disjunction out again.
    """
    return state.has_finalizing_clause

def update_query_state(self, state: "QueryState") -> None:
    """Record on ``state`` that ``DISTINCT`` is now part of the query."""
    state.distinct = True
    return None


@frozen
class SQLUnion(Step):
Expand Down Expand Up @@ -2044,7 +2168,7 @@ def apply(
self,
query_generator: QueryGenerator,
temp_tables: list[str],
*args,
state: "QueryState",
**kwargs,
) -> StepResult:
left_before = len(self.query1.temp_table_names)
Expand Down Expand Up @@ -2201,7 +2325,7 @@ def apply(
self,
query_generator: QueryGenerator,
temp_tables: list[str],
*args,
state: "QueryState",
**kwargs,
) -> StepResult:
q1 = self.get_query(self.query1, temp_tables)
Expand Down Expand Up @@ -2340,6 +2464,12 @@ def apply_sql_clause(self, query) -> Select:

return sqlalchemy.select(*unique_cols).select_from(subquery).group_by(*group_by)

def update_query_state(self, state: "QueryState") -> None:
    """Clear all flags on ``state``, then mark the query as grouped.

    ``apply_sql_clause`` of this step selects from ``query.subquery()``,
    so any previously finalized output already lives in that inner
    selectable; only the new ``GROUP BY`` remains pending.
    """
    state.reset()
    state.grouped = True
    return None


class UnionSchemaMismatchError(ValueError):
"""Union input columns mismatch."""
Expand Down Expand Up @@ -2639,6 +2769,9 @@ def q(*columns):
raise RuntimeError("DatasetQuery has no starting dataset or steps")

_hash = hasher.hexdigest()
# Stage flags reset at the start of every apply_steps run; they only
# describe the in-flight query under construction here.
state = QueryState(self.catalog.warehouse)
for step in query.steps:
hash_input = _hash
hasher.update(step.hash().encode("utf-8"))
Expand All @@ -2648,11 +2781,17 @@ def q(*columns):
result = step.apply(
result.query_generator,
self.temp_table_names,
state,
hash_input=hash_input,
hash_output=hash_output,
checkpoints_enabled=self.checkpoints_enabled,
) # a chain of steps linked by results
self.dependencies.update(result.dependencies)
# Most non-SQLClause steps materialize a fresh relation, so any
# prior query-state flags no longer apply. Steps that preserve the
# in-flight selectable, or manage flags themselves, opt out.
if step.resets_query_state_after_apply:
state.reset()

return result.query_generator

Expand Down
Loading
Loading