318 lines
10 KiB
Python
318 lines
10 KiB
Python
import json
import logging
from typing import AsyncGenerator, List, Optional

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse

from google.adk.runners import Runner
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.genai import types

from app.config import settings
from app.models import (
    CreateSessionRequest,
    ChatTurnRequest,
    SessionResponse,
    SessionDetailResponse,
    HistoryEvent,
    AgentConfig
)
from app.services import get_session_service
from app.services.agent_service import get_agent_service, AgentService
from app.agent_factory import create_agent
|
|
|
|
# Router for session CRUD and chat endpoints; every route below is mounted under /sessions.
router = APIRouter(prefix="/sessions", tags=["Sessions"])

# Logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# --- Helper ---

# Lower-case mode name -> ADK streaming mode. Built once at import time instead of per call.
_STREAMING_MODE_MAP = {
    "none": StreamingMode.NONE,
    "sse": StreamingMode.SSE,
    "bidi": StreamingMode.BIDI,
}


def get_streaming_mode(mode_str: Optional[str]) -> StreamingMode:
    """Convert a streaming-mode name into the ``StreamingMode`` enum.

    Matching is case-insensitive and ignores surrounding whitespace. Recognised
    names are ``"none"``, ``"sse"`` and ``"bidi"``.

    Args:
        mode_str: The requested mode name. ``None`` or an empty string is accepted.

    Returns:
        The matching ``StreamingMode``. Falls back to ``StreamingMode.SSE`` for
        ``None``, empty or unrecognised input.
    """
    if not mode_str:
        return StreamingMode.SSE

    mode = _STREAMING_MODE_MAP.get(mode_str.strip().lower())
    if mode is None:
        # Keep the lenient SSE fallback, but make a client typo visible in the logs.
        logger.warning("Unknown streaming mode %r; falling back to SSE", mode_str)
        return StreamingMode.SSE
    return mode
|
|
|
|
|
|
# ==========================
# 📋 Session CRUD
# ==========================


@router.post("", response_model=SessionResponse)
async def create_session(
    req: CreateSessionRequest,
    service=Depends(get_session_service)
):
    """Create a chat session and return its summary.

    ``req.app_name`` falls back to ``settings.DEFAULT_APP_NAME`` when it is empty.
    The summary's ``updated_at`` comes from ``last_update_time``, falling back to
    ``updated_at``, or ``None`` when neither attribute exists.
    """
    created = await service.create_session(
        app_name=req.app_name or settings.DEFAULT_APP_NAME,
        user_id=req.user_id,
    )
    last_updated = getattr(created, 'last_update_time', getattr(created, 'updated_at', None))
    return SessionResponse(
        id=created.id,
        app_name=created.app_name,
        user_id=created.user_id,
        updated_at=last_updated,
    )
|
|
|
|
|
|
@router.get("", response_model=List[SessionResponse])
async def list_sessions(
    user_id: str,
    app_name: str = None,
    service=Depends(get_session_service)
):
    """List the sessions owned by ``user_id`` under ``app_name``.

    ``app_name`` defaults to ``settings.DEFAULT_APP_NAME`` when omitted.
    """
    listing = await service.list_sessions(
        app_name=app_name or settings.DEFAULT_APP_NAME,
        user_id=user_id,
    )

    summaries = []
    for sess in listing.sessions:
        summaries.append(
            SessionResponse(
                id=sess.id,
                app_name=sess.app_name,
                user_id=sess.user_id,
                updated_at=getattr(sess, 'last_update_time', getattr(sess, 'updated_at', None)),
            )
        )
    return summaries
