From 12ef0292b7e87a96ebb9d2d39c2d4bcd3a27255d Mon Sep 17 00:00:00 2001 From: lhr Date: Mon, 19 Jan 2026 22:04:54 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A6=96=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 10 + app/__init__.py | 0 app/agent_factory.py | 143 ++++++++++++++ app/config.py | 30 +++ app/main.py | 81 ++++++++ app/models.py | 128 +++++++++++++ app/routers/__init__.py | 0 app/routers/agents.py | 68 +++++++ app/routers/sessions.py | 317 ++++++++++++++++++++++++++++++++ app/routers/tools.py | 43 +++++ app/services/__init__.py | 5 + app/services/agent_service.py | 89 +++++++++ app/services/session_service.py | 40 ++++ app/services/tool_service.py | 146 +++++++++++++++ app/tools.py | 29 +++ requirements.txt | 18 ++ 16 files changed, 1147 insertions(+) create mode 100644 .env.example create mode 100644 app/__init__.py create mode 100644 app/agent_factory.py create mode 100644 app/config.py create mode 100644 app/main.py create mode 100644 app/models.py create mode 100644 app/routers/__init__.py create mode 100644 app/routers/agents.py create mode 100644 app/routers/sessions.py create mode 100644 app/routers/tools.py create mode 100644 app/services/__init__.py create mode 100644 app/services/agent_service.py create mode 100644 app/services/session_service.py create mode 100644 app/services/tool_service.py create mode 100644 app/tools.py create mode 100644 requirements.txt diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..01355ee --- /dev/null +++ b/.env.example @@ -0,0 +1,10 @@ +# Database Configuration +DATABASE_URL=postgresql+asyncpg://user:password@localhost:5432/dbname + +# LLM API Configuration +LLM_API_KEY=your_api_key_here +LLM_API_BASE=https://api.openai.com/v1 +LLM_DEFAULT_MODEL=openai/gpt-4o + +# App Configuration +DEFAULT_APP_NAME=adk_chat_app diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/agent_factory.py b/app/agent_factory.py new file mode 100644 index 0000000..b4c73e6 --- /dev/null +++ b/app/agent_factory.py @@ -0,0 +1,143 @@ +import logging +from typing import Optional, List, Callable, Dict, Any, Union + +from google.adk.agents.llm_agent import LlmAgent, Agent +from google.adk.models.lite_llm import LiteLlm +from google.adk.planners import BuiltInPlanner +from google.adk.code_executors import BuiltInCodeExecutor +from google.genai import types + +from app.config import settings +from app.models import AgentConfig, LLMConfig, GenerationConfig +from app.services.tool_service import get_tool_service + +from app.tools import TOOL_REGISTRY + + + +logger = logging.getLogger(__name__) + + +# ========================== +# 🏭 Agent Factory +# ========================== + +def _create_llm_model( + model_name: str, + llm_config: Optional[LLMConfig] = None +) -> Union[str, LiteLlm]: + """ + Create the LLM model instance or return model identifier string. + + Logic: + 1. If model name implies Gemini (official Google model), return the string directly. + ADK's LlmAgent handles 'gemini-...' strings using the native google-genai SDK. + 2. For other models (GPT, Claude, etc.), wrap in LiteLlm adapter. + """ + # Normalize comparison + name_lower = model_name.lower() + + # Check if likely native Gemini (no provider prefix) + # Exclude obvious proxy prefixes that require LiteLLM (e.g. openai/gemini...) + provider_prefixes = ["openai/", "azure/", "anthropic/", "bedrock/", "mistral/"] + has_provider_prefix = any(p in name_lower for p in provider_prefixes) + + is_gemini = "gemini" in name_lower and not has_provider_prefix + + # Check if custom configuration forces a specific path + # If API Base is provided and distinct from default Google, usually implies using LiteLLM/Proxy + # even for Gemini models (e.g. via OpenAI-compatible endpoint). + has_custom_base = llm_config and llm_config.api_base and "googleapis.com" not in llm_config.api_base + + if is_gemini and not has_custom_base: + logger.info(f"Using Native Gemini Model: {model_name}") + return model_name + + # Fallback / Non-Gemini -> Use LiteLLM + api_key = (llm_config.api_key if llm_config and llm_config.api_key else settings.LLM_API_KEY) + api_base = (llm_config.api_base if llm_config and llm_config.api_base else settings.LLM_API_BASE) + + logger.info(f"Using LiteLLM for: {model_name} (Base: {api_base})") + return LiteLlm( + model=model_name, + api_base=api_base, + api_key=api_key + ) + + +def create_agent(config: AgentConfig) -> Agent: + """ + Create a fully configured ADK Agent based on the provided AgentConfig. + """ + logger.info(f"Creating Agent: {config.name} ({config.model})") + + # 1. Model Initialization + # Returns either a str (for native Gemini) or LiteLlm object + llm = _create_llm_model(config.model, config.llm_config) + + # 2. Tools Selection + selected_tools = [] + tool_service = get_tool_service() + + for tool_name in config.tools: + # A. Check Legacy/Hardcoded Registry + if tool_name in TOOL_REGISTRY: + selected_tools.append(TOOL_REGISTRY[tool_name]) + continue + + # B. Check Local Tools (tools/ folder) + local_tool = tool_service.load_local_tool(tool_name) + if local_tool: + selected_tools.append(local_tool) + continue + + # C. Check MCP Servers + mcp_tool = tool_service.get_mcp_toolset(tool_name) + if mcp_tool: + selected_tools.append(mcp_tool) + continue + + logger.warning(f"Tool '{tool_name}' not found (checked Registry, Local, MCP). Skipping.") + + # 3. Code Execution + code_executor = None + if config.enable_code_execution: + logger.info("Enabling BuiltInCodeExecutor") + code_executor = BuiltInCodeExecutor() + + # 4. Planner / Thinking + # Only applicable for models that support it (mostly Gemini) + planner = None + if config.thinking_config: + logger.info(f"Enabling BuiltInPlanner with budget {config.thinking_config.thinking_budget}") + t_config = types.ThinkingConfig( + include_thoughts=config.thinking_config.include_thoughts, + thinking_budget=config.thinking_config.thinking_budget + ) + planner = BuiltInPlanner(thinking_config=t_config) + + # 5. Generation Config + gen_config = None + if config.generation_config: + g_params = {} + if config.generation_config.temperature is not None: + g_params["temperature"] = config.generation_config.temperature + if config.generation_config.max_output_tokens is not None: + g_params["max_output_tokens"] = config.generation_config.max_output_tokens + if config.generation_config.top_p is not None: + g_params["top_p"] = config.generation_config.top_p + + if g_params: + gen_config = types.GenerateContentConfig(**g_params) + + # 6. Assemble LlmAgent + return LlmAgent( + name=config.name, + model=llm, + description=config.description or "", + instruction=config.instruction, + tools=selected_tools, + code_executor=code_executor, + planner=planner, + generate_content_config=gen_config + ) \ No newline at end of file diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..4c8b363 --- /dev/null +++ b/app/config.py @@ -0,0 +1,30 @@ +from pydantic_settings import BaseSettings +from functools import lru_cache + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + # Database Configuration + DATABASE_URL: str = "postgresql+asyncpg://myuser:mypassword@127.0.0.1:5432/mydatabase" + + # LLM API Configuration + LLM_API_KEY: str + LLM_API_BASE: str = "https://api.openai.com/v1" + LLM_DEFAULT_MODEL: str = "openai/gpt-4o" + + # App Configuration + DEFAULT_APP_NAME: str = "adk_chat_app" + + class Config: + env_file = ".env" + extra = "ignore" + + +@lru_cache +def get_settings() -> Settings: + """Get cached settings instance.""" + return Settings() + + +settings = get_settings() \ No newline at end of file diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..7e84c2b --- /dev/null +++ b/app/main.py @@ -0,0 +1,81 @@ +import json +import logging +from typing import AsyncGenerator +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse + +from google.adk.runners import Runner +from google.adk.agents.run_config import RunConfig, StreamingMode +from google.genai import types + +from app.config import settings + +from app.services import ( + get_session_service, + init_session_service, + close_session_service +) +from app.agent_factory import create_agent + +# Routers +from app.routers.agents import router as agents_router +from app.routers.sessions import router as sessions_router +from app.routers.tools import router as tools_router + + +# --- Logging Configuration --- + +# --- Logging Configuration --- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +# --- Application Lifespan --- +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager for startup/shutdown events.""" + logger.info("Starting up ADK Chat Backend...") + await init_session_service() + logger.info("Application started successfully") + yield + logger.info("Shutting down ADK Chat Backend...") + await close_session_service() + logger.info("Application shutdown complete") + + +app = FastAPI( + title="ADK Enterprise Chat Backend", + description="LLM Chat backend powered by Google ADK with PostgreSQL persistence", + version="2.0.0", + lifespan=lifespan +) + + +# --- CORS Configuration --- +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# --- Register Routers --- +app.include_router(agents_router) +app.include_router(sessions_router) +app.include_router(tools_router) + + +# --- Helper --- +def get_streaming_mode(mode_str: str) -> StreamingMode: + return StreamingMode.SSE if mode_str.lower() == "sse" else StreamingMode.NONE + + + diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..e1de4d5 --- /dev/null +++ b/app/models.py @@ -0,0 +1,128 @@ +from pydantic import BaseModel, Field +from typing import List, Optional, Any + + +from typing import List, Optional, Any, Dict, Literal, Union + +# --- MCP Tool Configuration Models --- + +class StdioConfig(BaseModel): + """Configuration for Stdio-based MCP server.""" + command: str = Field(..., description="Executable command (e.g., 'npx', 'python')") + args: List[str] = Field(default_factory=list, description="Command arguments") + env: Optional[Dict[str, str]] = Field(None, description="Environment variables") + +class SSEConfig(BaseModel): + """Configuration for SSE-based MCP server.""" + url: str = Field(..., description="Server URL (e.g., 'http://localhost:8000/sse')") + +class MCPServerConfig(BaseModel): + """Configuration for an MCP Server connection.""" + name: str = Field(..., description="Unique name for the toolset") + type: Literal["stdio", "sse"] = Field("stdio", description="Connection type") + stdio_config: Optional[StdioConfig] = Field(None, description="Stdio configuration") + sse_config: Optional[SSEConfig] = Field(None, description="SSE configuration") + tool_filter: Optional[List[str]] = Field(None, description="Optional whitelist of tools to enable") + +# --- LLM Configuration Models --- + +class LLMConfig(BaseModel): + """Custom LLM API configuration for per-request override.""" + api_key: Optional[str] = Field(None, description="API key for the LLM service") + api_base: Optional[str] = Field(None, description="Base URL for the LLM API") + model: Optional[str] = Field(None, description="Model name (e.g., 'openai/gpt-4o')") + + +class GenerationConfig(BaseModel): + """LLM generation parameters for fine-tuning responses.""" + temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Randomness (0.0-2.0)") + max_output_tokens: Optional[int] = Field(None, gt=0, description="Max response length") + top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling") + top_k: Optional[int] = Field(None, gt=0, description="Top-k sampling") + + +# --- Request Models --- + +class CreateSessionRequest(BaseModel): + """Request to create a new chat session.""" + user_id: str + app_name: Optional[str] = None + + + +class ThinkingConfigModel(BaseModel): + """Configuration for BuiltInPlanner thinking features.""" + include_thoughts: bool = Field(True, description="Whether to include internal thoughts in the response") + thinking_budget: int = Field(1024, gt=0, description="Token budget for thinking") + + +class AgentConfig(BaseModel): + """Comprehensive configuration for creating an ADK Agent.""" + name: str = Field(..., description="Unique name for the agent") + description: Optional[str] = Field(None, description="Description of agent capabilities") + model: str = Field(..., description="Model identifier (e.g., 'gemini-2.5-flash')") + instruction: str = Field(..., description="System instruction / persona for the agent") + tools: List[str] = Field(default_factory=list, description="List of tool names to enable") + enable_code_execution: bool = Field(False, description="Enable built-in code executor") + + thinking_config: Optional[ThinkingConfigModel] = Field(None, description="Configuration for reasoning/planning") + + # Nested configurations + llm_config: Optional[LLMConfig] = Field(None, description="Custom LLM API connection details") + generation_config: Optional[GenerationConfig] = Field(None, description="Generation parameters") + + + +class AgentDefinition(AgentConfig): + """Persisted Agent definition with unique ID.""" + id: str = Field(..., description="Unique identifier for the agent") + created_at: float = Field(default_factory=lambda: __import__("time").time(), description="Creation timestamp") + updated_at: float = Field(default_factory=lambda: __import__("time").time(), description="Last update timestamp") + + +class ChatTurnRequest(BaseModel): + """Request for a single chat turn with a specific agent.""" + agent_id: str = Field(..., description="ID of the agent to use for this turn") + message: str = Field(..., description="User message content") + # Optional overrides for this specific turn + streaming_mode: Optional[str] = Field("sse", description="Streaming mode: 'sse', 'none'") + + +class ChatRequest(BaseModel): + """ + LEGACY: Request for chat completion. + Used by the old /chat/stream endpoint. + """ + session_id: str + user_id: str + message: str + app_name: Optional[str] = None + + agent_config: Optional[AgentConfig] = Field(None, description="Ephemeral agent config") + max_llm_calls: Optional[int] = Field(500, gt=0) + streaming_mode: Optional[str] = "sse" + + +class SessionResponse(BaseModel): + """Response for session operations.""" + id: str + app_name: str + user_id: str + updated_at: Any = None + + +class HistoryEvent(BaseModel): + """A single event in chat history.""" + type: str + role: str + content: str + agent_name: Optional[str] = None + timestamp: Optional[Any] = None + invocation_id: Optional[str] = None + + +class SessionDetailResponse(SessionResponse): + """Session details including chat history.""" + history: List[HistoryEvent] = [] + events_count: int = 0 + created_at: Any = None \ No newline at end of file diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routers/agents.py b/app/routers/agents.py new file mode 100644 index 0000000..501e5fa --- /dev/null +++ b/app/routers/agents.py @@ -0,0 +1,68 @@ +from fastapi import APIRouter, HTTPException, Depends +from typing import List + +from app.models import AgentConfig, AgentDefinition +from app.services.agent_service import get_agent_service, AgentService + +# Initialize Router +router = APIRouter( + prefix="/agents", + tags=["Agents"] +) + +# ========================== +# 🤖 Agent Management API +# ========================== + +@router.post("", response_model=AgentDefinition) +async def create_agent_def( + config: AgentConfig, + service: AgentService = Depends(get_agent_service) +): + """Create a new Agent definition.""" + return service.create_agent(config) + + +@router.get("", response_model=List[AgentDefinition]) +async def list_agents( + service: AgentService = Depends(get_agent_service) +): + """List all available Agents.""" + return service.list_agents() + + +@router.get("/{agent_id}", response_model=AgentDefinition) +async def get_agent_def( + agent_id: str, + service: AgentService = Depends(get_agent_service) +): + """Get specific Agent definition.""" + agent = service.get_agent(agent_id) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + return agent + + +@router.put("/{agent_id}", response_model=AgentDefinition) +async def update_agent_def( + agent_id: str, + config: AgentConfig, + service: AgentService = Depends(get_agent_service) +): + """Update an Agent definition.""" + agent = service.update_agent(agent_id, config) + if not agent: + raise HTTPException(status_code=404, detail="Agent not found") + return agent + + +@router.delete("/{agent_id}") +async def delete_agent_def( + agent_id: str, + service: AgentService = Depends(get_agent_service) +): + """Delete an Agent definition.""" + success = service.delete_agent(agent_id) + if not success: + raise HTTPException(status_code=404, detail="Agent not found") + return {"status": "deleted", "agent_id": agent_id} diff --git a/app/routers/sessions.py b/app/routers/sessions.py new file mode 100644 index 0000000..e8a3a2a --- /dev/null +++ b/app/routers/sessions.py @@ -0,0 +1,317 @@ +from typing import List, AsyncGenerator +import json +import logging + +from fastapi import APIRouter, HTTPException, Depends +from fastapi.responses import StreamingResponse + +from google.adk.runners import Runner +from google.adk.agents.run_config import RunConfig, StreamingMode +from google.genai import types + +from app.config import settings +from app.models import ( + CreateSessionRequest, + ChatTurnRequest, + SessionResponse, + SessionDetailResponse, + HistoryEvent, + AgentConfig +) +from app.services import get_session_service +from app.services.agent_service import get_agent_service, AgentService +from app.agent_factory import create_agent + +# Initialize Router +router = APIRouter( + prefix="/sessions", + tags=["Sessions"] +) + +logger = logging.getLogger(__name__) + + +# --- Helper --- +def get_streaming_mode(mode_str: str) -> StreamingMode: + """Convert string streaming mode to StreamingMode enum.""" + mode_map = { + "none": StreamingMode.NONE, + "sse": StreamingMode.SSE, + "bidi": StreamingMode.BIDI, + } + return mode_map.get(mode_str.lower(), StreamingMode.SSE) + + +# ========================== +# 📋 Session CRUD +# ========================== + +@router.post("", response_model=SessionResponse) +async def create_session( + req: CreateSessionRequest, + service=Depends(get_session_service) +): + """Create a new chat session.""" + app_name = req.app_name or settings.DEFAULT_APP_NAME + session = await service.create_session(app_name=app_name, user_id=req.user_id) + return SessionResponse( + id=session.id, + app_name=session.app_name, + user_id=session.user_id, + updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None)) + ) + + +@router.get("", response_model=List[SessionResponse]) +async def list_sessions( + user_id: str, + app_name: str = None, + service=Depends(get_session_service) +): + """List all sessions for a user.""" + app_name = app_name or settings.DEFAULT_APP_NAME + response = await service.list_sessions(app_name=app_name, user_id=user_id) + + return [ + SessionResponse( + id=s.id, + app_name=s.app_name, + user_id=s.user_id, + updated_at=getattr(s, 'last_update_time', getattr(s, 'updated_at', None)) + ) for s in response.sessions + ] + + +@router.get("/{session_id}", response_model=SessionDetailResponse) +async def get_session_history( + session_id: str, + user_id: str, + app_name: str = None, + service=Depends(get_session_service) +): + """Get session details and chat history.""" + app_name = app_name or settings.DEFAULT_APP_NAME + session = await service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # Convert history events + history_events = [] + events = getattr(session, 'events', []) + + # Filter events based on Rewind actions + # ADK preserves all events in the log. We must retroactively remove rewound events for the view. + valid_events = [] + for e in events: + # Check for rewind action + actions = getattr(e, 'actions', None) + rewind_target_id = getattr(actions, 'rewind_before_invocation_id', None) if actions else None + + if rewind_target_id: + # Find the index where the truncated invocation began + truncate_idx = -1 + for i, ve in enumerate(valid_events): + if getattr(ve, 'invocation_id', None) == rewind_target_id: + truncate_idx = i + break + + if truncate_idx != -1: + logger.debug(f"Rewinding history to before {rewind_target_id}") + valid_events = valid_events[:truncate_idx] + else: + valid_events.append(e) + + for e in valid_events: + content_str = "" + if hasattr(e, 'content') and e.content: + if isinstance(e.content, str): + content_str = e.content + elif hasattr(e.content, 'parts'): + parts = e.content.parts + for part in parts: + if hasattr(part, 'text') and part.text: + content_str += part.text + + # Determine basic role (user vs model) + # ADK usually sets author to the agent name or "user" + author = getattr(e, 'author', 'unknown') + role = "user" if author == "user" else "model" + + # Agent Name is the specific author if it's a model response + agent_name = author if role == "model" else None + + # Extract timestamp + timestamp = getattr(e, 'timestamp', getattr(e, 'created_at', None)) + invocation_id = getattr(e, 'invocation_id', None) + + history_events.append(HistoryEvent( + type="message", + role=role, + agent_name=agent_name, + content=content_str, + timestamp=timestamp, + invocation_id=invocation_id + )) + + return SessionDetailResponse( + id=session.id, + app_name=session.app_name, + user_id=session.user_id, + updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None)), + created_at=getattr(session, 'created_at', getattr(session, 'create_time', None)), + history=history_events, + events_count=len(history_events) + ) + + +@router.delete("/{session_id}") +async def delete_session( + session_id: str, + user_id: str, + app_name: str = None, + service=Depends(get_session_service) +): + """Delete a session.""" + app_name = app_name or settings.DEFAULT_APP_NAME + await service.delete_session(app_name=app_name, user_id=user_id, session_id=session_id) + return {"status": "deleted", "session_id": session_id} + + +@router.post("/{session_id}/rewind") +async def rewind_session_state( + session_id: str, + user_id: str, + invocation_id: str, + app_name: str = None, + session_service=Depends(get_session_service), + agent_service: AgentService = Depends(get_agent_service) +): + """ + Rewind session state to before a specific invocation. + Undoes changes from the specified invocation and subsequent ones. + """ + app_name = app_name or settings.DEFAULT_APP_NAME + + # 1. Need a Runner instance to perform rewind. + # The Runner requires an agent, but for rewind (session/log operation), + # the specific agent configuration might not be critical if default is used, + # assuming session persistence is agent-agnostic. + # We'll use the default assistant to initialize the runner. + agent_def = agent_service.get_agent("default-assistant") + if not agent_def: + # Fallback if default deleted + agent_def = AgentConfig( + name="rewinder", + model=settings.LLM_DEFAULT_MODEL, + instruction="", + tools=[] + ) + + agent = create_agent(agent_def) + + runner = Runner( + agent=agent, + app_name=app_name, + session_service=session_service + ) + + try: + await runner.rewind_async( + user_id=user_id, + session_id=session_id, + rewind_before_invocation_id=invocation_id + ) + return {"status": "success", "rewound_before": invocation_id} + except Exception as e: + logger.error(f"Rewind failed: {e}") + # ADK might raise specific errors, generic catch for now + raise HTTPException(status_code=500, detail=f"Rewind failed: {str(e)}") + + +# ========================== +# 💬 Session Chat +# ========================== + +@router.post("/{session_id}/chat") +async def chat_with_agent( + session_id: str, + req: ChatTurnRequest, + user_id: str, + app_name: str = None, + session_service=Depends(get_session_service), + agent_service: AgentService = Depends(get_agent_service) +): + """ + Chat with a specific Agent in a Session. + Decoupled: Agent is loaded from AgentService, Session is loaded from SessionService. + """ + app_name = app_name or settings.DEFAULT_APP_NAME + + # 1. Load Session + session = await session_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # 2. Get Agent Definition + agent_def = agent_service.get_agent(req.agent_id) + if not agent_def: + raise HTTPException(status_code=404, detail=f"Agent '{req.agent_id}' not found") + + # 3. Create Runtime Agent + agent = create_agent(agent_def) + + # 4. Preparation for Run + streaming_mode = get_streaming_mode(req.streaming_mode or "sse") + run_config = RunConfig( + streaming_mode=streaming_mode, + max_llm_calls=500, # or configurable + ) + + runner = Runner( + agent=agent, + app_name=app_name, + session_service=session_service + ) + + # 5. Stream Generator + async def event_generator() -> AsyncGenerator[str, None]: + new_msg = types.Content( + role="user", + parts=[types.Part(text=req.message)] + ) + try: + async for event in runner.run_async( + session_id=session.id, + user_id=user_id, + new_message=new_msg, + run_config=run_config, + ): + if event.content: + text_content = "" + if hasattr(event.content, 'parts'): + for part in event.content.parts: + if hasattr(part, 'text') and part.text: + text_content += part.text + elif isinstance(event.content, str): + text_content = event.content + + if text_content: + payload = {"type": "content", "text": text_content, "role": "model"} + yield f"data: {json.dumps(payload)}\n\n" + + yield f"data: {json.dumps({'type': 'done'})}\n\n" + + except Exception as e: + logger.exception("Error during agent chat stream") + err_payload = {"type": "error", "text": str(e)} + yield f"data: {json.dumps(err_payload)}\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") diff --git a/app/routers/tools.py b/app/routers/tools.py new file mode 100644 index 0000000..900ea0d --- /dev/null +++ b/app/routers/tools.py @@ -0,0 +1,43 @@ + +from fastapi import APIRouter, Depends, HTTPException +from typing import List, Dict + +from app.services.tool_service import ToolService, get_tool_service +from app.models import MCPServerConfig + +router = APIRouter(prefix="/tools", tags=["tools"]) + +@router.get("/local") +async def list_local_tools( + service: ToolService = Depends(get_tool_service) +) -> List[str]: + """List available local tools (scan 'tools/' directory).""" + return service.get_local_tools() + +@router.get("/mcp", response_model=List[MCPServerConfig]) +async def list_mcp_servers( + service: ToolService = Depends(get_tool_service) +): + """List configured MCP servers.""" + return service.list_mcp_servers() + +@router.post("/mcp", response_model=MCPServerConfig) +async def add_mcp_server( + config: MCPServerConfig, + service: ToolService = Depends(get_tool_service) +): + """Register a new MCP server configuration.""" + # Check if exists? Overwrite behavior is default in dict + service.add_mcp_server(config) + return config + +@router.delete("/mcp/{name}") +async def remove_mcp_server( + name: str, + service: ToolService = Depends(get_tool_service) +): + """Remove a configured MCP server.""" + deleted = service.remove_mcp_server(name) + if not deleted: + raise HTTPException(status_code=404, detail=f"MCP Server '{name}' not found") + return {"status": "deleted", "name": name} diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..8e8e20b --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,5 @@ +from .session_service import ( + init_session_service, + close_session_service, + get_session_service +) diff --git a/app/services/agent_service.py b/app/services/agent_service.py new file mode 100644 index 0000000..5cba43f --- /dev/null +++ b/app/services/agent_service.py @@ -0,0 +1,89 @@ +import uuid +import time +from typing import Dict, List, Optional +import logging + +from app.config import settings +from app.models import AgentConfig, AgentDefinition + +logger = logging.getLogger(__name__) + +class AgentService: + """ + Service for managing Agent Definitions. + Currently uses in-memory storage. + Can be extended to support Database storage. + """ + def __init__(self): + # In-memory store: {agent_id: AgentDefinition} + self._store: Dict[str, AgentDefinition] = {} + + # Populate with a default agent for convenience + self._create_default_agent() + + def _create_default_agent(self): + default_id = "default-assistant" + if default_id not in self._store: + self._store[default_id] = AgentDefinition( + id=default_id, + name="Assistant", + description="A helpful default AI assistant", + model=settings.LLM_DEFAULT_MODEL, + instruction="You are a helpful assistant.", + tools=[], + created_at=time.time(), + updated_at=time.time() + ) + + def create_agent(self, config: AgentConfig) -> AgentDefinition: + """Create and store a new agent definition.""" + agent_id = str(uuid.uuid4()) + definition = AgentDefinition( + id=agent_id, + **config.model_dump(), + created_at=time.time(), + updated_at=time.time() + ) + self._store[agent_id] = definition + logger.info(f"Created new agent definition: {agent_id}") + return definition + + def get_agent(self, agent_id: str) -> Optional[AgentDefinition]: + """Retrieve an agent definition by ID.""" + return self._store.get(agent_id) + + def update_agent(self, agent_id: str, config: AgentConfig) -> Optional[AgentDefinition]: + """Update an existing agent definition.""" + if agent_id not in self._store: + return None + + # Preserve original creation time and ID + original = self._store[agent_id] + updated = AgentDefinition( + id=agent_id, + **config.model_dump(), + created_at=original.created_at, + updated_at=time.time() + ) + self._store[agent_id] = updated + logger.info(f"Updated agent definition: {agent_id}") + return updated + + def list_agents(self) -> List[AgentDefinition]: + """List all stored agents.""" + return list(self._store.values()) + + def delete_agent(self, agent_id: str) -> bool: + """Delete an agent definition.""" + if agent_id in self._store: + del self._store[agent_id] + logger.info(f"Deleted agent definition: {agent_id}") + return True + return False + + +# Singleton instance +_agent_service = AgentService() + +def get_agent_service() -> AgentService: + return _agent_service diff --git a/app/services/session_service.py b/app/services/session_service.py new file mode 100644 index 0000000..e7a73cd --- /dev/null +++ b/app/services/session_service.py @@ -0,0 +1,40 @@ +import logging +from google.adk.sessions import DatabaseSessionService +from app.config import settings + +logger = logging.getLogger(__name__) + +# Global service instance +_session_service: DatabaseSessionService | None = None + + +async def init_session_service() -> None: + """Initialize the database session service on application startup.""" + global _session_service + if _session_service is None: + logger.info(f"Initializing DatabaseSessionService with {settings.DATABASE_URL}...") + _session_service = DatabaseSessionService(db_url=settings.DATABASE_URL) + logger.info("DatabaseSessionService initialized successfully") + + +async def close_session_service() -> None: + """Cleanup session service resources on application shutdown.""" + global _session_service + if _session_service: + _session_service = None + logger.info("DatabaseSessionService closed") + + +def get_session_service() -> DatabaseSessionService: + """ + Get the session service instance for dependency injection. + + Raises: + RuntimeError: If service is not initialized. + """ + if _session_service is None: + raise RuntimeError("Session service not initialized. Application may not have started properly.") + return _session_service + + + diff --git a/app/services/tool_service.py b/app/services/tool_service.py new file mode 100644 index 0000000..ae32275 --- /dev/null +++ b/app/services/tool_service.py @@ -0,0 +1,146 @@ + +import os +import logging +import importlib.util +import inspect +from typing import List, Dict, Optional, Any, Callable +from pathlib import Path + +# ADK imports +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from google.adk.tools.mcp_tool.mcp_session_manager import ( + StreamableHTTPServerParams, + StdioConnectionParams, + StdioServerParameters +) + +from app.models import MCPServerConfig + +logger = logging.getLogger(__name__) + +class ToolService: + """ + Service for managing available tools (Local and MCP). + """ + def __init__(self, tools_dir: str = "tools"): + self.tools_dir = Path(tools_dir) + self._mcp_registry: Dict[str, MCPServerConfig] = {} + + # Ensure tools directory exists + if not self.tools_dir.exists(): + self.tools_dir.mkdir(parents=True, exist_ok=True) + + # --- Local Tools Scanning --- + + def get_local_tools(self) -> List[str]: + """ + Scan the 'tools/' directory for valid tool implementations. + Returns a list of tool names (folder names). + """ + tools = [] + if not self.tools_dir.exists(): + return [] + + for item in self.tools_dir.iterdir(): + if item.is_dir(): + # Check if it has a valid entry point + if (item / "tool.py").exists() or (item / "main.py").exists(): + tools.append(item.name) + return tools + + def load_local_tool(self, tool_name: str) -> Optional[Callable]: + """ + Dynamically load a local tool by name. + Expects a function named 'tool' or the directory name in tool.py/main.py. + """ + tool_path = self.tools_dir / tool_name + + # Try finding the module + module_file = None + if (tool_path / "tool.py").exists(): + module_file = tool_path / "tool.py" + elif (tool_path / "main.py").exists(): + module_file = tool_path / "main.py" + + if not module_file: + logger.warning(f"Tool implementation not found for {tool_name}") + return None + + try: + spec = importlib.util.spec_from_file_location(f"tools.{tool_name}", module_file) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Look for a callable object named 'tool' or scan for functions + if hasattr(module, "tool") and callable(module.tool): + return module.tool + + # Heuristic: Find first function that looks like a tool? + # For now, require 'tool' export or function with same name as folder + if hasattr(module, tool_name) and callable(getattr(module, tool_name)): + return getattr(module, tool_name) + + logger.warning(f"No callable 'tool' or '{tool_name}' found in {module_file}") + except Exception as e: + logger.error(f"Failed to load tool {tool_name}: {e}") + + return None + + # --- MCP Registry --- + + def list_mcp_servers(self) -> List[MCPServerConfig]: + """Return all configured MCP servers.""" + return list(self._mcp_registry.values()) + + def add_mcp_server(self, config: MCPServerConfig): + """Register a new MCP server.""" + self._mcp_registry[config.name] = config + logger.info(f"Registered MCP Server: {config.name}") + + def remove_mcp_server(self, name: str) -> bool: + """Remove an MCP server by name.""" + if name in self._mcp_registry: + del self._mcp_registry[name] + logger.info(f"Removed MCP Server: {name}") + return True + return False + + def get_mcp_toolset(self, name: str) -> Optional[MCPToolset]: + """ + Create an ADK MCPToolset instance from the configuration. + """ + config = self._mcp_registry.get(name) + if not config: + return None + + try: + if config.type == "sse": + if not config.sse_config: + raise ValueError("Missing SSE config") + params = StreamableHTTPServerParams(url=config.sse_config.url) + return MCPToolset(connection_params=params) + + elif config.type == "stdio": + if not config.stdio_config: + raise ValueError("Missing Stdio config") + + server_params = StdioServerParameters( + command=config.stdio_config.command, + args=config.stdio_config.args, + env=config.stdio_config.env + ) + params = StdioConnectionParams(server_params=server_params) + return MCPToolset(connection_params=params, tool_filter=config.tool_filter) + + except Exception as e: + logger.error(f"Failed to create MCP Toolset '{name}': {e}") + return None + + return None + +# Singleton instance +_tool_service = ToolService() + +def get_tool_service() -> ToolService: + return _tool_service diff --git a/app/tools.py b/app/tools.py new file mode 100644 index 0000000..6de029a --- /dev/null +++ b/app/tools.py @@ -0,0 +1,29 @@ +import logging +from typing import Dict, Callable, Any + +logger = logging.getLogger(__name__) + +def dummy_tool(query: str) -> dict: + """ + Dummy Tool: A placeholder tool for demonstration. + + Args: + query: Any string input. + + Returns: + A mock structured response. + """ + logger.info(f"[Tool] dummy_tool called with: {query}") + return { + "status": "success", + "message": f"Processed query: {query}", + "data": "dummy_value" + } + +# ========================== +# 🛠️ Tool Registry +# ========================== +# Map tool names to their function implementation +TOOL_REGISTRY: Dict[str, Callable] = { + "dummy_tool": dummy_tool, +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ce638e7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +# Core Framework +fastapi>=0.115.0 +uvicorn[standard]>=0.32.0 +pydantic>=2.0.0 +pydantic-settings>=2.6.0 + +# Google ADK +google-adk>=1.2.0 + +# Database +asyncpg>=0.30.0 +sqlalchemy[asyncio]>=2.0.0 + +# HTTP Client +httpx>=0.28.0 + +# Environment +python-dotenv>=1.0.0