Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ ui-component/notebook-storage/notebooks.json
ee/ui-component/package-lock.json
ee/ee_tokens/gdrive_token_dev_user.pickle
core/tests/integration/test_data/version_test_1.txt

migrations
44 changes: 40 additions & 4 deletions core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,43 @@ async def list_graphs(
raise HTTPException(status_code=500, detail=str(e))


@app.get("/graph/{name}/visualization", response_model=Dict[str, Any])
@telemetry.track(operation_type="get_graph_visualization", metadata_resolver=telemetry.get_graph_metadata)
async def get_graph_visualization(
name: str,
auth: AuthContext = Depends(verify_token),
folder_name: Optional[Union[str, List[str]]] = None,
end_user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Get graph visualization data.

This endpoint retrieves the nodes and links data needed for graph visualization.
It works with both local and API-based graph services.

Args:
name: Name of the graph to visualize
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to

Returns:
Dict: Visualization data containing nodes and links arrays
"""
try:
return await document_service.get_graph_visualization_data(
name=name,
auth=auth,
folder_name=folder_name,
end_user_id=end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"Error getting graph visualization data: {e}")
raise HTTPException(status_code=500, detail=str(e))


@app.post("/graph/{name}/update", response_model=Graph)
@telemetry.track(operation_type="update_graph", metadata_resolver=telemetry.update_graph_metadata)
async def update_graph(
Expand Down Expand Up @@ -1425,9 +1462,8 @@ async def update_graph(

# Create system filters for folder and user scoping
system_filters = {}
if request.folder_name is not None:
normalized_folder_name = normalize_folder_name(request.folder_name)
system_filters["folder_name"] = normalized_folder_name
if request.folder_name:
system_filters["folder_name"] = request.folder_name
if request.end_user_id:
system_filters["end_user_id"] = request.end_user_id

Expand Down Expand Up @@ -1907,7 +1943,7 @@ async def list_chat_conversations(

Args:
auth: Authentication context containing user and app identifiers.
limit: Maximum number of conversations to return.
limit: Maximum number of conversations to return (1-500)

Returns:
A list of dictionaries describing each conversation, ordered by most
Expand Down
25 changes: 19 additions & 6 deletions core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ class Settings(BaseSettings):
RULES_BATCH_SIZE: int = 4096

# Graph configuration
GRAPH_MODE: Literal["local", "api"] = "local"
GRAPH_PROVIDER: Literal["litellm"] = "litellm"
GRAPH_MODEL: str
ENABLE_ENTITY_RESOLUTION: bool = True
GRAPH_MODEL: Optional[str] = None
ENABLE_ENTITY_RESOLUTION: Optional[bool] = None
# Graph API configuration
MORPHIK_GRAPH_API_KEY: Optional[str] = None
MORPHIK_GRAPH_BASE_URL: Optional[str] = None

# Reranker configuration
USE_RERANKING: bool
Expand Down Expand Up @@ -325,10 +329,19 @@ def get_settings() -> Settings:
}

# load graph config
graph_config = {
"GRAPH_PROVIDER": "litellm",
"ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True),
}
graph_config = (
{
"GRAPH_MODE": "local",
"GRAPH_PROVIDER": "litellm",
"ENABLE_ENTITY_RESOLUTION": config["graph"].get("enable_entity_resolution", True),
}
if config["graph"].get("mode", "local") == "local"
else {
"GRAPH_MODE": "api",
"MORPHIK_GRAPH_BASE_URL": config["graph"].get("base_url", "https://graph-api.morphik.ai"),
"MORPHIK_GRAPH_API_KEY": os.environ.get("MORPHIK_GRAPH_API_KEY", None),
}
)

# Set the model key for LiteLLM
if "model" not in config["graph"]:
Expand Down
53 changes: 49 additions & 4 deletions core/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from core.parser.base_parser import BaseParser
from core.reranker.base_reranker import BaseReranker
from core.services.graph_service import GraphService
from core.services.morphik_graph_service import MorphikGraphService
from core.services.rules_processor import RulesProcessor
from core.storage.base_storage import BaseStorage
from core.vector_store.base_vector_store import BaseVectorStore
Expand All @@ -50,6 +51,8 @@
CHARS_PER_TOKEN = 4
TOKENS_PER_PAGE = 630

settings = get_settings()


class DocumentService:
async def _ensure_folder_exists(
Expand Down Expand Up @@ -135,10 +138,20 @@ def __init__(
# Initialize the graph service only if completion_model is provided
# (e.g., not needed for ingestion worker)
if completion_model is not None:
self.graph_service = GraphService(
db=database,
embedding_model=embedding_model,
completion_model=completion_model,
self.graph_service = (
GraphService(
db=database,
embedding_model=embedding_model,
completion_model=completion_model,
)
if settings.GRAPH_MODE == "local"
else MorphikGraphService(
db=database,
embedding_model=embedding_model,
completion_model=completion_model,
base_url=settings.MORPHIK_GRAPH_BASE_URL,
graph_api_key=settings.MORPHIK_GRAPH_API_KEY,
)
)
else:
self.graph_service = None
Expand Down Expand Up @@ -2142,3 +2155,35 @@ async def _upload_to_app_bucket(
) -> tuple[str, str]:
bucket_override = await self._get_bucket_for_app(auth.app_id)
return await self.storage.upload_from_base64(content_base64, key, content_type, bucket=bucket_override or "")

async def get_graph_visualization_data(
self,
name: str,
auth: AuthContext,
folder_name: Optional[Union[str, List[str]]] = None,
end_user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Get graph visualization data.

Args:
name: Name of the graph to visualize
auth: Authentication context
folder_name: Optional folder name for scoping
end_user_id: Optional end user ID for scoping

Returns:
Dict containing nodes and links for visualization
"""
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id

# Delegate to the GraphService
return await self.graph_service.get_graph_visualization_data(
graph_name=name,
auth=auth,
system_filters=system_filters,
)
68 changes: 68 additions & 0 deletions core/services/graph_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,3 +1462,71 @@ async def _generate_completion(
}

return response

async def get_graph_visualization_data(
self,
graph_name: str,
auth: AuthContext,
system_filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get graph visualization data for local graphs.

Args:
graph_name: Name of the graph to visualize
auth: Authentication context
system_filters: Optional system filters for graph retrieval

Returns:
Dict containing nodes and links for visualization
"""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}

graph = await self.db.get_graph(graph_name, auth, system_filters=system_filters)
if not graph:
logger.warning(f"Graph '{graph_name}' not found or not accessible")
return {"nodes": [], "links": []}

# Transform entities to nodes format
nodes = []
for entity in graph.entities:
nodes.append(
{
"id": entity.id,
"label": entity.label,
"type": entity.type,
"properties": entity.properties,
"color": self._get_node_color(entity.type),
}
)

# Transform relationships to links format
links = []
entity_id_set = {entity.id for entity in graph.entities}
for relationship in graph.relationships:
# Only include relationships where both source and target exist
if relationship.source_id in entity_id_set and relationship.target_id in entity_id_set:
links.append(
{"source": relationship.source_id, "target": relationship.target_id, "type": relationship.type}
)

return {"nodes": nodes, "links": links}

def _get_node_color(self, node_type: str) -> str:
"""Get color for a node type to match the UI color scheme."""
color_map = {
"person": "#4f46e5", # Indigo
"organization": "#06b6d4", # Cyan
"location": "#10b981", # Emerald
"date": "#f59e0b", # Amber
"concept": "#8b5cf6", # Violet
"event": "#ec4899", # Pink
"product": "#ef4444", # Red
"entity": "#4f46e5", # Indigo (for generic entities)
"attribute": "#f59e0b", # Amber
"relationship": "#ec4899", # Pink
"high_level_element": "#10b981", # Emerald
"semantic_unit": "#8b5cf6", # Violet
}
return color_map.get(node_type.lower(), "#6b7280") # Gray as default
93 changes: 93 additions & 0 deletions core/services/morphik_graph_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,97 @@ async def retrieve(
# Depending on requirements, either re-raise or return an error message / empty string
raise # Re-raise the exception to be handled by the caller

async def get_graph_visualization_data(
self,
graph_name: str,
auth: AuthContext,
system_filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get graph visualization data from the external graph API.

Args:
graph_name: Name of the graph to visualize
auth: Authentication context
system_filters: Optional system filters for graph retrieval

Returns:
Dict containing nodes and links for visualization
"""
graph = await self._find_graph(graph_name, auth, system_filters)
graph_id = graph.id

request_data = {"graph_id": graph_id}
try:
api_response = await self._make_api_request(
method="POST",
endpoint="/visualization",
auth=auth,
json_data=request_data,
)
logger.info(f"Visualization API call for graph_id {graph_id} successful.")

# The API should return a structure like:
# {
# "nodes": [{"id": "...", "label": "...", "type": "...", "properties": {...}}, ...],
# "links": [{"source": "...", "target": "...", "type": "..."}, ...]
# }

if isinstance(api_response, dict):
# Ensure we have the expected structure
nodes = api_response.get("nodes", [])
links = api_response.get("links", [])

# Transform to match the expected format for the UI
formatted_nodes = []
for node in nodes:
formatted_nodes.append(
{
"id": node.get("id", ""),
"label": node.get("label", ""),
"type": node.get("type", "unknown"),
"properties": node.get("properties", {}),
"color": self._get_node_color(node.get("type", "unknown")),
}
)

formatted_links = []
for link in links:
formatted_links.append(
{
"source": link.get("source", ""),
"target": link.get("target", ""),
"type": link.get("type", ""),
}
)

return {"nodes": formatted_nodes, "links": formatted_links}
else:
logger.warning(f"Unexpected response format from visualization API: {type(api_response)}")
return {"nodes": [], "links": []}

except Exception as e:
logger.error(f"Failed to call visualization API for graph_id {graph_id}: {e}")
# Return empty visualization data on error
return {"nodes": [], "links": []}

def _get_node_color(self, node_type: str) -> str:
"""Get color for a node type to match the UI color scheme."""
color_map = {
"person": "#4f46e5", # Indigo
"organization": "#06b6d4", # Cyan
"location": "#10b981", # Emerald
"date": "#f59e0b", # Amber
"concept": "#8b5cf6", # Violet
"event": "#ec4899", # Pink
"product": "#ef4444", # Red
"entity": "#4f46e5", # Indigo (for generic entities)
"attribute": "#f59e0b", # Amber
"relationship": "#ec4899", # Pink
"high_level_element": "#10b981", # Emerald
"semantic_unit": "#8b5cf6", # Violet
}
return color_map.get(node_type.lower(), "#6b7280") # Gray as default

async def query_with_graph(
self,
query: str,
Expand All @@ -391,6 +482,8 @@ async def query_with_graph(
system_filters: Optional[Dict[str, Any]] = None, # For graph retrieval in self.retrieve
folder_name: Optional[Union[str, List[str]]] = None, # For document_service and CompletionRequest
end_user_id: Optional[str] = None, # For document_service and CompletionRequest
hop_depth: Optional[int] = None, # maintain signature
include_paths: Optional[bool] = None, # maintain signature
) -> CompletionResponse:
"""Generate completion using combined context from an external graph API and standard document retrieval.

Expand Down
Loading