Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions ak-py/src/agentkernel/framework/adk/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,25 +140,26 @@ def _process_requests(requests: list[AgentRequest]) -> tuple[str, list[types.Par

return prompt, parts

async def _setup_session_context(self, agent: Any, session: Session, requests: list[AgentRequest]) -> tuple[str, Runner]:
async def _setup_session_context(self, agent: Any, session: Session, requests: list[AgentRequest]) -> tuple[str, Runner, AKToolContext]:
"""
Setup ADK session and tool context.
:param agent: The ADK agent.
:param session: The AgentKernel session.
:param requests: The requests.
:return: Tuple of (user_id, runner).
:return: Tuple of (user_id, runner, tool_context). The caller is responsible for entering/exiting
the returned tool_context around the runner's actual execution, since tools invoked by
the agent look up this context by id from the cache while the agent is running.
"""
app_name = "AgentKernel"
user_id = "AgentKernel"
adk_session = self._session(session)

ctx: AKToolContext = AKToolContext(Runtime.current(), agent, session, requests)
with ctx:
await adk_session.create_session(app_name=app_name, user_id=user_id, session_id=session.id)
await adk_session.update_session_state(ctx.id, agent.name, {"ak_tool_context": ctx.id})
await adk_session.create_session(app_name=app_name, user_id=user_id, session_id=session.id)
await adk_session.update_session_state(ctx.id, agent.name, {"ak_tool_context": ctx.id})

runner = Runner(agent=agent.agent, app_name=app_name, session_service=adk_session.session_service)
return user_id, runner
runner = Runner(agent=agent.agent, app_name=app_name, session_service=adk_session.session_service)
return user_id, runner, ctx

@staticmethod
async def get_response(runner: Runner, user_id: str, session_id: str, parts: list[types.Part]) -> str:
Expand All @@ -174,11 +175,15 @@ async def get_response(runner: Runner, user_id: str, session_id: str, parts: lis
response_text = ""

if hasattr(runner, "run_async"):
async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=new_message):
if event.is_final_response() and event.content and event.content.parts:
text_parts = [p.text for p in event.content.parts if hasattr(p, "text") and p.text]
response_text = " ".join(text_parts) if text_parts else ""
break
events = runner.run_async(user_id=user_id, session_id=session_id, new_message=new_message)
try:
async for event in events:
if event.is_final_response() and event.content and event.content.parts:
text_parts = [p.text for p in event.content.parts if hasattr(p, "text") and p.text]
response_text = " ".join(text_parts) if text_parts else ""
break
finally:
await events.aclose()
else:
for event in runner.run(user_id=user_id, session_id=session_id, new_message=new_message):
if event.is_final_response() and event.content and event.content.parts:
Expand All @@ -201,8 +206,9 @@ async def run(self, agent: Any, session: Session, requests: list[AgentRequest])
if not parts:
return AgentReplyText(text="Sorry. No valid content found in the requests")

user_id, runner = await self._setup_session_context(agent, session, requests)
reply = await self.get_response(runner=runner, session_id=session.id, parts=parts, user_id=user_id)
user_id, runner, ctx = await self._setup_session_context(agent, session, requests)
with ctx:
reply = await self.get_response(runner=runner, session_id=session.id, parts=parts, user_id=user_id)
return AgentReplyText(text=reply, prompt=prompt)
except Exception as e:
return AgentReplyText(text=user_facing_error_message(e), prompt=prompt)
Expand All @@ -220,24 +226,25 @@ async def stream(self, agent: Any, session: Session, requests: list[AgentRequest
if not parts:
return

user_id, runner = await self._setup_session_context(agent, session, requests)
user_id, runner, ctx = await self._setup_session_context(agent, session, requests)
new_message = types.Content(role="user", parts=parts)
run_config = RunConfig(streaming_mode=StreamingMode.SSE)

if hasattr(runner, "run_async"):
async for event in runner.run_async(
user_id=user_id,
session_id=session.id,
new_message=new_message,
run_config=run_config,
):
if not getattr(event, "partial", False):
continue
if not event.content or not event.content.parts:
continue
chunk = "".join(getattr(part, "text", "") or "" for part in event.content.parts)
if chunk:
yield chunk
with ctx:
async for event in runner.run_async(
user_id=user_id,
session_id=session.id,
new_message=new_message,
run_config=run_config,
):
if not getattr(event, "partial", False):
continue
if not event.content or not event.content.parts:
continue
chunk = "".join(getattr(part, "text", "") or "" for part in event.content.parts)
if chunk:
yield chunk


class GoogleADKAgent(AKBaseAgent):
Expand Down
Loading