首次提交
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
143
app/agent_factory.py
Normal file
143
app/agent_factory.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import logging
|
||||
from typing import Optional, List, Callable, Dict, Any, Union
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent, Agent
|
||||
from google.adk.models.lite_llm import LiteLlm
|
||||
from google.adk.planners import BuiltInPlanner
|
||||
from google.adk.code_executors import BuiltInCodeExecutor
|
||||
from google.genai import types
|
||||
|
||||
from app.config import settings
|
||||
from app.models import AgentConfig, LLMConfig, GenerationConfig
|
||||
from app.services.tool_service import get_tool_service
|
||||
|
||||
from app.tools import TOOL_REGISTRY
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==========================
|
||||
# 🏭 Agent Factory
|
||||
# ==========================
|
||||
|
||||
def _create_llm_model(
    model_name: str,
    llm_config: Optional[LLMConfig] = None
) -> Union[str, LiteLlm]:
    """
    Create the LLM model instance or return model identifier string.

    Logic:
    1. If model name implies Gemini (official Google model), return the string directly.
       ADK's LlmAgent handles 'gemini-...' strings using the native google-genai SDK.
    2. For other models (GPT, Claude, etc.), wrap in LiteLlm adapter.

    Args:
        model_name: Model identifier, e.g. 'gemini-2.5-flash' or 'openai/gpt-4o'.
        llm_config: Optional per-request overrides for API key / base URL.

    Returns:
        The plain model name for native Gemini, otherwise a LiteLlm wrapper.
    """
    # Normalize comparison
    name_lower = model_name.lower()

    # Check if likely native Gemini (no provider prefix).
    # Exclude obvious proxy prefixes that require LiteLLM (e.g. openai/gemini...).
    # Use startswith with a tuple of prefixes: a provider prefix is only
    # meaningful at the start of the name — the previous substring test
    # ("openai/" in name) could misclassify names containing a prefix mid-string.
    provider_prefixes = ("openai/", "azure/", "anthropic/", "bedrock/", "mistral/")
    has_provider_prefix = name_lower.startswith(provider_prefixes)

    is_gemini = "gemini" in name_lower and not has_provider_prefix

    # Check if custom configuration forces a specific path.
    # If API Base is provided and distinct from default Google, usually implies using LiteLLM/Proxy
    # even for Gemini models (e.g. via OpenAI-compatible endpoint).
    # bool() so the flag is always True/False rather than None / empty string.
    has_custom_base = bool(
        llm_config and llm_config.api_base and "googleapis.com" not in llm_config.api_base
    )

    if is_gemini and not has_custom_base:
        logger.info(f"Using Native Gemini Model: {model_name}")
        return model_name

    # Fallback / Non-Gemini -> Use LiteLLM.
    # Per-request overrides win over process-wide settings.
    api_key = (llm_config.api_key if llm_config and llm_config.api_key else settings.LLM_API_KEY)
    api_base = (llm_config.api_base if llm_config and llm_config.api_base else settings.LLM_API_BASE)

    logger.info(f"Using LiteLLM for: {model_name} (Base: {api_base})")
    return LiteLlm(
        model=model_name,
        api_base=api_base,
        api_key=api_key
    )
|
||||
|
||||
|
||||
def create_agent(config: AgentConfig) -> Agent:
    """
    Create a fully configured ADK Agent based on the provided AgentConfig.

    Args:
        config: Declarative agent description (model, instruction, tools,
            optional thinking/generation/LLM-connection settings).

    Returns:
        An assembled LlmAgent ready to be handed to a Runner.
    """
    logger.info(f"Creating Agent: {config.name} ({config.model})")

    # 1. Model Initialization
    # Returns either a str (for native Gemini) or LiteLlm object
    llm = _create_llm_model(config.model, config.llm_config)

    # 2. Tools Selection — resolve each name against three sources in order;
    # first match wins, unresolved names are logged and skipped.
    selected_tools = []
    tool_service = get_tool_service()

    for tool_name in config.tools:
        # A. Check Legacy/Hardcoded Registry
        if tool_name in TOOL_REGISTRY:
            selected_tools.append(TOOL_REGISTRY[tool_name])
            continue

        # B. Check Local Tools (tools/ folder)
        local_tool = tool_service.load_local_tool(tool_name)
        if local_tool:
            selected_tools.append(local_tool)
            continue

        # C. Check MCP Servers
        mcp_tool = tool_service.get_mcp_toolset(tool_name)
        if mcp_tool:
            selected_tools.append(mcp_tool)
            continue

        logger.warning(f"Tool '{tool_name}' not found (checked Registry, Local, MCP). Skipping.")

    # 3. Code Execution
    code_executor = None
    if config.enable_code_execution:
        logger.info("Enabling BuiltInCodeExecutor")
        code_executor = BuiltInCodeExecutor()

    # 4. Planner / Thinking
    # Only applicable for models that support it (mostly Gemini)
    planner = None
    if config.thinking_config:
        logger.info(f"Enabling BuiltInPlanner with budget {config.thinking_config.thinking_budget}")
        t_config = types.ThinkingConfig(
            include_thoughts=config.thinking_config.include_thoughts,
            thinking_budget=config.thinking_config.thinking_budget
        )
        planner = BuiltInPlanner(thinking_config=t_config)

    # 5. Generation Config — only build the object if at least one
    # parameter was actually provided.
    gen_config = None
    if config.generation_config:
        g_params = {}
        if config.generation_config.temperature is not None:
            g_params["temperature"] = config.generation_config.temperature
        if config.generation_config.max_output_tokens is not None:
            g_params["max_output_tokens"] = config.generation_config.max_output_tokens
        if config.generation_config.top_p is not None:
            g_params["top_p"] = config.generation_config.top_p
        # Fix: top_k is declared on GenerationConfig (app.models) but was
        # previously dropped here; forward it like the other parameters.
        if config.generation_config.top_k is not None:
            g_params["top_k"] = config.generation_config.top_k

        if g_params:
            gen_config = types.GenerateContentConfig(**g_params)

    # 6. Assemble LlmAgent
    return LlmAgent(
        name=config.name,
        model=llm,
        description=config.description or "",
        instruction=config.instruction,
        tools=selected_tools,
        code_executor=code_executor,
        planner=planner,
        generate_content_config=gen_config
    )
