首次提交

This commit is contained in:
lhr
2026-01-19 22:04:54 +08:00
parent 1fa5a4947a
commit 12ef0292b7
16 changed files with 1147 additions and 0 deletions

10
.env.example Normal file
View File

@@ -0,0 +1,10 @@
# Database Configuration
DATABASE_URL=postgresql+asyncpg://user:password@localhost:5432/dbname
# LLM API Configuration
LLM_API_KEY=your_api_key_here
LLM_API_BASE=https://api.openai.com/v1
LLM_DEFAULT_MODEL=openai/gpt-4o
# App Configuration
DEFAULT_APP_NAME=adk_chat_app

0
app/__init__.py Normal file
View File

143
app/agent_factory.py Normal file
View File

@@ -0,0 +1,143 @@
import logging
from typing import Optional, List, Callable, Dict, Any, Union
from google.adk.agents.llm_agent import LlmAgent, Agent
from google.adk.models.lite_llm import LiteLlm
from google.adk.planners import BuiltInPlanner
from google.adk.code_executors import BuiltInCodeExecutor
from google.genai import types
from app.config import settings
from app.models import AgentConfig, LLMConfig, GenerationConfig
from app.services.tool_service import get_tool_service
from app.tools import TOOL_REGISTRY
logger = logging.getLogger(__name__)
# ==========================
# 🏭 Agent Factory
# ==========================
def _create_llm_model(
model_name: str,
llm_config: Optional[LLMConfig] = None
) -> Union[str, LiteLlm]:
"""
Create the LLM model instance or return model identifier string.
Logic:
1. If model name implies Gemini (official Google model), return the string directly.
ADK's LlmAgent handles 'gemini-...' strings using the native google-genai SDK.
2. For other models (GPT, Claude, etc.), wrap in LiteLlm adapter.
"""
# Normalize comparison
name_lower = model_name.lower()
# Check if likely native Gemini (no provider prefix)
# Exclude obvious proxy prefixes that require LiteLLM (e.g. openai/gemini...)
provider_prefixes = ["openai/", "azure/", "anthropic/", "bedrock/", "mistral/"]
has_provider_prefix = any(p in name_lower for p in provider_prefixes)
is_gemini = "gemini" in name_lower and not has_provider_prefix
# Check if custom configuration forces a specific path
# If API Base is provided and distinct from default Google, usually implies using LiteLLM/Proxy
# even for Gemini models (e.g. via OpenAI-compatible endpoint).
has_custom_base = llm_config and llm_config.api_base and "googleapis.com" not in llm_config.api_base
if is_gemini and not has_custom_base:
logger.info(f"Using Native Gemini Model: {model_name}")
return model_name
# Fallback / Non-Gemini -> Use LiteLLM
api_key = (llm_config.api_key if llm_config and llm_config.api_key else settings.LLM_API_KEY)
api_base = (llm_config.api_base if llm_config and llm_config.api_base else settings.LLM_API_BASE)
logger.info(f"Using LiteLLM for: {model_name} (Base: {api_base})")
return LiteLlm(
model=model_name,
api_base=api_base,
api_key=api_key
)
def create_agent(config: AgentConfig) -> Agent:
"""
Create a fully configured ADK Agent based on the provided AgentConfig.
"""
logger.info(f"Creating Agent: {config.name} ({config.model})")
# 1. Model Initialization
# Returns either a str (for native Gemini) or LiteLlm object
llm = _create_llm_model(config.model, config.llm_config)
# 2. Tools Selection
selected_tools = []
tool_service = get_tool_service()
for tool_name in config.tools:
# A. Check Legacy/Hardcoded Registry
if tool_name in TOOL_REGISTRY:
selected_tools.append(TOOL_REGISTRY[tool_name])
continue
# B. Check Local Tools (tools/ folder)
local_tool = tool_service.load_local_tool(tool_name)
if local_tool:
selected_tools.append(local_tool)
continue
# C. Check MCP Servers
mcp_tool = tool_service.get_mcp_toolset(tool_name)
if mcp_tool:
selected_tools.append(mcp_tool)
continue
logger.warning(f"Tool '{tool_name}' not found (checked Registry, Local, MCP). Skipping.")
# 3. Code Execution
code_executor = None
if config.enable_code_execution:
logger.info("Enabling BuiltInCodeExecutor")
code_executor = BuiltInCodeExecutor()
# 4. Planner / Thinking
# Only applicable for models that support it (mostly Gemini)
planner = None
if config.thinking_config:
logger.info(f"Enabling BuiltInPlanner with budget {config.thinking_config.thinking_budget}")
t_config = types.ThinkingConfig(
include_thoughts=config.thinking_config.include_thoughts,
thinking_budget=config.thinking_config.thinking_budget
)
planner = BuiltInPlanner(thinking_config=t_config)
# 5. Generation Config
gen_config = None
if config.generation_config:
g_params = {}
if config.generation_config.temperature is not None:
g_params["temperature"] = config.generation_config.temperature
if config.generation_config.max_output_tokens is not None:
g_params["max_output_tokens"] = config.generation_config.max_output_tokens
if config.generation_config.top_p is not None:
g_params["top_p"] = config.generation_config.top_p
if g_params:
gen_config = types.GenerateContentConfig(**g_params)
# 6. Assemble LlmAgent
return LlmAgent(
name=config.name,
model=llm,
description=config.description or "",
instruction=config.instruction,
tools=selected_tools,
code_executor=code_executor,
planner=planner,
generate_content_config=gen_config
)

