betterbot/tool_registry.py
Andre K e68c84424f
Some checks failed
Deploy BetterBot / deploy (push) Failing after 3s
Deploy BetterBot / notify (push) Successful in 3s
feat: fork from CodeAnywhere framework
Replace standalone Telegram bot with full CodeAnywhere framework fork.
BetterBot shares all framework code and customizes only:
- instance.py: BetterBot identity, system prompt, feature flags
- tools/site_editing/: list_files, read_file, write_file with auto git push
- .env: model defaults and site directory paths
- compose/: Docker setup with betterlifesg + memoraiz mounts
- deploy script: RackNerd with Infisical secrets
2026-04-19 08:01:27 +08:00

211 lines
7.4 KiB
Python

"""Unified tool registry for the agent backend.
A ToolSet groups related tools (e.g. "vikunja", "karakeep") with:
- a factory that builds Copilot SDK Tool objects given per-user context
- a system prompt fragment injected when the tool set is active
- an optional capability label (e.g. "tasks", "bookmarks")
The registry collects tool sets and resolves the right tools + prompt
fragments for a given user at request time.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
from copilot import define_tool
from copilot.tools import Tool, ToolInvocation
from pydantic import create_model
# Module logger; ToolRegistry uses it to report registrations, skipped tool
# sets and tool-building failures.
logger = logging.getLogger(name=__name__)
# ── ToolSet ──────────────────────────────────────────────────────
# Signature of a ToolSet factory: called with the per-user context dict,
# returns the Copilot SDK Tool objects for that user.
ToolFactory = Callable[
    [dict[str, Any]],
    list[Tool],
]
@dataclass
class ToolSet:
    """Bundle of related tools plus the prompt text that describes them.

    A tool set is registered under ``name`` and tagged with an abstract
    ``capability`` (e.g. "tasks", "bookmarks"). Several tool sets may offer
    the same capability. The intent is that a user has at most one of them
    active per capability; this class itself does not enforce that.

    Attributes:
        name: Unique key for the set (e.g. "vikunja", "karakeep").
        description: Human-readable summary shown during onboarding.
        capability: Abstract capability label shared by interchangeable sets.
        system_prompt_fragment: Text appended to the system message whenever
            the set is active for a user.
        build_tools: Factory called with the user's context dict
            (credentials, config) that returns Copilot SDK ``Tool`` objects.
            Which keys it reads is specific to each tool set.
        required_keys: Context keys that must be present before the set is
            usable, e.g. ``["vikunja_api_url", "vikunja_api_key"]``.
            Defaults to an empty list.
    """

    name: str
    description: str
    capability: str
    system_prompt_fragment: str
    build_tools: ToolFactory
    # default_factory gives every instance its own list object.
    required_keys: list[str] = field(default_factory=list)
# ── ToolRegistry ─────────────────────────────────────────────────
class ToolRegistry:
    """Central registry of available tool sets.

    Tool sets are keyed by ``ToolSet.name``; registering a second set with
    the same name replaces the first. At request time the registry turns a
    list of active tool-set names plus a per-user context dict into Copilot
    SDK tools and system-prompt fragments.
    """

    def __init__(self) -> None:
        self._toolsets: dict[str, ToolSet] = {}

    def register(self, toolset: ToolSet) -> None:
        """Add *toolset*, replacing (with a warning) any set of the same name."""
        if toolset.name in self._toolsets:
            logger.warning("Replacing existing tool set %r", toolset.name)
        self._toolsets[toolset.name] = toolset
        logger.info(
            "Registered tool set %r (%s, %d required keys)",
            toolset.name,
            toolset.capability,
            len(toolset.required_keys),
        )

    @property
    def available(self) -> dict[str, ToolSet]:
        """Shallow copy of the registered tool sets, keyed by name."""
        return dict(self._toolsets)

    @staticmethod
    def _missing_keys(ts: ToolSet, user_context: dict[str, Any]) -> list[str]:
        """Return the required keys of *ts* that are absent or falsy in *user_context*."""
        # A falsy value (e.g. "" or None) counts as missing: an empty API key
        # is as unusable as no key at all.
        return [k for k in ts.required_keys if not user_context.get(k)]

    def get_tools(
        self,
        active_names: list[str],
        user_context: dict[str, Any],
    ) -> list[Tool]:
        """Build Copilot SDK tools for the requested tool sets.

        Args:
            active_names: Names of the tool sets enabled for this user. Order
                is preserved. A repeated name is built only once, so the same
                tools are never handed over twice.
            user_context: Per-user values (credentials, config) passed to each
                tool set's ``build_tools`` factory.

        Returns:
            The concatenated tools of every usable tool set. The following are
            logged and skipped rather than failing the whole request:
            unknown names, tool sets with missing or empty required context
            keys, and tool sets whose factory raises.
        """
        tools: list[Tool] = []
        # dict.fromkeys de-duplicates while keeping first-seen order.
        for name in dict.fromkeys(active_names):
            ts = self._toolsets.get(name)
            if ts is None:
                logger.warning("Requested unknown tool set %r — skipped", name)
                continue
            missing = self._missing_keys(ts, user_context)
            if missing:
                logger.warning(
                    "Tool set %r skipped: missing context keys %s",
                    name,
                    missing,
                )
                continue
            try:
                tools.extend(ts.build_tools(user_context))
            except Exception:
                # Best effort: one broken integration must not take down the
                # others, so log with traceback and move on.
                logger.exception("Failed to build tools for %r", name)
        return tools

    def get_system_prompt_fragments(
        self,
        active_names: list[str],
        user_context: dict[str, Any] | None = None,
    ) -> list[str]:
        """Return system prompt fragments for the given tool sets.

        Args:
            active_names: Names of the active tool sets. Unknown names and
                sets with an empty fragment contribute nothing. A repeated
                name contributes its fragment once.
            user_context: Optional per-user context. When given, tool sets
                whose required keys are missing or empty are skipped as well.
                This matches ``get_tools``, so the prompt never describes
                tools the model was not given. When omitted, only the name
                and fragment checks apply.

        Returns:
            The fragments in the order their tool sets were first named.
        """
        fragments: list[str] = []
        for name in dict.fromkeys(active_names):
            ts = self._toolsets.get(name)
            if ts is None or not ts.system_prompt_fragment:
                continue
            if user_context is not None and self._missing_keys(ts, user_context):
                continue
            fragments.append(ts.system_prompt_fragment)
        return fragments
# ── OpenAI schema bridge ─────────────────────────────────────────
def _json_type_to_python(json_type: str) -> type:
"""Map JSON Schema type strings to Python types for Pydantic."""
mapping: dict[str, type] = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
return mapping.get(json_type, str)
def openai_tools_to_copilot(
    schemas: list[dict[str, Any]],
    dispatcher: Callable[..., Awaitable[str]],
    context_kwargs: dict[str, Any] | None = None,
) -> list[Tool]:
    """Wrap OpenAI function-calling schemas as Copilot SDK tools.

    Each ``{"type": "function", "function": {...}}`` entry becomes one
    Copilot ``Tool``. Its parameters are described by a Pydantic model
    (named ``Params_<tool name>``) generated from the entry's JSON Schema
    ``parameters``. Properties listed in ``required`` are mandatory; every
    other property is optional and defaults to ``None``. Entries whose
    function has no name are skipped.

    The generated handler expects a Pydantic model instance as ``params``.
    It dumps it with ``exclude_none=True``, so omitted optional arguments are
    not forwarded. It then awaits
    ``dispatcher(tool_name, args, **context_kwargs)`` and returns the result.

    Args:
        schemas: OpenAI tool dicts.
        dispatcher: Async callable
            ``async def dispatcher(name, arguments, **context_kwargs) -> str``.
        context_kwargs: Extra keyword arguments forwarded unchanged to every
            dispatcher call (e.g. vikunja client, memory store). ``None``
            means no extras.

    Returns:
        Copilot SDK ``Tool`` objects, in schema order, ready to pass to
        ``create_session()``.
    """
    shared_ctx = context_kwargs or {}
    converted: list[Tool] = []

    for entry in schemas:
        spec = entry.get("function", {})
        tool_name: str = spec.get("name", "")
        tool_desc: str = spec.get("description", "")
        param_schema: dict = spec.get("parameters", {})
        props: dict = param_schema.get("properties", {})
        mandatory: list[str] = param_schema.get("required", [])
        if not tool_name:
            continue

        # Required properties get ``...`` (no default); the rest become
        # ``T | None`` with a None default so the caller may omit them.
        field_defs: dict[str, Any] = {}
        for prop, prop_schema in props.items():
            base = _json_type_to_python(prop_schema.get("type", "string"))
            field_defs[prop] = (base, ...) if prop in mandatory else (base | None, None)

        model_cls = create_model(f"Params_{tool_name}", **field_defs)  # type: ignore[call-overload]

        # The keyword-only defaults freeze this iteration's tool name and
        # context, so each handler dispatches to its own tool rather than the
        # last one in the loop (late-binding closure pitfall).
        async def _handler(
            params: Any,
            invocation: ToolInvocation,
            *,
            _tool_name: str = tool_name,
            _ctx: dict[str, Any] = shared_ctx,
        ) -> str:
            args = params.model_dump(exclude_none=True)
            return await dispatcher(_tool_name, args, **_ctx)

        converted.append(
            define_tool(
                name=tool_name,
                description=tool_desc,
                handler=_handler,
                params_type=model_cls,
            )
        )

    return converted
# ── Module-level singleton ───────────────────────────────────────
# Single shared ToolRegistry instance; every importer of this module sees the
# same object.
registry: ToolRegistry = ToolRegistry()