|
||||
30
app/config.py
Normal file
30
app/config.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables (and .env)."""

    # Database Configuration
    # NOTE(review): default embeds local dev credentials — must be overridden
    # via DATABASE_URL in any non-local deployment.
    DATABASE_URL: str = "postgresql+asyncpg://myuser:mypassword@127.0.0.1:5432/mydatabase"

    # LLM API Configuration
    # LLM_API_KEY has no default, so instantiation fails fast when it is unset.
    LLM_API_KEY: str
    LLM_API_BASE: str = "https://api.openai.com/v1"
    LLM_DEFAULT_MODEL: str = "openai/gpt-4o"

    # App Configuration
    DEFAULT_APP_NAME: str = "adk_chat_app"

    class Config:
        # Read overrides from .env; silently ignore unknown env keys.
        env_file = ".env"
        extra = "ignore"
|
||||
|
||||
|
||||
@lru_cache
def get_settings() -> Settings:
    """Get cached settings instance.

    lru_cache makes this a process-wide singleton: the environment and
    .env file are read only once, on first call.
    """
    return Settings()
|
||||
|
||||
|
||||
# Eagerly-built module-level singleton; importing app.config reads the env once.
settings = get_settings()
|
||||
81
app/main.py
Normal file
81
app/main.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.agents.run_config import RunConfig, StreamingMode
|
||||
from google.genai import types
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from app.services import (
|
||||
get_session_service,
|
||||
init_session_service,
|
||||
close_session_service
|
||||
)
|
||||
from app.agent_factory import create_agent
|
||||
|
||||
# Routers
|
||||
from app.routers.agents import router as agents_router
|
||||
from app.routers.sessions import router as sessions_router
|
||||
from app.routers.tools import router as tools_router
|
||||
|
||||
|
||||
# --- Logging Configuration ---
|
||||
# Configure root logging once at import time; all module loggers inherit this.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Application Lifespan ---
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup/shutdown events."""
    # Startup: open the DB-backed session service before serving requests.
    logger.info("Starting up ADK Chat Backend...")
    await init_session_service()
    logger.info("Application started successfully")
    yield
    # Shutdown: runs after the server stops accepting traffic.
    logger.info("Shutting down ADK Chat Backend...")
    await close_session_service()
    logger.info("Application shutdown complete")
|
||||
|
||||
|
||||
# FastAPI application; lifespan hooks manage the session-service lifecycle.
app = FastAPI(
    title="ADK Enterprise Chat Backend",
    description="LLM Chat backend powered by Google ADK with PostgreSQL persistence",
    version="2.0.0",
    lifespan=lifespan
)


# --- CORS Configuration ---
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — restrict allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Register Routers ---
app.include_router(agents_router)
app.include_router(sessions_router)
app.include_router(tools_router)
|
||||
|
||||
|
||||
# --- Helper ---
|
||||
def get_streaming_mode(mode_str: str) -> StreamingMode:
    # NOTE(review): diverges from app.routers.sessions.get_streaming_mode,
    # which also maps "bidi" and defaults unknown values to SSE (not NONE).
    # Confirm whether the two helpers should be unified; this one appears unused.
    return StreamingMode.SSE if mode_str.lower() == "sse" else StreamingMode.NONE
|
||||
|
||||
|
||||
|
||||
128
app/models.py
Normal file
128
app/models.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import time

from pydantic import BaseModel, Field
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# --- MCP Tool Configuration Models ---
|
||||
|
||||
class StdioConfig(BaseModel):
    """Configuration for Stdio-based MCP server (spawned as a subprocess)."""
    command: str = Field(..., description="Executable command (e.g., 'npx', 'python')")
    args: List[str] = Field(default_factory=list, description="Command arguments")
    # None means "inherit the parent process environment unchanged".
    env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
