"""FastAPI router for chat sessions: CRUD, history view, rewind, and streaming chat."""
from typing import Any, AsyncGenerator, List, Optional
import json
import logging

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from google.adk.runners import Runner
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.genai import types

from app.config import settings
from app.models import (
    CreateSessionRequest,
    ChatTurnRequest,
    SessionResponse,
    SessionDetailResponse,
    HistoryEvent,
    AgentConfig,
)
from app.services import get_session_service
from app.services.agent_service import get_agent_service, AgentService
from app.agent_factory import create_agent

# Initialize Router
router = APIRouter(
    prefix="/sessions",
    tags=["Sessions"],
)

logger = logging.getLogger(__name__)


# --- Helpers ---

def get_streaming_mode(mode_str: str) -> StreamingMode:
    """Convert a string streaming mode to the ``StreamingMode`` enum.

    Matching is case-insensitive ("none", "sse", "bidi"); any other value
    falls back to ``StreamingMode.SSE``.
    """
    mode_map = {
        "none": StreamingMode.NONE,
        "sse": StreamingMode.SSE,
        "bidi": StreamingMode.BIDI,
    }
    return mode_map.get(mode_str.lower(), StreamingMode.SSE)


def _resolve_app_name(app_name: Optional[str]) -> str:
    """Return *app_name*, or ``settings.DEFAULT_APP_NAME`` when it is empty/None."""
    return app_name or settings.DEFAULT_APP_NAME


def _updated_at(session: Any) -> Any:
    """Read a session's last-update time, accepting either attribute name."""
    return getattr(session, 'last_update_time', getattr(session, 'updated_at', None))


def _to_session_response(session: Any) -> SessionResponse:
    """Build the summary ``SessionResponse`` for a session object."""
    return SessionResponse(
        id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        updated_at=_updated_at(session),
    )


def _extract_text(content: Any) -> str:
    """Concatenate the text carried by a content object.

    *content* may be falsy (-> ""), a plain ``str`` (returned as-is), or an
    object exposing ``parts``; parts without truthy ``text`` are skipped.
    ``parts`` itself may be ``None`` (it is optional on ``types.Content``),
    which yields "" instead of raising ``TypeError``.
    """
    if not content:
        return ""
    if isinstance(content, str):
        return content
    parts = getattr(content, 'parts', None) or []
    return "".join(part.text for part in parts if getattr(part, 'text', None))


def _apply_rewinds(events: List[Any]) -> List[Any]:
    """Return the events still visible after replaying rewind actions.

    ADK preserves all events in the log, so rewound events must be removed
    retroactively for the view. An event whose
    ``actions.rewind_before_invocation_id`` is set truncates the visible list
    at the first event of that invocation; the rewind marker itself is never
    shown. A rewind whose target is not currently visible changes nothing.
    """
    visible: List[Any] = []
    for event in events:
        actions = getattr(event, 'actions', None)
        rewind_target_id = getattr(actions, 'rewind_before_invocation_id', None) if actions else None
        if not rewind_target_id:
            visible.append(event)
            continue
        # Find the index where the truncated invocation began.
        for idx, kept in enumerate(visible):
            if getattr(kept, 'invocation_id', None) == rewind_target_id:
                logger.debug("Rewinding history to before %s", rewind_target_id)
                del visible[idx:]
                break
    return visible


def _to_history_event(event: Any) -> HistoryEvent:
    """Convert one stored ADK event into the API's ``HistoryEvent`` message."""
    # ADK usually sets author to the agent name or "user".
    author = getattr(event, 'author', 'unknown')
    role = "user" if author == "user" else "model"
    return HistoryEvent(
        type="message",
        role=role,
        # The agent name is the specific author of a model response.
        agent_name=author if role == "model" else None,
        content=_extract_text(getattr(event, 'content', None)),
        timestamp=getattr(event, 'timestamp', getattr(event, 'created_at', None)),
        invocation_id=getattr(event, 'invocation_id', None),
    )


# ==========================
# 📋 Session CRUD
# ==========================

@router.post("", response_model=SessionResponse)
async def create_session(
    req: CreateSessionRequest,
    service=Depends(get_session_service),
):
    """Create a new chat session."""
    app_name = _resolve_app_name(req.app_name)
    session = await service.create_session(app_name=app_name, user_id=req.user_id)
    return _to_session_response(session)


@router.get("", response_model=List[SessionResponse])
async def list_sessions(
    user_id: str,
    app_name: Optional[str] = None,
    service=Depends(get_session_service),
):
    """List all sessions for a user."""
    app_name = _resolve_app_name(app_name)
    response = await service.list_sessions(app_name=app_name, user_id=user_id)
    return [_to_session_response(s) for s in (response.sessions or [])]


