Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ async def load_llm_bundle(
billing_tier="free",
)
return (
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
SanitizedChatLiteLLM(
model=model_string, **{**litellm_kwargs, "streaming": True}
),
agent_config,
None,
)
Expand Down Expand Up @@ -174,7 +176,9 @@ async def load_llm_bundle(
billing_tier=str(global_model.get("billing_tier", "free")).lower(),
)
return (
SanitizedChatLiteLLM(model=model_string, **litellm_kwargs),
SanitizedChatLiteLLM(
model=model_string, **{**litellm_kwargs, "streaming": True}
),
agent_config,
None,
)
154 changes: 154 additions & 0 deletions surfsense_backend/tests/unit/tasks/chat/streaming/test_llm_bundle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Contracts for chat LLM construction in streaming flows.

``stream_new_chat`` / ``stream_resume_chat`` depend on LangChain receiving
token chunks from ``ChatLiteLLM``. ``langchain-litellm`` defaults
``streaming`` to ``False``, so the shared bundle loader must opt in
explicitly for both DB-backed and global model paths.
"""

from __future__ import annotations

from types import SimpleNamespace
from typing import Any

import pytest

import app.tasks.chat.streaming.flows.shared.llm_bundle as llm_bundle

pytestmark = pytest.mark.unit


class _CapturedChatLiteLLM:
calls: list[dict[str, Any]] = []

def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs
self.__class__.calls.append(kwargs)


@pytest.fixture(autouse=True)
def _patch_common_bundle_dependencies(monkeypatch: pytest.MonkeyPatch):
"""Keep these tests focused on the LLM constructor contract."""

_CapturedChatLiteLLM.calls = []

async def _fake_search_space(_session: Any, _search_space_id: int) -> SimpleNamespace:
return SimpleNamespace(id=42, user_id="user-1")

monkeypatch.setattr(llm_bundle, "_load_search_space", _fake_search_space)
monkeypatch.setattr(llm_bundle, "SanitizedChatLiteLLM", _CapturedChatLiteLLM)
monkeypatch.setattr(llm_bundle, "register_model_usage_metadata", lambda **_kw: None)
monkeypatch.setattr(
llm_bundle,
"has_capability",
lambda _model, capability: capability in {"chat", "vision"},
)

return None


async def test_load_llm_bundle_enables_streaming_for_db_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
connection = SimpleNamespace(
provider="openai",
api_key="sk-test",
base_url=None,
extra={"litellm_params": {"temperature": 0.1}},
)
model = SimpleNamespace(
id=7,
model_id="gpt-4o-mini",
display_name="GPT 4o Mini",
connection=connection,
)

async def _fake_db_model(_session: Any, *, model_id: int, search_space: Any) -> Any:
assert model_id == 7
assert search_space.id == 42
return model

monkeypatch.setattr(llm_bundle, "_load_db_model", _fake_db_model)
monkeypatch.setattr(
llm_bundle,
"to_litellm",
lambda _conn, _model_id: (
"openai/gpt-4o-mini",
{"api_key": "sk-test", "temperature": 0.1},
),
)

llm, agent_config, error = await llm_bundle.load_llm_bundle(
object(),
config_id=7,
search_space_id=42,
)

assert error is None
assert llm is not None
assert agent_config is not None
assert _CapturedChatLiteLLM.calls == [
{
"model": "openai/gpt-4o-mini",
"api_key": "sk-test",
"temperature": 0.1,
"streaming": True,
}
]


async def test_load_llm_bundle_enables_streaming_for_global_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
global_model = {
"id": -11,
"connection_id": -101,
"model_id": "claude-sonnet-4-5",
"display_name": "Claude Sonnet",
"billing_tier": "premium",
}
global_connection = {
"id": -101,
"provider": "anthropic",
"api_key": "sk-ant-test",
"base_url": None,
"extra": {"litellm_params": {"temperature": 0.2}},
}
monkeypatch.setattr(
llm_bundle.config,
"GLOBAL_MODELS",
[global_model],
raising=False,
)
monkeypatch.setattr(
llm_bundle.config,
"GLOBAL_CONNECTIONS",
[global_connection],
raising=False,
)
monkeypatch.setattr(
llm_bundle,
"to_litellm",
lambda _conn, _model_id: (
"anthropic/claude-sonnet-4-5",
{"api_key": "sk-ant-test", "temperature": 0.2},
),
)

llm, agent_config, error = await llm_bundle.load_llm_bundle(
object(),
config_id=-11,
search_space_id=42,
)

assert error is None
assert llm is not None
assert agent_config is not None
assert _CapturedChatLiteLLM.calls == [
{
"model": "anthropic/claude-sonnet-4-5",
"api_key": "sk-ant-test",
"temperature": 0.2,
"streaming": True,
}
]
4 changes: 2 additions & 2 deletions surfsense_web/components/assistant-ui/chat-viewport.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ export interface ChatViewportProps {
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
<ThreadPrimitive.Viewport
turnAnchor="top"
autoScroll={false}
scrollToBottomOnRunStart={false}
autoScroll
scrollToBottomOnRunStart
scrollToBottomOnInitialize
scrollToBottomOnThreadSwitch
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
Expand Down
2 changes: 1 addition & 1 deletion surfsense_web/components/assistant-ui/markdown-text.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ const MarkdownTextImpl = () => {
return (
<CitationUrlMapContext.Provider value={urlMapRef}>
<MarkdownTextPrimitive
smooth={false}
smooth
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
rehypePlugins={[rehypeKatex]}
className="aui-md"
Expand Down
Loading