首次提交
This commit is contained in:
317
app/routers/sessions.py
Normal file
317
app/routers/sessions.py
Normal file
@@ -0,0 +1,317 @@
|
||||
from typing import List, AsyncGenerator
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.agents.run_config import RunConfig, StreamingMode
|
||||
from google.genai import types
|
||||
|
||||
from app.config import settings
|
||||
from app.models import (
|
||||
CreateSessionRequest,
|
||||
ChatTurnRequest,
|
||||
SessionResponse,
|
||||
SessionDetailResponse,
|
||||
HistoryEvent,
|
||||
AgentConfig
|
||||
)
|
||||
from app.services import get_session_service
|
||||
from app.services.agent_service import get_agent_service, AgentService
|
||||
from app.agent_factory import create_agent
|
||||
|
||||
# Initialize Router
|
||||
router = APIRouter(
|
||||
prefix="/sessions",
|
||||
tags=["Sessions"]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Helper ---
|
||||
def get_streaming_mode(mode_str: str) -> StreamingMode:
|
||||
"""Convert string streaming mode to StreamingMode enum."""
|
||||
mode_map = {
|
||||
"none": StreamingMode.NONE,
|
||||
"sse": StreamingMode.SSE,
|
||||
"bidi": StreamingMode.BIDI,
|
||||
}
|
||||
return mode_map.get(mode_str.lower(), StreamingMode.SSE)
|
||||
|
||||
|
||||
# ==========================
|
||||
# 📋 Session CRUD
|
||||
# ==========================
|
||||
|
||||
@router.post("", response_model=SessionResponse)
|
||||
async def create_session(
|
||||
req: CreateSessionRequest,
|
||||
service=Depends(get_session_service)
|
||||
):
|
||||
"""Create a new chat session."""
|
||||
app_name = req.app_name or settings.DEFAULT_APP_NAME
|
||||
session = await service.create_session(app_name=app_name, user_id=req.user_id)
|
||||
return SessionResponse(
|
||||
id=session.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None))
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[SessionResponse])
|
||||
async def list_sessions(
|
||||
user_id: str,
|
||||
app_name: str = None,
|
||||
service=Depends(get_session_service)
|
||||
):
|
||||
"""List all sessions for a user."""
|
||||
app_name = app_name or settings.DEFAULT_APP_NAME
|
||||
response = await service.list_sessions(app_name=app_name, user_id=user_id)
|
||||
|
||||
return [
|
||||
SessionResponse(
|
||||
id=s.id,
|
||||
app_name=s.app_name,
|
||||
user_id=s.user_id,
|
||||
updated_at=getattr(s, 'last_update_time', getattr(s, 'updated_at', None))
|
||||
) for s in response.sessions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=SessionDetailResponse)
|
||||
async def get_session_history(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
app_name: str = None,
|
||||
service=Depends(get_session_service)
|
||||
):
|
||||
"""Get session details and chat history."""
|
||||
app_name = app_name or settings.DEFAULT_APP_NAME
|
||||
session = await service.get_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Convert history events
|
||||
history_events = []
|
||||
events = getattr(session, 'events', [])
|
||||
|
||||
# Filter events based on Rewind actions
|
||||
# ADK preserves all events in the log. We must retroactively remove rewound events for the view.
|
||||
valid_events = []
|
||||
for e in events:
|
||||
# Check for rewind action
|
||||
actions = getattr(e, 'actions', None)
|
||||
rewind_target_id = getattr(actions, 'rewind_before_invocation_id', None) if actions else None
|
||||
|
||||
if rewind_target_id:
|
||||
# Find the index where the truncated invocation began
|
||||
truncate_idx = -1
|
||||
for i, ve in enumerate(valid_events):
|
||||
if getattr(ve, 'invocation_id', None) == rewind_target_id:
|
||||
truncate_idx = i
|
||||
break
|
||||
|
||||
if truncate_idx != -1:
|
||||
logger.debug(f"Rewinding history to before {rewind_target_id}")
|
||||
valid_events = valid_events[:truncate_idx]
|
||||
else:
|
||||
valid_events.append(e)
|
||||
|
||||
for e in valid_events:
|
||||
content_str = ""
|
||||
if hasattr(e, 'content') and e.content:
|
||||
if isinstance(e.content, str):
|
||||
content_str = e.content
|
||||
elif hasattr(e.content, 'parts'):
|
||||
parts = e.content.parts
|
||||
for part in parts:
|
||||
if hasattr(part, 'text') and part.text:
|
||||
content_str += part.text
|
||||
|
||||
# Determine basic role (user vs model)
|
||||
# ADK usually sets author to the agent name or "user"
|
||||
author = getattr(e, 'author', 'unknown')
|
||||
role = "user" if author == "user" else "model"
|
||||
|
||||
# Agent Name is the specific author if it's a model response
|
||||
agent_name = author if role == "model" else None
|
||||
|
||||
# Extract timestamp
|
||||
timestamp = getattr(e, 'timestamp', getattr(e, 'created_at', None))
|
||||
invocation_id = getattr(e, 'invocation_id', None)
|
||||
|
||||
history_events.append(HistoryEvent(
|
||||
type="message",
|
||||
role=role,
|
||||
agent_name=agent_name,
|
||||
content=content_str,
|
||||
timestamp=timestamp,
|
||||
invocation_id=invocation_id
|
||||
))
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None)),
|
||||
created_at=getattr(session, 'created_at', getattr(session, 'create_time', None)),
|
||||
history=history_events,
|
||||
events_count=len(history_events)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{session_id}")
|
||||
async def delete_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
app_name: str = None,
|
||||
service=Depends(get_session_service)
|
||||
):
|
||||
"""Delete a session."""
|
||||
app_name = app_name or settings.DEFAULT_APP_NAME
|
||||
await service.delete_session(app_name=app_name, user_id=user_id, session_id=session_id)
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/rewind")
|
||||
async def rewind_session_state(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
invocation_id: str,
|
||||
app_name: str = None,
|
||||
session_service=Depends(get_session_service),
|
||||
agent_service: AgentService = Depends(get_agent_service)
|
||||
):
|
||||
"""
|
||||
Rewind session state to before a specific invocation.
|
||||
Undoes changes from the specified invocation and subsequent ones.
|
||||
"""
|
||||
app_name = app_name or settings.DEFAULT_APP_NAME
|
||||
|
||||
# 1. Need a Runner instance to perform rewind.
|
||||
# The Runner requires an agent, but for rewind (session/log operation),
|
||||
# the specific agent configuration might not be critical if default is used,
|
||||
# assuming session persistence is agent-agnostic.
|
||||
# We'll use the default assistant to initialize the runner.
|
||||
agent_def = agent_service.get_agent("default-assistant")
|
||||
if not agent_def:
|
||||
# Fallback if default deleted
|
||||
agent_def = AgentConfig(
|
||||
name="rewinder",
|
||||
model=settings.LLM_DEFAULT_MODEL,
|
||||
instruction="",
|
||||
tools=[]
|
||||
)
|
||||
|
||||
agent = create_agent(agent_def)
|
||||
|
||||
runner = Runner(
|
||||
agent=agent,
|
||||
app_name=app_name,
|
||||
session_service=session_service
|
||||
)
|
||||
|
||||
try:
|
||||
await runner.rewind_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
rewind_before_invocation_id=invocation_id
|
||||
)
|
||||
return {"status": "success", "rewound_before": invocation_id}
|
||||
except Exception as e:
|
||||
logger.error(f"Rewind failed: {e}")
|
||||
# ADK might raise specific errors, generic catch for now
|
||||
raise HTTPException(status_code=500, detail=f"Rewind failed: {str(e)}")
|
||||
|
||||
|
||||
# ==========================
|
||||
# 💬 Session Chat
|
||||
# ==========================
|
||||
|
||||
@router.post("/{session_id}/chat")
|
||||
async def chat_with_agent(
|
||||
session_id: str,
|
||||
req: ChatTurnRequest,
|
||||
user_id: str,
|
||||
app_name: str = None,
|
||||
session_service=Depends(get_session_service),
|
||||
agent_service: AgentService = Depends(get_agent_service)
|
||||
):
|
||||
"""
|
||||
Chat with a specific Agent in a Session.
|
||||
Decoupled: Agent is loaded from AgentService, Session is loaded from SessionService.
|
||||
"""
|
||||
app_name = app_name or settings.DEFAULT_APP_NAME
|
||||
|
||||
# 1. Load Session
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# 2. Get Agent Definition
|
||||
agent_def = agent_service.get_agent(req.agent_id)
|
||||
if not agent_def:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{req.agent_id}' not found")
|
||||
|
||||
# 3. Create Runtime Agent
|
||||
agent = create_agent(agent_def)
|
||||
|
||||
# 4. Preparation for Run
|
||||
streaming_mode = get_streaming_mode(req.streaming_mode or "sse")
|
||||
run_config = RunConfig(
|
||||
streaming_mode=streaming_mode,
|
||||
max_llm_calls=500, # or configurable
|
||||
)
|
||||
|
||||
runner = Runner(
|
||||
agent=agent,
|
||||
app_name=app_name,
|
||||
session_service=session_service
|
||||
)
|
||||
|
||||
# 5. Stream Generator
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
new_msg = types.Content(
|
||||
role="user",
|
||||
parts=[types.Part(text=req.message)]
|
||||
)
|
||||
try:
|
||||
async for event in runner.run_async(
|
||||
session_id=session.id,
|
||||
user_id=user_id,
|
||||
new_message=new_msg,
|
||||
run_config=run_config,
|
||||
):
|
||||
if event.content:
|
||||
text_content = ""
|
||||
if hasattr(event.content, 'parts'):
|
||||
for part in event.content.parts:
|
||||
if hasattr(part, 'text') and part.text:
|
||||
text_content += part.text
|
||||
elif isinstance(event.content, str):
|
||||
text_content = event.content
|
||||
|
||||
if text_content:
|
||||
payload = {"type": "content", "text": text_content, "role": "model"}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'type': 'done'})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error during agent chat stream")
|
||||
err_payload = {"type": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err_payload)}\n\n"
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
Reference in New Issue
Block a user