@router.get("/{session_id}", response_model=SessionDetailResponse)
async def get_session_history(
    session_id: str,
    user_id: str,
    app_name: Optional[str] = None,
    service=Depends(get_session_service),
):
    """Get session details and its chat history, with rewound turns removed.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    app_name = _resolve_app_name(app_name)
    session = await service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    visible_events = _apply_rewinds(getattr(session, 'events', None) or [])
    history_events = [_to_history_event(e) for e in visible_events]

    return SessionDetailResponse(
        id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        updated_at=_updated_at(session),
        created_at=getattr(session, 'created_at', getattr(session, 'create_time', None)),
        history=history_events,
        events_count=len(history_events),
    )


@router.delete("/{session_id}")
async def delete_session(
    session_id: str,
    user_id: str,
    app_name: Optional[str] = None,
    service=Depends(get_session_service),
):
    """Delete a session."""
    app_name = _resolve_app_name(app_name)
    await service.delete_session(app_name=app_name, user_id=user_id, session_id=session_id)
    return {"status": "deleted", "session_id": session_id}


@router.post("/{session_id}/rewind")
async def rewind_session_state(
    session_id: str,
    user_id: str,
    invocation_id: str,
    app_name: Optional[str] = None,
    session_service=Depends(get_session_service),
    agent_service: AgentService = Depends(get_agent_service),
):
    """Rewind session state to before a specific invocation.

    Undoes changes from the specified invocation and subsequent ones.

    Raises:
        HTTPException: 500 if the ADK rewind call fails.
    """
    app_name = _resolve_app_name(app_name)

    # A Runner is required to perform the rewind, and a Runner requires an
    # agent. Rewind is a session/log operation, so the default assistant is
    # used here, assuming session persistence is agent-agnostic.
    agent_def = agent_service.get_agent("default-assistant")
    if not agent_def:
        # Fallback if the default agent was deleted.
        agent_def = AgentConfig(
            name="rewinder",
            model=settings.LLM_DEFAULT_MODEL,
            instruction="",
            tools=[],
        )

    runner = Runner(
        agent=create_agent(agent_def),
        app_name=app_name,
        session_service=session_service,
    )

    try:
        await runner.rewind_async(
            user_id=user_id,
            session_id=session_id,
            rewind_before_invocation_id=invocation_id,
        )
    except Exception as e:
        # Top-level API boundary: keep the traceback in the log and chain the cause.
        # ADK might raise specific errors; generic catch for now.
        logger.exception("Rewind failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Rewind failed: {str(e)}") from e
    return {"status": "success", "rewound_before": invocation_id}


# ==========================
# 💬 Session Chat
# ==========================

@router.post("/{session_id}/chat")
async def chat_with_agent(
    session_id: str,
    req: ChatTurnRequest,
    user_id: str,
    app_name: Optional[str] = None,
    session_service=Depends(get_session_service),
    agent_service: AgentService = Depends(get_agent_service),
):
    """Chat with a specific Agent in a Session, streamed as Server-Sent Events.

    Decoupled: the Agent is loaded from AgentService and the Session from
    SessionService. Each SSE ``data:`` line is a JSON object whose ``type`` is
    "content" (text from the model), "done" (the run finished), or "error".

    Raises:
        HTTPException: 404 if the session or the agent does not exist.
    """
    app_name = _resolve_app_name(app_name)

    # 1. Load Session
    session = await session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # 2. Get Agent Definition
    agent_def = agent_service.get_agent(req.agent_id)
    if not agent_def:
        raise HTTPException(status_code=404, detail=f"Agent '{req.agent_id}' not found")

    # 3. Create Runtime Agent
    agent = create_agent(agent_def)

    # 4. Preparation for Run
    run_config = RunConfig(
        streaming_mode=get_streaming_mode(req.streaming_mode or "sse"),
        max_llm_calls=500,  # or configurable
    )
    runner = Runner(
        agent=agent,
        app_name=app_name,
        session_service=session_service,
    )

    # 5. Stream Generator
    async def event_generator() -> AsyncGenerator[str, None]:
        new_msg = types.Content(
            role="user",
            parts=[types.Part(text=req.message)],
        )
        # In SSE mode ADK emits partial text chunks, then a non-partial event
        # repeating the aggregated text. Track what the current run of chunks
        # has already sent so that aggregate is not forwarded a second time.
        # Only an exact repeat is skipped, so no unsent text is ever dropped.
        streamed_text = ""
        try:
            async for event in runner.run_async(
                session_id=session.id,
                user_id=user_id,
                new_message=new_msg,
                run_config=run_config,
            ):
                text_content = _extract_text(getattr(event, 'content', None))

                if getattr(event, 'partial', False):
                    streamed_text += text_content
                elif streamed_text:
                    # A non-partial event ends the current run of chunks.
                    already_sent = text_content == streamed_text
                    streamed_text = ""
                    if already_sent:
                        continue

                if text_content:
                    payload = {"type": "content", "text": text_content, "role": "model"}
                    yield f"data: {json.dumps(payload)}\n\n"

            yield f"data: {json.dumps({'type': 'done'})}\n\n"

        except Exception as e:
            logger.exception("Error during agent chat stream")
            err_payload = {"type": "error", "text": str(e)}
            yield f"data: {json.dumps(err_payload)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")