"""Shared tool pipeline — single source of truth for user context, tool set activation, and integration tool building. Extracted from main.py and telegram_bot.py (T009) to eliminate code duplication. Both entry points call these functions instead of maintaining their own copies. T015/T020/T021: Added UserContext with lazy credential decryption and per-user tool pipeline functions. """ from __future__ import annotations import logging from datetime import datetime, timezone from typing import Any from config import settings from tool_registry import registry as tool_registry logger = logging.getLogger(__name__) # ── UserContext (T015) ─────────────────────────────────────────── class UserContext: """Per-request context that lazily decrypts user credentials. Credentials are decrypted on first access via ``get_credential()``, cached for the duration of one request, and discarded when the object goes out of scope. Expired credentials return ``None`` to signal re-provisioning. """ def __init__(self, user: Any, store: Any) -> None: from user_store import User # deferred to avoid circular import self.user: User = user self._store = store self._cache: dict[str, str | None] = {} def get_credential(self, service: str) -> str | None: """Return the decrypted API token for *service*, or None.""" if service in self._cache: return self._cache[service] from user_store import decrypt cred = self._store.get_credential(self.user.id, service) if cred is None: self._cache[service] = None return None # Check expiry if cred.expires_at: try: exp = datetime.fromisoformat(cred.expires_at) if exp < datetime.now(timezone.utc): logger.warning( "Credential for %s/%s expired at %s", self.user.id, service, cred.expires_at, ) self._cache[service] = None return None except ValueError: pass # ignore unparseable expiry — treat as valid try: token = decrypt(cred.encrypted_token) except Exception: logger.exception("Failed to decrypt credential for %s/%s", self.user.id, service) self._cache[service] = None return None self._cache[service] = token # Touch last_used_at self._store.touch_credential(self.user.id, service) return token # ── Shared pipeline functions ──────────────────────────────────── # Service name → context key mapping used by tool set factories. _SERVICE_CONTEXT_KEYS: dict[str, dict[str, str]] = { "vikunja": { "vikunja_api_url": "VIKUNJA_API_URL", "vikunja_api_key": "__credential__", }, "karakeep": { "karakeep_api_url": "KARAKEEP_API_URL", "karakeep_api_key": "__credential__", }, } def active_toolset_names(user: Any | None = None) -> list[str]: """Return names of tool sets whose required credentials are available. When *user* is given (a ``User`` from user_store), credentials are checked via the credential vault. When ``None``, falls back to the legacy env-var check for backwards compatibility. """ if user is None: # Legacy path — env-var based ctx = _build_legacy_context() active: list[str] = [] for name, ts in tool_registry.available.items(): if all(ctx.get(k) for k in ts.required_keys): active.append(name) return active # Per-user path from user_store import get_store store = get_store() uctx = UserContext(user, store) ctx = _build_user_context_dict(user, uctx) active = [] for name, ts in tool_registry.available.items(): if all(ctx.get(k) for k in ts.required_keys): active.append(name) return active def build_user_context(user: Any | None = None) -> dict[str, Any]: """Build the context dict consumed by tool set factories. When *user* is given, credentials come from the per-user vault. When ``None``, falls back to env-var credentials. """ if user is None: return _build_legacy_context() from user_store import get_store store = get_store() uctx = UserContext(user, store) return _build_user_context_dict(user, uctx) def build_integration_tools( user: Any | None = None, *, thread_id: str | None = None, system_prompt: str | None = None, ) -> list: """Build Copilot SDK tools for all active integrations. Always returns a list (possibly empty), never None. """ active = active_toolset_names(user) if not active: return [] ctx = build_user_context(user) if thread_id is not None: ctx["_thread_id"] = thread_id if system_prompt is not None: ctx["_system_prompt"] = system_prompt return tool_registry.get_tools(active, ctx) async def try_provision(user: Any, service: str) -> str | None: """Attempt to provision *service* for *user* with lock. Returns a human-readable status message, or None on success. Uses a per-user per-service lock to prevent concurrent provisioning. """ from provisioners.base import provisioner_registry from user_store import get_store provisioner = provisioner_registry.get(service) if provisioner is None: return f"No provisioner available for {service}." store = get_store() # Check if already provisioned existing = store.get_credential(user.id, service) if existing: return None # already has credentials lock = provisioner_registry.get_lock(user.id, service) if lock.locked(): return f"Provisioning {service} is already in progress." async with lock: # Double-check after acquiring lock existing = store.get_credential(user.id, service) if existing: return None try: result = await provisioner.provision(user, store) except Exception: logger.exception("Provisioning %s for user %s failed", service, user.id) store.log_provisioning(user.id, service, "provision_failed", '{"error": "unhandled exception"}') return f"Failed to set up {service}. The error has been logged." if not result.success: return result.error or f"Failed to set up {service}." # Mark onboarding complete on first successful provisioning if not user.onboarding_complete: store.set_onboarding_complete(user.id) user.onboarding_complete = True return None def get_provisionable_services(user: Any) -> list[dict[str, str]]: """Return list of services that could be provisioned for *user*.""" from provisioners.base import provisioner_registry from user_store import get_store store = get_store() existing = store.get_credentials(user.id) result = [] for name, prov in provisioner_registry.available.items(): status = "active" if name in existing else "available" result.append( { "service": name, "capabilities": ", ".join(prov.capabilities), "status": status, } ) return result def build_onboarding_fragment(user: Any) -> str | None: """Return a system prompt fragment for new/onboarding users (T035).""" if user.onboarding_complete: return None return ( "This is a new user who hasn't set up any services yet. " "Welcome them warmly, explain what you can help with, " "and offer to set up their accounts. Available services: " + ", ".join( f"{s['capabilities']} ({s['service']})" for s in get_provisionable_services(user) if s["status"] == "available" ) + ". Ask them which services they'd like to activate." ) def build_capability_fragment(user: Any) -> str | None: """Return a system prompt fragment showing active/available capabilities (T036).""" services = get_provisionable_services(user) if not services: return None active = [s for s in services if s["status"] == "active"] available = [s for s in services if s["status"] == "available"] parts = [] if active: parts.append("Active services: " + ", ".join(f"{s['capabilities']} ({s['service']})" for s in active)) if available: parts.append( "Available (say 'set up ' to activate): " + ", ".join(f"{s['capabilities']} ({s['service']})" for s in available) ) return "\n".join(parts) def build_provisioning_confirmation(service: str, result: Any) -> str | None: """Return a system prompt fragment confirming provisioning (T038).""" from config import settings as _settings if not result or not result.success: return None parts = [f"I just created a {service} account for this user."] if result.service_username: parts.append(f"Username: {result.service_username}.") if result.service_url: parts.append(f"Service URL: {result.service_url}") if not _settings.ALLOW_CREDENTIAL_REVEAL_IN_CHAT: parts.append("Do NOT reveal the password in chat — Telegram messages are not E2E encrypted.") return " ".join(parts) # ── Internal helpers ───────────────────────────────────────────── def _build_legacy_context() -> dict[str, Any]: """Build user context from static env-var settings (owner-only path).""" return { "vikunja_api_url": settings.VIKUNJA_API_URL, "vikunja_api_key": settings.VIKUNJA_API_KEY, "vikunja_memory_path": settings.VIKUNJA_MEMORY_PATH, "memory_owner": "default", "karakeep_api_url": settings.KARAKEEP_API_URL, "karakeep_api_key": settings.KARAKEEP_API_KEY, "_user": None, } def _build_user_context_dict(user: Any, uctx: UserContext) -> dict[str, Any]: """Build a context dict from per-user credentials.""" ctx: dict[str, Any] = { "vikunja_memory_path": settings.VIKUNJA_MEMORY_PATH, "memory_owner": user.id, "_user": user, # passed through for meta-tools } for service, key_map in _SERVICE_CONTEXT_KEYS.items(): for ctx_key, source in key_map.items(): if source == "__credential__": val = uctx.get_credential(service) else: val = getattr(settings, source, "") if val: ctx[ctx_key] = val return ctx