|
||||
|
||||
class SSEConfig(BaseModel):
    """Configuration for SSE-based (HTTP) MCP server."""
    url: str = Field(..., description="Server URL (e.g., 'http://localhost:8000/sse')")
|
||||
|
||||
class MCPServerConfig(BaseModel):
    """Configuration for an MCP Server connection."""
    name: str = Field(..., description="Unique name for the toolset")
    type: Literal["stdio", "sse"] = Field("stdio", description="Connection type")
    # NOTE(review): no validator ensures the config matching `type` is set —
    # confirm whether a model_validator should enforce that.
    stdio_config: Optional[StdioConfig] = Field(None, description="Stdio configuration")
    sse_config: Optional[SSEConfig] = Field(None, description="SSE configuration")
    tool_filter: Optional[List[str]] = Field(None, description="Optional whitelist of tools to enable")
|
||||
|
||||
# --- LLM Configuration Models ---
|
||||
|
||||
class LLMConfig(BaseModel):
    """Custom LLM API configuration for per-request override.

    Any field left as None falls back to the process-wide settings
    (see app.agent_factory._create_llm_model).
    """
    api_key: Optional[str] = Field(None, description="API key for the LLM service")
    api_base: Optional[str] = Field(None, description="Base URL for the LLM API")
    model: Optional[str] = Field(None, description="Model name (e.g., 'openai/gpt-4o')")
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
    """LLM generation parameters for fine-tuning responses.

    All fields are optional; None means "use the model's default".
    """
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Randomness (0.0-2.0)")
    max_output_tokens: Optional[int] = Field(None, gt=0, description="Max response length")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling")
    top_k: Optional[int] = Field(None, gt=0, description="Top-k sampling")
|
||||
|
||||
|
||||
# --- Request Models ---
|
||||
|
||||
class CreateSessionRequest(BaseModel):
    """Request to create a new chat session."""
    user_id: str
    # Defaults to settings.DEFAULT_APP_NAME when omitted (see sessions router).
    app_name: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
class ThinkingConfigModel(BaseModel):
    """Configuration for BuiltInPlanner thinking features.

    Mapped onto google.genai types.ThinkingConfig in app.agent_factory.
    """
    include_thoughts: bool = Field(True, description="Whether to include internal thoughts in the response")
    thinking_budget: int = Field(1024, gt=0, description="Token budget for thinking")
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
    """Comprehensive configuration for creating an ADK Agent.

    Consumed by app.agent_factory.create_agent; tool names are resolved
    there against the registry, local tools, and MCP servers (in that order).
    """
    name: str = Field(..., description="Unique name for the agent")
    description: Optional[str] = Field(None, description="Description of agent capabilities")
    model: str = Field(..., description="Model identifier (e.g., 'gemini-2.5-flash')")
    instruction: str = Field(..., description="System instruction / persona for the agent")
    tools: List[str] = Field(default_factory=list, description="List of tool names to enable")
    enable_code_execution: bool = Field(False, description="Enable built-in code executor")

    thinking_config: Optional[ThinkingConfigModel] = Field(None, description="Configuration for reasoning/planning")

    # Nested configurations
    llm_config: Optional[LLMConfig] = Field(None, description="Custom LLM API connection details")
    generation_config: Optional[GenerationConfig] = Field(None, description="Generation parameters")
|
||||
|
||||
|
||||
|
||||
class AgentDefinition(AgentConfig):
    """Persisted Agent definition with unique ID.

    created_at/updated_at default to "now" per instantiation via
    default_factory; time.time (imported at module top) replaces the
    opaque __import__("time").time() hack.
    """
    id: str = Field(..., description="Unique identifier for the agent")
    created_at: float = Field(default_factory=time.time, description="Creation timestamp")
    updated_at: float = Field(default_factory=time.time, description="Last update timestamp")
|
||||
|
||||
|
||||
class ChatTurnRequest(BaseModel):
    """Request for a single chat turn with a specific agent."""
    agent_id: str = Field(..., description="ID of the agent to use for this turn")
    message: str = Field(..., description="User message content")
    # Optional overrides for this specific turn
    streaming_mode: Optional[str] = Field("sse", description="Streaming mode: 'sse', 'none'")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """
    LEGACY: Request for chat completion.
    Used by the old /chat/stream endpoint.
    """
    session_id: str
    user_id: str
    message: str
    app_name: Optional[str] = None

    # Ephemeral per-request agent, built instead of a stored definition.
    agent_config: Optional[AgentConfig] = Field(None, description="Ephemeral agent config")
    max_llm_calls: Optional[int] = Field(500, gt=0)
    streaming_mode: Optional[str] = "sse"
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
    """Response for session operations."""
    id: str
    app_name: str
    user_id: str
    # Any: populated via getattr fallbacks from ADK session objects
    # (last_update_time / updated_at), so the concrete type varies.
    updated_at: Any = None
|
||||
|
||||
|
||||
class HistoryEvent(BaseModel):
    """A single event in chat history."""
    # Currently always "message" (see sessions router construction).
    type: str
    # "user" or "model".
    role: str
    content: str
    # Author agent name for model events; None for user events.
    agent_name: Optional[str] = None
    timestamp: Optional[Any] = None
    invocation_id: Optional[str] = None
