Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 95 additions & 92 deletions backend/graph_builder.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,114 @@
"""Converts SQL rows into a stable graph payload: {nodes: [], edges: []}."""
"""Builds a data-driven graph payload from SQL rows: {nodes: [], edges: []}."""

from __future__ import annotations

from typing import Any, Dict, List, Tuple

SUPPORTED_TYPES: Tuple[str, ...] = (
"customer",
"order",
"delivery",
"invoice",
"payment",
"product",
)

RELATION_LABELS: Dict[Tuple[str, str], str] = {
("customer", "order"): "placed",
("order", "delivery"): "fulfilled by",
("delivery", "invoice"): "billed via",
("order", "invoice"): "billed via",
("invoice", "payment"): "paid by",
("order", "product"): "contains",
from typing import Any, Dict, List, Optional, Tuple

# Maps each graph entity type to the column names whose values identify an
# entity of that type. Row column names are lowercased before being looked up
# against these names, so every name listed here must be lowercase.
ENTITY_COLUMN_MAP: Dict[str, Tuple[str, ...]] = {
"customer": ("customer_id", "customer", "sold_to_party", "business_partner"),
"order": ("order_id", "sales_order"),
"delivery": ("delivery_id", "delivery_document"),
"invoice": ("invoice_id", "billing_document", "invoice_reference"),
"payment": ("payment_id", "accounting_document", "clearing_accounting_document"),
"product": ("product_id", "product", "material"),
}


def _node_id(entity_type: str, entity_value: Any) -> str:
    """Return the graph node identifier for an entity.

    Args:
        entity_type: Entity kind, e.g. ``"order"``.
        entity_value: Identifier of the entity; formatted with ``str()`` rules.

    Returns:
        The identifier in the form ``"<entity_type>_<entity_value>"``.
    """
    return f"{entity_type}_{entity_value}"
def _normalize_value(value: Any) -> Optional[str]:
    """Convert a raw cell value to trimmed text.

    Args:
        value: Any cell value taken from a result row.

    Returns:
        ``str(value)`` with surrounding whitespace removed, or ``None`` when
        the value is ``None`` or its text form is empty after trimming.
    """
    if value is None:
        return None
    text = str(value).strip()
    # An empty or whitespace-only string counts as "no value".
    return text or None


def _make_node(entity_type: str, entity_value: str, row: Dict[str, Any]) -> Dict[str, Any]:
    """Build the node payload for one entity.

    Args:
        entity_type: Entity kind, e.g. ``"order"``.
        entity_value: Normalized identifier of the entity.
        row: The result row the entity was found in.

    Returns:
        A dict with ``id`` (``"<type>_<value>"``), ``type`` and a ``data``
        dict holding the raw id, the type, a display label and the source row
        under ``context``.
    """
    return {
        "id": f"{entity_type}_{entity_value}",
        "type": entity_type,
        "data": {
            "id": entity_value,
            "type": entity_type,
            "label": f"{entity_type.title()} {entity_value}",
            # The row is stored by reference, not copied.
            "context": row,
        },
    }


def _extract_entities_in_row(row: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Extract entities in row column order. Returns [(entity_type, entity_value), ...].

    A column contributes an entity when its lowercased name appears in
    ``ENTITY_COLUMN_MAP`` and its value is non-empty after normalization.
    An entity that appears more than once in the same row (same type and
    value) is reported only at its first position.

    Args:
        row: One result row, keyed by column name.

    Returns:
        ``(entity_type, entity_value)`` pairs in the order the columns occur
        in ``row``.
    """
    # Reverse index: recognised column name -> entity type.
    column_to_entity: Dict[str, str] = {
        col: entity_type
        for entity_type, columns in ENTITY_COLUMN_MAP.items()
        for col in columns
    }

    entities: List[Tuple[str, str]] = []
    seen_in_row = set()

    for col_name, raw_value in row.items():
        entity_type = column_to_entity.get(str(col_name).lower())
        if not entity_type:
            continue
        entity_value = _normalize_value(raw_value)
        if not entity_value:
            continue
        entity_id = f"{entity_type}_{entity_value}"
        if entity_id in seen_in_row:
            continue
        seen_in_row.add(entity_id)
        entities.append((entity_type, entity_value))

    return entities


def build_graph(rows: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Build a ``{"nodes": [...], "edges": [...]}`` payload from result rows.

    Each distinct entity found in the rows becomes one node; the first row
    that mentions an entity supplies that node's context. Within a row,
    consecutive entities (in column order) are linked by an edge labelled
    ``"related"``. Nodes and edges are de-duplicated by id across rows.

    If no row yields any entity, a single fallback node ``"row_0"`` of type
    ``"record"`` is returned, carrying the first row as context.

    Args:
        rows: Result rows, each keyed by column name.

    Returns:
        A dict with ``nodes`` and ``edges`` lists; both are empty when
        ``rows`` is empty.
    """
    if not rows:
        return {"nodes": [], "edges": []}

    unique_nodes: Dict[str, Dict[str, Any]] = {}
    unique_edges: Dict[str, Dict[str, Any]] = {}

    for row in rows:
        entities = _extract_entities_in_row(row)

        for entity_type, entity_value in entities:
            node_id = f"{entity_type}_{entity_value}"
            if node_id not in unique_nodes:
                unique_nodes[node_id] = _make_node(entity_type, entity_value, row)

        # Build relationships strictly from same-row entities (in row order).
        # zip() yields nothing when the row holds fewer than two entities.
        for (source_type, source_value), (target_type, target_value) in zip(entities, entities[1:]):
            source_id = f"{source_type}_{source_value}"
            target_id = f"{target_type}_{target_value}"
            if source_id == target_id:
                continue
            edge_id = f"e_{source_id}_{target_id}"
            if edge_id in unique_edges:
                continue
            unique_edges[edge_id] = {
                "id": edge_id,
                "source": source_id,
                "target": target_id,
                "label": "related",
            }

    # rows is non-empty here, so rows[0] is safe.
    if not unique_nodes:
        fallback_id = "row_0"
        unique_nodes[fallback_id] = {
            "id": fallback_id,
            "type": "record",
            "data": {
                "id": fallback_id,
                "type": "record",
                "label": "Result Row",
                "context": rows[0],
            },
        }

    return {
        "nodes": list(unique_nodes.values()),
        "edges": list(unique_edges.values()),
    }
78 changes: 47 additions & 31 deletions backend/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
1) Schema loaded from schema_enforcer (schema.json / live DB fallback)
2) Guardrails — reject non-dataset queries
3) Groq generates SQL from strict SAP schema prompt
4) Auto-correct phantom column/table names
5) SQL validation — allow only SELECT/WITH + LIMIT
6) Column validation against real SAP schema
4) SQL validation — allow only SELECT/WITH + LIMIT
5) Strict column validation against real SAP schema
7) Async PostgreSQL execution
8) On DB error → auto-retry with error context sent back to LLM
8) On DB error → one retry with error context sent back to LLM
9) Graph builder — {nodes, edges} from result rows
10) Gemini summarizes if >10 rows
11) Consistent response: {type, summary, total, data, graph}
Expand Down Expand Up @@ -149,12 +148,9 @@ def set(self, key: str, value: Any) -> None:

SQL_PROMPT_TEMPLATE = """You are a PostgreSQL expert working with an SAP Order-to-Cash database.

{column_mapping}

STRICT RULES — VIOLATION WILL CAUSE QUERY FAILURE:
1. Use ONLY the tables and columns from the SCHEMA below. DO NOT invent column names.
2. DO NOT use: customer_id, order_id, product_id, invoice_id, payment_id, delivery_id
— these columns do NOT exist. Use the MAPPING above instead.
2. Do NOT invent aliases for missing columns. Use exact schema names only.
3. Always use explicit column names. No SELECT *.
4. All ID/text values in WHERE clauses MUST be wrapped in single quotes:
e.g. sold_to_party = '320000085' NOT sold_to_party = 320000085
Expand All @@ -175,18 +171,16 @@ def set(self, key: str, value: Any) -> None:

Validate and fix the SQL below so it correctly answers: "{question}"

{column_mapping}

SCHEMA:
{schema}

SQL TO VALIDATE:
{sql}

VALIDATION CHECKLIST:
1. Every table name MUST exist in the SCHEMA. Replace any phantom table with the correct SAP table.
2. Every column name MUST exist in the SCHEMA. Remove or replace any hallucinated column.
3. Use MAPPING above to fix wrong column names (e.g. customer_id → customer, order_id → sales_order).
1. Every table name MUST exist in the SCHEMA.
2. Every column name MUST exist in the SCHEMA.
3. Do NOT invent columns or tables and do NOT use any mapping.
4. All literal values in WHERE clauses must be single-quoted.
5. Apply LIMIT 20 if missing.
6. Use proper LEFT JOINs via the JOIN HINTS for rich results.
Expand Down Expand Up @@ -241,12 +235,12 @@ def format_small_result(rows: List[Dict[str, Any]]) -> str:
def _error_response(msg: str, detail: str = "") -> Dict[str, Any]:
    """Build the error payload returned to the client.

    Args:
        msg: Short error message.
        detail: Optional longer description.

    Returns:
        ``{"type": "error", "message": ...}``. The message is ``msg`` itself
        when it equals ``"Invalid column in query"``; otherwise ``detail`` is
        used when non-empty, falling back to ``msg``.
    """
    # The invalid-column message is always reported verbatim, never replaced
    # by the detail text.
    if msg == "Invalid column in query":
        message = msg
    else:
        message = detail or msg
    return {"type": "error", "message": message}


def _empty_response(message: str = "No data found") -> Dict[str, Any]:
    """Build the payload returned when a query yields no rows.

    Args:
        message: Text to show to the client.

    Returns:
        ``{"type": "empty", "message": message, "graph": {...}}`` where the
        graph has empty ``nodes`` and ``edges`` lists, so the response always
        carries a graph key.
    """
    return {"type": "empty", "message": message, "graph": {"nodes": [], "edges": []}}


def normalize_question(question: str) -> Tuple[str, Dict[str, str]]:
Expand Down Expand Up @@ -302,8 +296,7 @@ def _groq_call(question: str, schema_str: str) -> Dict[str, Any]:
if not GROQ_API_KEY:
raise RuntimeError("GROQ_API_KEY not set")
client = Groq(api_key=GROQ_API_KEY)
mapping_str = schema_enforcer.build_column_mapping_prompt()
prompt = SQL_PROMPT_TEMPLATE.format(schema=schema_str, column_mapping=mapping_str)
prompt = SQL_PROMPT_TEMPLATE.format(schema=schema_str)
resp = client.chat.completions.create(
model="llama-3.1-8b-instant",
response_format={"type": "json_object"},
Expand All @@ -321,9 +314,8 @@ def _groq_call(question: str, schema_str: str) -> Dict[str, Any]:
def _groq_validate_sql_call(question: str, sql: str, schema_str: str) -> Dict[str, Any]:
from groq import Groq # pyre-ignore[21]
client = Groq(api_key=GROQ_API_KEY)
mapping_str = schema_enforcer.build_column_mapping_prompt()
prompt = SQL_VALIDATOR_PROMPT_TEMPLATE.format(
schema=schema_str, question=question, sql=sql, column_mapping=mapping_str
schema=schema_str, question=question, sql=sql
)
resp = client.chat.completions.create(
model="llama-3.1-8b-instant",
Expand Down Expand Up @@ -358,6 +350,21 @@ def _groq_retry_call(question: str, failed_sql: str, error_msg: str) -> Dict[str
return json.loads(content) # type: ignore[return-value]


def _gemini_sql_call(question: str, schema_str: str) -> Dict[str, Any]:
    """Ask Gemini to generate SQL for ``question`` and return the parsed JSON reply.

    Args:
        question: The user's question, sent as ``"Question: <question>"``.
        schema_str: Schema text substituted into ``SQL_PROMPT_TEMPLATE``.

    Returns:
        The model reply decoded with ``json.loads``; an empty dict when the
        reply has no text.

    Raises:
        RuntimeError: If ``GEMINI_API_KEY`` is not set.
        json.JSONDecodeError: If the reply is not valid JSON.
    """
    import google.generativeai as genai  # pyre-ignore[21]
    if not GEMINI_API_KEY:
        raise RuntimeError("GEMINI_API_KEY not set")
    genai.configure(api_key=GEMINI_API_KEY)
    model = genai.GenerativeModel(
        model_name="gemini-2.0-flash",
        generation_config=genai.types.GenerationConfig(temperature=0.0, max_output_tokens=400),  # pyre-ignore[16]
    )
    prompt = SQL_PROMPT_TEMPLATE.format(schema=schema_str)
    result = model.generate_content([prompt, f"Question: {question}"])
    text = (result.text or "{}").strip()
    # No JSON response type is requested here, so the reply may arrive wrapped
    # in a Markdown code fence (```json ... ```); unwrap it before parsing.
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text[:4].lower() == "json":
            text = text[4:].strip()
    return json.loads(text or "{}")  # type: ignore[return-value]


def _gemini_call(question: str, sql: str, rows: List[Dict[str, Any]]) -> str:
import google.generativeai as genai # pyre-ignore[21]
if not GEMINI_API_KEY:
Expand Down Expand Up @@ -448,19 +455,27 @@ async def generate_sql_from_llm(question: str, retry_context: Optional[Tuple[str
except asyncio.TimeoutError:
raise ValueError("LLM timeout — Groq did not respond within 14 seconds")
except Exception as exc:
if is_llm_rate_error(str(exc)):
if is_llm_rate_error(str(exc)) and not retry_context:
logger.warning("[GROQ] Rate-limited. Falling back to Gemini SQL generation.")
await asyncio.sleep(1.0)
try:
result = await asyncio.wait_for(
asyncio.to_thread(_gemini_sql_call, normalized_question, schema_str), # pyre-ignore[6]
timeout=14.0,
)
sql_validated = str(result.get("sql", "")).strip()
logger.info("[GEMINI] Fallback SQL: %s", _trunc(sql_validated))
except Exception as gemini_exc:
logger.warning("[GEMINI] SQL fallback failed: %s", gemini_exc)
raise ValueError("LLM limit exceeded")
elif is_llm_rate_error(str(exc)):
raise ValueError("LLM limit exceeded")
logger.warning("[GROQ] Error: %s", exc)
return None

# ── Step 6: Auto-correct phantom column/table names ───────────────────────
sql_corrected, corrections = schema_enforcer.autocorrect_sql(sql_validated)
if corrections:
c_list = list(corrections)
logger.info("[AUTOCORRECT] %d corrections applied: %s", len(c_list), ", ".join(c_list))
if "sql_validated" not in locals():
return None

# ── Step 4 (syntax): Ensure SELECT/WITH + LIMIT ───────────────────────────
sql_clean = sanitize_and_validate_sql(sql_corrected)
sql_clean = sanitize_and_validate_sql(sql_validated)

# ── Step 5: Column validation against real SAP schema ─────────────────────
if sql_clean:
Expand Down Expand Up @@ -513,7 +528,7 @@ async def summarize(question: str, sql: str, rows: List[Dict[str, Any]]) -> str:


async def build_query_response(question: str) -> Dict[str, Any]:
"""Full pipeline with auto-correct, schema enforcement, and retry-on-error."""
"""Full pipeline with strict schema enforcement and single retry-on-error."""
try:
logger.info("[DEBUG] build_query_response started for: %s", question)

Expand All @@ -525,7 +540,7 @@ async def build_query_response(question: str) -> Dict[str, Any]:
logger.info("[DEBUG] SQL generated: %s", sql)

if not sql:
return _error_response("Could not generate valid SQL for this query.")
return _error_response("Invalid column in query")

# ── Step 7: Execute SQL ───────────────────────────────────────────────
db_result = await run_sql(sql)
Expand All @@ -540,10 +555,11 @@ async def build_query_response(question: str) -> Dict[str, Any]:
failed_sql: str = str(failed_sql_raw)

logger.warning("[RETRY] DB error '%s'. Retrying SQL generation...", _trunc(error_msg, 80))

await asyncio.sleep(1.0)
sql_retry = await generate_sql_from_llm(question, retry_context=(failed_sql, error_msg))
if sql_retry and sql_retry != sql:
logger.info("[RETRY] Retrying with corrected SQL: %s", _trunc(sql_retry))
await asyncio.sleep(1.0)
db_result = await run_sql(sql_retry)
sql = sql_retry

Expand Down
Loading