|
|
|
|
|
|
def _drop_rewound_events(events):
    """Return the events that remain visible once rewind markers are applied.

    ADK keeps every event in the log, including those a rewind undid. An event
    whose ``actions.rewind_before_invocation_id`` is set acts as a marker: the
    first visible event with that ``invocation_id`` and everything after it are
    dropped. The marker itself is never kept. A marker whose target is not
    currently visible is ignored.
    """
    visible = []
    for event in events:
        actions = getattr(event, 'actions', None)
        target_id = getattr(actions, 'rewind_before_invocation_id', None) if actions else None
        if not target_id:
            visible.append(event)
            continue

        cut = next(
            (
                i for i, kept in enumerate(visible)
                if getattr(kept, 'invocation_id', None) == target_id
            ),
            None,
        )
        if cut is not None:
            logger.debug("Rewinding history to before %s", target_id)
            del visible[cut:]
    return visible


def _content_text(content):
    """Concatenate the text carried by an event's ``content``.

    ``content`` may be a plain string, or an object with a ``parts`` sequence
    whose items may carry ``text``. Falsy content yields ``""``. ``parts`` may
    be ``None``, since google-genai declares it Optional. That case also yields
    ``""`` instead of raising TypeError.
    """
    if not content:
        return ""
    if isinstance(content, str):
        return content
    parts = getattr(content, 'parts', None) or []
    return "".join(part.text for part in parts if getattr(part, 'text', None))


@router.get("/{session_id}", response_model=SessionDetailResponse)
async def get_session_history(
    session_id: str,
    user_id: str,
    app_name: Optional[str] = None,
    service=Depends(get_session_service)
):
    """Return session metadata plus its visible chat history.

    Events undone by a rewind are removed first (see ``_drop_rewound_events``).
    Each remaining event becomes one ``HistoryEvent`` of type ``"message"``:
    - Author ``"user"`` maps to role ``"user"``.
    - Any other author maps to role ``"model"``, with ``agent_name`` set to that author.
    - Events without any text still produce an entry with empty ``content``.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME
    session = await service.get_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # `or []` also covers a session whose `events` attribute is present but None.
    events = getattr(session, 'events', None) or []

    history_events = []
    for e in _drop_rewound_events(events):
        # ADK usually sets author to the agent name, or "user" for user turns.
        author = getattr(e, 'author', 'unknown')
        role = "user" if author == "user" else "model"
        history_events.append(HistoryEvent(
            type="message",
            role=role,
            agent_name=author if role == "model" else None,
            content=_content_text(getattr(e, 'content', None)),
            timestamp=getattr(e, 'timestamp', getattr(e, 'created_at', None)),
            invocation_id=getattr(e, 'invocation_id', None)
        ))

    return SessionDetailResponse(
        id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None)),
        created_at=getattr(session, 'created_at', getattr(session, 'create_time', None)),
        history=history_events,
        events_count=len(history_events)
    )
|
|
|
|
|
|
@router.delete("/{session_id}")
async def delete_session(
    session_id: str,
    user_id: str,
    app_name: str = None,
    service=Depends(get_session_service)
):
    """Delete a session and echo its id.

    ``app_name`` defaults to ``settings.DEFAULT_APP_NAME``. This handler does not
    check that the session exists before asking the service to delete it.
    """
    await service.delete_session(
        app_name=app_name or settings.DEFAULT_APP_NAME,
        user_id=user_id,
        session_id=session_id,
    )
    return {"status": "deleted", "session_id": session_id}
|
|
|
|
|
|
@router.post("/{session_id}/rewind")
async def rewind_session_state(
    session_id: str,
    user_id: str,
    invocation_id: str,
    app_name: Optional[str] = None,
    session_service=Depends(get_session_service),
    agent_service: AgentService = Depends(get_agent_service)
):
    """Rewind a session to the state it had before ``invocation_id``.

    This undoes the changes made by that invocation and every later one, by
    delegating to ``Runner.rewind_async``. A Runner needs an agent, so one is
    built from ``default-assistant``, or from a minimal ``rewinder`` config if
    that agent is missing.
    NOTE(review): this assumes rewind does not depend on which agent the runner holds.

    Raises:
        HTTPException: 404 if the session does not exist; 500 if the rewind fails.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME

    # Check existence up front, as the other session routes do.
    # Otherwise a missing session surfaces as a generic 500 from the runner.
    session = await session_service.get_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    agent_def = agent_service.get_agent("default-assistant")
    if not agent_def:
        # Fallback if the default agent was deleted.
        agent_def = AgentConfig(
            name="rewinder",
            model=settings.LLM_DEFAULT_MODEL,
            instruction="",
            tools=[]
        )

    runner = Runner(
        agent=create_agent(agent_def),
        app_name=app_name,
        session_service=session_service
    )

    try:
        await runner.rewind_async(
            user_id=user_id,
            session_id=session_id,
            rewind_before_invocation_id=invocation_id
        )
    except Exception as e:
        # The exact exception types ADK raises are not visible here, so catch broadly.
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("Rewind failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Rewind failed: {str(e)}") from e

    return {"status": "success", "rewound_before": invocation_id}