|
||||
|
||||
|
||||
class SessionDetailResponse(SessionResponse):
    """Session details including chat history."""
    # History after rewind filtering has been applied (see sessions router).
    history: List[HistoryEvent] = []
    events_count: int = 0
    created_at: Any = None
|
||||
0
app/routers/__init__.py
Normal file
0
app/routers/__init__.py
Normal file
68
app/routers/agents.py
Normal file
68
app/routers/agents.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import List
|
||||
|
||||
from app.models import AgentConfig, AgentDefinition
|
||||
from app.services.agent_service import get_agent_service, AgentService
|
||||
|
||||
# Initialize Router
|
||||
# Router exposing CRUD for persisted Agent definitions under /agents.
router = APIRouter(
    prefix="/agents",
    tags=["Agents"]
)
|
||||
|
||||
# ==========================
|
||||
# 🤖 Agent Management API
|
||||
# ==========================
|
||||
|
||||
@router.post("", response_model=AgentDefinition)
async def create_agent_def(
    config: AgentConfig,
    service: AgentService = Depends(get_agent_service)
):
    """Create a new Agent definition."""
    # The service assigns the UUID and timestamps; returns the stored definition.
    return service.create_agent(config)
|
||||
|
||||
|
||||
@router.get("", response_model=List[AgentDefinition])
async def list_agents(
    service: AgentService = Depends(get_agent_service)
):
    """List all available Agents."""
    # Includes the seeded 'default-assistant' unless it has been deleted.
    return service.list_agents()
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentDefinition)
async def get_agent_def(
    agent_id: str,
    service: AgentService = Depends(get_agent_service)
):
    """Get specific Agent definition."""
    # Guard clause: unknown ids map to a 404 response.
    found = service.get_agent(agent_id)
    if not found:
        raise HTTPException(status_code=404, detail="Agent not found")
    return found
|
||||
|
||||
|
||||
@router.put("/{agent_id}", response_model=AgentDefinition)
async def update_agent_def(
    agent_id: str,
    config: AgentConfig,
    service: AgentService = Depends(get_agent_service)
):
    """Update an Agent definition."""
    # Service returns None for unknown ids; id and created_at are preserved.
    agent = service.update_agent(agent_id, config)
    if not agent:
        raise HTTPException(status_code=404, detail="Agent not found")
    return agent
|
||||
|
||||
|
||||
@router.delete("/{agent_id}")
async def delete_agent_def(
    agent_id: str,
    service: AgentService = Depends(get_agent_service)
):
    """Delete an Agent definition."""
    # Service returns False when the id does not exist.
    success = service.delete_agent(agent_id)
    if not success:
        raise HTTPException(status_code=404, detail="Agent not found")
    return {"status": "deleted", "agent_id": agent_id}
|
||||
317
app/routers/sessions.py
Normal file
317
app/routers/sessions.py
Normal file
@@ -0,0 +1,317 @@
|
||||
from typing import List, AsyncGenerator
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.agents.run_config import RunConfig, StreamingMode
|
||||
from google.genai import types
|
||||
|
||||
from app.config import settings
|
||||
from app.models import (
|
||||
CreateSessionRequest,
|
||||
ChatTurnRequest,
|
||||
SessionResponse,
|
||||
SessionDetailResponse,
|
||||
HistoryEvent,
|
||||
AgentConfig
|
||||
)
|
||||
from app.services import get_session_service
|
||||
from app.services.agent_service import get_agent_service, AgentService
|
||||
from app.agent_factory import create_agent
|
||||
|
||||
# Initialize Router
|
||||
# Router exposing session CRUD, rewind, and chat streaming under /sessions.
router = APIRouter(
    prefix="/sessions",
    tags=["Sessions"]
)

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Helper ---
|
||||
def get_streaming_mode(mode_str: str) -> StreamingMode:
    """Convert string streaming mode to StreamingMode enum."""
    normalized = mode_str.lower()
    if normalized == "none":
        return StreamingMode.NONE
    if normalized == "bidi":
        return StreamingMode.BIDI
    # "sse" and any unrecognized value both resolve to SSE.
    return StreamingMode.SSE
|
||||
|
||||
|
||||
# ==========================
|
||||
# 📋 Session CRUD
|
||||
# ==========================
|
||||
|
||||
@router.post("", response_model=SessionResponse)
async def create_session(
    req: CreateSessionRequest,
    service=Depends(get_session_service)
):
    """Create a new chat session."""
    # Fall back to the configured app name when the client omits one.
    app_name = req.app_name or settings.DEFAULT_APP_NAME
    session = await service.create_session(app_name=app_name, user_id=req.user_id)
    return SessionResponse(
        id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        # ADK session objects expose either last_update_time or updated_at.
        updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None))
    )
