Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## [Unreleased]

### Added

- **Conditional rule applicability via `when:` block** (issue #73, PR 1/3) --
Rules in `.watchflow/rules.yaml` can now declare a structured `when:` block
with named predicates that gate rule evaluation. Supported predicates:
`contributor: first_time | trusted`, `pr_count_below: N`, and
`files_match: <glob>` (or a list of globs). All predicates must hold for the
rule to run; otherwise it is skipped and logged at debug level. Enables
stricter checks for first-time contributors and path-scoped rules without an
expression parser. Contributor context (merged PR count, first-time flag,
trusted flag) is fetched via the GitHub Search API and attached to the
enriched event data. Expression-parser support (`and`/`or`/comparisons) and
extended predicates (`risk.level`, `contributor.role`) will follow in later
PRs.

### Fixed

- **Stale PR data in CODEOWNERS checks** -- `PullRequestEnricher` now
Expand Down
3 changes: 3 additions & 0 deletions src/agents/engine_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def _convert_rules_to_descriptions(self, rules: list[Any]) -> list[RuleDescripti
event_types = [et.value if hasattr(et, "value") else str(et) for et in rule.event_types]
severity = str(rule.severity.value) if hasattr(rule.severity, "value") else str(rule.severity)
rule_id = getattr(rule, "rule_id", None)
when = getattr(rule, "when", None)
else:
# It's a dict
description = (
Expand All @@ -216,6 +217,7 @@ def _convert_rules_to_descriptions(self, rules: list[Any]) -> list[RuleDescripti
event_types = rule.get("event_types", [])
severity = rule.get("severity", "medium")
rule_id = rule.get("rule_id")
when = None

rule_description = RuleDescription(
description=description,
Expand All @@ -227,6 +229,7 @@ def _convert_rules_to_descriptions(self, rules: list[Any]) -> list[RuleDescripti
validator_name=None, # Will be selected by LLM
fallback_to_llm=True,
conditions=conditions,
when=when,
)

rule_descriptions.append(rule_description)
Expand Down
5 changes: 4 additions & 1 deletion src/agents/engine_agent/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from src.core.models import Violation # noqa: TCH001, TCH002, TC001
from src.rules.conditions.base import BaseCondition # noqa: TCH001, TCH002, TC001
from src.rules.models import Rule # noqa: TCH001, TCH002, TC001
from src.rules.models import Rule, RuleWhen # noqa: TCH001, TCH002, TC001


class EngineRequest(BaseModel):
Expand Down Expand Up @@ -111,6 +111,9 @@ class RuleDescription(BaseModel):
validator_name: str | None = Field(default=None, description="Specific validator to use")
fallback_to_llm: bool = Field(default=True, description="Whether to fallback to LLM if validator fails")
conditions: list["BaseCondition"] = Field(default_factory=list, description="Attached executable conditions") # noqa: UP037
when: RuleWhen | None = Field(
default=None, description="Optional predicate block for conditional rule applicability"
)

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down
18 changes: 13 additions & 5 deletions src/agents/engine_agent/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_llm_evaluation_system_prompt,
)
from src.integrations.providers import get_chat_model
from src.rules.when_evaluator import should_apply_rule

logger = logging.getLogger(__name__)

Expand All @@ -36,16 +37,23 @@ async def analyze_rule_descriptions(state: EngineState) -> dict[str, Any]:
try:
logger.info(f"🔍 Analyzing {len(state.rule_descriptions)} rule descriptions")

# Filter rules applicable to this event type
# Filter rules applicable to this event type, then apply optional `when:` predicates.
applicable_rules = []
for rule_desc in state.rule_descriptions:
if state.event_type in rule_desc.event_types:
applicable_rules.append(rule_desc)
logger.info(f"🔍 Rule '{rule_desc.description[:50]}...' is applicable to {state.event_type}")
else:
if state.event_type not in rule_desc.event_types:
logger.info(
f"🔍 Rule '{rule_desc.description[:50]}...' is NOT applicable (expects: {rule_desc.event_types})"
)
continue

applies, reason = should_apply_rule(rule_desc.when, state.event_data)
if not applies:
logger.debug(f'Rule "{rule_desc.description}" skipped: {reason}')
state.analysis_steps.append(f'Rule "{rule_desc.description}" skipped: {reason}')
continue

applicable_rules.append(rule_desc)
logger.info(f"🔍 Rule '{rule_desc.description[:50]}...' is applicable to {state.event_type}")

state.rule_descriptions = applicable_rules
state.analysis_steps.append(f"Found {len(applicable_rules)} applicable rules out of {len(state.rules)} total")
Expand Down
35 changes: 35 additions & 0 deletions src/event_processors/pull_request/enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ async def enrich_event_data(self, task: Any, github_token: str) -> dict[str, Any
]
event_data["diff_summary"] = self.summarize_files(files)

# Build contributor context for rule `when:` predicates (first-time / trusted / pr_count_below).
author_login = (pr_data.get("user") or {}).get("login")
if author_login:
event_data["contributor_context"] = await self._build_contributor_context(
repo_full_name, author_login, installation_id
)

# Fetch CODEOWNERS so path-has-code-owner rule can evaluate without a local repo
codeowners_paths = [".github/CODEOWNERS", "CODEOWNERS", "docs/CODEOWNERS"]
for path in codeowners_paths:
Expand All @@ -109,6 +116,34 @@ async def enrich_event_data(self, task: Any, github_token: str) -> dict[str, Any

return event_data

async def _build_contributor_context(
self, repo_full_name: str, username: str, installation_id: int
) -> dict[str, Any]:
"""
Build contributor context used by rule `when:` predicates.

Uses the Search API to count the author's prior merged PRs in this repo.
The PR currently being evaluated is not merged yet, so it is not counted.
On API failure, returns a context with `merged_pr_count=None` and
boolean predicates set to False — the `when_evaluator` treats missing
data as fail-open and will apply the rule.
"""
merged_count: int | None = None
if hasattr(self.github_client, "search_merged_pr_count"):
try:
merged_count = await self.github_client.search_merged_pr_count(
repo_full_name, username, installation_id
)
except Exception as e:
logger.warning(f"Error fetching merged PR count for {username} in {repo_full_name}: {e}")

return {
"login": username,
"merged_pr_count": merged_count,
"is_first_time": merged_count == 0,
"trusted": bool(merged_count and merged_count > 0),
}

async def fetch_acknowledgments(self, repo: str, pr_number: int, installation_id: int) -> dict[str, Acknowledgment]:
"""Fetch and parse previous acknowledgments from PR comments."""
try:
Expand Down
39 changes: 37 additions & 2 deletions src/integrations/github/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,7 @@ async def get_pull_request_files(self, repo: str, pr_number: int, installation_i
session = await self._get_session()
while True:
url = (
f"{config.github.api_base_url}/repos/{repo}/pulls/{pr_number}"
f"/files?per_page={per_page}&page={page}"
f"{config.github.api_base_url}/repos/{repo}/pulls/{pr_number}/files?per_page={per_page}&page={page}"
)
async with session.get(url, headers=headers) as response:
if response.status != 200:
Expand Down Expand Up @@ -1106,6 +1105,42 @@ async def get_commits_for_file(
logger.warning(f"Error getting commits for file {file_path} in {repo}: {e}")
return []

async def search_merged_pr_count(self, repo: str, username: str, installation_id: int) -> int | None:
    """
    Count merged PRs authored by `username` in `repo` via the GitHub Search API.

    Args:
        repo: Repository in "owner/name" form.
        username: GitHub login whose merged PRs are counted.
        installation_id: GitHub App installation used to mint the access token.

    Returns:
        The number of merged PRs, or None when the token or the request fails,
        so callers can distinguish "no data" from "zero merged PRs".
    """
    token = await self.get_installation_access_token(installation_id)
    if not token:
        return None

    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/vnd.github.v3+json",
    }
    # per_page=1 keeps the payload minimal; only `total_count` is read.
    query = f"is:pr is:merged repo:{repo} author:{username}"
    url = f"{config.github.api_base_url}/search/issues?q={quote(query)}&per_page=1"

    try:
        session = await self._get_session()
        async with session.get(url, headers=headers) as response:
            if response.status == 200:
                data = await response.json()
                return cast("int", data.get("total_count", 0))
            error_text = await response.text()
            # Fix: stdlib logging.Logger.warning rejects arbitrary keyword
            # arguments (TypeError at runtime). Log a formatted message
            # instead, matching the rest of this module's logging style.
            logger.warning(
                f"search_merged_pr_count failed for {username} in {repo}: "
                f"status={response.status}, response={error_text[:200]}"
            )
            return None
    except Exception as e:
        logger.warning(f"search_merged_pr_count error for {username} in {repo}: {e}")
        return None

async def get_user_pull_requests(
self, repo: str, username: str, installation_id: int, limit: int = 100
) -> list[dict[str, Any]]:
Expand Down
19 changes: 18 additions & 1 deletion src/rules/loaders/github_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from src.core.models import EventType
from src.integrations.github import GitHubClient, github_client
from src.rules.interface import RuleLoader
from src.rules.models import Rule, RuleAction, RuleSeverity
from src.rules.models import Rule, RuleAction, RuleSeverity, RuleWhen
from src.rules.registry import CONDITION_CLASS_TO_RULE_ID, ConditionRegistry

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -113,6 +113,22 @@ def _parse_rule(rule_data: dict[str, Any]) -> Rule:
rule_id_val = rid.value
break

# Parse optional `when:` block (structured predicates controlling rule applicability).
when_block: RuleWhen | None = None
when_data = rule_data.get("when")
if when_data is not None:
if isinstance(when_data, dict):
try:
when_block = RuleWhen(**when_data)
except Exception as e:
logger.warning(
f"Invalid `when` block in rule '{rule_data.get('description', 'unknown')}': {e} — ignoring"
)
else:
logger.warning(
f"`when` block in rule '{rule_data.get('description', 'unknown')}' is not a mapping — ignoring"
)

# Actions are optional and not mapped
actions = []
if "actions" in rule_data:
Expand All @@ -129,6 +145,7 @@ def _parse_rule(rule_data: dict[str, Any]) -> Rule:
actions=actions,
parameters=parameters,
rule_id=rule_id_val,
when=when_block,
)
return rule

Expand Down
26 changes: 26 additions & 0 deletions src/rules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@ class RuleCategory(str, Enum):
HYGIENE = "hygiene" # AI spam detection, contribution governance (AI Immune System)


class RuleWhen(BaseModel):
    """
    Structured predicate block controlling whether a rule is applied to an event.

    When all predicates evaluate true, the rule runs; otherwise it is skipped.
    An absent or empty block means the rule always runs. Predicates are
    evaluated by `src.rules.when_evaluator.should_apply_rule` against the
    enriched event data.
    """

    # 'first_time' = no prior merged PRs; 'trusted' = at least one merged PR.
    contributor: str | None = Field(
        default=None,
        description="Contributor predicate: 'first_time' (no prior merged PRs) or 'trusted' (has merged PRs).",
    )
    # Numeric threshold on the author's prior merged-PR count.
    pr_count_below: int | None = Field(
        default=None,
        description="Rule applies only when the author has fewer than N prior merged PRs.",
    )
    # A single glob or a list of globs matched against changed file paths.
    files_match: str | list[str] | None = Field(
        default=None,
        description="Glob pattern(s); rule applies only when at least one changed file matches.",
    )


class RuleCondition(BaseModel):
"""
Represents a condition that must be met for a rule to be triggered.
Expand Down Expand Up @@ -60,6 +82,10 @@ class Rule(BaseModel):
conditions: list["BaseCondition"] = Field(default_factory=list) # noqa: UP037
actions: list[RuleAction] = Field(default_factory=list)
parameters: dict[str, Any] = Field(default_factory=dict) # Store parameters as-is from YAML
when: RuleWhen | None = Field(
default=None,
description="Optional predicate block. Rule is skipped when predicates do not match the event.",
)

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down
72 changes: 72 additions & 0 deletions src/rules/when_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Evaluation of rule `when:` predicate blocks (see `src.rules.models.RuleWhen`)."""

from __future__ import annotations

import fnmatch
import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from src.rules.models import RuleWhen

logger = logging.getLogger(__name__)


def should_apply_rule(when: RuleWhen | None, event_data: dict[str, Any]) -> tuple[bool, str]:
    """
    Return whether the rule should be evaluated for this event, and a reason if skipped.

    Args:
        when: Parsed RuleWhen block (or None when the rule has no predicates).
        event_data: Enriched event data, expected to include `contributor_context`
            and `changed_files` when the predicates reference them.

    Returns:
        A tuple of (applies, reason). ``applies`` is True when all named
        predicates in ``when`` hold (or ``when`` is empty/None). ``reason``
        is a human-readable explanation when the rule is skipped, or an
        empty string when the rule applies. If a predicate is present but
        its required context is missing, the rule is applied (fail-open)
        and a warning is logged — skipping silently on missing data would
        hide misconfiguration.
    """
    if when is None:
        return True, ""

    contributor_ctx = event_data.get("contributor_context") or {}

    if when.contributor is not None:
        if not contributor_ctx:
            logger.warning("when.contributor set but contributor_context missing — applying rule")
        elif contributor_ctx.get("merged_pr_count") is None:
            # API failure: we cannot tell whether the author is first-time or trusted.
            # Fail-open (apply the rule) so a transient Search API outage does not
            # silently disable stricter checks for newcomers.
            logger.warning(f"when.contributor='{when.contributor}' set but merged_pr_count is unknown — applying rule")
        else:
            predicate = when.contributor.strip().lower()
            if predicate == "first_time":
                if not contributor_ctx.get("is_first_time", False):
                    return False, "contributor is not first-time"
            elif predicate == "trusted":
                if not contributor_ctx.get("trusted", False):
                    return False, "contributor is not trusted"
            else:
                logger.warning(f"Unknown contributor predicate '{when.contributor}' — ignoring")

    if when.pr_count_below is not None:
        if not contributor_ctx:
            logger.warning("when.pr_count_below set but contributor_context missing — applying rule")
        else:
            merged_count = contributor_ctx.get("merged_pr_count")
            if merged_count is None:
                logger.warning("when.pr_count_below set but merged_pr_count is None — applying rule")
            elif merged_count >= when.pr_count_below:
                return False, f"contributor has {merged_count} merged PRs (threshold: {when.pr_count_below})"

    if when.files_match is not None:
        patterns: list[str] = [when.files_match] if isinstance(when.files_match, str) else list(when.files_match)
        changed_files = event_data.get("changed_files") or []
        filenames = [f.get("filename", "") for f in changed_files if isinstance(f, dict) and f.get("filename")]
        # Fix: use fnmatchcase — plain fnmatch.fnmatch applies os.path.normcase,
        # making globs case-insensitive on Windows, but GitHub file paths are
        # case-sensitive. fnmatchcase matches identically on every OS.
        if not any(fnmatch.fnmatchcase(name, pat) for name in filenames for pat in patterns):
            return False, f"no changed files match pattern {patterns}"

    return True, ""
Loading
Loading