|
|
|
|
|
|
# ==========================
# 💬 Session Chat
# ==========================


@router.post("/{session_id}/chat")
async def chat_with_agent(
    session_id: str,
    req: ChatTurnRequest,
    user_id: str,
    app_name: Optional[str] = None,
    session_service=Depends(get_session_service),
    agent_service: AgentService = Depends(get_agent_service)
):
    """Chat with a specific agent in a session, streamed as Server-Sent Events.

    The agent is loaded from AgentService and the session from SessionService,
    independently of each other. The stream emits ``data:`` frames with JSON payloads:

    - ``{"type": "content", "text": ..., "role": "model", "partial": bool}``
      for each event that carries text. ``partial`` mirrors ``event.partial``.
    - ``{"type": "done"}`` after the run completes.
    - ``{"type": "error", "text": ...}`` if the run raises. No ``done`` frame follows.

    NOTE(review): in SSE mode ADK appears to send a final non-partial event with
    the aggregated text after the partial chunks. Clients that append every
    chunk can use ``partial`` to avoid showing the text twice. Confirm this
    against the ADK version in use.

    Raises:
        HTTPException: 404 if the session or the agent does not exist.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME

    # 1. Load the session.
    session = await session_service.get_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # 2. Load the agent definition.
    agent_def = agent_service.get_agent(req.agent_id)
    if not agent_def:
        raise HTTPException(status_code=404, detail=f"Agent '{req.agent_id}' not found")

    # 3. Build the runtime agent, run config and runner.
    agent = create_agent(agent_def)
    run_config = RunConfig(
        streaming_mode=get_streaming_mode(req.streaming_mode or "sse"),
        max_llm_calls=500,  # TODO: make configurable
    )
    runner = Runner(
        agent=agent,
        app_name=app_name,
        session_service=session_service
    )

    # 4. Stream the run as SSE frames.
    async def event_generator() -> AsyncGenerator[str, None]:
        new_msg = types.Content(
            role="user",
            parts=[types.Part(text=req.message)]
        )
        try:
            async for event in runner.run_async(
                session_id=session.id,
                user_id=user_id,
                new_message=new_msg,
                run_config=run_config,
            ):
                if not event.content:
                    continue

                text_content = ""
                if hasattr(event.content, 'parts'):
                    # google-genai declares Content.parts Optional.
                    # A None value must not abort the whole stream.
                    for part in event.content.parts or []:
                        if hasattr(part, 'text') and part.text:
                            text_content += part.text
                elif isinstance(event.content, str):
                    text_content = event.content

                if text_content:
                    payload = {
                        "type": "content",
                        "text": text_content,
                        "role": "model",
                        "partial": bool(getattr(event, 'partial', False)),
                    }
                    yield f"data: {json.dumps(payload)}\n\n"

            yield f"data: {json.dumps({'type': 'done'})}\n\n"

        except Exception as e:
            logger.exception("Error during agent chat stream")
            err_payload = {"type": "error", "text": str(e)}
            yield f"data: {json.dumps(err_payload)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|