30
app/config.py Normal file
View File

@@ -0,0 +1,30 @@
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# Database Configuration
DATABASE_URL: str = "postgresql+asyncpg://myuser:mypassword@127.0.0.1:5432/mydatabase"
# LLM API Configuration
LLM_API_KEY: str
LLM_API_BASE: str = "https://api.openai.com/v1"
LLM_DEFAULT_MODEL: str = "openai/gpt-4o"
# App Configuration
DEFAULT_APP_NAME: str = "adk_chat_app"
class Config:
env_file = ".env"
extra = "ignore"
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
settings = get_settings()

81
app/main.py Normal file
View File

@@ -0,0 +1,81 @@
import json
import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from google.adk.runners import Runner
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.genai import types
from app.config import settings
from app.services import (
get_session_service,
init_session_service,
close_session_service
)
from app.agent_factory import create_agent
# Routers
from app.routers.agents import router as agents_router
from app.routers.sessions import router as sessions_router
from app.routers.tools import router as tools_router
# --- Logging Configuration ---
# --- Logging Configuration ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# --- Application Lifespan ---
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager for startup/shutdown events."""
logger.info("Starting up ADK Chat Backend...")
await init_session_service()
logger.info("Application started successfully")
yield
logger.info("Shutting down ADK Chat Backend...")
await close_session_service()
logger.info("Application shutdown complete")
app = FastAPI(
title="ADK Enterprise Chat Backend",
description="LLM Chat backend powered by Google ADK with PostgreSQL persistence",
version="2.0.0",
lifespan=lifespan
)
# --- CORS Configuration ---
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- Register Routers ---
app.include_router(agents_router)
app.include_router(sessions_router)
app.include_router(tools_router)
# --- Helper ---
def get_streaming_mode(mode_str: str) -> StreamingMode:
return StreamingMode.SSE if mode_str.lower() == "sse" else StreamingMode.NONE

128
app/models.py Normal file
View File

@@ -0,0 +1,128 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Any
from typing import List, Optional, Any, Dict, Literal, Union
# --- MCP Tool Configuration Models ---
class StdioConfig(BaseModel):
"""Configuration for Stdio-based MCP server."""
command: str = Field(..., description="Executable command (e.g., 'npx', 'python')")
args: List[str] = Field(default_factory=list, description="Command arguments")
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
class SSEConfig(BaseModel):
"""Configuration for SSE-based MCP server."""
url: str = Field(..., description="Server URL (e.g., 'http://localhost:8000/sse')")
class MCPServerConfig(BaseModel):
"""Configuration for an MCP Server connection."""
name: str = Field(..., description="Unique name for the toolset")
type: Literal["stdio", "sse"] = Field("stdio", description="Connection type")
stdio_config: Optional[StdioConfig] = Field(None, description="Stdio configuration")
sse_config: Optional[SSEConfig] = Field(None, description="SSE configuration")
tool_filter: Optional[List[str]] = Field(None, description="Optional whitelist of tools to enable")
# --- LLM Configuration Models ---
class LLMConfig(BaseModel):
"""Custom LLM API configuration for per-request override."""
api_key: Optional[str] = Field(None, description="API key for the LLM service")
api_base: Optional[str] = Field(None, description="Base URL for the LLM API")
model: Optional[str] = Field(None, description="Model name (e.g., 'openai/gpt-4o')")
class GenerationConfig(BaseModel):
"""LLM generation parameters for fine-tuning responses."""
temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Randomness (0.0-2.0)")
max_output_tokens: Optional[int] = Field(None, gt=0, description="Max response length")
top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling")
top_k: Optional[int] = Field(None, gt=0, description="Top-k sampling")
# --- Request Models ---
class CreateSessionRequest(BaseModel):
"""Request to create a new chat session."""
user_id: str
app_name: Optional[str] = None
class ThinkingConfigModel(BaseModel):
"""Configuration for BuiltInPlanner thinking features."""
include_thoughts: bool = Field(True, description="Whether to include internal thoughts in the response")
thinking_budget: int = Field(1024, gt=0, description="Token budget for thinking")
class AgentConfig(BaseModel):
"""Comprehensive configuration for creating an ADK Agent."""
name: str = Field(..., description="Unique name for the agent")
description: Optional[str] = Field(None, description="Description of agent capabilities")
model: str = Field(..., description="Model identifier (e.g., 'gemini-2.5-flash')")
instruction: str = Field(..., description="System instruction / persona for the agent")
tools: List[str] = Field(default_factory=list, description="List of tool names to enable")
enable_code_execution: bool = Field(False, description="Enable built-in code executor")
thinking_config: Optional[ThinkingConfigModel] = Field(None, description="Configuration for reasoning/planning")
# Nested configurations
llm_config: Optional[LLMConfig] = Field(None, description="Custom LLM API connection details")
generation_config: Optional[GenerationConfig] = Field(None, description="Generation parameters")
class AgentDefinition(AgentConfig):
"""Persisted Agent definition with unique ID."""
id: str = Field(..., description="Unique identifier for the agent")
created_at: float = Field(default_factory=lambda: __import__("time").time(), description="Creation timestamp")
updated_at: float = Field(default_factory=lambda: __import__("time").time(), description="Last update timestamp")
class ChatTurnRequest(BaseModel):
"""Request for a single chat turn with a specific agent."""
agent_id: str = Field(..., description="ID of the agent to use for this turn")
message: str = Field(..., description="User message content")
# Optional overrides for this specific turn
streaming_mode: Optional[str] = Field("sse", description="Streaming mode: 'sse', 'none'")
class ChatRequest(BaseModel):
"""
LEGACY: Request for chat completion.
Used by the old /chat/stream endpoint.
"""
session_id: str
user_id: str
message: str
app_name: Optional[str] = None
agent_config: Optional[AgentConfig] = Field(None, description="Ephemeral agent config")
max_llm_calls: Optional[int] = Field(500, gt=0)
streaming_mode: Optional[str] = "sse"
class SessionResponse(BaseModel):
"""Response for session operations."""
id: str
app_name: str
user_id: str
updated_at: Any = None
class HistoryEvent(BaseModel):
"""A single event in chat history."""
type: str
role: str
content: str
agent_name: Optional[str] = None
timestamp: Optional[Any] = None
invocation_id: Optional[str] = None
class SessionDetailResponse(SessionResponse):
"""Session details including chat history."""
history: List[HistoryEvent] = []
events_count: int = 0
created_at: Any = None

0
app/routers/__init__.py Normal file
View File

68
app/routers/agents.py Normal file
View File

@@ -0,0 +1,68 @@
from fastapi import APIRouter, HTTPException, Depends
from typing import List
from app.models import AgentConfig, AgentDefinition
from app.services.agent_service import get_agent_service, AgentService
# Initialize Router
router = APIRouter(
prefix="/agents",
tags=["Agents"]
)
# ==========================
# 🤖 Agent Management API
# ==========================
@router.post("", response_model=AgentDefinition)
async def create_agent_def(
config: AgentConfig,
service: AgentService = Depends(get_agent_service)
):
"""Create a new Agent definition."""
return service.create_agent(config)
@router.get("", response_model=List[AgentDefinition])
async def list_agents(
service: AgentService = Depends(get_agent_service)
):
"""List all available Agents."""
return service.list_agents()
@router.get("/{agent_id}", response_model=AgentDefinition)
async def get_agent_def(
agent_id: str,
service: AgentService = Depends(get_agent_service)
):
"""Get specific Agent definition."""
agent = service.get_agent(agent_id)
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")
return agent
@router.put("/{agent_id}", response_model=AgentDefinition)
async def update_agent_def(
agent_id: str,
config: AgentConfig,
service: AgentService = Depends(get_agent_service)
):
"""Update an Agent definition."""
agent = service.update_agent(agent_id, config)
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")
return agent
@router.delete("/{agent_id}")
async def delete_agent_def(
agent_id: str,
service: AgentService = Depends(get_agent_service)
):
"""Delete an Agent definition."""
success = service.delete_agent(agent_id)
if not success:
raise HTTPException(status_code=404, detail="Agent not found")
return {"status": "deleted", "agent_id": agent_id}

317
app/routers/sessions.py Normal file
View File

@@ -0,0 +1,317 @@
from typing import List, AsyncGenerator
import json
import logging
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from google.adk.runners import Runner
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.genai import types
from app.config import settings
from app.models import (
CreateSessionRequest,
ChatTurnRequest,
SessionResponse,
SessionDetailResponse,
HistoryEvent,
AgentConfig
)
from app.services import get_session_service
from app.services.agent_service import get_agent_service, AgentService
from app.agent_factory import create_agent
# Initialize Router
router = APIRouter(
prefix="/sessions",
tags=["Sessions"]
)
logger = logging.getLogger(__name__)
# --- Helper ---
def get_streaming_mode(mode_str: str) -> StreamingMode:
"""Convert string streaming mode to StreamingMode enum."""
mode_map = {
"none": StreamingMode.NONE,
"sse": StreamingMode.SSE,
"bidi": StreamingMode.BIDI,
}
return mode_map.get(mode_str.lower(), StreamingMode.SSE)
# ==========================
# 📋 Session CRUD
# ==========================
@router.post("", response_model=SessionResponse)
async def create_session(
req: CreateSessionRequest,
service=Depends(get_session_service)
):
"""Create a new chat session."""
app_name = req.app_name or settings.DEFAULT_APP_NAME
session = await service.create_session(app_name=app_name, user_id=req.user_id)
return SessionResponse(
id=session.id,
app_name=session.app_name,
user_id=session.user_id,
updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None))
)
@router.get("", response_model=List[SessionResponse])
async def list_sessions(
user_id: str,
app_name: str = None,
service=Depends(get_session_service)
):
"""List all sessions for a user."""
app_name = app_name or settings.DEFAULT_APP_NAME
response = await service.list_sessions(app_name=app_name, user_id=user_id)
return [
SessionResponse(
id=s.id,
app_name=s.app_name,
user_id=s.user_id,
updated_at=getattr(s, 'last_update_time', getattr(s, 'updated_at', None))
) for s in response.sessions
]
@router.get("/{session_id}", response_model=SessionDetailResponse)
async def get_session_history(
session_id: str,
user_id: str,
app_name: str = None,
service=Depends(get_session_service)
):
"""Get session details and chat history."""
app_name = app_name or settings.DEFAULT_APP_NAME
session = await service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# Convert history events
history_events = []
events = getattr(session, 'events', [])
# Filter events based on Rewind actions
# ADK preserves all events in the log. We must retroactively remove rewound events for the view.
valid_events = []
for e in events:
# Check for rewind action
actions = getattr(e, 'actions', None)
rewind_target_id = getattr(actions, 'rewind_before_invocation_id', None) if actions else None
if rewind_target_id:
# Find the index where the truncated invocation began
truncate_idx = -1
for i, ve in enumerate(valid_events):
if getattr(ve, 'invocation_id', None) == rewind_target_id:
truncate_idx = i
break
if truncate_idx != -1:
logger.debug(f"Rewinding history to before {rewind_target_id}")
valid_events = valid_events[:truncate_idx]
else:
valid_events.append(e)
for e in valid_events:
content_str = ""
if hasattr(e, 'content') and e.content:
if isinstance(e.content, str):
content_str = e.content
elif hasattr(e.content, 'parts'):
parts = e.content.parts
for part in parts:
if hasattr(part, 'text') and part.text:
content_str += part.text
# Determine basic role (user vs model)
# ADK usually sets author to the agent name or "user"
author = getattr(e, 'author', 'unknown')
role = "user" if author == "user" else "model"
# Agent Name is the specific author if it's a model response
agent_name = author if role == "model" else None
# Extract timestamp
timestamp = getattr(e, 'timestamp', getattr(e, 'created_at', None))
invocation_id = getattr(e, 'invocation_id', None)
history_events.append(HistoryEvent(
type="message",
role=role,
agent_name=agent_name,
content=content_str,
timestamp=timestamp,
invocation_id=invocation_id
))
return SessionDetailResponse(
id=session.id,
app_name=session.app_name,
user_id=session.user_id,
updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None)),
created_at=getattr(session, 'created_at', getattr(session, 'create_time', None)),
history=history_events,
events_count=len(history_events)
)
@router.delete("/{session_id}")
async def delete_session(
session_id: str,
user_id: str,
app_name: str = None,
service=Depends(get_session_service)
):
"""Delete a session."""
app_name = app_name or settings.DEFAULT_APP_NAME
await service.delete_session(app_name=app_name, user_id=user_id, session_id=session_id)
return {"status": "deleted", "session_id": session_id}
@router.post("/{session_id}/rewind")
async def rewind_session_state(
session_id: str,
user_id: str,
invocation_id: str,
app_name: str = None,
session_service=Depends(get_session_service),
agent_service: AgentService = Depends(get_agent_service)
):
"""
Rewind session state to before a specific invocation.
Undoes changes from the specified invocation and subsequent ones.
"""
app_name = app_name or settings.DEFAULT_APP_NAME
# 1. Need a Runner instance to perform rewind.
# The Runner requires an agent, but for rewind (session/log operation),
# the specific agent configuration might not be critical if default is used,
# assuming session persistence is agent-agnostic.
# We'll use the default assistant to initialize the runner.
agent_def = agent_service.get_agent("default-assistant")
if not agent_def:
# Fallback if default deleted
agent_def = AgentConfig(
name="rewinder",
model=settings.LLM_DEFAULT_MODEL,
instruction="",
tools=[]
)
agent = create_agent(agent_def)
runner = Runner(
agent=agent,
app_name=app_name,
session_service=session_service
)
try:
await runner.rewind_async(
user_id=user_id,
session_id=session_id,
rewind_before_invocation_id=invocation_id
)
return {"status": "success", "rewound_before": invocation_id}
except Exception as e:
logger.error(f"Rewind failed: {e}")
# ADK might raise specific errors, generic catch for now
raise HTTPException(status_code=500, detail=f"Rewind failed: {str(e)}")
# ==========================
# 💬 Session Chat
# ==========================
@router.post("/{session_id}/chat")
async def chat_with_agent(
session_id: str,
req: ChatTurnRequest,
user_id: str,
app_name: str = None,
session_service=Depends(get_session_service),
agent_service: AgentService = Depends(get_agent_service)
):
"""
Chat with a specific Agent in a Session.
Decoupled: Agent is loaded from AgentService, Session is loaded from SessionService.
"""
app_name = app_name or settings.DEFAULT_APP_NAME
# 1. Load Session
session = await session_service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# 2. Get Agent Definition
agent_def = agent_service.get_agent(req.agent_id)
if not agent_def:
raise HTTPException(status_code=404, detail=f"Agent '{req.agent_id}' not found")
# 3. Create Runtime Agent
agent = create_agent(agent_def)
# 4. Preparation for Run
streaming_mode = get_streaming_mode(req.streaming_mode or "sse")
run_config = RunConfig(
streaming_mode=streaming_mode,
max_llm_calls=500, # or configurable
)
runner = Runner(
agent=agent,
app_name=app_name,
session_service=session_service
)
# 5. Stream Generator
async def event_generator() -> AsyncGenerator[str, None]:
new_msg = types.Content(
role="user",
parts=[types.Part(text=req.message)]
)
try:
async for event in runner.run_async(
session_id=session.id,
user_id=user_id,
new_message=new_msg,
run_config=run_config,
):
if event.content:
text_content = ""
if hasattr(event.content, 'parts'):
for part in event.content.parts:
if hasattr(part, 'text') and part.text:
text_content += part.text
elif isinstance(event.content, str):
text_content = event.content
if text_content:
payload = {"type": "content", "text": text_content, "role": "model"}
yield f"data: {json.dumps(payload)}\n\n"
yield f"data: {json.dumps({'type': 'done'})}\n\n"
except Exception as e:
logger.exception("Error during agent chat stream")
err_payload = {"type": "error", "text": str(e)}
yield f"data: {json.dumps(err_payload)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")

43
app/routers/tools.py Normal file
View File

@@ -0,0 +1,43 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import List, Dict
from app.services.tool_service import ToolService, get_tool_service
from app.models import MCPServerConfig
router = APIRouter(prefix="/tools", tags=["tools"])
@router.get("/local")
async def list_local_tools(
service: ToolService = Depends(get_tool_service)
) -> List[str]:
"""List available local tools (scan 'tools/' directory)."""
return service.get_local_tools()
@router.get("/mcp", response_model=List[MCPServerConfig])
async def list_mcp_servers(
service: ToolService = Depends(get_tool_service)
):
"""List configured MCP servers."""
return service.list_mcp_servers()
@router.post("/mcp", response_model=MCPServerConfig)
async def add_mcp_server(
config: MCPServerConfig,
service: ToolService = Depends(get_tool_service)
):
"""Register a new MCP server configuration."""
# Check if exists? Overwrite behavior is default in dict
service.add_mcp_server(config)
return config
@router.delete("/mcp/{name}")
async def remove_mcp_server(
name: str,
service: ToolService = Depends(get_tool_service)
):
"""Remove a configured MCP server."""
deleted = service.remove_mcp_server(name)
if not deleted:
raise HTTPException(status_code=404, detail=f"MCP Server '{name}' not found")
return {"status": "deleted", "name": name}

5
app/services/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .session_service import (
init_session_service,
close_session_service,
get_session_service
)

View File

@@ -0,0 +1,89 @@
import uuid
import time
from typing import Dict, List, Optional
import logging
from app.config import settings
from app.models import AgentConfig, AgentDefinition
logger = logging.getLogger(__name__)
class AgentService:
"""
Service for managing Agent Definitions.
Currently uses in-memory storage.
Can be extended to support Database storage.
"""
def __init__(self):
# In-memory store: {agent_id: AgentDefinition}
self._store: Dict[str, AgentDefinition] = {}
# Populate with a default agent for convenience
self._create_default_agent()
def _create_default_agent(self):
default_id = "default-assistant"
if default_id not in self._store:
self._store[default_id] = AgentDefinition(
id=default_id,
name="Assistant",
description="A helpful default AI assistant",
model=settings.LLM_DEFAULT_MODEL,
instruction="You are a helpful assistant.",
tools=[],
created_at=time.time(),
updated_at=time.time()
)
def create_agent(self, config: AgentConfig) -> AgentDefinition:
"""Create and store a new agent definition."""
agent_id = str(uuid.uuid4())
definition = AgentDefinition(
id=agent_id,
**config.model_dump(),
created_at=time.time(),
updated_at=time.time()
)
self._store[agent_id] = definition
logger.info(f"Created new agent definition: {agent_id}")
return definition
def get_agent(self, agent_id: str) -> Optional[AgentDefinition]:
"""Retrieve an agent definition by ID."""
return self._store.get(agent_id)
def update_agent(self, agent_id: str, config: AgentConfig) -> Optional[AgentDefinition]:
"""Update an existing agent definition."""
if agent_id not in self._store:
return None
# Preserve original creation time and ID
original = self._store[agent_id]
updated = AgentDefinition(
id=agent_id,
**config.model_dump(),
created_at=original.created_at,
updated_at=time.time()
)
self._store[agent_id] = updated
logger.info(f"Updated agent definition: {agent_id}")
return updated
def list_agents(self) -> List[AgentDefinition]:
"""List all stored agents."""
return list(self._store.values())
def delete_agent(self, agent_id: str) -> bool:
"""Delete an agent definition."""
if agent_id in self._store:
del self._store[agent_id]
logger.info(f"Deleted agent definition: {agent_id}")
return True
return False
# Singleton instance
_agent_service = AgentService()
def get_agent_service() -> AgentService:
return _agent_service

View File

@@ -0,0 +1,40 @@
import logging
from google.adk.sessions import DatabaseSessionService
from app.config import settings
logger = logging.getLogger(__name__)
# Global service instance
_session_service: DatabaseSessionService | None = None
async def init_session_service() -> None:
"""Initialize the database session service on application startup."""
global _session_service
if _session_service is None:
logger.info(f"Initializing DatabaseSessionService with {settings.DATABASE_URL}...")
_session_service = DatabaseSessionService(db_url=settings.DATABASE_URL)
logger.info("DatabaseSessionService initialized successfully")
async def close_session_service() -> None:
"""Cleanup session service resources on application shutdown."""
global _session_service
if _session_service:
_session_service = None
logger.info("DatabaseSessionService closed")
def get_session_service() -> DatabaseSessionService:
"""
Get the session service instance for dependency injection.
Raises:
RuntimeError: If service is not initialized.
"""
if _session_service is None:
raise RuntimeError("Session service not initialized. Application may not have started properly.")
return _session_service

View File

@@ -0,0 +1,146 @@
import os
import logging
import importlib.util
import inspect
from typing import List, Dict, Optional, Any, Callable
from pathlib import Path
# ADK imports
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
from google.adk.tools.mcp_tool.mcp_session_manager import (
StreamableHTTPServerParams,
StdioConnectionParams,
StdioServerParameters
)
from app.models import MCPServerConfig
logger = logging.getLogger(__name__)
class ToolService:
"""
Service for managing available tools (Local and MCP).
"""
def __init__(self, tools_dir: str = "tools"):
self.tools_dir = Path(tools_dir)
self._mcp_registry: Dict[str, MCPServerConfig] = {}
# Ensure tools directory exists
if not self.tools_dir.exists():
self.tools_dir.mkdir(parents=True, exist_ok=True)
# --- Local Tools Scanning ---
def get_local_tools(self) -> List[str]:
"""
Scan the 'tools/' directory for valid tool implementations.
Returns a list of tool names (folder names).
"""
tools = []
if not self.tools_dir.exists():
return []
for item in self.tools_dir.iterdir():
if item.is_dir():
# Check if it has a valid entry point
if (item / "tool.py").exists() or (item / "main.py").exists():
tools.append(item.name)
return tools
def load_local_tool(self, tool_name: str) -> Optional[Callable]:
"""
Dynamically load a local tool by name.
Expects a function named 'tool' or the directory name in tool.py/main.py.
"""
tool_path = self.tools_dir / tool_name
# Try finding the module
module_file = None
if (tool_path / "tool.py").exists():
module_file = tool_path / "tool.py"
elif (tool_path / "main.py").exists():
module_file = tool_path / "main.py"
if not module_file:
logger.warning(f"Tool implementation not found for {tool_name}")
return None
try:
spec = importlib.util.spec_from_file_location(f"tools.{tool_name}", module_file)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Look for a callable object named 'tool' or scan for functions
if hasattr(module, "tool") and callable(module.tool):
return module.tool
# Heuristic: Find first function that looks like a tool?
# For now, require 'tool' export or function with same name as folder
if hasattr(module, tool_name) and callable(getattr(module, tool_name)):
return getattr(module, tool_name)
logger.warning(f"No callable 'tool' or '{tool_name}' found in {module_file}")
except Exception as e:
logger.error(f"Failed to load tool {tool_name}: {e}")
return None
# --- MCP Registry ---
def list_mcp_servers(self) -> List[MCPServerConfig]:
"""Return all configured MCP servers."""
return list(self._mcp_registry.values())
def add_mcp_server(self, config: MCPServerConfig):
"""Register a new MCP server."""
self._mcp_registry[config.name] = config
logger.info(f"Registered MCP Server: {config.name}")
def remove_mcp_server(self, name: str) -> bool:
"""Remove an MCP server by name."""
if name in self._mcp_registry:
del self._mcp_registry[name]
logger.info(f"Removed MCP Server: {name}")
return True
return False
def get_mcp_toolset(self, name: str) -> Optional[MCPToolset]:
"""
Create an ADK MCPToolset instance from the configuration.
"""
config = self._mcp_registry.get(name)
if not config:
return None
try:
if config.type == "sse":
if not config.sse_config:
raise ValueError("Missing SSE config")
params = StreamableHTTPServerParams(url=config.sse_config.url)
return MCPToolset(connection_params=params)
elif config.type == "stdio":
if not config.stdio_config:
raise ValueError("Missing Stdio config")
server_params = StdioServerParameters(
command=config.stdio_config.command,
args=config.stdio_config.args,
env=config.stdio_config.env
)
params = StdioConnectionParams(server_params=server_params)
return MCPToolset(connection_params=params, tool_filter=config.tool_filter)
except Exception as e:
logger.error(f"Failed to create MCP Toolset '{name}': {e}")
return None
return None
# Singleton instance
_tool_service = ToolService()
def get_tool_service() -> ToolService:
return _tool_service

29
app/tools.py Normal file
View File

@@ -0,0 +1,29 @@
import logging
from typing import Dict, Callable, Any
logger = logging.getLogger(__name__)
def dummy_tool(query: str) -> dict:
"""
Dummy Tool: A placeholder tool for demonstration.
Args:
query: Any string input.
Returns:
A mock structured response.
"""
logger.info(f"[Tool] dummy_tool called with: {query}")
return {
"status": "success",
"message": f"Processed query: {query}",
"data": "dummy_value"
}
# ==========================
# 🛠️ Tool Registry
# ==========================
# Map tool names to their function implementation
TOOL_REGISTRY: Dict[str, Callable] = {
"dummy_tool": dummy_tool,
}

18
requirements.txt Normal file
View File

@@ -0,0 +1,18 @@
# Core Framework
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
pydantic>=2.0.0
pydantic-settings>=2.6.0
# Google ADK
google-adk>=1.2.0
# Database
asyncpg>=0.30.0
sqlalchemy[asyncio]>=2.0.0
# HTTP Client
httpx>=0.28.0
# Environment
python-dotenv>=1.0.0