147 lines
5.1 KiB
Python
147 lines
5.1 KiB
Python
|
|
import os
|
|
import logging
|
|
import importlib.util
|
|
import inspect
|
|
from typing import List, Dict, Optional, Any, Callable
|
|
from pathlib import Path
|
|
|
|
# ADK imports
|
|
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
|
|
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
|
StreamableHTTPServerParams,
|
|
StdioConnectionParams,
|
|
StdioServerParameters
|
|
)
|
|
|
|
from app.models import MCPServerConfig
|
|
|
|
logger = logging.getLogger(__name__)


class ToolService:
    """Service for managing available tools (local and MCP).

    Local tools are sub-directories of ``tools_dir`` that contain a
    ``tool.py`` or ``main.py`` entry point. MCP servers are kept in an
    in-memory registry keyed by the ``name`` of their configuration.
    """

    # Entry-point file names, in lookup priority order.
    _ENTRY_POINTS = ("tool.py", "main.py")

    def __init__(self, tools_dir: str = "tools"):
        """Create the service and make sure the tools directory exists.

        Args:
            tools_dir: Directory that holds one sub-directory per local tool.
        """
        self.tools_dir = Path(tools_dir)
        self._mcp_registry: Dict[str, MCPServerConfig] = {}

        # Ensure tools directory exists
        if not self.tools_dir.exists():
            self.tools_dir.mkdir(parents=True, exist_ok=True)

    # --- Local Tools Scanning ---

    def _find_entry_point(self, tool_name: str) -> Optional[Path]:
        """Return the entry-point file of ``tool_name``, or None if absent.

        Only a single path component is accepted as a tool name. Anything
        containing separators, ``.``/``..`` or an absolute path is rejected so
        that a caller-supplied name can never resolve outside ``tools_dir``.
        """
        is_single_component = (
            bool(tool_name)
            and tool_name not in (".", "..")
            and Path(tool_name).name == tool_name
        )
        if not is_single_component:
            return None

        tool_path = self.tools_dir / tool_name
        for filename in self._ENTRY_POINTS:
            candidate = tool_path / filename
            if candidate.is_file():
                return candidate
        return None

    def get_local_tools(self) -> List[str]:
        """Scan the tools directory for valid tool implementations.

        Returns:
            The names of the sub-directories that contain a ``tool.py`` or
            ``main.py`` file, sorted alphabetically so the result does not
            depend on filesystem iteration order. Empty when the tools
            directory is missing or is not a directory.
        """
        if not self.tools_dir.is_dir():
            return []

        return sorted(
            item.name
            for item in self.tools_dir.iterdir()
            if item.is_dir() and self._find_entry_point(item.name) is not None
        )

    def load_local_tool(self, tool_name: str) -> Optional[Callable]:
        """Dynamically load a local tool by name.

        The entry-point module must expose a callable named ``tool`` or a
        callable named after the tool directory; ``tool`` takes precedence.

        Args:
            tool_name: Name of the tool's sub-directory inside ``tools_dir``.

        Returns:
            The tool callable, or None when the name is invalid, no entry
            point exists, the module fails to import, or it exposes no
            matching callable. Failures are logged, never raised.
        """
        module_file = self._find_entry_point(tool_name)
        if module_file is None:
            logger.warning(f"Tool implementation not found for {tool_name}")
            return None

        try:
            spec = importlib.util.spec_from_file_location(
                f"tools.{tool_name}", module_file
            )
            if spec is None or spec.loader is None:
                logger.warning(f"Could not create import spec for {module_file}")
                return None

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Require a 'tool' export or a function with the same name as
            # the folder; no heuristic scanning of other functions.
            for attr_name in ("tool", tool_name):
                candidate = getattr(module, attr_name, None)
                if callable(candidate):
                    return candidate

            logger.warning(
                f"No callable 'tool' or '{tool_name}' found in {module_file}"
            )
        except Exception as e:
            # Tool modules are arbitrary plugin code, so any error raised
            # while importing one is treated as "tool unavailable". The
            # traceback is logged to keep the failure diagnosable.
            logger.exception(f"Failed to load tool {tool_name}: {e}")

        return None

    # --- MCP Registry ---

    def list_mcp_servers(self) -> List[MCPServerConfig]:
        """Return all configured MCP servers."""
        return list(self._mcp_registry.values())

    def add_mcp_server(self, config: MCPServerConfig):
        """Register an MCP server, replacing any entry with the same name."""
        self._mcp_registry[config.name] = config
        logger.info(f"Registered MCP Server: {config.name}")

    def remove_mcp_server(self, name: str) -> bool:
        """Remove an MCP server by name.

        Returns:
            True if a server with that name was registered, False otherwise.
        """
        if name not in self._mcp_registry:
            return False

        del self._mcp_registry[name]
        logger.info(f"Removed MCP Server: {name}")
        return True

    def get_mcp_toolset(self, name: str) -> Optional[MCPToolset]:
        """Create an ADK MCPToolset instance from a registered configuration.

        Args:
            name: Name under which the MCP server was registered.

        Returns:
            A new ``MCPToolset``, or None when the name is unknown, the
            server type is unsupported, the type-specific config is missing,
            or building the toolset fails. Failures are logged, never raised.
        """
        config = self._mcp_registry.get(name)
        if config is None:
            return None

        try:
            if config.type == "sse":
                if not config.sse_config:
                    raise ValueError("Missing SSE config")
                # NOTE(review): "sse" servers are connected with the
                # streamable-HTTP parameters class - confirm this is intended.
                params = StreamableHTTPServerParams(url=config.sse_config.url)
            elif config.type == "stdio":
                if not config.stdio_config:
                    raise ValueError("Missing Stdio config")
                server_params = StdioServerParameters(
                    command=config.stdio_config.command,
                    args=config.stdio_config.args,
                    env=config.stdio_config.env,
                )
                params = StdioConnectionParams(server_params=server_params)
            else:
                logger.warning(
                    f"Unsupported MCP server type '{config.type}' for '{name}'"
                )
                return None

            # The tool filter applies to every transport, not only stdio.
            return MCPToolset(
                connection_params=params, tool_filter=config.tool_filter
            )
        except Exception as e:
            logger.error(f"Failed to create MCP Toolset '{name}': {e}")
            return None
|
|
|
|
# Shared module-level instance, created once at import time with the
# default tools directory.
_tool_service = ToolService()


def get_tool_service() -> ToolService:
    """Return the shared module-level ``ToolService`` instance."""
    return _tool_service
|