|
||||
|
||||
|
||||
@router.get("", response_model=List[SessionResponse])
async def list_sessions(
    user_id: str,
    app_name: str = None,
    service=Depends(get_session_service)
):
    """List all sessions for a user."""
    app_name = app_name or settings.DEFAULT_APP_NAME
    # service.list_sessions returns an object exposing a .sessions sequence.
    response = await service.list_sessions(app_name=app_name, user_id=user_id)

    return [
        SessionResponse(
            id=s.id,
            app_name=s.app_name,
            user_id=s.user_id,
            updated_at=getattr(s, 'last_update_time', getattr(s, 'updated_at', None))
        ) for s in response.sessions
    ]
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=SessionDetailResponse)
async def get_session_history(
    session_id: str,
    user_id: str,
    app_name: str = None,
    service=Depends(get_session_service)
):
    """Get session details and chat history.

    The raw event log is first filtered to honor rewind actions, then each
    surviving event is flattened into a HistoryEvent.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME
    session = await service.get_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Convert history events
    history_events = []
    events = getattr(session, 'events', [])

    # Filter events based on Rewind actions
    # ADK preserves all events in the log. We must retroactively remove rewound events for the view.
    valid_events = []
    for e in events:
        # Check for rewind action
        actions = getattr(e, 'actions', None)
        rewind_target_id = getattr(actions, 'rewind_before_invocation_id', None) if actions else None

        if rewind_target_id:
            # Find the index where the truncated invocation began
            truncate_idx = -1
            for i, ve in enumerate(valid_events):
                if getattr(ve, 'invocation_id', None) == rewind_target_id:
                    truncate_idx = i
                    break

            if truncate_idx != -1:
                logger.debug(f"Rewinding history to before {rewind_target_id}")
                valid_events = valid_events[:truncate_idx]
            # The rewind-carrying event itself is never shown in the view
            # (it reaches neither branch's append).
        else:
            valid_events.append(e)

    for e in valid_events:
        # Flatten content: plain string, or concatenation of text parts.
        content_str = ""
        if hasattr(e, 'content') and e.content:
            if isinstance(e.content, str):
                content_str = e.content
            elif hasattr(e.content, 'parts'):
                parts = e.content.parts
                for part in parts:
                    if hasattr(part, 'text') and part.text:
                        content_str += part.text

        # Determine basic role (user vs model)
        # ADK usually sets author to the agent name or "user"
        author = getattr(e, 'author', 'unknown')
        role = "user" if author == "user" else "model"

        # Agent Name is the specific author if it's a model response
        agent_name = author if role == "model" else None

        # Extract timestamp
        timestamp = getattr(e, 'timestamp', getattr(e, 'created_at', None))
        invocation_id = getattr(e, 'invocation_id', None)

        history_events.append(HistoryEvent(
            type="message",
            role=role,
            agent_name=agent_name,
            content=content_str,
            timestamp=timestamp,
            invocation_id=invocation_id
        ))

    return SessionDetailResponse(
        id=session.id,
        app_name=session.app_name,
        user_id=session.user_id,
        updated_at=getattr(session, 'last_update_time', getattr(session, 'updated_at', None)),
        created_at=getattr(session, 'created_at', getattr(session, 'create_time', None)),
        history=history_events,
        events_count=len(history_events)
    )
|
||||
|
||||
|
||||
@router.delete("/{session_id}")
async def delete_session(
    session_id: str,
    user_id: str,
    app_name: str = None,
    service=Depends(get_session_service)
):
    """Delete a session.

    NOTE(review): no existence check — deleting an unknown session still
    returns 200; confirm whether a 404 is desired for parity with GET.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME
    await service.delete_session(app_name=app_name, user_id=user_id, session_id=session_id)
    return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/rewind")
async def rewind_session_state(
    session_id: str,
    user_id: str,
    invocation_id: str,
    app_name: str = None,
    session_service=Depends(get_session_service),
    agent_service: AgentService = Depends(get_agent_service)
):
    """
    Rewind session state to before a specific invocation.
    Undoes changes from the specified invocation and subsequent ones.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME

    # 1. Need a Runner instance to perform rewind.
    # The Runner requires an agent, but for rewind (session/log operation),
    # the specific agent configuration might not be critical if default is used,
    # assuming session persistence is agent-agnostic.
    # We'll use the default assistant to initialize the runner.
    agent_def = agent_service.get_agent("default-assistant")
    if not agent_def:
        # Fallback if default deleted: a throwaway ephemeral config suffices
        # since the agent is never invoked for a rewind.
        agent_def = AgentConfig(
            name="rewinder",
            model=settings.LLM_DEFAULT_MODEL,
            instruction="",
            tools=[]
        )

    agent = create_agent(agent_def)

    runner = Runner(
        agent=agent,
        app_name=app_name,
        session_service=session_service
    )

    try:
        await runner.rewind_async(
            user_id=user_id,
            session_id=session_id,
            rewind_before_invocation_id=invocation_id
        )
        return {"status": "success", "rewound_before": invocation_id}
    except Exception as e:
        logger.error(f"Rewind failed: {e}")
        # ADK might raise specific errors, generic catch for now
        raise HTTPException(status_code=500, detail=f"Rewind failed: {str(e)}")
|
||||
|
||||
|
||||
# ==========================
|
||||
# 💬 Session Chat
|
||||
# ==========================
|
||||
|
||||
@router.post("/{session_id}/chat")
async def chat_with_agent(
    session_id: str,
    req: ChatTurnRequest,
    user_id: str,
    app_name: str = None,
    session_service=Depends(get_session_service),
    agent_service: AgentService = Depends(get_agent_service)
):
    """
    Chat with a specific Agent in a Session.
    Decoupled: Agent is loaded from AgentService, Session is loaded from SessionService.

    Streams SSE frames of the form
    `data: {"type": "content"|"done"|"error", ...}\\n\\n`.
    """
    app_name = app_name or settings.DEFAULT_APP_NAME

    # 1. Load Session
    session = await session_service.get_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # 2. Get Agent Definition
    agent_def = agent_service.get_agent(req.agent_id)
    if not agent_def:
        raise HTTPException(status_code=404, detail=f"Agent '{req.agent_id}' not found")

    # 3. Create Runtime Agent
    agent = create_agent(agent_def)

    # 4. Preparation for Run
    streaming_mode = get_streaming_mode(req.streaming_mode or "sse")
    run_config = RunConfig(
        streaming_mode=streaming_mode,
        max_llm_calls=500, # or configurable
    )

    runner = Runner(
        agent=agent,
        app_name=app_name,
        session_service=session_service
    )

    # 5. Stream Generator
    async def event_generator() -> AsyncGenerator[str, None]:
        new_msg = types.Content(
            role="user",
            parts=[types.Part(text=req.message)]
        )
        try:
            async for event in runner.run_async(
                session_id=session.id,
                user_id=user_id,
                new_message=new_msg,
                run_config=run_config,
            ):
                if event.content:
                    # Flatten text from parts (normal case) or a raw string.
                    text_content = ""
                    if hasattr(event.content, 'parts'):
                        for part in event.content.parts:
                            if hasattr(part, 'text') and part.text:
                                text_content += part.text
                    elif isinstance(event.content, str):
                        text_content = event.content

                    if text_content:
                        payload = {"type": "content", "text": text_content, "role": "model"}
                        yield f"data: {json.dumps(payload)}\n\n"

            yield f"data: {json.dumps({'type': 'done'})}\n\n"

        except Exception as e:
            # Errors mid-stream cannot change the HTTP status; report in-band.
            logger.exception("Error during agent chat stream")
            err_payload = {"type": "error", "text": str(e)}
            yield f"data: {json.dumps(err_payload)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
43
app/routers/tools.py
Normal file
43
app/routers/tools.py
Normal file
@@ -0,0 +1,43 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List, Dict
|
||||
|
||||
from app.services.tool_service import ToolService, get_tool_service
|
||||
from app.models import MCPServerConfig
|
||||
|
||||
# Router exposing local-tool discovery and MCP server management under /tools.
router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
@router.get("/local")
async def list_local_tools(
    service: ToolService = Depends(get_tool_service)
) -> List[str]:
    """List available local tools (scan 'tools/' directory)."""
    return service.get_local_tools()
|
||||
|
||||
@router.get("/mcp", response_model=List[MCPServerConfig])
async def list_mcp_servers(
    service: ToolService = Depends(get_tool_service)
):
    """List configured MCP servers."""
    return service.list_mcp_servers()
|
||||
|
||||
@router.post("/mcp", response_model=MCPServerConfig)
async def add_mcp_server(
    config: MCPServerConfig,
    service: ToolService = Depends(get_tool_service)
):
    """Register a new MCP server configuration."""
    # Check if exists? Overwrite behavior is default in dict
    # NOTE(review): re-posting the same name silently replaces the old
    # config — confirm that overwrite (vs 409) is the intended behavior.
    service.add_mcp_server(config)
    return config
|
||||
|
||||
@router.delete("/mcp/{name}")
async def remove_mcp_server(
    name: str,
    service: ToolService = Depends(get_tool_service)
):
    """Remove a configured MCP server."""
    # Service reports whether anything was actually removed.
    deleted = service.remove_mcp_server(name)
    if not deleted:
        raise HTTPException(status_code=404, detail=f"MCP Server '{name}' not found")
    return {"status": "deleted", "name": name}
|
||||
5
app/services/__init__.py
Normal file
5
app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .session_service import (
|
||||
init_session_service,
|
||||
close_session_service,
|
||||
get_session_service
|
||||
)
|
||||
89
app/services/agent_service.py
Normal file
89
app/services/agent_service.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
import logging
|
||||
|
||||
from app.config import settings
|
||||
from app.models import AgentConfig, AgentDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AgentService:
    """
    Service for managing Agent Definitions.

    Backed by a plain in-memory dict for now; swap the dict for a
    repository layer to persist definitions to a database.
    """

    def __init__(self):
        # Maps agent_id -> AgentDefinition.
        self._store: Dict[str, AgentDefinition] = {}
        # Seed the store so the API is usable out of the box.
        self._create_default_agent()

    def _create_default_agent(self):
        """Install the built-in 'default-assistant' unless it already exists."""
        default_id = "default-assistant"
        if default_id in self._store:
            return
        self._store[default_id] = AgentDefinition(
            id=default_id,
            name="Assistant",
            description="A helpful default AI assistant",
            model=settings.LLM_DEFAULT_MODEL,
            instruction="You are a helpful assistant.",
            tools=[],
            created_at=time.time(),
            updated_at=time.time()
        )

    def create_agent(self, config: AgentConfig) -> AgentDefinition:
        """Create and store a new agent definition under a fresh UUID."""
        agent_id = str(uuid.uuid4())
        definition = AgentDefinition(
            id=agent_id,
            created_at=time.time(),
            updated_at=time.time(),
            **config.model_dump()
        )
        self._store[agent_id] = definition
        logger.info(f"Created new agent definition: {agent_id}")
        return definition

    def get_agent(self, agent_id: str) -> Optional[AgentDefinition]:
        """Return the definition for agent_id, or None if unknown."""
        return self._store.get(agent_id)

    def update_agent(self, agent_id: str, config: AgentConfig) -> Optional[AgentDefinition]:
        """Replace an existing definition, preserving its id and created_at."""
        original = self._store.get(agent_id)
        if original is None:
            return None
        updated = AgentDefinition(
            id=agent_id,
            created_at=original.created_at,
            updated_at=time.time(),
            **config.model_dump()
        )
        self._store[agent_id] = updated
        logger.info(f"Updated agent definition: {agent_id}")
        return updated

    def list_agents(self) -> List[AgentDefinition]:
        """Return every stored definition."""
        return list(self._store.values())

    def delete_agent(self, agent_id: str) -> bool:
        """Remove a definition; True if something was actually deleted."""
        if self._store.pop(agent_id, None) is None:
            return False
        logger.info(f"Deleted agent definition: {agent_id}")
        return True
|
||||
|
||||
|
||||
# Singleton instance
# Created at import time so every caller shares the same in-memory store;
# note this also seeds the default agent as a module-import side effect.
_agent_service = AgentService()

def get_agent_service() -> AgentService:
    """Return the process-wide AgentService singleton (for dependency injection)."""
    return _agent_service
|
||||
40
app/services/session_service.py
Normal file
40
app/services/session_service.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global service instance
# None until init_session_service() runs during application startup.
_session_service: DatabaseSessionService | None = None
|
||||
|
||||
|
||||
async def init_session_service() -> None:
    """
    Initialize the database session service on application startup.

    Idempotent: calling it again after a successful init is a no-op.
    """
    global _session_service
    if _session_service is None:
        import re  # local import: only needed for one-time URL redaction

        # Redact any embedded "user:password@" credentials before logging the
        # DSN so database secrets never end up in log output.
        safe_url = re.sub(r"//[^/@]*@", "//***@", settings.DATABASE_URL)
        logger.info(f"Initializing DatabaseSessionService with {safe_url}...")
        _session_service = DatabaseSessionService(db_url=settings.DATABASE_URL)
        logger.info("DatabaseSessionService initialized successfully")
|
||||
|
||||
|
||||
async def close_session_service() -> None:
    """Release the session service reference on application shutdown."""
    global _session_service
    if _session_service is None:
        return
    # NOTE(review): this only drops the reference; presumably the underlying
    # engine cleans up on GC — confirm DatabaseSessionService needs no
    # explicit close/dispose call.
    _session_service = None
    logger.info("DatabaseSessionService closed")
|
||||
|
||||
|
||||
def get_session_service() -> DatabaseSessionService:
    """
    Return the session service instance for dependency injection.

    Raises:
        RuntimeError: If the service has not been initialized yet.
    """
    # Snapshot the module global so the None-check and the return see the
    # same object even if startup/shutdown races this call.
    service = _session_service
    if service is None:
        raise RuntimeError("Session service not initialized. Application may not have started properly.")
    return service
|
||||
|
||||
|
||||
|
||||
146
app/services/tool_service.py
Normal file
146
app/services/tool_service.py
Normal file
@@ -0,0 +1,146 @@
|
||||
|
||||
import os
|
||||
import logging
|
||||
import importlib.util
|
||||
import inspect
|
||||
from typing import List, Dict, Optional, Any, Callable
|
||||
from pathlib import Path
|
||||
|
||||
# ADK imports
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
||||
StreamableHTTPServerParams,
|
||||
StdioConnectionParams,
|
||||
StdioServerParameters
|
||||
)
|
||||
|
||||
from app.models import MCPServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolService:
    """
    Service for managing available tools (Local and MCP).

    Local tools live as sub-directories of `tools_dir`, each providing a
    tool.py or main.py entry point. MCP servers are kept in an in-memory
    registry keyed by server name.
    """

    def __init__(self, tools_dir: str = "tools"):
        self.tools_dir = Path(tools_dir)
        # {server_name: MCPServerConfig}
        self._mcp_registry: Dict[str, MCPServerConfig] = {}

        # Ensure tools directory exists (exist_ok makes this race-free;
        # no separate exists() check needed).
        self.tools_dir.mkdir(parents=True, exist_ok=True)

    # --- Local Tools Scanning ---

    def get_local_tools(self) -> List[str]:
        """
        Scan the 'tools/' directory for valid tool implementations.

        Returns a list of tool names (folder names) that contain a
        tool.py or main.py entry point.
        """
        if not self.tools_dir.exists():
            return []

        tools = []
        for item in self.tools_dir.iterdir():
            if item.is_dir():
                # A directory counts as a tool only if it has an entry point.
                if (item / "tool.py").exists() or (item / "main.py").exists():
                    tools.append(item.name)
        return tools

    def load_local_tool(self, tool_name: str) -> Optional[Callable]:
        """
        Dynamically load a local tool by name.

        Expects a callable named 'tool' — or one named after the folder —
        exported from tool.py/main.py. Returns None when the module or a
        suitable callable cannot be found, or when execution fails.
        """
        tool_path = self.tools_dir / tool_name

        # Prefer tool.py, fall back to main.py.
        module_file = None
        if (tool_path / "tool.py").exists():
            module_file = tool_path / "tool.py"
        elif (tool_path / "main.py").exists():
            module_file = tool_path / "main.py"

        if not module_file:
            logger.warning(f"Tool implementation not found for {tool_name}")
            return None

        try:
            spec = importlib.util.spec_from_file_location(f"tools.{tool_name}", module_file)
            if spec and spec.loader:
                module = importlib.util.module_from_spec(spec)
                # NOTE: executes arbitrary module-level code from the tool file.
                spec.loader.exec_module(module)

                # Preferred export: a callable named 'tool'.
                if hasattr(module, "tool") and callable(module.tool):
                    return module.tool

                # Fallback: a callable named after the folder.
                if hasattr(module, tool_name) and callable(getattr(module, tool_name)):
                    return getattr(module, tool_name)

                logger.warning(f"No callable 'tool' or '{tool_name}' found in {module_file}")
        except Exception as e:
            logger.error(f"Failed to load tool {tool_name}: {e}")

        return None

    # --- MCP Registry ---

    def list_mcp_servers(self) -> List[MCPServerConfig]:
        """Return all configured MCP servers."""
        return list(self._mcp_registry.values())

    def add_mcp_server(self, config: MCPServerConfig):
        """Register a new MCP server (replaces any existing entry with the same name)."""
        self._mcp_registry[config.name] = config
        logger.info(f"Registered MCP Server: {config.name}")

    def remove_mcp_server(self, name: str) -> bool:
        """Remove an MCP server by name; True when something was removed."""
        if name in self._mcp_registry:
            del self._mcp_registry[name]
            logger.info(f"Removed MCP Server: {name}")
            return True
        return False

    def get_mcp_toolset(self, name: str) -> Optional[MCPToolset]:
        """
        Create an ADK MCPToolset instance from the stored configuration.

        Returns None when the server is unknown, its transport type is
        unrecognized, or toolset construction fails.
        """
        config = self._mcp_registry.get(name)
        if not config:
            return None

        try:
            if config.type == "sse":
                if not config.sse_config:
                    raise ValueError("Missing SSE config")
                params = StreamableHTTPServerParams(url=config.sse_config.url)
                # Fix: pass tool_filter here too — previously only the stdio
                # branch honored it, so SSE servers could not be filtered.
                return MCPToolset(connection_params=params, tool_filter=config.tool_filter)

            elif config.type == "stdio":
                if not config.stdio_config:
                    raise ValueError("Missing Stdio config")

                server_params = StdioServerParameters(
                    command=config.stdio_config.command,
                    args=config.stdio_config.args,
                    env=config.stdio_config.env
                )
                params = StdioConnectionParams(server_params=server_params)
                return MCPToolset(connection_params=params, tool_filter=config.tool_filter)

        except Exception as e:
            logger.error(f"Failed to create MCP Toolset '{name}': {e}")
            return None

        # Unknown transport type.
        return None
|
||||
|
||||
# Singleton instance
# Created at import time; note this also creates the tools/ directory
# (relative to the process CWD) as a module-import side effect.
_tool_service = ToolService()

def get_tool_service() -> ToolService:
    """Return the process-wide ToolService singleton (for dependency injection)."""
    return _tool_service
|
||||
29
app/tools.py
Normal file
29
app/tools.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import logging
|
||||
from typing import Dict, Callable, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def dummy_tool(query: str) -> dict:
    """
    Dummy Tool: A placeholder tool for demonstration.

    Args:
        query: Any string input.

    Returns:
        A mock structured response with status, message, and data keys.
    """
    # getLogger returns the same named logger instance the module uses.
    logging.getLogger(__name__).info(f"[Tool] dummy_tool called with: {query}")
    response = {
        "status": "success",
        "message": f"Processed query: {query}",
        "data": "dummy_value",
    }
    return response
|
||||
|
||||
# ==========================
# 🛠️ Tool Registry
# ==========================
# Map tool names to their function implementation.
# The key is the name agent configurations reference in their `tools` list;
# register new built-in tools by adding an entry here.
TOOL_REGISTRY: Dict[str, Callable] = {
    "dummy_tool": dummy_tool,
}
|
||||
Reference in New Issue
Block a user