feat: fork from CodeAnywhere framework
Some checks failed
Deploy BetterBot / deploy (push) Failing after 3s
Deploy BetterBot / notify (push) Successful in 3s

Replace standalone Telegram bot with full CodeAnywhere framework fork.
BetterBot shares all framework code and customizes only:
- instance.py: BetterBot identity, system prompt, feature flags
- tools/site_editing/: list_files, read_file, write_file with auto git push
- .env: model defaults and site directory paths
- compose/: Docker setup with betterlifesg + memoraiz mounts
- deploy script: RackNerd with Infisical secrets
This commit is contained in:
Andre Kamarudin 2026-04-19 08:01:27 +08:00
parent 8bd9ce3beb
commit e68c84424f
50 changed files with 16983 additions and 448 deletions

10
.env Normal file
View file

@ -0,0 +1,10 @@
# Non-secret configuration (committed to git).
# Secrets (TG_BOT_TOKEN, VERCEL_API_KEY, OWNER_TELEGRAM_CHAT_ID) come from Infisical at deploy time.
DEFAULT_MODEL=anthropic/claude-sonnet-4
OPENAI_BASE_URL=https://ai-gateway.vercel.sh/v1
DATA_DIR=/data
# BetterBot site paths (also set as env vars in compose)
SITE_DIR=/repo/betterlifesg/site
MEMORAIZ_DIR=/repo/memoraiz/frontend

View file

@ -0,0 +1,116 @@
# Forgejo Actions workflow: deploy BetterBot to the VPS on every push to master,
# then send an ntfy notification summarizing the job results.
name: Deploy BetterBot
on:
  push:
    branches: [master]
  workflow_dispatch:
concurrency:
  # Serialize deploys per ref; never cancel an in-flight deploy.
  group: deploy-betterbot-${{ forgejo.ref }}
  cancel-in-progress: false
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - name: Configure SSH
        shell: bash
        env:
          VPS_SSH_KEY: ${{ secrets.VPS_SSH_KEY }}
        run: |
          set -euo pipefail
          install -d -m 700 ~/.ssh
          # Normalize key material: secrets sometimes arrive with literal "\n"
          # escapes or CRLF endings; write a clean LF-terminated key file.
          python3 - <<'PY'
          import os
          from pathlib import Path
          key = os.environ["VPS_SSH_KEY"]
          if "\\n" in key and "\n" not in key:
              key = key.replace("\\n", "\n")
          key = key.replace("\r\n", "\n").replace("\r", "\n").strip() + "\n"
          Path.home().joinpath(".ssh", "id_ed25519").write_text(key, encoding="utf-8")
          PY
          chmod 600 ~/.ssh/id_ed25519
          # Sanity-check that the private key parses before using it.
          ssh-keygen -y -f ~/.ssh/id_ed25519 >/dev/null
          ssh-keyscan -H 23.226.133.245 >> ~/.ssh/known_hosts
      - name: Deploy on VPS
        shell: bash
        run: |
          set -euo pipefail
          ssh root@23.226.133.245 "cd /opt/src/betterbot && bash scripts/deploy-betterbot.sh"
  notify:
    # Always runs, even when deploy fails, so failures are reported too.
    if: ${{ always() }}
    needs: [deploy]
    runs-on: ubuntu-latest
    steps:
      - name: Send ntfy notification
        shell: bash
        env:
          JOB_RESULTS: ${{ toJson(needs) }}
          NTFY_TOKEN: ${{ secrets.NTFY_TOKEN }}
          NTFY_TOPIC_URL: https://ntfy.bytesizeprotip.com/deploy
          RUN_URL: ${{ forgejo.server_url }}/${{ forgejo.repository }}/actions/runs/${{ forgejo.run_number }}
        run: |
          set -euo pipefail
          if [ -z "$NTFY_TOKEN" ]; then
            echo "NTFY_TOKEN secret not configured, skipping notification."
            exit 0
          fi
          # Derive status/priority/tags/message from the upstream job results
          # and export them as shell variables via eval of shlex-quoted pairs.
          eval "$(python3 - <<'PY'
          import json
          import os
          import shlex
          needs = json.loads(os.environ["JOB_RESULTS"])
          results = {name: str(data.get("result") or "unknown") for name, data in needs.items()}
          has_failure = any(result == "failure" for result in results.values())
          has_cancelled = any(result == "cancelled" for result in results.values())
          if has_failure:
              status = "failure"
              priority = "5"
              tags = "rotating_light,x"
          elif has_cancelled:
              status = "cancelled"
              priority = "4"
              tags = "warning"
          else:
              status = "success"
              priority = "2"
              tags = "white_check_mark"
          summary = ", ".join(f"{name}:{result}" for name, result in results.items()) or "no upstream jobs"
          message = "\n".join([
              f"Repo: {os.environ['FORGEJO_REPOSITORY']}",
              f"Workflow: {os.environ['FORGEJO_WORKFLOW']}",
              f"Status: {status}",
              f"Ref: {os.environ.get('FORGEJO_REF_NAME', '')}",
              f"Actor: {os.environ.get('FORGEJO_ACTOR', '')}",
              f"Run: {os.environ['RUN_URL']}",
              f"Jobs: {summary}",
          ])
          values = {
              "NTFY_TITLE": f"{os.environ['FORGEJO_WORKFLOW']} {status}",
              "NTFY_PRIORITY": priority,
              "NTFY_TAGS": tags,
              "NTFY_MESSAGE": message,
          }
          for key, value in values.items():
              print(f"{key}={shlex.quote(value)}")
          PY
          )"
          curl --fail --show-error --silent \
            -H "Authorization: Bearer $NTFY_TOKEN" \
            -H "Title: $NTFY_TITLE" \
            -H "Priority: $NTFY_PRIORITY" \
            -H "Tags: $NTFY_TAGS" \
            -H "Click: $RUN_URL" \
            -d "$NTFY_MESSAGE" \
            "$NTFY_TOPIC_URL"

3
.gitignore vendored
View file

@ -1,3 +1,6 @@
__pycache__/ __pycache__/
.venv/ .venv/
*.env.local *.env.local
data/
static/dist/
frontend/node_modules/

View file

@ -1,7 +1,25 @@
# ── BetterBot Dockerfile ──
# Fork of CodeAnywhere — Telegram-only, no web UI build stage needed.
FROM python:3.12-slim FROM python:3.12-slim
RUN apt-get update \
&& apt-get install -y --no-install-recommends git openssh-client \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends git openssh-client && rm -rf /var/lib/apt/lists/*
# Site project mount points
RUN mkdir -p /site /memoraiz /data
COPY requirements.txt . COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
COPY main.py .
CMD ["python", "main.py"] COPY . .
RUN python -m compileall .
RUN chmod +x /app/docker-entrypoint.sh
EXPOSE 3000
ENTRYPOINT ["/app/docker-entrypoint.sh"]
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "3000", "--log-level", "info"]

View file

@ -1,18 +1,43 @@
# BetterBot # BetterBot
A simplified Telegram bot that lets non-technical users edit the Better Life SG website by sending natural language instructions. A Telegram bot for editing the Better Life SG website and Memoraiz app frontend.
Fork of [CodeAnywhere](https://github.com/andrekamarudin/code_anywhere) — shares
the core Copilot SDK framework and only customises `instance.py` and
`tools/site_editing/`.
## How it works ## How it works
1. User sends a message to the Telegram bot (e.g. "Change the phone number to 91234567") 1. User sends a message to the Telegram bot (e.g. "Change the phone number to 91234567")
2. BetterBot uses an LLM (GPT-4.1) with file-editing tools to read and modify the website HTML 2. BetterBot uses the CodeAnywhere Copilot SDK framework with site-editing tools
3. Changes are written directly to the site files served by Caddy 3. Tools list / read / write files and automatically commit + push to git
## Projects
| Key | What it manages | Mount path |
|-----|----------------|------------|
| `betterlifesg` | Static HTML site (Tailwind CSS via CDN) | `/repo/betterlifesg/site` |
| `memoraiz` | React 19 + Vite 6 frontend | `/repo/memoraiz/frontend` |
## Stack ## Stack
- Python 3.12 + python-telegram-bot + OpenAI SDK - Python 3.12 + CodeAnywhere framework (Copilot SDK, python-telegram-bot)
- Reads/writes static HTML files mounted from `/opt/betterlifesg/site/` - Runs on RackNerd
- Runs on RackNerd at `betterbot.bytesizeprotip.com` - Forgejo repo: `andre/betterbot`
## Fork relationship
BetterBot shares all framework files with CodeAnywhere. The only
betterbot-specific files are:
| File | Purpose |
|------|---------|
| `instance.py` | Bot identity, system prompt, feature flags, tool registration |
| `tools/site_editing/` | list_files / read_file / write_file with auto git push |
| `.env` | Non-secret config (model defaults, site paths) |
| `compose/` | Docker compose for RackNerd with site-dir mounts |
| `scripts/deploy-betterbot.sh` | Deploy script targeting RackNerd + Infisical |
To sync upstream changes, copy updated framework files from `code_anywhere/`.
## Deployment ## Deployment
@ -24,3 +49,5 @@ ssh racknerd bash /opt/src/betterbot/scripts/deploy-betterbot.sh
- `/start` — Introduction and examples - `/start` — Introduction and examples
- `/reset` — Clear conversation history - `/reset` — Clear conversation history
- `/model <name>` — Switch LLM model
- `/current` — Show current model

147
background_tasks.py Normal file
View file

@ -0,0 +1,147 @@
"""Background task manager — spawns long-running Copilot SDK sessions outside the request cycle."""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Awaitable, Callable
from config import settings
from copilot_runtime import copilot, stream_session
from llm_costs import extract_usage_and_cost
from model_selection import ModelSelection, build_provider_config, resolve_selection
from ux import extract_final_text
logger = logging.getLogger(__name__)
BACKGROUND_TIMEOUT = 600  # 10 minutes — hard cap on a single background agent run


@dataclass
class BackgroundTask:
    """State record for one background agent run spawned from a chat thread."""

    task_id: str        # manager-assigned id, e.g. "bg-3"
    description: str    # prompt given to the background agent
    thread_id: str      # chat thread that spawned this task
    # Aware UTC timestamp captured at task creation.
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    result: str | None = None   # final text output on success
    error: str | None = None    # error message on failure or timeout
    usage: dict | None = None   # usage/cost info extracted from session events
    # Underlying asyncio task; excluded from repr to keep logs readable.
    _asyncio_task: asyncio.Task[None] | None = field(default=None, repr=False)

    @property
    def done(self) -> bool:
        """True once the underlying asyncio task has finished (any outcome)."""
        return self._asyncio_task is not None and self._asyncio_task.done()

    @property
    def elapsed_seconds(self) -> float:
        """Seconds elapsed since the task was created."""
        return (datetime.now(timezone.utc) - self.started_at).total_seconds()
class BackgroundTaskManager:
    """Tracks background agent tasks per thread. One active task per thread."""

    def __init__(self) -> None:
        self._tasks: dict[str, BackgroundTask] = {}    # task_id -> task record
        self._thread_tasks: dict[str, list[str]] = {}  # thread_id -> task ids in start order
        self._counter = 0                              # monotonically increasing id source

    def _next_id(self) -> str:
        """Return the next sequential task id ("bg-1", "bg-2", ...)."""
        self._counter += 1
        return f"bg-{self._counter}"

    def start(
        self,
        *,
        thread_id: str,
        description: str,
        selection: ModelSelection,
        system_message: str,
        on_complete: Callable[["BackgroundTask"], Awaitable[None]],
    ) -> BackgroundTask:
        """Spawn a background Copilot session for *description*; return its record.

        The task is registered immediately; its result/error/usage fields are
        filled in asynchronously, and *on_complete* is awaited exactly once when
        the run finishes (success, error, or timeout).

        NOTE(review): *selection* is accepted but never used — background runs
        always resolve settings.BACKGROUND_MODEL; confirm this is intentional.
        """
        task_id = self._next_id()
        task = BackgroundTask(task_id=task_id, description=description, thread_id=thread_id)
        # Background agents use a dedicated model/provider (not the parent session's provider)
        bg_selection = resolve_selection(model=settings.BACKGROUND_MODEL)
        bg_system = (
            system_message.rstrip() + "\n\nYou are running as a background agent. "
            "Complete the task fully and return your findings. Be thorough."
        )

        async def _run() -> None:
            try:
                async def _collect() -> str:
                    # Drain the whole event stream before extracting anything.
                    events: list = []
                    async for ev in stream_session(
                        copilot,
                        model=bg_selection.model,
                        provider_config=build_provider_config(bg_selection),
                        system_message=bg_system,
                        prompt=description,
                    ):
                        events.append(ev)
                    # Extract usage before final text
                    task.usage = extract_usage_and_cost(bg_selection.model, bg_selection.provider, events)
                    return extract_final_text(events) or "Task completed but produced no output."

                task.result = await asyncio.wait_for(_collect(), timeout=BACKGROUND_TIMEOUT)
            except asyncio.TimeoutError:
                task.error = f"Timed out after {BACKGROUND_TIMEOUT}s"
            except Exception as exc:
                task.error = str(exc)
                logger.exception("Background task %s failed", task_id)
            finally:
                # Notify the caller regardless of outcome; never let the
                # callback's own failure kill the task wrapper.
                try:
                    await on_complete(task)
                except Exception:
                    logger.exception("Callback failed for background task %s", task_id)

        task._asyncio_task = asyncio.create_task(_run())
        self._tasks[task_id] = task
        self._thread_tasks.setdefault(thread_id, []).append(task_id)
        return task

    def get_active(self, thread_id: str) -> BackgroundTask | None:
        """Return the most recently started, still-running task for *thread_id*, or None."""
        for task_id in reversed(self._thread_tasks.get(thread_id, [])):
            task = self._tasks.get(task_id)
            if task and not task.done:
                return task
        return None

    def get_latest(self, thread_id: str) -> BackgroundTask | None:
        """Return the most recently started task for *thread_id* (done or not), or None."""
        ids = self._thread_tasks.get(thread_id, [])
        if not ids:
            return None
        return self._tasks.get(ids[-1])

    def context_summary(self, thread_id: str) -> str | None:
        """Return background-agent context to inject into the system message, or None."""
        active = self.get_active(thread_id)
        if active:
            return (
                f"A background agent is currently running ({active.elapsed_seconds:.0f}s elapsed).\n"
                f"Task: {active.description}\n"
                "Its results will be posted to the chat when done."
            )
        latest = self.get_latest(thread_id)
        if latest is None:
            return None
        if latest.error:
            return f"The last background agent failed: {latest.error}"
        if latest.result:
            # Cap injected result at 2000 chars to keep the system message small.
            snippet = latest.result[:2000] + ("..." if len(latest.result) > 2000 else "")
            return f"The last background agent completed.\nTask: {latest.description}\nResult:\n{snippet}"
        return None

    def format_status(self, thread_id: str) -> str:
        """Human-readable status line for this thread's background agent state."""
        active = self.get_active(thread_id)
        if active:
            return f"A background agent is running ({active.elapsed_seconds:.0f}s elapsed).\nTask: {active.description}"
        latest = self.get_latest(thread_id)
        if latest is None:
            return "No background agent has run in this thread."
        if latest.error:
            return f"Last background agent failed: {latest.error}"
        return f"Last background agent completed.\nTask: {latest.description}"

View file

@ -1,5 +1,4 @@
# Secrets from Infisical — used as .env in the compose stack dir
TG_BOT_TOKEN=CHANGE_ME TG_BOT_TOKEN=CHANGE_ME
VERCEL_AI_GATEWAY_KEY=CHANGE_ME VERCEL_API_KEY=CHANGE_ME
OPENAI_BASE_URL=https://ai-gateway.vercel.sh/v1 OWNER_TELEGRAM_CHAT_ID=CHANGE_ME
MODEL=anthropic/claude-sonnet-4
ALLOWED_USERS=876499264,417471802

View file

@ -3,14 +3,24 @@ services:
build: /opt/src/betterbot build: /opt/src/betterbot
container_name: betterbot container_name: betterbot
restart: unless-stopped restart: unless-stopped
env_file:
- defaults.env
- .env
volumes: volumes:
- /opt/src/betterlifesg:/repo/betterlifesg:rw - /opt/src/betterlifesg:/repo/betterlifesg:rw
- /opt/src/hk_memoraiz:/repo/memoraiz:rw - /opt/src/hk_memoraiz:/repo/memoraiz:rw
- /root/.ssh:/root/.ssh:ro - /root/.ssh:/root/.ssh:ro
env_file: - betterbot-data:/data
- .env
environment: environment:
- TZ=${TZ:-Asia/Singapore} - TZ=${TZ:-Asia/Singapore}
- SITE_DIR=/repo/betterlifesg/site - SITE_DIR=/repo/betterlifesg/site
- MEMORAIZ_DIR=/repo/memoraiz/frontend - MEMORAIZ_DIR=/repo/memoraiz/frontend
- GIT_SSH_COMMAND=ssh -o StrictHostKeyChecking=no - GIT_SSH_COMMAND=ssh -o StrictHostKeyChecking=no
- GIT_AUTHOR_NAME=BetterBot
- GIT_AUTHOR_EMAIL=betterbot@bytesizeprotip.com
- GIT_COMMITTER_NAME=BetterBot
- GIT_COMMITTER_EMAIL=betterbot@bytesizeprotip.com
volumes:
betterbot-data:
name: betterbot-data

66
config.py Normal file
View file

@ -0,0 +1,66 @@
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application configuration loaded from environment variables / .env.

    Every field can be overridden by an environment variable of the same name;
    unknown variables are ignored (extra="ignore").
    """

    # LLM
    DEFAULT_MODEL: str = "gpt-5.4"
    BACKGROUND_MODEL: str = "vercel:google/gemma-4-31b-it"
    OPENAI_API_KEY: str = ""
    OPENAI_BASE_URL: str = "https://api.openai.com/v1"
    OPENROUTER_API_KEY: str = ""
    OPENROUTER_BASE_URL: str = "https://openrouter.ai/api/v1"
    VERCEL_API_KEY: str = ""
    VERCEL_BASE_URL: str = "https://ai-gateway.vercel.sh/v1"
    HUGGINGFACE_API_KEY: str = ""
    HUGGINGFACE_BASE_URL: str = "https://router.huggingface.co/hf-inference/v1"
    GITHUB_TOKEN: str = ""

    # Telegram
    TG_BOT_TOKEN: str = ""
    OWNER_TELEGRAM_CHAT_ID: str = ""  # Telegram user ID of the bot owner for approval requests

    # Voice synthesis (ElevenLabs)
    ELEVENLABS_API_KEY: str = ""
    ELEVENLABS_VOICE_ID: str = ""
    ELEVENLABS_MODEL: str = "eleven_multilingual_v2"

    # Auth
    AUTH_TOKEN: str = ""

    # Integrations
    VIKUNJA_API_URL: str = ""
    VIKUNJA_API_KEY: str = ""
    VIKUNJA_MEMORY_PATH: str = ""
    KARAKEEP_API_URL: str = ""
    KARAKEEP_API_KEY: str = ""

    # User identity & provisioning
    CREDENTIAL_VAULT_KEY: str = ""
    VIKUNJA_ADMIN_API_KEY: str = ""
    KARAKEEP_ADMIN_API_KEY: str = ""
    ALLOW_CREDENTIAL_REVEAL_IN_CHAT: bool = False
    EVENT_POLL_INTERVAL_SECONDS: int = 300

    # Advisor
    ADVISOR_ENABLED: bool = False
    ADVISOR_DEFAULT_MODEL: str = "claude-opus-4.6"
    ADVISOR_MAX_USES: int = 3
    ADVISOR_MAX_TOKENS: int = 700

    # Paths
    REPOS_DIR: str = "/repos"
    DATA_DIR: str = "/data"
    TG_PERSISTENCE_DIR: str = ""

    # BetterBot — site directories
    SITE_DIR: str = "/site"
    MEMORAIZ_DIR: str = "/memoraiz"

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
# Module-level singleton consumed by the rest of the app.
settings = Settings()

# Honor alternate environment variable names as fallbacks when the primary
# setting was not provided.
settings.VERCEL_API_KEY = settings.VERCEL_API_KEY or os.getenv("AI_GATEWAY_API_KEY", "")
settings.HUGGINGFACE_API_KEY = settings.HUGGINGFACE_API_KEY or os.getenv("HF_ACCESS_TOKEN", "")

333
copilot_runtime.py Normal file
View file

@ -0,0 +1,333 @@
"""Copilot SDK runtime — manages CopilotClient lifecycle and session creation."""
import asyncio
import logging
import random
from typing import Any, AsyncIterator
from copilot import CopilotClient, SubprocessConfig
from copilot.generated.session_events import SessionEventType
from copilot.session import CopilotSession, PermissionHandler, SessionEvent
from copilot.tools import Tool
from config import settings
from instance import skill_directories as _default_skill_directories
logger = logging.getLogger(__name__)
# ── Retry configuration ─────────────────────────────────────────────
MAX_SESSION_RETRIES = 3
RETRY_BASE_DELAY = 3.0  # seconds
RETRY_MAX_DELAY = 30.0  # seconds

# Substrings (lowercased) that identify transient provider failures.
_RETRYABLE_PATTERNS = (
    "failed to get response",
    "operation was aborted",
    "timed out",
    "timeout",
    "502",
    "503",
    "504",
    "service unavailable",
    "overloaded",
)

# Rate-limit errors: the SDK already retries 5 times internally, so our
# outer retry loop should NOT retry these (it would just multiply the wait).
_RATE_LIMIT_PATTERNS = ("429", "rate limit", "rate_limit", "too many requests")


def _is_rate_limited(error_msg: str) -> bool:
    """Return True when *error_msg* looks like a provider rate-limit (429) error."""
    haystack = error_msg.lower()
    for pattern in _RATE_LIMIT_PATTERNS:
        if pattern in haystack:
            return True
    return False


def _is_retryable(error_msg: str) -> bool:
    """Return True when *error_msg* indicates a transient failure worth retrying.

    Rate-limit errors are excluded: the SDK already burned through its own
    retries internally, so adding more at this level only wastes time.
    """
    if _is_rate_limited(error_msg):
        return False
    haystack = error_msg.lower()
    return any(pattern in haystack for pattern in _RETRYABLE_PATTERNS)


def _backoff_delay(attempt: int) -> float:
    """Exponential backoff capped at RETRY_MAX_DELAY, with 50-100% jitter."""
    capped = min(RETRY_BASE_DELAY * (2 ** attempt), RETRY_MAX_DELAY)
    jitter = 0.5 + random.random() * 0.5
    return capped * jitter
# Event types that mark the end of a session's event stream; consumers stop
# reading once one of these arrives.
TERMINAL_EVENT_TYPES = frozenset(
    {
        SessionEventType.SESSION_IDLE,
        SessionEventType.SESSION_ERROR,
        SessionEventType.SESSION_SHUTDOWN,
    }
)

# Use PermissionHandler.approve_all from the SDK (single-user trusted environment)
def format_prompt_with_history(
    history: list[dict[str, Any]],
    new_message: str,
) -> str:
    """Flatten prior conversation turns plus the new message into one prompt.

    The last entry of *history* is assumed to be the new user message (it was
    just appended to the store), so only the entries before it are rendered as
    context inside a <conversation_history> wrapper.
    """
    context = history[:-1] if history else []
    rendered: list[str] = []
    for entry in context:
        role = entry.get("role", "assistant")
        body = entry.get("content", "")
        if isinstance(body, list):
            # Multimodal content: join text parts, then note attached images.
            texts: list[str] = []
            images = 0
            for part in body:
                if not isinstance(part, dict):
                    continue
                kind = part.get("type")
                if kind == "input_text":
                    text = part.get("text", "")
                    if text:
                        texts.append(text)
                elif kind == "input_image":
                    images += 1
            body = " ".join(texts)
            if images:
                body = f"{body} [+{images} image(s)]" if body else f"[{images} image(s)]"
        if body:
            rendered.append(f"[{role}]: {body}")
    if not rendered:
        return new_message
    joined = "\n".join(rendered)
    return f"<conversation_history>\n{joined}\n</conversation_history>\n\n{new_message}"
class CopilotRuntime:
    """Manages a long-lived CopilotClient and creates per-request sessions."""

    def __init__(self) -> None:
        # Created lazily in start(); None means not started (or stopped).
        self._client: CopilotClient | None = None

    async def start(self) -> None:
        """Launch the Copilot subprocess client rooted at settings.REPOS_DIR."""
        # The SDK is given a placeholder token when no GitHub auth is configured.
        github_token = settings.GITHUB_TOKEN or "not-needed"
        self._client = CopilotClient(
            SubprocessConfig(
                cwd=settings.REPOS_DIR,
                github_token=github_token,
                use_logged_in_user=False,
            )
        )
        await self._client.start()
        logger.info(
            "Copilot SDK client started (cwd=%s, copilot_auth=%s)",
            settings.REPOS_DIR,
            "yes" if settings.GITHUB_TOKEN else "no",
        )

    async def stop(self) -> None:
        """Best-effort shutdown of the client; always clears the reference."""
        if self._client:
            try:
                await self._client.stop()
            # NOTE(review): BaseException also swallows asyncio.CancelledError /
            # KeyboardInterrupt during shutdown — confirm this is intentional.
            except BaseException:
                logger.warning("Copilot SDK client stop errors", exc_info=True)
            self._client = None
            logger.info("Copilot SDK client stopped")

    @property
    def client(self) -> CopilotClient:
        """The running client; raises RuntimeError if start() was never called."""
        if self._client is None:
            raise RuntimeError("CopilotRuntime not started — call start() first")
        return self._client

    async def create_session(
        self,
        *,
        model: str,
        provider_config: dict[str, Any] | None,
        system_message: str,
        tools: list[Tool] | None = None,
        streaming: bool = True,
    ) -> CopilotSession:
        """Create an SDK session with this app's standard configuration.

        Permission requests are auto-approved (single-user trusted environment),
        the built-in "task" tool is excluded, and *system_message* fully
        replaces the SDK's default system message.
        """
        kwargs: dict[str, Any] = {
            "on_permission_request": PermissionHandler.approve_all,
            "model": model,
            "streaming": streaming,
            "working_directory": settings.REPOS_DIR,
            "system_message": {"mode": "replace", "content": system_message},
            "tools": tools or None,
            "excluded_tools": ["task"],
            "skill_directories": _default_skill_directories(),
        }
        # Only pass a provider when one was explicitly resolved.
        if provider_config is not None:
            kwargs["provider"] = provider_config
        return await self.client.create_session(**kwargs)
async def stream_session(
    rt: CopilotRuntime,
    *,
    model: str,
    provider_config: dict[str, Any] | None,
    system_message: str,
    prompt: str,
    tools: list[Tool] | None = None,
    attachments: list[dict[str, Any]] | None = None,
    thread_id: str | None = None,
) -> AsyncIterator[SessionEvent]:
    """Create a session, send a prompt, yield events until idle, then destroy.

    Retries transparently on transient errors (model timeouts, 5xx, overload)
    with exponential backoff + jitter. Rate-limit (429) errors are NOT retried
    here — _is_retryable excludes them because the SDK retries those itself.
    """
    # Reset advisor counter at the start of each run
    if thread_id is not None:
        try:
            from tools.advisor import _reset_advisor_state

            _reset_advisor_state(thread_id)
        except ImportError:
            # Advisor tool not present in this deployment — nothing to reset.
            pass
    last_error_msg = ""
    for attempt in range(MAX_SESSION_RETRIES):
        session = await rt.create_session(
            model=model,
            provider_config=provider_config,
            system_message=system_message,
            tools=tools,
        )
        # Bridge the SDK's callback-style events into an async queue; a None
        # sentinel is enqueued after any terminal event to end the read loop.
        queue: asyncio.Queue[SessionEvent | None] = asyncio.Queue()

        def on_event(event: SessionEvent) -> None:
            queue.put_nowait(event)
            if event.type in TERMINAL_EVENT_TYPES:
                queue.put_nowait(None)  # sentinel

        unsub = session.on(on_event)
        hit_retryable_error = False
        try:
            await session.send(prompt, attachments=attachments)
            while True:
                item = await queue.get()
                if item is None:
                    break
                # Intercept SESSION_ERROR for retry
                if item.type == SessionEventType.SESSION_ERROR:
                    error_msg = (item.data and item.data.message) or "Unknown session error"
                    last_error_msg = error_msg
                    retries_left = MAX_SESSION_RETRIES - attempt - 1
                    if _is_retryable(error_msg) and retries_left > 0:
                        delay = _backoff_delay(attempt)
                        logger.warning(
                            "Retryable session error (attempt %d/%d, next in %.1fs): %s",
                            attempt + 1,
                            MAX_SESSION_RETRIES,
                            delay,
                            error_msg,
                        )
                        # NOTE(review): the actual sleep below re-draws the
                        # jittered delay, so it differs from the logged value.
                        hit_retryable_error = True
                        break
                    # Not retryable or out of retries — yield the error to caller
                    if retries_left == 0 and _is_retryable(error_msg):
                        logger.error(
                            "All %d session retries exhausted: %s",
                            MAX_SESSION_RETRIES,
                            error_msg,
                        )
                    yield item
                    return
                yield item
        finally:
            # Always unsubscribe and tear down the session, even on retry/return.
            unsub()
            await session.destroy()
        if hit_retryable_error:
            await asyncio.sleep(_backoff_delay(attempt))
            continue
        # Completed normally
        return
    # Should not reach here, but safety net
    logger.error("stream_session fell through retry loop; last error: %s", last_error_msg)
async def run_session(
    rt: CopilotRuntime,
    *,
    model: str,
    provider_config: dict[str, Any] | None,
    system_message: str,
    prompt: str,
    tools: list[Tool] | None = None,
    attachments: list[dict[str, Any]] | None = None,
    thread_id: str | None = None,
) -> Any:
    """Create a session, send a prompt, wait for completion, return raw result.

    Retries transparently on transient errors with exponential backoff + jitter.
    When the result event carries no content, falls back to the last assistant
    message found in the session's message history.
    """
    # Reset advisor counter at the start of each run
    if thread_id is not None:
        try:
            from tools.advisor import _reset_advisor_state

            _reset_advisor_state(thread_id)
        except ImportError:
            # Advisor tool not present in this deployment — nothing to reset.
            pass
    last_exc: Exception | None = None
    for attempt in range(MAX_SESSION_RETRIES):
        session = await rt.create_session(
            model=model,
            provider_config=provider_config,
            system_message=system_message,
            tools=tools,
        )
        try:
            result = await session.send_and_wait(prompt, attachments=attachments, timeout=300)
            if result and result.data and result.data.content:
                return result
            # Fallback: check message history
            messages = await session.get_messages()
            for msg in reversed(messages):
                if msg.type == SessionEventType.ASSISTANT_MESSAGE and msg.data and msg.data.content:
                    return msg
            return result
        except Exception as exc:
            last_exc = exc
            retries_left = MAX_SESSION_RETRIES - attempt - 1
            if _is_retryable(str(exc)) and retries_left > 0:
                delay = _backoff_delay(attempt)
                logger.warning(
                    "Retryable run_session error (attempt %d/%d, next in %.1fs): %s",
                    attempt + 1,
                    MAX_SESSION_RETRIES,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)
                continue
            raise
        finally:
            # Always destroy the session, whether we returned, retried, or raised.
            await session.destroy()
    # All retries exhausted — raise the last exception
    if last_exc is not None:
        raise last_exc


# Module-level singleton shared by the app; started/stopped by the host process.
copilot = CopilotRuntime()

14
docker-entrypoint.sh Normal file
View file

@ -0,0 +1,14 @@
#!/usr/bin/env sh
# Container entrypoint: copy the host's git identity (mounted at /host-git)
# into root's home, then exec the container command.
set -eu

copy_host_git_file() {
    # $1 = file name under /host-git, $2 = target file mode
    if [ -f "/host-git/$1" ]; then
        cp "/host-git/$1" "/root/$1"
        chmod "$2" "/root/$1"
    fi
}

copy_host_git_file .gitconfig 644
copy_host_git_file .git-credentials 600

exec "$@"

2
frontend/.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
node_modules/
dist/

15
frontend/index.html Normal file
View file

@ -0,0 +1,15 @@
<!DOCTYPE html>
<!-- Vite entry document: the React app mounts into #root via src/main.tsx. -->
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>CodeAnywhere</title>
  </head>
  <body>
    <div id="root"></div>
    <!-- Module entry point bundled by Vite -->
    <script type="module" src="/src/main.tsx"></script>
  </body>
</html>

3449
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load diff

25
frontend/package.json Normal file
View file

@ -0,0 +1,25 @@
{
"name": "code-anywhere-frontend",
"private": true,
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview"
},
"dependencies": {
"react": "^19.1.0",
"react-dom": "^19.1.0",
"@ai-sdk/react": "latest",
"ai": "latest",
"react-markdown": "^10.1.0",
"remark-gfm": "^4.0.0"
},
"devDependencies": {
"vite": "^6.3.0",
"@vitejs/plugin-react": "^4.5.0",
"typescript": "^5.8.0",
"@types/react": "^19.1.0",
"@types/react-dom": "^19.1.0"
}
}

344
frontend/src/App.tsx Normal file
View file

@ -0,0 +1,344 @@
import { useChat } from '@ai-sdk/react';
import type { UIMessage } from 'ai';
import { DefaultChatTransport } from 'ai';
import { useCallback, useEffect, useMemo, useRef, useState, type CSSProperties } from 'react';
import { Chat } from './Chat';
import { Sidebar } from './Sidebar';
import * as api from './api';
import type { ServerThread, UIConfig } from './types';
export function App() {
const [config, setConfig] = useState<UIConfig | null>(null);
const [threads, setThreads] = useState<ServerThread[]>([]);
const [activeThreadId, setActiveThreadId] = useState<string | null>(null);
const [activeThread, setActiveThread] = useState<ServerThread | null>(null);
const [toast, setToast] = useState<{ text: string; type?: 'error' } | null>(null);
const [mobileSidebar, setMobileSidebar] = useState(false);
const [searchQuery, setSearchQuery] = useState('');
const [sidebarWidth, setSidebarWidth] = useState(() => {
if (typeof window === 'undefined') return 260;
const raw = window.localStorage.getItem('codeanywhere.sidebarWidth');
const parsed = Number(raw || 260);
return Number.isFinite(parsed) ? Math.min(460, Math.max(220, parsed)) : 260;
});
const threadIdRef = useRef<string | null>(null);
const sidebarWidthRef = useRef(sidebarWidth);
threadIdRef.current = activeThreadId;
sidebarWidthRef.current = sidebarWidth;
const transport = useMemo(
() =>
new DefaultChatTransport({
api: '/api/chat',
credentials: 'same-origin',
prepareSendMessagesRequest: ({ messages }) => {
const last = messages[messages.length - 1];
const textContent = last?.parts
?.filter((p): p is { type: 'text'; text: string } => p.type === 'text')
.map((p) => p.text)
.join('\n') || '';
return {
body: {
threadId: threadIdRef.current,
content: textContent,
},
};
},
}),
[],
);
const {
messages,
sendMessage,
setMessages,
status,
error: chatError,
stop,
} = useChat({ transport });
const flash = useCallback((text: string, type?: 'error') => {
setToast({ text, type });
setTimeout(() => setToast(null), 4000);
}, []);
useEffect(() => {
(async () => {
try {
const [cfg, threadList] = await Promise.all([api.loadConfig(), api.listThreads()]);
setConfig(cfg);
setThreads(threadList);
if (threadList.length > 0) {
const first = threadList[0];
setActiveThreadId(first.id);
const full = await api.getThread(first.id);
setActiveThread(full);
setMessages(toUIMessages(full.messages || []));
}
} catch {
// 401 handled by api.ts redirect
}
})();
}, [setMessages]);
const refreshThreads = useCallback(
async (selectId?: string) => {
const list = await api.listThreads(searchQuery || undefined);
setThreads(list);
if (selectId && list.some((thread) => thread.id === selectId)) {
setActiveThreadId(selectId);
}
},
[searchQuery],
);
useEffect(() => {
const timer = setTimeout(() => {
refreshThreads(activeThreadId || undefined).catch(() => { });
}, 250);
return () => clearTimeout(timer);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [searchQuery]);
const switchThread = useCallback(
async (id: string) => {
setActiveThreadId(id);
setMobileSidebar(false);
try {
const full = await api.getThread(id);
setActiveThread(full);
setMessages(toUIMessages(full.messages || []));
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Failed to load thread', 'error');
}
},
[flash, setMessages],
);
const createNewThread = useCallback(
async (selection?: { model?: string; provider?: string }) => {
try {
const thread = await api.createThread({
model: selection?.model || config?.defaultModel,
provider: selection?.provider || config?.defaultProvider,
});
setActiveThreadId(thread.id);
setActiveThread(thread);
setMessages([]);
await refreshThreads(thread.id);
setMobileSidebar(false);
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Failed to create thread', 'error');
}
},
[config, flash, refreshThreads, setMessages],
);
const handleSend = useCallback(
async (text: string) => {
let threadId = activeThreadId;
if (!threadId) {
const firstLine = text.split('\n', 1)[0].slice(0, 60).trim() || 'New chat';
const thread = await api.createThread({
title: firstLine,
model: config?.defaultModel,
provider: config?.defaultProvider,
});
threadId = thread.id;
setActiveThreadId(threadId);
setActiveThread(thread);
threadIdRef.current = threadId;
await refreshThreads(threadId);
}
sendMessage({ text });
},
[activeThreadId, config, refreshThreads, sendMessage],
);
useEffect(() => {
if (status === 'ready' && activeThreadId && messages.length > 0) {
refreshThreads(activeThreadId).catch(() => { });
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [status]);
const handleDeleteThread = useCallback(
async (id: string) => {
if (!confirm('Delete this thread?')) return;
try {
await api.deleteThread(id);
if (activeThreadId === id) {
setActiveThreadId(null);
setActiveThread(null);
setMessages([]);
}
await refreshThreads();
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Delete failed', 'error');
}
},
[activeThreadId, flash, refreshThreads, setMessages],
);
const handleRename = useCallback(
async (id: string) => {
const name = prompt('New name:');
if (!name) return;
try {
const updated = await api.updateThread(id, { title: name });
if (activeThreadId === id) setActiveThread(updated);
await refreshThreads(id);
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Rename failed', 'error');
}
},
[activeThreadId, flash, refreshThreads],
);
const handleMagicRename = useCallback(
async (id: string) => {
try {
flash('Summarizing with AI...');
const updated = await api.magicRename(id);
if (activeThreadId === id) setActiveThread(updated);
await refreshThreads(id);
flash('Thread renamed');
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Magic rename failed', 'error');
}
},
[activeThreadId, flash, refreshThreads],
);
const handleMagicRenameAll = useCallback(async () => {
try {
flash('Summarizing eligible threads with gpt-5.4-nano...');
const result = await api.magicRenameAll();
await refreshThreads(activeThreadId || undefined);
if (activeThreadId) {
const full = await api.getThread(activeThreadId);
setActiveThread(full);
}
if (result.renamedCount > 0) {
flash(
result.skippedCount > 0
? `Renamed ${result.renamedCount} threads. Skipped ${result.skippedCount}.`
: `Renamed ${result.renamedCount} threads.`,
);
} else {
flash('No eligible threads were renamed.');
}
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Magic rename all failed', 'error');
}
}, [activeThreadId, flash, refreshThreads]);
const handleSaveModel = useCallback(
async (provider: string, model: string) => {
if (!activeThreadId) return;
try {
const updated = await api.updateThread(activeThreadId, { provider, model });
setActiveThread(updated);
await refreshThreads(activeThreadId);
flash(`Routing set to ${updated.provider}:${updated.model}`);
} catch (e: unknown) {
flash(e instanceof Error ? e.message : 'Routing save failed', 'error');
}
},
[activeThreadId, flash, refreshThreads],
);
const handleSidebarResizeStart = useCallback((event: React.MouseEvent<HTMLDivElement>) => {
if (window.matchMedia('(max-width: 768px)').matches) return;
event.preventDefault();
const startX = event.clientX;
const startWidth = sidebarWidthRef.current;
let nextWidth = startWidth;
document.body.classList.add('sidebar-resizing');
const onMove = (moveEvent: MouseEvent) => {
nextWidth = Math.min(460, Math.max(220, startWidth + moveEvent.clientX - startX));
setSidebarWidth(nextWidth);
};
const onUp = () => {
document.body.classList.remove('sidebar-resizing');
window.removeEventListener('mousemove', onMove);
window.removeEventListener('mouseup', onUp);
window.localStorage.setItem('codeanywhere.sidebarWidth', String(nextWidth));
};
window.addEventListener('mousemove', onMove);
window.addEventListener('mouseup', onUp);
}, []);
const shellStyle = useMemo(
() => ({ '--sidebar-width': `${sidebarWidth}px` }) as CSSProperties,
[sidebarWidth],
);
if (!config) {
return (
<div style={{ display: 'grid', placeItems: 'center', height: '100vh', color: 'var(--muted)' }}>
Loading CodeAnywhere...
</div>
);
}
return (
<div className="app-shell" style={shellStyle}>
{mobileSidebar && <div className="mobile-backdrop" onClick={() => setMobileSidebar(false)} />}
<Sidebar
threads={threads}
activeThreadId={activeThreadId}
defaultModel={config.defaultModel}
defaultProvider={config.defaultProvider}
providerOptions={config.providerOptions}
curatedModels={config.curatedModels}
mobileOpen={mobileSidebar}
currentModel={activeThread?.model}
currentProvider={activeThread?.provider}
searchQuery={searchQuery}
onSearchChange={setSearchQuery}
onSelectThread={switchThread}
onNewThread={createNewThread}
onSaveModel={handleSaveModel}
onMagicRenameAll={handleMagicRenameAll}
/>
<div
className="sidebar-resizer"
onMouseDown={handleSidebarResizeStart}
role="separator"
aria-orientation="vertical"
aria-label="Resize sidebar"
/>
<Chat
thread={activeThread}
messages={messages}
status={status}
error={chatError}
onSend={handleSend}
onStop={stop}
onDeleteThread={activeThreadId ? () => handleDeleteThread(activeThreadId) : undefined}
onRename={activeThreadId ? () => handleRename(activeThreadId) : undefined}
onMagicRename={activeThreadId ? () => handleMagicRename(activeThreadId) : undefined}
onToggleSidebar={() => setMobileSidebar((value) => !value)}
/>
{toast && <div className={`toast ${toast.type || ''}`}>{toast.text}</div>}
</div>
);
}
function toUIMessages(serverMessages: { id: string; role: string; content: string; created_at: string; parts?: { type: string;[k: string]: unknown }[] }[]): UIMessage[] {
  // Convert persisted server messages into the ai-sdk UIMessage shape.
  //
  // Fix: the server may persist rich `parts` (tool invocations, reasoning,
  // file attachments) alongside the flat `content` column. Previously these
  // were discarded on reload and every message collapsed to a single text
  // part; now server-provided parts are kept when present, and we only
  // synthesize a text part as a fallback.
  return serverMessages.map((m) => ({
    id: m.id,
    role: m.role as 'user' | 'assistant',
    parts:
      m.parts && m.parts.length > 0
        ? (m.parts as UIMessage['parts'])
        : [{ type: 'text' as const, text: m.content || '' }],
    createdAt: new Date(m.created_at),
  }));
}

286
frontend/src/Chat.tsx Normal file
View file

@ -0,0 +1,286 @@
import type { UIMessage } from 'ai';
import { useCallback, useEffect, useRef, useState } from 'react';
import ReactMarkdown from 'react-markdown';
import remarkGfm from 'remark-gfm';
import type { ServerThread } from './types';
// Props for the main conversation pane.
interface ChatProps {
  thread: ServerThread | null; // null until a thread is selected or created
  messages: UIMessage[];
  status: string; // ai-sdk chat status: 'ready' | 'submitted' | 'streaming' | ...
  error: Error | undefined;
  onSend: (text: string) => void;
  onStop: () => void; // abort the in-flight generation
  // Optional actions — omitted (undefined) when no thread is active.
  onDeleteThread?: () => void;
  onRename?: () => void;
  onMagicRename?: () => void;
  onToggleSidebar: () => void; // mobile-only sidebar toggle
}
// Main conversation pane: header (thread title, usage, actions), scrollable
// message list, busy indicator with elapsed timer, and the composer form.
export function Chat({
  thread,
  messages,
  status,
  error,
  onSend,
  onStop,
  onDeleteThread,
  onRename,
  onMagicRename,
  onToggleSidebar,
}: ChatProps) {
  // Composer draft and a human-readable timer shown while a request runs.
  const [input, setInput] = useState('');
  const [busyElapsed, setBusyElapsed] = useState('0s');
  const scrollRef = useRef<HTMLDivElement>(null);
  const textareaRef = useRef<HTMLTextAreaElement>(null);
  // Auto-scroll on new content
  useEffect(() => {
    const el = scrollRef.current;
    if (el) el.scrollTop = el.scrollHeight;
  }, [messages, status]);
  // Send the trimmed draft; no-op when empty or a request is already in flight.
  const handleSubmit = useCallback(
    (e?: React.FormEvent) => {
      e?.preventDefault();
      const text = input.trim();
      if (!text || status === 'submitted' || status === 'streaming') return;
      onSend(text);
      setInput('');
    },
    [input, status, onSend],
  );
  // Enter sends; Shift+Enter falls through so the textarea inserts a newline.
  const handleKeyDown = useCallback(
    (e: React.KeyboardEvent) => {
      if (e.key === 'Enter' && !e.shiftKey) {
        e.preventDefault();
        handleSubmit();
      }
    },
    [handleSubmit],
  );
  const busy = status === 'submitted' || status === 'streaming';
  // Tick the elapsed counter once per second while busy; reset to 0s when idle.
  useEffect(() => {
    if (!busy) {
      setBusyElapsed('0s');
      return;
    }
    const startedAt = Date.now();
    const tick = () => {
      setBusyElapsed(formatElapsedDuration((Date.now() - startedAt) / 1000));
    };
    tick();
    const timer = window.setInterval(tick, 1000);
    return () => window.clearInterval(timer);
  }, [busy]);
  return (
    <section className="chat-area">
      <div className="chat-header">
        <div className="chat-header-left">
          <button className="btn-ghost btn-sm mobile-toggle" onClick={onToggleSidebar}>
            Threads
          </button>
          <div className="eyebrow">Workspace chat</div>
          <h2>{thread?.title || 'Start a chat'}</h2>
          {thread?.usage && (
            <div className="chat-usage">
              {thread.usage.total_tokens.toLocaleString()} tokens · ${Number(thread.usage.cost_usd || '0').toFixed(6)}
            </div>
          )}
        </div>
        {thread && (
          <div className="chat-actions">
            {onMagicRename && (
              <button className="btn-ghost btn-sm" onClick={onMagicRename} disabled={busy}>
                Magic rename
              </button>
            )}
            {onRename && (
              <button className="btn-ghost btn-sm" onClick={onRename} disabled={busy}>
                Rename
              </button>
            )}
            {onDeleteThread && (
              <button className="btn-ghost btn-sm" onClick={onDeleteThread} disabled={busy}>
                Delete
              </button>
            )}
          </div>
        )}
      </div>
      <div className="messages-scroll" ref={scrollRef}>
        {messages.length === 0 && !busy && (
          <div className="empty-state">
            {thread
              ? 'This thread is empty. Type your first prompt below.'
              : 'Select a thread or create a new chat to get started.'}
          </div>
        )}
        {messages.map((msg) => (
          <MessageBubble key={msg.id} message={msg} />
        ))}
        {busy && (
          <div className="status-indicator">
            {status === 'submitted' ? 'Sending...' : 'Streaming...'} · {busyElapsed}
            <button
              className="btn-ghost btn-sm"
              style={{ marginLeft: 12 }}
              onClick={onStop}
            >
              Stop
            </button>
          </div>
        )}
        {error && (
          <div className="message assistant" style={{ borderColor: 'rgba(255,157,181,0.35)' }}>
            <div className="message-label">Error</div>
            <div className="message-body" style={{ color: 'var(--warning-text)' }}>
              {error.message || 'An error occurred.'}
            </div>
          </div>
        )}
      </div>
      <form className="composer" onSubmit={handleSubmit}>
        <div className="composer-input-row">
          <textarea
            ref={textareaRef}
            value={input}
            onChange={(e) => setInput(e.target.value)}
            onKeyDown={handleKeyDown}
            placeholder={thread ? 'Ask anything...' : 'Type your first prompt to create a new chat'}
            disabled={busy}
          />
          <button type="submit" disabled={busy || !input.trim()}>
            Send
          </button>
        </div>
        <div className="composer-actions">
          <span>Shift+Enter for newline · Enter to send</span>
        </div>
      </form>
    </section>
  );
}
// Render one message. User text is shown verbatim; assistant text is rendered
// as Markdown. Reasoning, tool invocations, and image file parts each get a
// dedicated presentation; unknown part types render nothing.
function MessageBubble({ message }: { message: UIMessage }) {
  const isUser = message.role === 'user';
  return (
    <article className={`message ${message.role}`}>
      <div className="message-label">
        {isUser ? 'You' : 'CodeAnywhere'}
        {'createdAt' in message && message.createdAt ? (
          <span> · {formatTime(message.createdAt as Date)}</span>
        ) : null}
      </div>
      <div className="message-body">
        {message.parts.map((part, i) => {
          if (part.type === 'text') {
            // Preserve user whitespace; assistant text goes through Markdown.
            if (isUser) {
              return <div key={i} style={{ whiteSpace: 'pre-wrap' }}>{part.text}</div>;
            }
            return (
              <div key={i} className="markdown-body">
                <ReactMarkdown remarkPlugins={[remarkGfm]}>{part.text}</ReactMarkdown>
              </div>
            );
          }
          if (part.type === 'reasoning') {
            // Collapsed by default; the user opens it to read chain-of-thought.
            return (
              <details key={i} className="reasoning-block">
                <summary>Reasoning</summary>
                <div className="reasoning-body">{part.text}</div>
              </details>
            );
          }
          if (part.type?.startsWith('tool-')) {
            // ai-sdk encodes tool calls as 'tool-<name>' parts; show input and
            // (once available) output, truncating long output at 2000 chars.
            const toolPart = part as { type: string; toolCallId?: string; toolName?: string; state?: string; input?: unknown; output?: unknown };
            const toolName = ('toolName' in toolPart ? toolPart.toolName : toolPart.type) || 'tool';
            const state = toolPart.state || '';
            return (
              <details key={i} className="tool-invocation">
                <summary>
                  <span className="tool-invocation-label">{toolName}</span>
                  <span className="tool-invocation-state">
                    {state === 'result' ? 'done' : state}
                  </span>
                </summary>
                <div className="tool-invocation-body">
                  {toolPart.input != null && (
                    <>
                      <strong>Input:</strong>
                      {'\n'}
                      {typeof toolPart.input === 'string' ? toolPart.input : JSON.stringify(toolPart.input, null, 2)}
                    </>
                  )}
                  {state === 'result' && toolPart.output != null && (
                    <>
                      {'\n\n'}
                      <strong>Output:</strong>
                      {'\n'}
                      {typeof toolPart.output === 'string'
                        ? toolPart.output.length > 2000
                          ? toolPart.output.slice(0, 2000) + '...'
                          : toolPart.output
                        : JSON.stringify(toolPart.output, null, 2)?.slice(0, 2000)}
                    </>
                  )}
                </div>
              </details>
            );
          }
          if (part.type === 'file' && 'mediaType' in part && typeof part.mediaType === 'string' && part.mediaType.startsWith('image/')) {
            return (
              <img
                key={i}
                src={'url' in part ? String(part.url) : ''}
                alt="Generated"
                style={{ maxWidth: '100%', borderRadius: 12, marginTop: 8 }}
              />
            );
          }
          return null;
        })}
      </div>
    </article>
  );
}
// Render a Date (or ISO string) as a localized hours:minutes label.
// Returns '' for unparseable input rather than throwing.
function formatTime(date: Date | string): string {
  const parsed = date instanceof Date ? date : new Date(date);
  if (Number.isNaN(parsed.getTime())) return '';
  const formatter = new Intl.DateTimeFormat(undefined, {
    hour: 'numeric',
    minute: '2-digit',
  });
  return formatter.format(parsed);
}
// Render a non-negative duration in seconds as 'Ss', 'Mm SSs', or 'Hh MMm SSs',
// zero-padding the trailing minute/second fields to two digits.
function formatElapsedDuration(totalSeconds: number): string {
  const total = Math.max(0, Math.floor(totalSeconds));
  const seconds = total % 60;
  const minutes = Math.floor(total / 60) % 60;
  const hours = Math.floor(total / 3600);
  const pad = (value: number) => String(value).padStart(2, '0');
  if (hours > 0) {
    return `${hours}h ${pad(minutes)}m ${pad(seconds)}s`;
  }
  if (minutes > 0) {
    return `${minutes}m ${pad(seconds)}s`;
  }
  return `${seconds}s`;
}

254
frontend/src/Sidebar.tsx Normal file
View file

@ -0,0 +1,254 @@
import { useCallback, useEffect, useMemo, useState } from 'react';
import type { CuratedModelOption, ProviderOption, ServerThread } from './types';
// Thread-list ordering: by creation time or by last update.
type SortMode = 'created' | 'updated';
// Props for the thread list / model-routing sidebar.
interface SidebarProps {
  threads: ServerThread[];
  activeThreadId: string | null;
  defaultModel: string; // fallback when the active thread has no model
  defaultProvider: string;
  providerOptions: ProviderOption[];
  curatedModels: Record<string, CuratedModelOption[]>; // presets keyed by provider id
  mobileOpen: boolean; // whether the sidebar overlay is visible on mobile
  currentModel?: string; // routing of the active thread, if any
  currentProvider?: string;
  searchQuery: string;
  onSearchChange: (query: string) => void;
  onSelectThread: (id: string) => void;
  onNewThread: (selection?: { model?: string; provider?: string }) => void;
  onSaveModel: (provider: string, model: string) => void;
  onMagicRenameAll: () => void;
}
// Sidebar: app header with "new chat" / bulk-rename actions, the provider and
// model routing panel (with curated presets), and the searchable, sortable
// thread list.
export function Sidebar({
  threads,
  activeThreadId,
  defaultModel,
  defaultProvider,
  providerOptions,
  curatedModels,
  mobileOpen,
  currentModel,
  currentProvider,
  searchQuery,
  onSearchChange,
  onSelectThread,
  onNewThread,
  onSaveModel,
  onMagicRenameAll,
}: SidebarProps) {
  // Local draft of the routing inputs, re-synced whenever the active thread
  // (or the configured defaults) change.
  const [providerInput, setProviderInput] = useState(currentProvider || defaultProvider);
  const [modelInput, setModelInput] = useState(currentModel || defaultModel);
  const [sortMode, setSortMode] = useState<SortMode>('created');
  useEffect(() => {
    setProviderInput(currentProvider || defaultProvider);
  }, [currentProvider, defaultProvider]);
  useEffect(() => {
    setModelInput(currentModel || defaultModel);
  }, [currentModel, defaultModel]);
  const providerLookup = useMemo(
    () => new Map(providerOptions.map((option) => [option.id, option])),
    [providerOptions],
  );
  const selectedProvider = providerInput || defaultProvider;
  const selectedProviderOption = providerLookup.get(selectedProvider);
  const suggestedModels = curatedModels[selectedProvider] || [];
  // A 'provider:model' combined ref bypasses the provider dropdown entirely.
  const hasExplicitProviderRef = modelInput.includes(':');
  // Routing can be applied when both fields are filled and the chosen provider
  // is configured (or an explicit combined ref overrides it).
  const canApplyRouting =
    !!providerInput.trim() &&
    !!modelInput.trim() &&
    (selectedProviderOption?.available !== false || hasExplicitProviderRef);
  // First alphanumeric character of a title, used as the thread badge.
  const initial = (title: string) => {
    const match = title.match(/[A-Za-z0-9]/);
    return match ? match[0].toUpperCase() : '#';
  };
  // Short localized date/time; '' for unparseable timestamps.
  const formatDate = (iso: string) => {
    const date = new Date(iso);
    if (isNaN(date.getTime())) return '';
    return new Intl.DateTimeFormat(undefined, {
      month: 'short',
      day: 'numeric',
      hour: 'numeric',
      minute: '2-digit',
    }).format(date);
  };
  const toggleSort = useCallback(() => {
    setSortMode((previous) => (previous === 'created' ? 'updated' : 'created'));
  }, []);
  // Newest first; ISO-8601 strings sort correctly with string comparison.
  const sortedThreads = useMemo(() => {
    const key = sortMode === 'created' ? 'created_at' : 'updated_at';
    return [...threads].sort((a, b) => b[key].localeCompare(a[key]));
  }, [threads, sortMode]);
  // Threads eligible for bulk magic rename: non-empty and never renamed
  // manually or by a previous magic rename.
  const magicRenameCandidateCount = useMemo(
    () =>
      threads.filter((thread) => {
        const messageCount = Number(thread.message_count || thread.messages?.length || 0);
        const titleSource = String(thread.title_source || '').trim().toLowerCase();
        return messageCount > 0 && !['manual', 'magic'].includes(titleSource);
      }).length,
    [threads],
  );
  return (
    <aside className={`sidebar ${mobileOpen ? 'mobile-open' : ''}`}>
      <div className="sidebar-header">
        <div className="sidebar-top">
          <h1>CodeAnywhere</h1>
          <div className="sidebar-top-actions">
            <button
              className="btn-ghost btn-sm"
              onClick={onMagicRenameAll}
              disabled={magicRenameCandidateCount === 0}
              title={
                magicRenameCandidateCount > 0
                  ? `Use gpt-5.4-nano to rename ${magicRenameCandidateCount} eligible threads`
                  : 'Only untouched auto-titled conversations are eligible'
              }
            >
              {magicRenameCandidateCount > 0
                ? `Magic rename all (${magicRenameCandidateCount})`
                : 'Magic rename all'}
            </button>
            <button
              type="button"
              onClick={() =>
                onNewThread({
                  provider: providerInput || undefined,
                  model: modelInput || undefined,
                })
              }
              disabled={!canApplyRouting}
            >
              New chat
            </button>
          </div>
        </div>
      </div>
      <div className="model-panel">
        <label>
          Active provider
          <select
            value={providerInput}
            onChange={(e) => {
              // When switching providers, auto-advance the model only if the
              // current value is blank or was itself a curated preset of the
              // previous provider — never clobber a hand-typed model.
              const nextProvider = e.target.value;
              const previousCurated = curatedModels[providerInput] || [];
              const nextCurated = curatedModels[nextProvider] || [];
              setProviderInput(nextProvider);
              setModelInput((current) => {
                const currentMatchesPrevious = previousCurated.some((option) => option.id === current);
                if ((!current.trim() || currentMatchesPrevious) && nextCurated.length > 0) {
                  return nextCurated[0].id;
                }
                return current;
              });
            }}
          >
            {providerOptions.map((option) => (
              <option key={option.id} value={option.id}>
                {option.available ? option.label : `${option.label} (not configured)`}
              </option>
            ))}
          </select>
        </label>
        <label>
          Active model
          <input
            type="text"
            value={modelInput}
            onChange={(e) => setModelInput(e.target.value)}
            placeholder={defaultModel}
          />
        </label>
        {suggestedModels.length > 0 && (
          <details className="model-presets">
            <summary className="model-presets-toggle">
              Presets for {selectedProviderOption?.label || selectedProvider}
              <span className="model-presets-count">{suggestedModels.length}</span>
            </summary>
            <div className="preset-grid">
              {suggestedModels.map((option) => (
                <button
                  key={option.ref}
                  type="button"
                  className={`preset-chip ${modelInput === option.id ? 'active' : ''}`}
                  onClick={() => setModelInput(option.id)}
                  title={option.ref}
                >
                  <span className="preset-chip-name">{option.label}</span>
                  <span className="preset-chip-meta">{option.description}</span>
                </button>
              ))}
            </div>
          </details>
        )}
        {selectedProviderOption?.available === false && !hasExplicitProviderRef && selectedProviderOption.reason && (
          <div className="provider-status provider-status-error">{selectedProviderOption.reason}</div>
        )}
        <div className="model-hint">
          {hasExplicitProviderRef
            ? 'Combined refs override the provider dropdown until you remove the provider prefix.'
            : 'You can also paste a combined ref like vercel:openai/gpt-5.4-mini.'}
        </div>
        <button
          className="btn-sm"
          type="button"
          onClick={() => onSaveModel(providerInput, modelInput)}
          disabled={!canApplyRouting}
        >
          Save routing
        </button>
      </div>
      <div className="threads-toolbar">
        <input
          className="threads-search"
          type="text"
          placeholder="Search threads…"
          value={searchQuery}
          onChange={(e) => onSearchChange(e.target.value)}
        />
        <button
          className="sort-toggle"
          onClick={toggleSort}
          title={`Sorted by ${sortMode === 'created' ? 'created' : 'last modified'}`}
        >
          {sortMode === 'created' ? '⏱ Created' : '✏️ Modified'}
        </button>
      </div>
      <div className="threads-list">
        {sortedThreads.length === 0 && (
          <div style={{ padding: '20px', color: 'var(--muted)', fontSize: 13, textAlign: 'center' }}>
            {searchQuery ? 'No matching threads.' : 'No threads yet. Start a new chat.'}
          </div>
        )}
        {sortedThreads.map((thread) => (
          <button
            key={thread.id}
            className={`thread-card ${thread.id === activeThreadId ? 'active' : ''}`}
            onClick={() => onSelectThread(thread.id)}
          >
            <span className="thread-badge">{initial(thread.title)}</span>
            <div>
              <div className="thread-title">{thread.title}</div>
              <div className="thread-meta">
                {thread.provider}:{thread.model} · ${Number(thread.usage?.cost_usd || '0').toFixed(6)} · {formatDate(sortMode === 'created' ? thread.created_at : thread.updated_at)}
              </div>
            </div>
          </button>
        ))}
      </div>
    </aside>
  );
}

120
frontend/src/api.ts Normal file
View file

@ -0,0 +1,120 @@
import type { ServerThread, UIConfig } from './types';
// Shared fetch wrapper: sends same-origin credentials and a JSON content type,
// redirects to /login on 401, and converts non-2xx responses into Errors,
// preferring the server's JSON `detail` field when one is present.
async function apiFetch(url: string, init?: RequestInit): Promise<Response> {
  const response = await fetch(url, {
    credentials: 'same-origin',
    ...init,
    headers: { 'Content-Type': 'application/json', ...(init?.headers || {}) },
  });
  if (response.status === 401) {
    window.location.href = '/login';
    throw new Error('Unauthorized');
  }
  if (response.ok) return response;
  const raw = await response.text();
  let detail = `Request failed (${response.status})`;
  try {
    const parsed = JSON.parse(raw);
    if (parsed.detail) detail = parsed.detail;
  } catch { /* ignore */ }
  throw new Error(detail);
}
// Fetch via apiFetch and decode the JSON body as T.
async function apiJson<T = unknown>(url: string, init?: RequestInit): Promise<T> {
  const response = await apiFetch(url, init);
  return response.json() as Promise<T>;
}
// Fetch the UI bootstrap configuration (providers, curated models, defaults).
export async function loadConfig(): Promise<UIConfig> {
  return apiJson('/ui-config');
}
// List threads, optionally filtered by a server-side search query.
export async function listThreads(query?: string): Promise<ServerThread[]> {
  const suffix = query ? `?q=${encodeURIComponent(query)}` : '';
  const payload = await apiJson<{ threads: ServerThread[] }>(`/api/threads${suffix}`);
  return payload.threads;
}
// Fetch one thread, including its messages.
export async function getThread(id: string): Promise<ServerThread> {
  const { thread } = await apiJson<{ thread: ServerThread }>(`/api/threads/${id}`);
  return thread;
}
// Create a thread; all fields are optional and default server-side.
export async function createThread(opts: {
  title?: string;
  model?: string;
  provider?: string;
}): Promise<ServerThread> {
  const { thread } = await apiJson<{ thread: ServerThread }>('/api/threads', {
    method: 'POST',
    body: JSON.stringify(opts),
  });
  return thread;
}
// Partially update a thread's title and/or model routing.
export async function updateThread(
  id: string,
  patch: { title?: string; model?: string; provider?: string },
): Promise<ServerThread> {
  const { thread } = await apiJson<{ thread: ServerThread }>(`/api/threads/${id}`, {
    method: 'PATCH',
    body: JSON.stringify(patch),
  });
  return thread;
}
// Delete a thread; the response body (if any) is discarded.
export async function deleteThread(id: string): Promise<void> {
  await apiFetch(`/api/threads/${id}`, { method: 'DELETE' });
}
// Ask the server to AI-summarize one thread's title.
export async function magicRename(id: string): Promise<ServerThread> {
  const { thread } = await apiJson<{ thread: ServerThread }>(
    `/api/threads/${id}/magic-rename`,
    { method: 'POST' },
  );
  return thread;
}
// Bulk-rename all eligible threads; the server decides eligibility and
// reports how many were renamed vs. skipped.
export async function magicRenameAll(): Promise<{
  threads: ServerThread[];
  renamedCount: number;
  skippedCount: number;
}> {
  return apiJson('/api/threads/magic-rename-all', { method: 'POST' });
}
// Upload an image as a base64 data URL; returns the server's upload record.
export async function uploadImage(file: File): Promise<{
  id: string;
  name: string;
  mime_type: string;
  size: number;
  preview_url?: string;
}> {
  const payload = {
    name: file.name || 'image',
    mimeType: file.type || 'image/png',
    dataUrl: await readFileAsDataUrl(file),
  };
  const response = await apiJson<{ upload: { id: string; name: string; mime_type: string; size: number; preview_url?: string } }>(
    '/api/uploads',
    {
      method: 'POST',
      body: JSON.stringify(payload),
    },
  );
  return response.upload;
}
// Remove a pending upload by id (id is URL-encoded into the path).
export async function deleteUpload(id: string): Promise<void> {
  await apiFetch(`/api/uploads/${encodeURIComponent(id)}`, { method: 'DELETE' });
}
// Promisify the callback-based FileReader API, resolving with a data URL.
function readFileAsDataUrl(file: File): Promise<string> {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onerror = () => reject(new Error(`Failed to read ${file.name}`));
    reader.onload = () => resolve(reader.result as string);
    reader.readAsDataURL(file);
  });
}

10
frontend/src/main.tsx Normal file
View file

@ -0,0 +1,10 @@
import { StrictMode } from 'react';
import { createRoot } from 'react-dom/client';
import { App } from './App';
import './styles.css';
// Mount the SPA at #root. StrictMode intentionally double-invokes effects in
// development to surface unsafe side effects; it is a no-op in production.
createRoot(document.getElementById('root')!).render(
  <StrictMode>
    <App />
  </StrictMode>,
);

809
frontend/src/styles.css Normal file
View file

@ -0,0 +1,809 @@
/* Design tokens for the dark theme; every rule below references these. */
:root {
  color-scheme: dark;
  /* Surfaces: page background plus translucent panel layers. */
  --bg: #0b1220;
  --panel: rgba(17, 27, 47, 0.9);
  --panel-strong: rgba(22, 36, 61, 0.92);
  --line: rgba(151, 180, 255, 0.18);
  /* Text: primary, secondary, accent, and error/warning tints. */
  --text: #eaf1ff;
  --muted: #9fb4d9;
  --accent: #80b3ff;
  --accent-strong: #4d8dff;
  --warning-text: #ffd6de;
  /* Shared elevation and corner rounding. */
  --shadow: 0 18px 50px rgba(0, 0, 0, 0.28);
  --radius: 20px;
}
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
html,
body,
#root {
height: 100%;
background: linear-gradient(180deg, #07101d 0%, #0b1220 100%);
color: var(--text);
font-family: "Segoe UI", "IBM Plex Sans", system-ui, sans-serif;
}
/* ── Layout ── */
.app-shell {
display: grid;
grid-template-columns: var(--sidebar-width, 260px) 10px minmax(0, 1fr);
height: 100vh;
}
/* ── Sidebar ── */
.sidebar {
display: grid;
grid-template-rows: auto auto auto 1fr;
border-right: 1px solid var(--line);
background: var(--panel);
}
.sidebar-header {
padding: 18px;
border-bottom: 1px solid var(--line);
background: var(--panel-strong);
}
.sidebar-resizer {
position: relative;
cursor: col-resize;
background: linear-gradient(180deg, rgba(128, 179, 255, 0.06), rgba(128, 179, 255, 0.18));
}
.sidebar-resizer::after {
content: '';
position: absolute;
inset: 0 3px;
border-radius: 999px;
background: rgba(128, 179, 255, 0.16);
opacity: 0.45;
transition: opacity 0.12s ease;
}
.sidebar-resizer:hover::after {
opacity: 1;
}
body.sidebar-resizing,
body.sidebar-resizing * {
cursor: col-resize !important;
user-select: none !important;
}
.sidebar-top {
display: flex;
align-items: center;
justify-content: space-between;
gap: 10px;
margin-bottom: 12px;
}
.sidebar-top-actions {
display: flex;
flex-wrap: wrap;
justify-content: flex-end;
gap: 8px;
}
.sidebar-top h1 {
font-size: 18px;
font-weight: 700;
}
.model-panel {
padding: 14px 18px;
border-bottom: 1px solid var(--line);
display: grid;
gap: 8px;
}
.model-panel label {
display: grid;
gap: 4px;
color: var(--muted);
font-size: 12px;
}
.model-hint {
color: var(--muted);
font-size: 11px;
line-height: 1.5;
}
.model-presets {
border: 1px solid var(--line);
border-radius: 12px;
background: rgba(8, 15, 28, 0.5);
overflow: hidden;
}
.model-presets>.preset-grid {
padding: 8px;
display: grid;
gap: 6px;
}
.model-presets-toggle {
list-style: none;
cursor: pointer;
padding: 10px 12px;
display: flex;
align-items: center;
justify-content: space-between;
color: var(--muted);
font-size: 12px;
font-weight: 600;
user-select: none;
}
.model-presets-toggle::-webkit-details-marker {
display: none;
}
.model-presets-toggle:hover {
color: var(--accent);
}
.model-presets-count {
display: inline-grid;
place-items: center;
min-width: 20px;
height: 20px;
padding: 0 6px;
border-radius: 999px;
background: rgba(128, 179, 255, 0.14);
color: var(--accent);
font-size: 11px;
font-weight: 700;
}
.preset-chip {
width: 100%;
border-radius: 14px;
padding: 10px 12px;
text-align: left;
display: grid;
gap: 4px;
background: rgba(8, 15, 28, 0.72);
border: 1px solid var(--line);
color: var(--text);
}
.preset-chip:hover:not(:disabled) {
border-color: rgba(128, 179, 255, 0.42);
}
.preset-chip.active {
border-color: rgba(128, 179, 255, 0.5);
background: rgba(21, 36, 63, 0.94);
}
.preset-chip-name {
font-size: 12px;
font-weight: 700;
}
.preset-chip-meta {
color: var(--muted);
font-size: 11px;
font-weight: 500;
line-height: 1.4;
}
.provider-status {
border-radius: 12px;
padding: 10px 12px;
font-size: 12px;
line-height: 1.5;
}
.provider-status-error {
border: 1px solid rgba(255, 157, 181, 0.3);
background: rgba(63, 20, 29, 0.55);
color: var(--warning-text);
}
.threads-list {
overflow-y: auto;
padding: 8px;
display: flex;
flex-direction: column;
gap: 4px;
}
.threads-toolbar {
display: flex;
gap: 6px;
padding: 8px 12px;
border-bottom: 1px solid var(--line);
align-items: center;
}
.threads-search {
flex: 1;
padding: 6px 10px;
border-radius: 8px;
border: 1px solid var(--line);
background: rgba(8, 15, 28, 0.9);
color: var(--text);
font: inherit;
font-size: 12px;
}
.threads-search:focus {
outline: none;
border-color: var(--accent);
}
.sort-toggle {
flex-shrink: 0;
padding: 4px 8px;
border-radius: 8px;
border: 1px solid var(--line);
background: rgba(8, 15, 28, 0.7);
color: var(--muted);
font-size: 11px;
cursor: pointer;
white-space: nowrap;
}
.sort-toggle:hover {
border-color: var(--accent);
color: var(--accent);
}
.thread-card {
width: 100%;
padding: 10px 12px;
border: 1px solid transparent;
border-radius: 14px;
background: rgba(8, 15, 28, 0.82);
color: var(--text);
text-align: left;
cursor: pointer;
display: grid;
grid-template-columns: 32px minmax(0, 1fr);
align-items: center;
gap: 10px;
transition: border-color 0.12s;
}
.thread-card:hover {
border-color: rgba(128, 179, 255, 0.2);
}
.thread-card.active {
border-color: rgba(128, 179, 255, 0.45);
background: rgba(21, 36, 63, 0.98);
}
.thread-badge {
width: 32px;
height: 32px;
display: grid;
place-items: center;
border-radius: 8px;
background: rgba(128, 179, 255, 0.14);
color: var(--accent);
font-weight: 700;
font-size: 13px;
}
.thread-card.active .thread-badge {
background: var(--accent);
color: #03101f;
}
.thread-title {
font-size: 13px;
font-weight: 600;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.thread-meta {
color: var(--muted);
font-size: 11px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
/* ── Chat area ── */
.chat-area {
display: grid;
grid-template-rows: auto minmax(0, 1fr) auto;
height: 100%;
min-height: 0;
overflow: hidden;
}
.chat-header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
padding: 14px 20px;
border-bottom: 1px solid var(--line);
background: var(--panel-strong);
}
.chat-header-left h2 {
font-size: 18px;
}
.chat-header-left .eyebrow {
color: var(--muted);
font-size: 11px;
letter-spacing: 0.1em;
text-transform: uppercase;
}
.chat-actions {
display: flex;
gap: 8px;
}
/* ── Messages ── */
.messages-scroll {
overflow-y: auto;
min-height: 0;
padding: 20px;
display: flex;
flex-direction: column;
gap: 14px;
}
.empty-state {
align-self: center;
justify-self: center;
max-width: 420px;
padding: 40px 20px;
text-align: center;
color: var(--muted);
line-height: 1.6;
}
.message {
max-width: min(760px, 100%);
padding: 14px 16px;
border-radius: 18px;
border: 1px solid var(--line);
background: rgba(17, 27, 47, 0.88);
}
.message.user {
align-self: flex-end;
background: rgba(77, 141, 255, 0.18);
border-color: rgba(128, 179, 255, 0.34);
}
.message.assistant {
align-self: flex-start;
}
.message-label {
margin-bottom: 8px;
color: var(--muted);
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.1em;
}
.message-body {
color: var(--text);
line-height: 1.6;
word-break: break-word;
font-size: 15px;
}
.status-indicator {
align-self: flex-start;
padding: 10px 14px;
border-radius: 14px;
border: 1px dashed var(--line);
background: rgba(12, 24, 43, 0.9);
color: var(--accent);
font-size: 14px;
animation: pulse 1.5s ease-in-out infinite;
}
@keyframes pulse {
0%,
100% {
opacity: 1;
}
50% {
opacity: 0.6;
}
}
/* ── Tool invocations ── */
.tool-invocation {
margin: 8px 0;
border: 1px solid rgba(94, 196, 255, 0.28);
border-radius: 14px;
background: rgba(8, 15, 28, 0.9);
overflow: hidden;
}
.tool-invocation summary {
list-style: none;
cursor: pointer;
padding: 10px 14px;
font-size: 13px;
color: var(--muted);
}
.tool-invocation summary::-webkit-details-marker {
display: none;
}
.tool-invocation-label {
color: var(--accent);
font-weight: 700;
}
.tool-invocation-state {
color: var(--muted);
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.06em;
margin-left: 8px;
}
.tool-invocation-body {
padding: 0 14px 12px;
border-top: 1px solid var(--line);
color: #c8d8f0;
font-size: 12px;
line-height: 1.6;
font-family: "Cascadia Code", Consolas, monospace;
white-space: pre-wrap;
word-break: break-word;
max-height: 300px;
overflow: auto;
}
/* ── Reasoning ── */
.reasoning-block {
margin: 8px 0;
border: 1px solid rgba(128, 179, 255, 0.28);
border-radius: 14px;
background: rgba(8, 15, 28, 0.9);
overflow: hidden;
}
.reasoning-block summary {
list-style: none;
cursor: pointer;
padding: 10px 14px;
font-size: 13px;
font-weight: 600;
color: var(--muted);
}
.reasoning-block summary::-webkit-details-marker {
display: none;
}
.reasoning-body {
padding: 0 14px 12px;
border-top: 1px solid var(--line);
color: #c8d8f0;
font-size: 13px;
line-height: 1.6;
white-space: pre-wrap;
word-break: break-word;
max-height: 300px;
overflow: auto;
}
/* ── Composer ── */
.composer {
display: grid;
gap: 10px;
padding: 16px 20px;
border-top: 1px solid var(--line);
background: var(--panel-strong);
}
.composer-input-row {
display: grid;
grid-template-columns: 1fr auto;
gap: 10px;
align-items: end;
}
.composer textarea {
width: 100%;
min-height: 80px;
max-height: 240px;
resize: vertical;
padding: 12px 14px;
border-radius: 14px;
border: 1px solid var(--line);
background: rgba(8, 15, 28, 0.9);
color: var(--text);
font: inherit;
font-size: 15px;
line-height: 1.5;
}
.composer textarea:focus {
outline: none;
border-color: var(--accent);
}
.composer-actions {
display: flex;
gap: 8px;
align-items: center;
font-size: 12px;
color: var(--muted);
}
.composer-uploads {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.composer-upload-item {
display: flex;
align-items: center;
gap: 8px;
padding: 6px 10px;
border: 1px solid var(--line);
border-radius: 10px;
background: rgba(8, 15, 28, 0.88);
font-size: 12px;
}
.composer-upload-item img {
width: 40px;
height: 40px;
object-fit: cover;
border-radius: 6px;
}
/* ── Buttons ── */
button {
border: none;
border-radius: 999px;
background: linear-gradient(135deg, var(--accent-strong), var(--accent));
color: #03101f;
padding: 8px 14px;
font: inherit;
font-size: 13px;
font-weight: 700;
cursor: pointer;
transition: opacity 0.12s;
}
button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-ghost {
background: transparent;
color: var(--text);
border: 1px solid var(--line);
font-weight: 600;
}
.btn-ghost:hover:not(:disabled) {
border-color: var(--accent);
}
.btn-sm {
padding: 6px 10px;
font-size: 12px;
}
input[type="text"],
select {
width: 100%;
padding: 8px 12px;
border-radius: 10px;
border: 1px solid var(--line);
background: rgba(8, 15, 28, 0.9);
color: var(--text);
font: inherit;
font-size: 13px;
}
input[type="text"]:focus,
select:focus {
  outline: none;
  border-color: var(--accent);
}

/* ── Markdown body ── */

/* Collapse outer margins so rendered markdown sits flush in its container. */
.markdown-body> :first-child {
  margin-top: 0;
}

.markdown-body> :last-child {
  margin-bottom: 0;
}

/* Headings share the body font size; weight alone distinguishes them. */
.markdown-body h1,
.markdown-body h2,
.markdown-body h3,
.markdown-body h4,
.markdown-body h5,
.markdown-body h6 {
  margin: 0 0 12px;
  font-size: inherit;
  font-weight: 700;
}

.markdown-body p,
.markdown-body ul,
.markdown-body ol,
.markdown-body pre,
.markdown-body blockquote,
.markdown-body table,
.markdown-body hr {
  margin: 0 0 12px;
}

.markdown-body ul,
.markdown-body ol {
  padding-left: 22px;
}

.markdown-body li+li {
  margin-top: 4px;
}

.markdown-body a {
  color: var(--accent);
}

.markdown-body blockquote {
  padding-left: 12px;
  border-left: 3px solid rgba(128, 179, 255, 0.35);
  color: var(--muted);
}

.markdown-body hr {
  border: none;
  border-top: 1px solid var(--line);
}

.markdown-body code,
.markdown-body pre {
  font-family: "Cascadia Code", Consolas, monospace;
}

/* Inline code only; code inside pre gets its own reset below. */
.markdown-body :not(pre)>code {
  padding: 0.16em 0.38em;
  border-radius: 8px;
  background: rgba(128, 179, 255, 0.12);
  font-size: 0.94em;
}

.markdown-body pre {
  padding: 12px 14px;
  border-radius: 14px;
  border: 1px solid var(--line);
  background: rgba(6, 12, 22, 0.92);
  overflow: auto;
}

.markdown-body pre code {
  padding: 0;
  background: transparent;
  font-size: 0.94em;
}

.markdown-body table {
  width: 100%;
  border-collapse: collapse;
  font-size: 14px;
}

.markdown-body th,
.markdown-body td {
  border: 1px solid var(--line);
  padding: 8px 10px;
  text-align: left;
}

.markdown-body img {
  max-width: 100%;
  border-radius: 12px;
}

/* ── Banner / toast ── */
.toast {
  position: fixed;
  top: 16px;
  right: 16px;
  z-index: 100;
  max-width: 360px;
  padding: 12px 16px;
  border-radius: 14px;
  border: 1px solid var(--line);
  background: var(--panel-strong);
  color: var(--accent);
  font-size: 14px;
  box-shadow: var(--shadow);
  animation: slideIn 0.2s ease;
}

.toast.error {
  border-color: rgba(255, 157, 181, 0.35);
  color: var(--warning-text);
}

@keyframes slideIn {
  from {
    transform: translateY(-12px);
    opacity: 0;
  }

  to {
    transform: translateY(0);
    opacity: 1;
  }
}

/* ── Mobile ── */
@media (max-width: 768px) {
  .app-shell {
    grid-template-columns: 1fr;
  }

  /* The sidebar becomes an overlay drawer, toggled via .mobile-open. */
  .sidebar {
    display: none;
  }

  .sidebar.mobile-open {
    display: grid;
    position: fixed;
    inset: 0;
    z-index: 50;
    width: min(86vw, 320px);
  }

  .sidebar-resizer {
    display: none;
  }

  .mobile-backdrop {
    position: fixed;
    inset: 0;
    z-index: 40;
    background: rgba(3, 8, 16, 0.6);
    backdrop-filter: blur(4px);
  }

  /* !important so this wins over the later base rule that hides the toggle. */
  .mobile-toggle {
    display: inline-flex !important;
  }

  .chat-header {
    flex-wrap: wrap;
  }

  .chat-actions {
    width: 100%;
    justify-content: flex-start;
  }
}

/* Hidden on desktop; the media-query rule above overrides this with !important. */
.mobile-toggle {
  display: none;
}

58
frontend/src/types.ts Normal file
View file

@ -0,0 +1,58 @@
/** Aggregated token counts and dollar cost for a request or a whole thread. */
export interface UsageSummary {
  prompt_tokens: number;
  completion_tokens: number;
  total_tokens: number;
  reasoning_tokens: number;
  cached_tokens: number;
  web_search_requests?: number;
  request_count: number;
  // Serialized decimal string (e.g. "0.000123"), not a float.
  cost_usd: string;
  cost_breakdown?: Record<string, string>;
  pricing_source?: string;
  // Sidecar usage attributed to an advisor model, when one ran.
  advisor_prompt_tokens?: number;
  advisor_completion_tokens?: number;
  advisor_total_tokens?: number;
}

/** Thread metadata from the server; `messages` present only on detail fetches. */
export interface ServerThread {
  id: string;
  title: string;
  model: string;
  provider: string;
  message_count: number;
  title_source: string;
  created_at: string;
  updated_at: string;
  usage: UsageSummary;
  messages?: ServerMessage[];
}

/** One chat message; `parts` carries typed content blocks when present. */
export interface ServerMessage {
  id: string;
  role: 'user' | 'assistant';
  content: string;
  parts?: { type: string; [k: string]: unknown }[];
  created_at: string;
}

/** A selectable LLM provider; `reason` explains why it is unavailable. */
export interface ProviderOption {
  id: string;
  label: string;
  available: boolean;
  reason?: string;
}

/** A curated model menu entry; `ref` is the "provider:model" string. */
export interface CuratedModelOption {
  id: string;
  label: string;
  description: string;
  ref: string;
}

/** Server-driven UI configuration: defaults plus provider/model menus. */
export interface UIConfig {
  defaultModel: string;
  defaultProvider: string;
  pricingSource?: string;
  providerOptions: ProviderOption[];
  curatedModels: Record<string, CuratedModelOption[]>;
}

27
frontend/tsconfig.json Normal file
View file

@ -0,0 +1,27 @@
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"lib": [
"ES2020",
"DOM",
"DOM.Iterable"
],
"module": "ESNext",
"skipLibCheck": true,
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"isolatedModules": true,
"moduleDetection": "force",
"noEmit": true,
"jsx": "react-jsx",
"strict": true,
"noUnusedLocals": false,
"noUnusedParameters": false,
"noFallthroughCasesInSwitch": true,
"forceConsistentCasingInFileNames": true
},
"include": [
"src"
]
}

19
frontend/vite.config.ts Normal file
View file

@ -0,0 +1,19 @@
import react from '@vitejs/plugin-react';
import { defineConfig } from 'vite';

// During `vite dev`, all backend routes are proxied to the local API server.
const BACKEND_ORIGIN = 'http://localhost:3000';
const PROXIED_PATHS = ['/api', '/login', '/health', '/uploads', '/ui-config'];

export default defineConfig({
  plugins: [react()],
  build: {
    // Emit the production bundle where the backend serves static assets from.
    outDir: '../static/dist',
    emptyOutDir: true,
  },
  server: {
    proxy: Object.fromEntries(PROXIED_PATHS.map((path) => [path, BACKEND_ORIGIN])),
  },
});

90
instance.py Normal file
View file

@ -0,0 +1,90 @@
"""BetterBot instance configuration.

BetterBot is a fork of CodeAnywhere tailored for editing the Better Life SG
website and Memoraiz frontend via Telegram. This file is the only
customization point — everything else is shared framework code.
"""
from __future__ import annotations

# ── Identity ─────────────────────────────────────────────────────
# Display name / description picked up by the shared framework code.
BOT_NAME: str = "BetterBot"
BOT_DESCRIPTION: str = "Telegram bot for editing Better Life SG website and Memoraiz"
FASTAPI_TITLE: str = "BetterBot"

# ── System prompt ────────────────────────────────────────────────
# Injected verbatim as the LLM system prompt. This is runtime text shown to
# the model — edit with care.
BASE_CONTEXT: str = """\
You are BetterBot, a helpful assistant that manages two projects:
1. **Better Life SG website** (project: "betterlifesg")
- Static HTML site using Tailwind CSS (loaded via CDN)
- Key files: index.html, fresh-grads.html, prenatal.html, retirement.html, \
legacy.html, team.html, contact.html, images/ folder
- Brand color: teal (#00b49a)
- Changes go live immediately after writing
2. **Memoraiz app** (project: "memoraiz")
- React 19 + Vite 6 + Tailwind CSS 4 frontend
- Source code is under frontend/src/ (pages in frontend/src/pages/, \
components in frontend/src/components/)
- Changes require a rebuild to go live (handled automatically after you write)
When the user asks you to change something:
1. Ask which project if unclear (default to betterlifesg for website questions)
2. First read the relevant file(s) to understand the current state
3. Make the requested changes
4. Write the updated file back
5. Confirm what you changed
When writing a file, always write the COMPLETE file content never partial.
Keep your responses concise and friendly. Always confirm changes after making them.
Do NOT change the overall page structure unless explicitly asked.
"""

# BetterBot ships no extra skills, so the skills prompt section stays empty.
SKILLS_CONTEXT: str = ""
# ── Skill directories ────────────────────────────────────────────
def skill_directories() -> list[str]:
    """Return the skill directories to load — always empty for BetterBot."""
    directories: list[str] = []
    return directories
# ── Tool registration ────────────────────────────────────────────
def register_tools() -> None:
    """Framework hook: register BetterBot's toolsets at startup.

    BetterBot contributes exactly one toolset — the site-editing tools
    (list/read/write with automatic git push).
    """
    # Imports are deferred so merely importing this module stays side-effect free.
    from tools.site_editing import site_editing_toolset
    from tool_registry import registry

    registry.register(site_editing_toolset)
# ── Telegram ─────────────────────────────────────────────────────
# Sent verbatim in reply to /start; runtime text shown to end users.
TELEGRAM_START_MESSAGE: str = (
    "Hi! I'm BetterBot 🤖\n\n"
    "I manage two projects:\n"
    "• **Better Life SG** website\n"
    "• **Memoraiz** app\n\n"
    "Just tell me what you'd like to change!\n\n"
    "Examples:\n"
    '"Change the WhatsApp number to 91234567"\n'
    '"Update Hendri\'s title to Senior Consultant"\n'
    '"Update the login page text in Memoraiz"\n\n'
    "/reset — start a fresh conversation\n"
    "/model <name> — switch LLM model\n"
    "/current — show current model\n"
    "Send a photo with an optional caption to ask about an image."
)

# ── Features ─────────────────────────────────────────────────────
# Feature flags read by the shared framework: Telegram is the only enabled
# surface; the web UI, ChatKit, and background agents stay off for this fork.
ENABLE_WEB_UI: bool = False
ENABLE_TELEGRAM: bool = True
ENABLE_CHATKIT: bool = False
ENABLE_BACKGROUND_AGENTS: bool = False

144
learning.py Normal file
View file

@ -0,0 +1,144 @@
"""Learning extraction — identifies facts worth persisting from conversation turns."""
from __future__ import annotations
import re
from typing import Any
def _collect(
    text: str,
    patterns: tuple[str, ...],
    category: str,
    min_len: int,
    max_len: int,
    sink: list[tuple[str, str]],
) -> None:
    """Append ``(fact, category)`` to *sink* for each regex hit in *text*.

    Matching is case-insensitive; a hit is kept only when its stripped
    length is strictly between *min_len* and *max_len* (same bounds the
    original inline loops used).
    """
    for pattern in patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            fact = match.group(0).strip()
            if min_len < len(fact) < max_len:
                sink.append((fact, category))


def extract_learnings_from_turn(
    user_message: str,
    assistant_message: str,
) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Extract project-scoped and global learnings from a completed turn.

    Returns:
        (project_learnings, global_learnings) where each item is (fact, category).

    Strategy: rule-based extraction (zero cost, no extra LLM call).
    - User-stated preferences, personality cues, contact info → global
    - Technical discoveries, patterns, architecture notes → project

    NOTE(review): the patterns use no word-boundary anchors, so short
    alternatives (e.g. "be", "use") can match inside longer words, and a
    sentence matched by several patterns is recorded once per pattern.
    Behaviour is preserved exactly; only the repeated scan loops were
    factored into ``_collect``.
    """
    project: list[tuple[str, str]] = []
    global_facts: list[tuple[str, str]] = []

    # ── Global learnings (user preferences, personality) ──────────────
    # Explicit preference statements.
    _collect(
        user_message,
        (
            r"(?:I (?:prefer|like|want|always|never|hate|don't like|can't stand)\s+.+?)[.!?]",
            r"(?:my (?:favorite|preferred|default|usual)\s+.+?(?:is|are)\s+.+?)[.!?]",
            r"(?:I'm (?:a|an)\s+.+?)(?:\s+(?:at|who|that|working))?[.,!?]",
        ),
        "preference",
        10,
        200,
        global_facts,
    )
    # Job/role mentions.
    _collect(
        user_message,
        (
            r"(?:I (?:work|am working)\s+(?:at|for|on|with)\s+.+?)[.,!?]",
            r"(?:I'm\s+(?:a|an)\s+\w+(?:\s+\w+)?\s+(?:engineer|developer|designer|manager|analyst|researcher|founder|student))",
            r"(?:my (?:role|title|position|job)\s+(?:is|as)\s+.+?)[.,!?]",
        ),
        "profile",
        10,
        200,
        global_facts,
    )
    # Contact/identity info.
    _collect(
        user_message,
        (
            r"(?:my (?:email|github|twitter|handle|username|name|phone|number)\s+(?:is)\s+.+?)[.,!?]",
        ),
        "identity",
        10,
        200,
        global_facts,
    )
    # Family / relationships / personal facts (longer 300-char cap).
    _collect(
        user_message,
        (
            r"(?:I (?:have|am married to|live with)\s+(?:a\s+)?(?:wife|husband|partner|spouse)\s+\w+[^.!?]*)[.!?]",
            r"(?:my (?:wife|husband|partner|spouse|daughter|son|child|kid|mother|father|parent|sibling|brother|sister)(?:'s\s+\w+)?\s+(?:is|was|name(?:d| is)?|born)\s+[^.!?]+)[.!?]",
            r"(?:I (?:have|have got)\s+(?:two|three|four|five|\d+)\s+(?:daughters?|sons?|kids?|children)[^.!?]*)[.!?]",
            r"[A-Z][a-z]+\s+born\s+\d{4}",
        ),
        "personal",
        10,
        300,
        global_facts,
    )
    # Tone/personality instructions (e.g., "be more concise", "use bullet points").
    _collect(
        user_message,
        (
            r"(?:(?:be|use|respond|reply|answer|speak|write)\s+(?:more\s+)?(?:concise|brief|detailed|verbose|short|formal|casual|friendly|professional|terse|bullet|markdown|code))[.!?]?",
            r"(?:don't (?:use|add|include|give)\s+.+?(?:explanation|comment|context|preamble|prefix))[.!?]?",
        ),
        "tone",
        5,
        200,
        global_facts,
    )

    # ── Project learnings (technical facts, patterns) ─────────────────
    # Architecture/discovery from assistant responses.
    _collect(
        assistant_message,
        (
            r"(?:the\s+\w[\w-]*(?:\s+\w[\w-]*)?\s+(?:uses|runs on|depends on|is backed by|is configured with|requires)\s+.+?)[.,]",
        ),
        "architecture",
        15,
        250,
        project,
    )
    # Bug/pattern from assistant.
    _collect(
        assistant_message,
        (
            r"(?:(?:root cause|the issue|the problem|the bug)\s+(?:is|was)\s+.+?)[.,]",
            r"(?:this\s+(?:happens|occurs)\s+(?:because|due to|when)\s+.+?)[.,]",
        ),
        "bug_pattern",
        15,
        250,
        project,
    )
    # Deployment patterns from assistant.
    _collect(
        assistant_message,
        (
            r"(?:deployed?\s+(?:via|through|using|to)\s+.+?)[.,]",
            r"(?:the\s+(?:deploy|ci|pipeline|action)\s+(?:uses|runs|triggers)\s+.+?)[.,]",
        ),
        "deployment",
        15,
        250,
        project,
    )
    return project, global_facts
def format_learnings_for_prompt(
project_learnings: list[dict[str, Any]],
global_learnings: list[dict[str, Any]],
) -> str | None:
"""Format learnings into a section to append to the system prompt.
Returns None if there are no learnings to inject.
"""
sections: list[str] = []
if global_learnings:
lines = ["## User Preferences & Profile"]
for item in global_learnings[-15:]: # Keep prompt concise
lines.append(f"- {item['fact']}")
sections.append("\n".join(lines))
if project_learnings:
lines = ["## Project Learnings"]
for item in project_learnings[-15:]:
lines.append(f"- {item['fact']}")
sections.append("\n".join(lines))
return "\n\n".join(sections) if sections else None

450
llm_costs.py Normal file
View file

@ -0,0 +1,450 @@
from __future__ import annotations
import logging
import time
from decimal import ROUND_HALF_UP, Decimal
from typing import Any
from config import settings
logger = logging.getLogger(__name__)
# External pricing endpoints: OpenRouter's public model catalog and the
# Vercel AI Gateway per-generation cost lookup.
OPENROUTER_PRICING_SOURCE = "https://openrouter.ai/api/v1/models"
VERCEL_GENERATION_URL = "https://ai-gateway.vercel.sh/v1/generation"

# Shared Decimal constants: ZERO as a safe default, USD_DISPLAY_QUANT for
# fixed 6-decimal-place display rounding.
ZERO = Decimal("0")
USD_DISPLAY_QUANT = Decimal("0.000001")

_PRICING_CACHE_TTL = 3600  # 1 hour

# In-memory cache: model_id → (pricing_dict, fetched_at)
_pricing_cache: dict[str, tuple[dict[str, Decimal], float]] = {}

# Maps bare model prefixes to their OpenRouter vendor slug so we can look up
# pricing even when the model was routed directly (e.g. openai provider).
_VENDOR_PREFIXES: tuple[tuple[str, str], ...] = (
    ("gpt-", "openai"),
    ("chatgpt-", "openai"),
    ("o1", "openai"),
    ("o3", "openai"),
    ("o4", "openai"),
    ("computer-use-", "openai"),
    ("claude-", "anthropic"),
    ("gemini-", "google"),
    ("deepseek-", "deepseek"),
)
def _to_openrouter_model_id(model: str, provider: str) -> str:
    """Normalise a model string into the ``vendor/model`` format OpenRouter uses.

    Names already containing a slash are returned unchanged; bare names are
    prefixed using the ``_VENDOR_PREFIXES`` hints when one matches.
    (*provider* is accepted for signature compatibility but not consulted.)
    """
    name = str(model or "").strip()
    if "/" in name:
        return name
    folded = name.lower()
    vendor_slug = next(
        (vendor for prefix, vendor in _VENDOR_PREFIXES if folded.startswith(prefix)),
        None,
    )
    return f"{vendor_slug}/{name}" if vendor_slug else name
def _to_decimal(value: Any) -> Decimal:
    """Best-effort conversion to Decimal; None/empty/unparsable input yields ZERO."""
    try:
        return ZERO if value in (None, "") else Decimal(str(value))
    except Exception:
        # Anything that fails comparison or parsing is treated as zero cost.
        return ZERO
def _to_int(value: Any) -> int:
try:
if value in (None, ""):
return 0
return int(value)
except Exception:
return 0
def fetch_vercel_generation_cost(api_call_ids: list[str]) -> Decimal | None:
    """Look up exact cost from Vercel AI Gateway for generation IDs.

    Returns total cost in USD across all generations, or None if lookup fails.

    Best-effort: each ID is fetched independently, and a failure on one ID is
    logged and skipped rather than aborting the whole lookup.
    """
    # Nothing to do without a gateway key or any IDs to price.
    if not settings.VERCEL_API_KEY or not api_call_ids:
        return None
    try:
        import httpx
    except Exception:
        # httpx is treated as optional here; callers fall back to catalog pricing.
        return None
    total_cost = ZERO
    found_any = False
    with httpx.Client(timeout=10.0) as client:
        for gen_id in api_call_ids:
            try:
                response = client.get(
                    VERCEL_GENERATION_URL,
                    params={"id": gen_id},
                    headers={"Authorization": f"Bearer {settings.VERCEL_API_KEY}"},
                )
                response.raise_for_status()
                data = response.json()
                cost = data.get("total_cost")
                if cost is not None:
                    total_cost += _to_decimal(cost)
                    found_any = True
                else:
                    logger.debug("Vercel generation %s returned no total_cost: %s", gen_id, data)
            except Exception:
                # One bad ID must not sink the rest of the batch.
                logger.warning("Failed to fetch Vercel generation cost for %s", gen_id)
                continue
    # None signals "no usable data" so callers can fall back to estimation.
    return total_cost if found_any else None
def find_openrouter_pricing(model_id: str) -> dict[str, Decimal]:
    """Fetch per-token pricing for *model_id* from the OpenRouter catalog.

    Results — including misses — are cached in-process for _PRICING_CACHE_TTL
    seconds. Returns {} when httpx is unavailable, the fetch fails, or the
    model is not in the catalog.
    """
    try:
        import httpx
    except Exception:
        # httpx missing: pricing simply unavailable.
        return {}
    model_id = str(model_id or "").strip()
    if not model_id:
        return {}
    # Check in-memory cache
    cached = _pricing_cache.get(model_id)
    if cached is not None:
        pricing_data, fetched_at = cached
        if time.monotonic() - fetched_at < _PRICING_CACHE_TTL:
            return pricing_data
    headers = {}
    # The catalog is public; the key only raises rate limits when present.
    if settings.OPENROUTER_API_KEY:
        headers["Authorization"] = f"Bearer {settings.OPENROUTER_API_KEY}"
    try:
        with httpx.Client(timeout=20.0) as client:
            response = client.get(OPENROUTER_PRICING_SOURCE, headers=headers)
            response.raise_for_status()
            payload = response.json()
    except Exception:
        # Note: a failed fetch is NOT cached, so the next call retries.
        logger.warning("Failed to fetch OpenRouter pricing for %s", model_id)
        return {}
    for item in payload.get("data", []):
        if str(item.get("id") or "").strip() != model_id:
            continue
        pricing = item.get("pricing") or {}
        result: dict[str, Decimal] = {}
        # Keep only the rate components that are present and non-zero.
        for key in ("prompt", "completion", "input_cache_read", "web_search"):
            amount = _to_decimal(pricing.get(key))
            if amount:
                result[key] = amount
        _pricing_cache[model_id] = (result, time.monotonic())
        return result
    # Cache the miss too so we don't re-fetch for unknown models
    _pricing_cache[model_id] = ({}, time.monotonic())
    return {}
def _aggregate_sdk_usage(events: list[Any]) -> dict[str, int]:
    """Aggregate token counts from Copilot SDK events.

    Prefers session.usage_info (aggregated model_metrics) when available,
    otherwise sums individual assistant.usage events. Returns {} when
    neither event kind is present.
    """
    from copilot.generated.session_events import SessionEventType
    # Collect api_call_ids from ASSISTANT_USAGE events (used for Vercel cost lookup)
    api_call_ids: list[str] = []
    for event in events:
        if getattr(event, "type", None) != SessionEventType.ASSISTANT_USAGE:
            continue
        data = getattr(event, "data", None)
        call_id = getattr(data, "api_call_id", None) if data else None
        if call_id:
            api_call_ids.append(str(call_id))
    if api_call_ids:
        logger.debug("Collected %d api_call_ids (first: %s)", len(api_call_ids), api_call_ids[0])
    else:
        logger.debug("No api_call_ids found in %d events", len(events))
    # Try session.usage_info first — it carries aggregated model_metrics.
    # Scanned newest-first; the most recent usable snapshot wins.
    for event in reversed(events):
        if getattr(event, "type", None) != SessionEventType.SESSION_USAGE_INFO:
            continue
        data = getattr(event, "data", None)
        metrics = getattr(data, "model_metrics", None)
        if not metrics:
            continue
        agg: dict[str, Any] = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_tokens": 0,
            "cache_write_tokens": 0,
            "request_count": 0,
            "api_call_ids": api_call_ids,
        }
        # Sum across the per-model metric entries.
        for metric in metrics.values():
            usage = getattr(metric, "usage", None)
            if usage:
                agg["input_tokens"] += _to_int(getattr(usage, "input_tokens", 0))
                agg["output_tokens"] += _to_int(getattr(usage, "output_tokens", 0))
                agg["cache_read_tokens"] += _to_int(getattr(usage, "cache_read_tokens", 0))
                agg["cache_write_tokens"] += _to_int(getattr(usage, "cache_write_tokens", 0))
            requests = getattr(metric, "requests", None)
            if requests:
                agg["request_count"] += _to_int(getattr(requests, "count", 0))
        return agg
    # Fall back to summing individual assistant.usage events
    agg = {
        "input_tokens": 0,
        "output_tokens": 0,
        "cache_read_tokens": 0,
        "cache_write_tokens": 0,
        "request_count": 0,
        "api_call_ids": api_call_ids,
    }
    found = False
    for event in events:
        if getattr(event, "type", None) != SessionEventType.ASSISTANT_USAGE:
            continue
        data = getattr(event, "data", None)
        if data is None:
            continue
        found = True
        agg["input_tokens"] += _to_int(getattr(data, "input_tokens", 0))
        agg["output_tokens"] += _to_int(getattr(data, "output_tokens", 0))
        agg["cache_read_tokens"] += _to_int(getattr(data, "cache_read_tokens", 0))
        agg["cache_write_tokens"] += _to_int(getattr(data, "cache_write_tokens", 0))
        agg["request_count"] += 1
    if found:
        return agg
    return {}
def extract_usage_and_cost(
    model: str,
    provider: str,
    result: Any,
    advisor_usage: dict[str, int] | None = None,
) -> dict[str, Any]:
    """Build a usage/cost summary dict for one completed LLM interaction.

    *result* may be a list of SDK events or a legacy single result object.
    Cost resolution order: exact Vercel gateway lookup (vercel provider with
    api_call_ids) → OpenRouter catalog token-rate estimation → zero cost.
    *advisor_usage*, when given, is merged in as advisor_* fields.
    """
    # When given a list of SDK events, aggregate directly from typed events
    if isinstance(result, list):
        sdk_usage = _aggregate_sdk_usage(result)
    else:
        sdk_usage = {}
    # Fall back to legacy extraction for non-list results
    if not sdk_usage:
        single = _find_usage_event(result) if isinstance(result, list) else result
        usage = _extract_usage_payload(single)
        # Field names vary by provider/SDK generation; try each known spelling.
        prompt_tokens = _pick_first_int(usage, "input_tokens", "prompt_tokens", "promptTokens")
        completion_tokens = _pick_first_int(
            usage, "output_tokens", "completion_tokens", "completionTokens", "outputTokens"
        )
        reasoning_tokens = _pick_nested_int(
            usage, [("output_tokens_details", "reasoning_tokens"), ("completion_tokens_details", "reasoning_tokens")]
        )
        cached_tokens = _pick_nested_int(
            usage, [("input_tokens_details", "cached_tokens"), ("prompt_tokens_details", "cached_tokens")]
        )
        request_count = 1
    else:
        prompt_tokens = sdk_usage.get("input_tokens", 0)
        completion_tokens = sdk_usage.get("output_tokens", 0)
        cached_tokens = sdk_usage.get("cache_read_tokens", 0)
        # The SDK aggregate does not expose reasoning tokens separately.
        reasoning_tokens = 0
        request_count = sdk_usage.get("request_count", 1) or 1
    api_call_ids: list[str] = sdk_usage.get("api_call_ids", []) if sdk_usage else []
    total_tokens = prompt_tokens + completion_tokens
    summary: dict[str, Any] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "reasoning_tokens": reasoning_tokens,
        "cached_tokens": cached_tokens,
        "request_count": request_count,
        "cost_usd": "0",
        "cost_breakdown": {},
    }
    # Merge advisor sidecar usage when provided
    if advisor_usage:
        summary["advisor_prompt_tokens"] = advisor_usage.get("prompt_tokens", 0)
        summary["advisor_completion_tokens"] = advisor_usage.get("completion_tokens", 0)
        summary["advisor_total_tokens"] = (
            summary["advisor_prompt_tokens"] + summary["advisor_completion_tokens"]
        )
    # Prefer Vercel generation cost lookup (exact cost from gateway)
    if provider == "vercel" and api_call_ids:
        vercel_cost = fetch_vercel_generation_cost(api_call_ids)
        if vercel_cost is not None:
            summary.update(
                {
                    "cost_usd": _format_usd(vercel_cost),
                    "pricing_source": VERCEL_GENERATION_URL,
                }
            )
            return summary
        logger.debug("Vercel generation cost lookup returned None for %d IDs", len(api_call_ids))
    elif provider == "vercel":
        logger.debug("Vercel provider but no api_call_ids available — falling back to OpenRouter pricing")
    # Fall back to OpenRouter pricing catalog (token-rate estimation)
    pricing = find_openrouter_pricing(_to_openrouter_model_id(model, provider))
    if not pricing:
        # No pricing data: return the token counts with zero cost.
        return summary
    prompt_rate = pricing.get("prompt", ZERO)
    completion_rate = pricing.get("completion", ZERO)
    cache_rate = pricing.get("input_cache_read", ZERO)
    web_search_rate = pricing.get("web_search", ZERO)
    # Cached prompt tokens are billed at the (cheaper) cache-read rate.
    uncached_prompt_tokens = max(prompt_tokens - cached_tokens, 0)
    prompt_cost = Decimal(uncached_prompt_tokens) * prompt_rate
    completion_cost = Decimal(completion_tokens) * completion_rate
    cache_cost = Decimal(cached_tokens) * cache_rate
    web_search_requests = _count_web_search_requests(result)
    web_search_cost = Decimal(web_search_requests) * web_search_rate
    total_cost = prompt_cost + completion_cost + cache_cost + web_search_cost
    # Breakdown only lists non-zero components.
    breakdown = {}
    if prompt_cost:
        breakdown["prompt"] = _format_usd(prompt_cost)
    if completion_cost:
        breakdown["completion"] = _format_usd(completion_cost)
    if cache_cost:
        breakdown["cache_read"] = _format_usd(cache_cost)
    if web_search_cost:
        breakdown["web_search"] = _format_usd(web_search_cost)
    summary.update(
        {
            "web_search_requests": web_search_requests,
            "cost_usd": _format_usd(total_cost),
            "cost_breakdown": breakdown,
            "pricing_source": OPENROUTER_PRICING_SOURCE,
        }
    )
    return summary
def format_usage_line(usage: dict[str, Any] | None) -> str:
    """Render a one-line human-readable usage/cost summary, or "" when absent."""
    if not isinstance(usage, dict):
        return ""
    tokens_in = _to_int(usage.get("prompt_tokens"))
    tokens_out = _to_int(usage.get("completion_tokens"))
    tokens_total = _to_int(usage.get("total_tokens"))
    cost = _format_display_usd(usage.get("cost_usd"))
    return (
        f"Usage: {tokens_total:,} tok ({tokens_in:,} in / {tokens_out:,} out)"
        f" · Cost: ${cost}"
    )
def _extract_usage_payload(result: Any) -> dict[str, Any]:
    """Pull a usage mapping off a result object, its ``.data``, or a plain dict.

    Checks, in order: result.usage, result.data.usage, result["usage"].
    Returns {} when no truthy usage payload is found.
    """
    direct = getattr(result, "usage", None)
    if direct:
        return _to_dict(direct)
    holder = getattr(result, "data", None)
    if holder is not None:
        nested = getattr(holder, "usage", None)
        if nested:
            return _to_dict(nested)
    if isinstance(result, dict):
        mapped = result.get("usage")
        if mapped:
            return _to_dict(mapped)
    return {}
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
if hasattr(value, "model_dump"):
try:
return value.model_dump()
except Exception:
return {}
if hasattr(value, "dict"):
try:
return value.dict()
except Exception:
return {}
return {k: getattr(value, k) for k in dir(value) if not k.startswith("_") and not callable(getattr(value, k))}
def _pick_first_int(data: dict[str, Any], *keys: str) -> int:
    """Return the first truthy int found under any of *keys*, else 0."""
    for candidate in keys:
        if candidate not in data:
            continue
        parsed = _to_int(data.get(candidate))
        if parsed:
            return parsed
    return 0
def _pick_nested_int(data: dict[str, Any], paths: list[tuple[str, str]]) -> int:
    """Return the first truthy int found at any (outer, inner) path, else 0."""
    for outer_key, inner_key in paths:
        nested = data.get(outer_key)
        if not isinstance(nested, dict) or inner_key not in nested:
            continue
        parsed = _to_int(nested.get(inner_key))
        if parsed:
            return parsed
    return 0
def _count_web_search_requests(result: Any) -> int:
events = getattr(result, "events", None)
if not events:
return 0
count = 0
for event in events:
tool_name = getattr(getattr(event, "data", None), "tool_name", None) or getattr(
getattr(event, "data", None), "name", None
)
if str(tool_name or "").strip().lower() == "web_search":
count += 1
return count
def _format_usd(value: Decimal) -> str:
return format(value.quantize(Decimal("0.000000000001"), rounding=ROUND_HALF_UP).normalize(), "f")
def _format_display_usd(value: Any) -> str:
    """Render any cost-like value as a fixed 6-decimal-place USD string."""
    return format(
        _to_decimal(value).quantize(USD_DISPLAY_QUANT, rounding=ROUND_HALF_UP),
        "f",
    )
def _find_usage_event(events: list[Any]) -> Any:
    """Return the most recent event carrying a usage payload, else the last event."""
    fallback = events[-1] if events else None
    return next(
        (event for event in reversed(events) if _extract_usage_payload(event)),
        fallback,
    )
def format_cost_value(value: Any) -> str:
    """Public wrapper: render any cost-like value as a display USD string."""
    return _format_display_usd(value)
def summarize_usage(usage: dict[str, Any] | None) -> dict[str, Any] | None:
    """Normalize a raw usage dict; None when it carries no meaningful data.

    total_tokens falls back to prompt + completion when absent.
    """
    if not isinstance(usage, dict):
        return None
    tokens_in = _to_int(usage.get("prompt_tokens"))
    tokens_out = _to_int(usage.get("completion_tokens"))
    tokens_total = _to_int(usage.get("total_tokens")) or (tokens_in + tokens_out)
    requests = _to_int(usage.get("request_count"))
    cost = _to_decimal(usage.get("cost_usd"))
    # Drop summaries that contain no requests, no tokens, and no cost.
    has_signal = requests > 0 or tokens_total > 0 or cost > ZERO
    if not has_signal:
        return None
    return {
        "prompt_tokens": tokens_in,
        "completion_tokens": tokens_out,
        "total_tokens": tokens_total,
        "reasoning_tokens": _to_int(usage.get("reasoning_tokens")),
        "cached_tokens": _to_int(usage.get("cached_tokens")),
        "request_count": max(requests, 0),
        "cost_usd": format(cost, "f"),
    }

85
local_media_store.py Normal file
View file

@ -0,0 +1,85 @@
"""Disk-backed media blobs shared by the web UI, Telegram, and ChatKit uploads."""
from __future__ import annotations
import base64
import json
import shutil
import uuid
from pathlib import Path
from typing import Any
class LocalMediaStore:
def __init__(self, data_dir: str = "/data"):
self._root = Path(data_dir) / "uploads"
self._root.mkdir(parents=True, exist_ok=True)
def save_bytes(
self,
data: bytes,
*,
name: str,
mime_type: str,
file_id: str | None = None,
) -> dict[str, Any]:
normalized_name = str(name).strip() or "upload"
normalized_mime_type = str(mime_type).strip() or "application/octet-stream"
normalized_file_id = str(file_id).strip() if file_id else uuid.uuid4().hex
item_dir = self._item_dir(normalized_file_id)
item_dir.mkdir(parents=True, exist_ok=True)
(item_dir / "blob").write_bytes(data)
metadata = {
"id": normalized_file_id,
"name": normalized_name,
"mime_type": normalized_mime_type,
"size": len(data),
}
(item_dir / "meta.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8")
return dict(metadata)
def get_meta(self, file_id: str) -> dict[str, Any]:
meta_path = self._item_dir(file_id) / "meta.json"
if not meta_path.exists():
raise KeyError(file_id)
try:
payload = json.loads(meta_path.read_text(encoding="utf-8"))
except json.JSONDecodeError as error:
raise KeyError(file_id) from error
if not isinstance(payload, dict):
raise KeyError(file_id)
payload.setdefault("id", str(file_id))
payload.setdefault("name", "upload")
payload.setdefault("mime_type", "application/octet-stream")
payload.setdefault("size", self._blob_path(file_id).stat().st_size if self._blob_path(file_id).exists() else 0)
return payload
def read_bytes(self, file_id: str) -> bytes:
blob_path = self._blob_path(file_id)
if not blob_path.exists():
raise KeyError(file_id)
return blob_path.read_bytes()
def delete(self, file_id: str) -> None:
item_dir = self._item_dir(file_id)
if item_dir.exists():
shutil.rmtree(item_dir)
def build_data_url(self, file_id: str) -> str:
metadata = self.get_meta(file_id)
encoded = base64.b64encode(self.read_bytes(file_id)).decode("ascii")
return f"data:{metadata['mime_type']};base64,{encoded}"
def _item_dir(self, file_id: str) -> Path:
normalized_file_id = str(file_id).strip()
if not normalized_file_id:
raise KeyError(file_id)
return self._root / normalized_file_id
def _blob_path(self, file_id: str) -> Path:
return self._item_dir(file_id) / "blob"

1750
main.py

File diff suppressed because it is too large Load diff

451
model_selection.py Normal file
View file

@ -0,0 +1,451 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Final
from config import settings
# Base URLs that count as the official OpenAI endpoint (gateways do not).
OFFICIAL_OPENAI_BASE_URLS: Final[set[str]] = {
    "https://api.openai.com",
    "https://api.openai.com/v1",
}

SUPPORTED_PROVIDERS: Final[tuple[str, ...]] = ("copilot", "openai", "openrouter", "vercel", "huggingface")

# Bare model-name prefixes recognised as OpenAI-family models.
OPENAI_PREFIXES: Final[tuple[str, ...]] = (
    "gpt-",
    "chatgpt-",
    "o1",
    "o3",
    "o4",
    "computer-use-",
)

# (prefix, vendor) hints used to expand bare model names into vendor slugs
# for each gateway provider.
OPENROUTER_VENDOR_HINTS: Final[tuple[tuple[str, str], ...]] = (
    ("gpt-", "openai"),
    ("chatgpt-", "openai"),
    ("o1", "openai"),
    ("o3", "openai"),
    ("o4", "openai"),
    ("claude-", "anthropic"),
    ("gemini-", "google"),
    ("deepseek-", "deepseek"),
)

VERCEL_VENDOR_HINTS: Final[tuple[tuple[str, str], ...]] = (
    ("gpt-", "openai"),
    ("chatgpt-", "openai"),
    ("o1", "openai"),
    ("o3", "openai"),
    ("o4", "openai"),
    ("claude-", "anthropic"),
    ("gemini-", "google"),
    ("deepseek-", "deepseek"),
)

# Vendor names users may type, mapped to canonical vendor slugs.
VENDOR_PROVIDER_ALIASES: Final[dict[str, str]] = {
    "anthropic": "anthropic",
    "claude": "anthropic",
    "google": "google",
    "gemini": "google",
    "deepseek": "deepseek",
}

PROVIDER_DISPLAY_NAMES: Final[dict[str, str]] = {
    "copilot": "GitHub Copilot",
    "openai": "OpenAI",
    "openrouter": "OpenRouter",
    "vercel": "Vercel AI Gateway",
    "huggingface": "HuggingFace",
}

# Model name fragments known NOT to support vision input.
_NO_VISION_FRAGMENTS: Final[tuple[str, ...]] = (
    "gemma-2",
    "gemma-1",
    "llama",
    "mistral",
    "mixtral",
    "phi-",
    "qwen",
    "deepseek",
    "command-r",
    "codestral",
    "wizardlm",
    "yi-",
    "dbrx",
    "o1-mini",
    "o1-preview",
    "o3-mini",
    "gpt-3.5",
)

# Hand-picked (model, description) menu entries shown per provider in the UI.
CURATED_MODELS: Final[dict[str, tuple[tuple[str, str], ...]]] = {
    "copilot": (
        ("claude-sonnet-4.5", "Claude Sonnet 4.5 via Copilot"),
        ("claude-sonnet-4.6", "Claude Sonnet 4.6 via Copilot"),
        ("claude-opus-4.6", "Claude Opus 4.6 via Copilot (premium)"),
        ("gpt-4o", "GPT-4o via Copilot"),
        ("o3-mini", "o3-mini via Copilot"),
        ("gemini-2.5-pro", "Gemini 2.5 Pro via Copilot"),
    ),
    "openai": (
        ("gpt-5.4", "default orchestrator model"),
        ("gpt-5.4-mini", "fast subagent / background model"),
        ("gpt-5.4-nano", "smallest low-latency model"),
    ),
    "openrouter": (
        ("anthropic/claude-haiku-4.5", "fast Claude"),
        ("anthropic/claude-sonnet-4.5", "strong Claude"),
        ("anthropic/claude-sonnet-4.6", "newer Claude Sonnet"),
        ("google/gemini-2.5-flash", "fast Gemini"),
        ("google/gemini-2.5-pro", "strong Gemini"),
    ),
    "vercel": (
        ("anthropic/claude-sonnet-4.5", "strong Claude via Vercel"),
        ("anthropic/claude-sonnet-4.6", "newer Claude Sonnet via Vercel"),
        ("anthropic/claude-haiku-4.5", "fast Claude via Vercel"),
        ("google/gemini-2.5-pro", "strong Gemini via Vercel"),
        ("google/gemini-2.5-flash", "fast Gemini via Vercel"),
        ("openai/gpt-4o", "OpenAI via Vercel gateway"),
        ("xai/grok-2", "Grok via Vercel gateway"),
        ("google/gemma-4-31b-it", "Gemma 4 via Vercel (background agent)"),
    ),
    "huggingface": (
        ("Qwen/QwQ-32B", "QwQ 32B reasoning model"),
        ("Qwen/Qwen3-235B-A22B", "Qwen3 235B MoE"),
        ("mistralai/Mistral-Small-3.1-24B-Instruct-2503", "Mistral Small 3.1 24B"),
        ("google/gemma-3-27b-it", "Gemma 3 27B"),
    ),
}
def provider_configuration_error(provider: str) -> str | None:
    """Return a user-facing setup error for *provider*, or None when usable.

    Unsupported names get an "Unsupported provider" message; supported names
    get a configuration message when their API key is missing from settings.
    """
    normalized = normalize_provider_name(provider)
    if normalized is None:
        supported = ", ".join(SUPPORTED_PROVIDERS)
        return f"Unsupported provider `{provider}`. Supported providers: {supported}."
    # settings attribute name + message to show when that credential is absent.
    requirements = {
        "copilot": ("GITHUB_TOKEN", "GitHub Copilot is not configured. Set GITHUB_TOKEN in Infisical and redeploy."),
        "openai": ("OPENAI_API_KEY", "OpenAI is not configured. Set OPENAI_API_KEY in Infisical and redeploy."),
        "openrouter": ("OPENROUTER_API_KEY", "OpenRouter is not configured. Add OPENROUTER_API_KEY in Infisical and redeploy."),
        "vercel": ("VERCEL_API_KEY", "Vercel AI Gateway is not configured. Add VERCEL_API_KEY in Infisical and redeploy."),
        "huggingface": ("HUGGINGFACE_API_KEY", "HuggingFace is not configured. Add HUGGINGFACE_API_KEY in Infisical and redeploy."),
    }
    requirement = requirements.get(normalized)
    if requirement is not None:
        credential_attr, message = requirement
        # getattr keeps the per-provider laziness of the original if-chain:
        # only the selected provider's credential is inspected.
        if not getattr(settings, credential_attr):
            return message
    return None
def provider_is_configured(provider: str) -> bool:
    """True when *provider* is supported and its credentials are present."""
    return provider_configuration_error(provider) is None
def build_provider_options() -> list[dict[str, Any]]:
    """Describe every supported provider with availability info for the UI."""
    def _describe(provider_id: str) -> dict[str, Any]:
        return {
            "id": provider_id,
            "label": PROVIDER_DISPLAY_NAMES[provider_id],
            "available": provider_is_configured(provider_id),
            # None when available; otherwise the configuration error message.
            "reason": provider_configuration_error(provider_id),
        }

    return [_describe(provider_id) for provider_id in SUPPORTED_PROVIDERS]
def build_curated_model_options() -> dict[str, list[dict[str, str]]]:
    """Curated model menu per provider, each entry shaped for the UI picker."""
    def _entry(provider: str, model_name: str, description: str) -> dict[str, str]:
        return {
            "id": model_name,
            "label": model_name,
            "description": description,
            # Fully-qualified reference the backend accepts for selection.
            "ref": f"{provider}:{model_name}",
        }

    return {
        provider: [
            _entry(provider, model_name, description)
            for model_name, description in CURATED_MODELS[provider]
        ]
        for provider in SUPPORTED_PROVIDERS
    }
class ModelSelectionError(ValueError):
    """Raised when a provider/model selection cannot be resolved or validated."""
    pass
@dataclass(frozen=True)
class ModelSelection:
    """Immutable, validated (provider, model) pair."""

    provider: str
    model: str

    @property
    def ref(self) -> str:
        """Canonical ``provider:model`` reference string."""
        return f"{self.provider}:{self.model}"

    @property
    def supports_hosted_web_search(self) -> bool:
        # Hosted web search only works against the official OpenAI endpoint,
        # not gateway/proxy base URLs.
        return self.provider == "openai" and is_official_openai_base_url(settings.OPENAI_BASE_URL)

    @property
    def likely_supports_vision(self) -> bool:
        """Heuristic: return False for model families known not to accept images."""
        name = self.model.lower()
        # Strip vendor prefix (e.g. "google/gemma-4-31b-it" → "gemma-4-31b-it")
        if "/" in name:
            name = name.split("/", 1)[1]
        return not any(frag in name for frag in _NO_VISION_FRAGMENTS)
def default_selection() -> ModelSelection:
    """Selection used before the user has chosen anything."""
    return resolve_selection(provider="openai", model=settings.DEFAULT_MODEL)
def resolve_selection(
    *,
    model: str | None,
    provider: str | None = None,
    current: ModelSelection | None = None,
) -> ModelSelection:
    """Resolve a user-supplied model/provider pair into a validated ModelSelection.

    Provider resolution order: explicit `provider:` prefix inside *model*,
    then the *provider* argument, then inference from the model name, then
    the *current* selection, then "openai" as the final fallback.

    Raises ModelSelectionError for empty models, unsupported or unconfigured
    providers, and models that don't belong to the chosen provider.
    """
    # Fall back to the current selection's model, then the configured default.
    raw_model = (model or (current.model if current else settings.DEFAULT_MODEL)).strip()
    raw_provider = (provider or "").strip().lower()
    if not raw_model:
        raise ModelSelectionError("Model cannot be empty.")
    # An explicit "provider:model" reference overrides the provider argument.
    explicit_provider, explicit_model = _split_explicit_provider(raw_model)
    if explicit_provider:
        raw_provider = explicit_provider
        raw_model = explicit_model
    if not raw_provider:
        inferred = _infer_provider(raw_model)
        if inferred:
            raw_provider = inferred
        elif current:
            raw_provider = current.provider
        else:
            raw_provider = "openai"
    # Vendor aliases (e.g. "anthropic") are rewritten onto the Vercel gateway.
    raw_provider, raw_model = _normalize_vendor_alias(raw_provider, raw_model)
    if raw_provider not in SUPPORTED_PROVIDERS:
        supported = ", ".join(SUPPORTED_PROVIDERS)
        raise ModelSelectionError(f"Unsupported provider `{raw_provider}`. Supported providers: {supported}.")
    normalized_model = _normalize_model_for_provider(raw_model, raw_provider)
    _validate_provider_configuration(raw_provider)
    _validate_model_for_provider(raw_provider, normalized_model)
    return ModelSelection(provider=raw_provider, model=normalized_model)
def build_provider_config(selection: ModelSelection) -> dict[str, Any] | None:
    """Build a Copilot SDK ProviderConfig dict for the given selection.

    Returns None for the copilot provider — the SDK uses the built-in
    GitHub Copilot model catalog when no custom provider is passed.
    All other providers speak the OpenAI-compatible protocol.
    """
    if selection.provider == "copilot":
        return None
    endpoints: dict[str, tuple[str, str]] = {
        "openai": (settings.OPENAI_BASE_URL, settings.OPENAI_API_KEY),
        "openrouter": (settings.OPENROUTER_BASE_URL, settings.OPENROUTER_API_KEY),
        "vercel": (settings.VERCEL_BASE_URL, settings.VERCEL_API_KEY),
        "huggingface": (settings.HUGGINGFACE_BASE_URL, settings.HUGGINGFACE_API_KEY),
    }
    if selection.provider not in endpoints:
        raise ModelSelectionError(f"Unsupported provider `{selection.provider}`.")
    base_url, api_key = endpoints[selection.provider]
    return {
        "type": "openai",
        "base_url": base_url,
        "api_key": api_key,
    }
def format_selection(selection: ModelSelection) -> str:
    """Two-line Markdown summary of the active provider and model."""
    lines = (f"Provider: `{selection.provider}`", f"Model: `{selection.model}`")
    return "\n".join(lines)
def normalize_provider_name(provider: str | None) -> str | None:
    """Canonicalize a user-supplied provider name.

    Returns None for missing/blank input, the canonical name for known
    providers and shorthands, and raises ModelSelectionError otherwise.
    """
    if provider is None:
        return None
    cleaned = provider.strip().lower()
    if not cleaned:
        return None
    if cleaned == "hf":
        # "hf" is accepted shorthand for HuggingFace.
        return "huggingface"
    if cleaned in VENDOR_PROVIDER_ALIASES:
        # Vendor names ("anthropic", "google", ...) route through Vercel.
        return "vercel"
    if cleaned in SUPPORTED_PROVIDERS:
        return cleaned
    raise ModelSelectionError(
        f"Unsupported provider `{cleaned}`. Supported providers: {', '.join(SUPPORTED_PROVIDERS)}."
    )
def format_known_models(*, current: ModelSelection | None = None, provider: str | None = None) -> str:
    """Render the curated model list as Markdown, optionally scoped to one provider.

    Includes the current selection when given, usage hints for the bot
    commands, and any provider configuration warnings in scope.
    """
    requested_provider = normalize_provider_name(provider)
    # None means "show every provider"; otherwise limit to the requested one.
    provider_names = (requested_provider,) if requested_provider else SUPPORTED_PROVIDERS
    lines: list[str] = []
    if current is not None:
        lines.append("Current selection")
        lines.append(format_selection(current))
        lines.append("")
    lines.append("Suggested models")
    for provider_name in provider_names:
        lines.append("")
        lines.append(PROVIDER_DISPLAY_NAMES[provider_name])
        for model_name, description in CURATED_MODELS[provider_name]:
            lines.append(f"- `{model_name}` - {description}")
    lines.append("")
    lines.append("Usage")
    lines.append("- `/models` - show the curated list")
    lines.append("- `/provider copilot` - switch to GitHub Copilot subscription models")
    lines.append("- `/provider openrouter` - switch backend provider")
    lines.append("- `/provider vercel` - switch backend provider")
    lines.append("- `/provider huggingface` - switch to HuggingFace serverless inference")
    lines.append("- `/model copilot:anthropic/claude-opus-4.6` - switch provider and model at once")
    lines.append("- `/model huggingface:Qwen/QwQ-32B` - use HuggingFace with specific model (shorthand: `hf:`)")
    lines.append("- `/model anthropic/claude-sonnet-4.5` - switch model inside the current provider")
    # Surface configuration problems only for providers in the requested scope.
    copilot_error = provider_configuration_error("copilot")
    if requested_provider in (None, "copilot") and copilot_error:
        lines.append("")
        lines.append(copilot_error)
    openrouter_error = provider_configuration_error("openrouter")
    if requested_provider in (None, "openrouter") and openrouter_error:
        lines.append("")
        lines.append(openrouter_error)
    vercel_error = provider_configuration_error("vercel")
    if requested_provider in (None, "vercel") and vercel_error:
        lines.append("")
        lines.append(vercel_error)
    huggingface_error = provider_configuration_error("huggingface")
    if requested_provider in (None, "huggingface") and huggingface_error:
        lines.append("")
        lines.append(huggingface_error)
    return "\n".join(lines).strip()
def is_official_openai_base_url(base_url: str) -> bool:
    """True when *base_url* points at one of OpenAI's own API endpoints."""
    trimmed = base_url.rstrip("/")
    return trimmed in OFFICIAL_OPENAI_BASE_URLS
def _split_explicit_provider(model: str) -> tuple[str | None, str]:
    """Split a `provider:model` reference; return (None, model) when no valid prefix."""
    head, separator, tail = model.partition(":")
    if not separator:
        return None, model
    candidate = head.strip().lower()
    if candidate == "hf":
        candidate = "huggingface"
    if candidate not in SUPPORTED_PROVIDERS and candidate not in VENDOR_PROVIDER_ALIASES:
        # Unknown prefix — treat the whole string as a bare model name.
        return None, model
    return candidate, tail.strip()
def _infer_provider(model: str) -> str | None:
    """Return the provider a model clearly belongs to, or None for ambiguous names."""
    lowered = model.strip().lower()
    if lowered.startswith(OPENAI_PREFIXES):
        return "openai"
    vercel_prefixes = (
        "claude-", "gemini-", "deepseek-",
        "anthropic/", "google/", "deepseek/",
        "openai/", "xai/",
    )
    if lowered.startswith(vercel_prefixes):
        return "vercel"
    # Any vendor/model format (e.g. zai/glm-5.1) defaults to the Vercel gateway.
    if "/" in lowered:
        return "vercel"
    return None
def _normalize_copilot_model(model: str) -> str:
"""Strip vendor prefix (e.g. anthropic/claude-opus-4.6 -> claude-opus-4.6) for Copilot models."""
if "/" in model:
return model.split("/", 1)[1]
return model
def _normalize_vendor_alias(provider: str, model: str) -> tuple[str, str]:
    """Rewrite vendor-name provider aliases onto the Vercel gateway.

    e.g. ("anthropic", "claude-x") -> ("vercel", "anthropic/claude-x").
    Non-alias providers pass through with whitespace trimmed.
    """
    key = provider.strip().lower()
    trimmed_model = model.strip()
    vendor = VENDOR_PROVIDER_ALIASES.get(key)
    if vendor is None:
        return key, trimmed_model
    prefix = f"{vendor}/"
    if not trimmed_model.startswith(prefix):
        trimmed_model = prefix + trimmed_model
    return "vercel", trimmed_model
def _normalize_model_for_provider(model: str, provider: str) -> str:
normalized = model.strip()
lowered = normalized.lower()
if provider == "copilot":
return _normalize_copilot_model(normalized)
if provider == "openai":
if lowered.startswith("openai/"):
return normalized.split("/", 1)[1]
return normalized
if provider == "openrouter":
if "/" in normalized:
return normalized
for prefix, vendor in OPENROUTER_VENDOR_HINTS:
if lowered.startswith(prefix):
return f"{vendor}/{normalized}"
return normalized
if provider == "vercel":
if "/" in normalized:
return normalized
for prefix, vendor in VERCEL_VENDOR_HINTS:
if lowered.startswith(prefix):
return f"{vendor}/{normalized}"
return normalized
if provider == "huggingface":
# HuggingFace models are always org/model format — pass through as-is
return normalized
return normalized
def _validate_provider_configuration(provider: str) -> None:
    """Raise ModelSelectionError when *provider* is missing credentials."""
    problem = provider_configuration_error(provider)
    if problem is not None:
        raise ModelSelectionError(problem)
def _validate_model_for_provider(provider: str, model: str) -> None:
lowered = model.lower()
if provider != "openai":
return
if lowered.startswith(("anthropic/", "google/", "deepseek/", "claude-", "gemini-", "deepseek-")):
raise ModelSelectionError(
"That model is not an OpenAI model. `/model` only switches models inside the active provider. For Claude or Gemini, use `/provider vercel` or `/model vercel:<vendor/model>`."
)

9
prompt_utils.py Normal file
View file

@ -0,0 +1,9 @@
import difflib
def generate_diff(old: str, new: str) -> str:
    """Return a unified diff between two prompt strings, or empty string if identical."""
    delta = difflib.unified_diff(
        old.splitlines(keepends=True),
        new.splitlines(keepends=True),
        fromfile="current",
        tofile="proposed",
    )
    return "".join(delta)

16
provisioners/__init__.py Normal file
View file

@ -0,0 +1,16 @@
"""Service account provisioners.
Importing this module registers all available provisioners with the
global provisioner registry.
"""
from config import settings
from provisioners.base import provisioner_registry
from provisioners.karakeep import KarakeepProvisioner
from provisioners.vikunja import VikunjaProvisioner
if settings.VIKUNJA_ADMIN_API_KEY and settings.VIKUNJA_API_URL:
provisioner_registry.register(VikunjaProvisioner())
if settings.KARAKEEP_ADMIN_API_KEY and settings.KARAKEEP_API_URL:
provisioner_registry.register(KarakeepProvisioner())

91
provisioners/base.py Normal file
View file

@ -0,0 +1,91 @@
"""Provisioner base class and registry (T029)."""
from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from user_store import ServiceCredential, User, UserStore
logger = logging.getLogger(__name__)
@dataclass
class ProvisionResult:
    """Result of a provisioning attempt."""

    success: bool  # True when the account exists and a credential was stored
    service_username: str | None = None  # username created on the service
    service_user_id: str | None = None  # service-side user id (stringified)
    service_url: str | None = None  # base URL the user should visit
    error: str | None = None  # human-readable failure reason when success is False
class ServiceProvisioner(ABC):
    """Abstract base for service account provisioners."""

    @property
    @abstractmethod
    def service_name(self) -> str:
        """Unique service key (e.g. 'vikunja', 'karakeep')."""

    @property
    def capabilities(self) -> list[str]:
        """Human-readable capability labels (e.g. 'tasks'); empty by default."""
        return []

    @property
    def requires_consent(self) -> bool:
        """Whether the user must explicitly opt-in before provisioning."""
        return True

    @abstractmethod
    async def provision(self, user: User, store: UserStore) -> ProvisionResult:
        """Create a service account for *user* and store credentials.

        Implementations MUST:
        - Generate credentials (username, password, API token)
        - Store the token via ``store.store_credential(...)``
        - Log to ``store.log_provisioning(...)``
        - Attempt rollback on partial failure
        """

    @abstractmethod
    async def health_check(self, credential: ServiceCredential) -> bool:
        """Return True if *credential* is still valid and usable."""
# ── Provisioner registry ─────────────────────────────────────────
class ProvisionerRegistry:
    """Central registry of service provisioners."""

    def __init__(self) -> None:
        self._provisioners: dict[str, ServiceProvisioner] = {}
        self._locks: dict[str, asyncio.Lock] = {}

    def register(self, provisioner: ServiceProvisioner) -> None:
        """Add (or replace) the provisioner for its service name."""
        self._provisioners[provisioner.service_name] = provisioner
        logger.info("Registered provisioner: %s", provisioner.service_name)

    def get(self, service_name: str) -> ServiceProvisioner | None:
        """Look up a provisioner by service name; None when absent."""
        return self._provisioners.get(service_name)

    @property
    def available(self) -> dict[str, ServiceProvisioner]:
        """Snapshot copy of all registered provisioners."""
        return dict(self._provisioners)

    def get_lock(self, user_id: str, service: str) -> asyncio.Lock:
        """Per-user, per-service provisioning lock to prevent races."""
        return self._locks.setdefault(f"{user_id}:{service}", asyncio.Lock())
provisioner_registry = ProvisionerRegistry()

75
provisioners/karakeep.py Normal file
View file

@ -0,0 +1,75 @@
"""Karakeep service account provisioner (T031).
Karakeep does NOT expose an admin user-creation API via its REST surface.
The tRPC admin routes are not accessible through /api/v1/. Therefore
Karakeep remains owner-only: the owner's API key is migrated from env
vars during bootstrap, and no automated provisioning is available for
other users.
This provisioner stub exists so the registry can report Karakeep as
'available but not provisionable' during onboarding.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import httpx
from config import settings
from provisioners.base import ProvisionResult, ServiceProvisioner
if TYPE_CHECKING:
from user_store import ServiceCredential, User, UserStore
logger = logging.getLogger(__name__)
class KarakeepProvisioner(ServiceProvisioner):
    """Stub provisioner: Karakeep accounts must be created manually by the admin."""

    @property
    def service_name(self) -> str:
        return "karakeep"

    @property
    def capabilities(self) -> list[str]:
        return ["bookmarks"]

    @property
    def requires_consent(self) -> bool:
        return True

    async def provision(self, user: User, store: UserStore) -> ProvisionResult:
        """Karakeep does not support automatic user provisioning."""
        # Record the attempt so onboarding can explain why this failed.
        store.log_provisioning(
            user.id,
            "karakeep",
            "provision_failed",
            '{"reason": "Karakeep admin API does not support user creation"}',
        )
        # Point the user at the web UI root rather than the REST base.
        admin_url = settings.KARAKEEP_API_URL.rstrip("/").replace("/api/v1", "")
        message = (
            "Karakeep doesn't support automatic account creation. "
            "Ask the admin to create an account for you at "
            f"{admin_url}, "
            "then share your API key with me."
        )
        return ProvisionResult(success=False, error=message)

    async def health_check(self, credential: ServiceCredential) -> bool:
        """Verify the stored API key still works."""
        from user_store import decrypt

        root = settings.KARAKEEP_API_URL.rstrip("/")
        try:
            token = decrypt(credential.encrypted_token)
            headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(f"{root}/users/me", headers=headers)
        except Exception:
            logger.exception("Karakeep health check failed")
            return False
        return response.status_code == 200

226
provisioners/vikunja.py Normal file
View file

@ -0,0 +1,226 @@
"""Vikunja service account provisioner (T030).
Creates a Vikunja account via the public registration API, then activates
the user by setting status=0 through direct DB access (docker exec into
the Vikunja DB container). This is needed because Vikunja marks new local
users as 'email confirmation required' when the mailer is enabled.
"""
from __future__ import annotations
import json
import logging
import re
import secrets
import subprocess
from typing import TYPE_CHECKING
import httpx
from config import settings
from provisioners.base import ProvisionResult, ServiceProvisioner
if TYPE_CHECKING:
from user_store import ServiceCredential, User, UserStore
logger = logging.getLogger(__name__)
# Name of the Postgres container targeted by docker-exec for user activation.
_VIKUNJA_DB_CONTAINER = "vikunja-db"
# How many random usernames to try before giving up on collisions.
_MAX_USERNAME_RETRIES = 3
def _slugify(name: str) -> str:
"""Turn a display name into a safe username fragment."""
slug = re.sub(r"[^a-zA-Z0-9]", "", name).lower()
return slug[:12] if slug else "user"
def _generate_username(display_name: str) -> str:
    """Slug plus 6 random hex chars (e.g. 'johndoe_a1b2c3') to avoid collisions."""
    return f"{_slugify(display_name)}_{secrets.token_hex(3)}"
def _generate_password() -> str:
return secrets.token_urlsafe(16)
class VikunjaProvisioner(ServiceProvisioner):
    """Provision Vikunja accounts: register, activate via DB, mint an API token."""

    @property
    def service_name(self) -> str:
        return "vikunja"

    @property
    def capabilities(self) -> list[str]:
        return ["tasks"]

    @property
    def requires_consent(self) -> bool:
        # Account creation touches an external service; user must opt in.
        return True

    async def provision(self, user: User, store: UserStore) -> ProvisionResult:
        """Create and activate a Vikunja account for *user* and store an API token.

        Retries with a fresh random username on collision (up to
        _MAX_USERNAME_RETRIES); every failure path is recorded via
        ``store.log_provisioning``.
        """
        base_url = settings.VIKUNJA_API_URL.rstrip("/")
        vikunja_user_id: str | None = None
        for attempt in range(_MAX_USERNAME_RETRIES):
            username = _generate_username(user.display_name)
            password = _generate_password()
            # Synthetic address — Vikunja requires an email on registration.
            email = f"{username}@code.bytesizeprotip.com"
            try:
                # Step 1: Register
                async with httpx.AsyncClient(timeout=15) as client:
                    reg_resp = await client.post(
                        f"{base_url}/register",
                        json={"username": username, "password": password, "email": email},
                    )
                    if reg_resp.status_code == 400 and "already exists" in reg_resp.text.lower():
                        logger.warning(
                            "Username %s taken, retrying (%d/%d)", username, attempt + 1, _MAX_USERNAME_RETRIES
                        )
                        continue
                    if reg_resp.status_code not in (200, 201):
                        return ProvisionResult(
                            success=False,
                            error=f"Registration failed ({reg_resp.status_code}): {reg_resp.text[:200]}",
                        )
                    reg_data = reg_resp.json()
                    vikunja_user_id = str(reg_data.get("id", ""))
                    # Step 2: Admin-side activation — set user status=0 via DB
                    if vikunja_user_id:
                        try:
                            _activate_vikunja_user(vikunja_user_id)
                        except Exception:
                            logger.exception("Failed to activate Vikunja user %s via DB", vikunja_user_id)
                            # Continue anyway — user may still be able to log in
                    # Step 3: Login to get JWT
                    login_resp = await client.post(
                        f"{base_url}/login",
                        json={"username": username, "password": password},
                    )
                    if login_resp.status_code != 200:
                        store.log_provisioning(
                            user.id,
                            "vikunja",
                            "provision_failed",
                            json.dumps({"step": "login", "status": login_resp.status_code}),
                        )
                        return ProvisionResult(
                            success=False,
                            service_user_id=vikunja_user_id,
                            error=f"Login failed after registration ({login_resp.status_code})",
                        )
                    jwt_token = login_resp.json().get("token", "")
                    # Step 4: Create long-lived API token
                    from datetime import datetime, timedelta, timezone

                    expires = datetime.now(timezone.utc) + timedelta(days=365)
                    token_resp = await client.put(
                        f"{base_url}/tokens",
                        json={
                            "title": "CodeAnywhere",
                            "expires_at": expires.strftime("%Y-%m-%dT%H:%M:%S+00:00"),
                            "right": 2,  # read+write
                        },
                        headers={"Authorization": f"Bearer {jwt_token}"},
                    )
                    if token_resp.status_code not in (200, 201):
                        store.log_provisioning(
                            user.id,
                            "vikunja",
                            "provision_failed",
                            json.dumps({"step": "create_token", "status": token_resp.status_code}),
                        )
                        return ProvisionResult(
                            success=False,
                            service_user_id=vikunja_user_id,
                            error=f"Token creation failed ({token_resp.status_code})",
                        )
                    api_token = token_resp.json().get("token", "")
                    # Step 5: Store credential
                    store.store_credential(
                        user.id,
                        "vikunja",
                        api_token,
                        service_user_id=vikunja_user_id,
                        service_username=username,
                        expires_at=expires.isoformat(),
                    )
                    store.log_provisioning(
                        user.id,
                        "vikunja",
                        "provisioned",
                        json.dumps({"username": username, "vikunja_user_id": vikunja_user_id}),
                    )
                    logger.info("Provisioned Vikunja account %s for user %s", username, user.id)
                    return ProvisionResult(
                        success=True,
                        service_username=username,
                        service_user_id=vikunja_user_id,
                        service_url=settings.VIKUNJA_API_URL,
                    )
            except httpx.HTTPError as exc:
                store.log_provisioning(
                    user.id,
                    "vikunja",
                    "provision_failed",
                    json.dumps({"error": str(exc)}),
                )
                return ProvisionResult(success=False, error=f"HTTP error: {exc}")
        # All retries exhausted
        store.log_provisioning(
            user.id,
            "vikunja",
            "provision_failed",
            json.dumps({"error": "username collision after max retries"}),
        )
        return ProvisionResult(success=False, error="Could not find available username after retries")

    async def health_check(self, credential: ServiceCredential) -> bool:
        """Verify the stored API token still works."""
        from user_store import decrypt

        base_url = settings.VIKUNJA_API_URL.rstrip("/")
        try:
            token = decrypt(credential.encrypted_token)
            async with httpx.AsyncClient(timeout=10) as client:
                resp = await client.get(
                    f"{base_url}/tasks",
                    headers={"Authorization": f"Bearer {token}"},
                    params={"per_page": 1},
                )
                return resp.status_code == 200
        except Exception:
            logger.exception("Vikunja health check failed")
            return False
def _activate_vikunja_user(vikunja_user_id: str) -> None:
"""Set Vikunja user status to 0 (active) via docker exec into the DB."""
# Sanitise: vikunja_user_id must be numeric
if not vikunja_user_id.isdigit():
raise ValueError(f"Invalid Vikunja user ID: {vikunja_user_id}")
cmd = [
"docker",
"exec",
_VIKUNJA_DB_CONTAINER,
"psql",
"-U",
"vikunja",
"-d",
"vikunja",
"-c",
f"UPDATE users SET status = 0 WHERE id = {vikunja_user_id};",
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode != 0:
logger.error("DB activation failed: %s", result.stderr)
raise RuntimeError(f"DB activation failed: {result.stderr}")
logger.info("Activated Vikunja user %s via DB", vikunja_user_id)

15
pyproject.toml Normal file
View file

@ -0,0 +1,15 @@
[project]
name = "code-anywhere"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"github-copilot-sdk>=0.2.2",
"python-telegram-bot>=22.7",
]
[dependency-groups]
dev = [
"ipykernel>=7.2.0",
]

View file

@ -1,2 +1,9 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
github-copilot-sdk>=0.2.2
openai>=1.0.0 openai>=1.0.0
python-telegram-bot>=21.0 python-telegram-bot>=21.0
pydantic-settings>=2.0
httpx>=0.27.0
duckduckgo-search>=7.0
cryptography>=43.0

2
ruff.toml Normal file
View file

@ -0,0 +1,2 @@
line-length = 120
target-version = "py312"

View file

@ -5,27 +5,34 @@ set -euo pipefail
REPO_DIR="${REPO_DIR:-/opt/src/betterbot}" REPO_DIR="${REPO_DIR:-/opt/src/betterbot}"
STACK_DIR="${STACK_DIR:-/opt/betterbot}" STACK_DIR="${STACK_DIR:-/opt/betterbot}"
CONFIG_DIR="${CONFIG_DIR:-/opt/config}"
cd "$REPO_DIR" cd "$REPO_DIR"
git pull --ff-only origin master git pull --ff-only origin master
# ── Copy compose ────────────────────────────────────────────────────
mkdir -p "$STACK_DIR" mkdir -p "$STACK_DIR"
cp "$REPO_DIR/compose/docker-compose.yml" "$STACK_DIR/docker-compose.yml"
cp compose/docker-compose.yml "$STACK_DIR/docker-compose.yml" # ── Seed defaults from committed .env ───────────────────────────────
cp "$REPO_DIR/.env" "$STACK_DIR/defaults.env"
# Seed .env from example if it doesn't exist # ── First-run guard ─────────────────────────────────────────────────
if [ ! -f "$STACK_DIR/.env" ]; then if [ ! -f "$STACK_DIR/.env" ]; then
cp compose/.env.example "$STACK_DIR/.env" cp "$REPO_DIR/compose/.env.example" "$STACK_DIR/.env"
echo "WARNING: $STACK_DIR/.env created from template — edit it with real secrets." echo "⚠ First deploy — edit $STACK_DIR/.env with real secrets, then re-run."
exit 0
fi fi
# Fetch secrets from Infisical if available # ── Fetch secrets from Infisical ────────────────────────────────────
if [ -f /opt/config/infisical-agent.env ] && [ -f /opt/src/self_hosting/infra/scripts/infisical-env.sh ]; then INFISICAL_AGENT_ENV="$CONFIG_DIR/infisical-agent.env"
source /opt/src/self_hosting/infra/scripts/infisical-env.sh if [ -f "$INFISICAL_AGENT_ENV" ]; then
infisical_fetch racknerd-betterbot "$STACK_DIR/.env" INFRA_DIR="${INFRA_DIR:-/opt/src/self_hosting/infra}"
source "$INFRA_DIR/scripts/infisical-env.sh"
infisical_fetch racknerd-betterbot "$STACK_DIR/.env" prod || echo "Warning: Infisical fetch failed, using existing .env"
fi fi
# Configure git identity for the bot's commits in mounted repos # ── Configure git identity for site repos ───────────────────────────
for repo_dir in /opt/src/betterlifesg /opt/src/hk_memoraiz; do for repo_dir in /opt/src/betterlifesg /opt/src/hk_memoraiz; do
if [ -d "$repo_dir/.git" ]; then if [ -d "$repo_dir/.git" ]; then
git -C "$repo_dir" config user.email "betterbot@bytesizeprotip.com" git -C "$repo_dir" config user.email "betterbot@bytesizeprotip.com"
@ -33,6 +40,7 @@ for repo_dir in /opt/src/betterlifesg /opt/src/hk_memoraiz; do
fi fi
done done
# ── Deploy ──────────────────────────────────────────────────────────
cd "$STACK_DIR" cd "$STACK_DIR"
docker compose build --pull docker compose build --pull
docker compose up -d docker compose up -d

2437
static/index.html Normal file

File diff suppressed because it is too large Load diff

2012
telegram_bot.py Normal file

File diff suppressed because it is too large Load diff

311
tool_pipeline.py Normal file
View file

@ -0,0 +1,311 @@
"""Shared tool pipeline — single source of truth for user context, tool set
activation, and integration tool building.
Extracted from main.py and telegram_bot.py (T009) to eliminate code
duplication. Both entry points call these functions instead of maintaining
their own copies.
T015/T020/T021: Added UserContext with lazy credential decryption and
per-user tool pipeline functions.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
from config import settings
from tool_registry import registry as tool_registry
logger = logging.getLogger(__name__)
# ── UserContext (T015) ───────────────────────────────────────────
class UserContext:
    """Per-request context that lazily decrypts user credentials.

    Credentials are decrypted on first access via ``get_credential()``,
    cached for the duration of one request, and discarded when the object
    goes out of scope. Expired credentials return ``None`` to signal
    re-provisioning.
    """

    def __init__(self, user: Any, store: Any) -> None:
        from user_store import User  # deferred to avoid circular import

        self.user: User = user
        self._store = store
        # service name -> decrypted token, or None when missing/expired/undecryptable.
        self._cache: dict[str, str | None] = {}

    def get_credential(self, service: str) -> str | None:
        """Return the decrypted API token for *service*, or None."""
        if service in self._cache:
            return self._cache[service]
        from user_store import decrypt

        cred = self._store.get_credential(self.user.id, service)
        if cred is None:
            self._cache[service] = None
            return None
        # Check expiry
        if cred.expires_at:
            try:
                exp = datetime.fromisoformat(cred.expires_at)
                # NOTE(review): assumes expires_at is timezone-aware (true for
                # tokens written by the Vikunja provisioner); a naive timestamp
                # would raise TypeError on this comparison — confirm all writers.
                if exp < datetime.now(timezone.utc):
                    logger.warning(
                        "Credential for %s/%s expired at %s",
                        self.user.id,
                        service,
                        cred.expires_at,
                    )
                    self._cache[service] = None
                    return None
            except ValueError:
                pass  # ignore unparseable expiry — treat as valid
        try:
            token = decrypt(cred.encrypted_token)
        except Exception:
            logger.exception("Failed to decrypt credential for %s/%s", self.user.id, service)
            self._cache[service] = None
            return None
        self._cache[service] = token
        # Touch last_used_at
        self._store.touch_credential(self.user.id, service)
        return token
# ── Shared pipeline functions ────────────────────────────────────
# Service name → context key mapping used by tool set factories.
# Each value names either a settings attribute to copy, or the
# "__credential__" sentinel meaning "fetch the token from the per-user vault".
_SERVICE_CONTEXT_KEYS: dict[str, dict[str, str]] = {
    "vikunja": {
        "vikunja_api_url": "VIKUNJA_API_URL",
        "vikunja_api_key": "__credential__",
    },
    "karakeep": {
        "karakeep_api_url": "KARAKEEP_API_URL",
        "karakeep_api_key": "__credential__",
    },
}
def active_toolset_names(user: Any | None = None) -> list[str]:
    """Return names of tool sets whose required credentials are available.

    When *user* is given (a ``User`` from user_store), credentials are
    checked via the credential vault. When ``None``, falls back to the
    legacy env-var check for backwards compatibility. Both paths delegate
    to ``build_user_context`` so the availability check cannot drift
    between them (previously the context construction and filter loop
    were duplicated for each path).
    """
    ctx = build_user_context(user)
    return [
        name
        for name, toolset in tool_registry.available.items()
        if all(ctx.get(key) for key in toolset.required_keys)
    ]
def build_user_context(user: Any | None = None) -> dict[str, Any]:
    """Build the context dict consumed by tool set factories.

    Per-user vault credentials when *user* is given; env-var fallback
    when ``None``.
    """
    if user is None:
        return _build_legacy_context()
    from user_store import get_store

    uctx = UserContext(user, get_store())
    return _build_user_context_dict(user, uctx)
def build_integration_tools(
    user: Any | None = None,
    *,
    thread_id: str | None = None,
    system_prompt: str | None = None,
) -> list:
    """Build Copilot SDK tools for all active integrations.

    Always returns a list (possibly empty), never None. Optional
    *thread_id* / *system_prompt* are threaded into the factory context
    under underscore-prefixed keys.
    """
    active = active_toolset_names(user)
    if not active:
        return []
    ctx = build_user_context(user)
    extras = {"_thread_id": thread_id, "_system_prompt": system_prompt}
    for key, value in extras.items():
        if value is not None:
            ctx[key] = value
    return tool_registry.get_tools(active, ctx)
async def try_provision(user: Any, service: str) -> str | None:
    """Attempt to provision *service* for *user* with lock.

    Returns a human-readable status message, or None on success.
    Uses a per-user per-service lock to prevent concurrent provisioning.
    """
    from provisioners.base import provisioner_registry
    from user_store import get_store

    provisioner = provisioner_registry.get(service)
    if provisioner is None:
        return f"No provisioner available for {service}."
    store = get_store()
    # Check if already provisioned
    existing = store.get_credential(user.id, service)
    if existing:
        return None  # already has credentials
    lock = provisioner_registry.get_lock(user.id, service)
    # Fast-path bail-out: another request is mid-provision for this pair.
    if lock.locked():
        return f"Provisioning {service} is already in progress."
    async with lock:
        # Double-check after acquiring lock — a concurrent request may have
        # finished provisioning between the check above and lock acquisition.
        existing = store.get_credential(user.id, service)
        if existing:
            return None
        try:
            result = await provisioner.provision(user, store)
        except Exception:
            logger.exception("Provisioning %s for user %s failed", service, user.id)
            store.log_provisioning(user.id, service, "provision_failed", '{"error": "unhandled exception"}')
            return f"Failed to set up {service}. The error has been logged."
        if not result.success:
            return result.error or f"Failed to set up {service}."
        # Mark onboarding complete on first successful provisioning
        if not user.onboarding_complete:
            store.set_onboarding_complete(user.id)
            user.onboarding_complete = True
        return None
def get_provisionable_services(user: Any) -> list[dict[str, str]]:
    """Return list of services that could be provisioned for *user*.

    Each entry has "service", "capabilities" (comma-joined), and
    "status" ("active" when a credential already exists).
    """
    from provisioners.base import provisioner_registry
    from user_store import get_store

    provisioned = get_store().get_credentials(user.id)
    services: list[dict[str, str]] = []
    for name, prov in provisioner_registry.available.items():
        services.append(
            {
                "service": name,
                "capabilities": ", ".join(prov.capabilities),
                "status": "active" if name in provisioned else "available",
            }
        )
    return services
def build_onboarding_fragment(user: Any) -> str | None:
    """Return a system prompt fragment for new/onboarding users (T035)."""
    if user.onboarding_complete:
        return None
    offers = ", ".join(
        f"{svc['capabilities']} ({svc['service']})"
        for svc in get_provisionable_services(user)
        if svc["status"] == "available"
    )
    return (
        "This is a new user who hasn't set up any services yet. "
        "Welcome them warmly, explain what you can help with, "
        "and offer to set up their accounts. Available services: "
        + offers
        + ". Ask them which services they'd like to activate."
    )
def build_capability_fragment(user: Any) -> str | None:
    """Return a system prompt fragment showing active/available capabilities (T036)."""
    services = get_provisionable_services(user)
    if not services:
        return None

    def _describe(subset: list[dict[str, str]]) -> str:
        return ", ".join(f"{s['capabilities']} ({s['service']})" for s in subset)

    active = [s for s in services if s["status"] == "active"]
    available = [s for s in services if s["status"] == "available"]
    parts: list[str] = []
    if active:
        parts.append("Active services: " + _describe(active))
    if available:
        parts.append("Available (say 'set up <service>' to activate): " + _describe(available))
    return "\n".join(parts)
def build_provisioning_confirmation(service: str, result: Any) -> str | None:
    """Return a system prompt fragment confirming provisioning (T038)."""
    from config import settings as _settings

    if not result or not result.success:
        return None
    parts = [f"I just created a {service} account for this user."]
    if result.service_username:
        parts.append(f"Username: {result.service_username}.")
    if result.service_url:
        parts.append(f"Service URL: {result.service_url}")
    if not _settings.ALLOW_CREDENTIAL_REVEAL_IN_CHAT:
        # Telegram transport is not end-to-end encrypted; keep secrets out of chat.
        parts.append("Do NOT reveal the password in chat — Telegram messages are not E2E encrypted.")
    return " ".join(parts)
# ── Internal helpers ─────────────────────────────────────────────
def _build_legacy_context() -> dict[str, Any]:
    """Build user context from static env-var settings (owner-only path)."""
    ctx: dict[str, Any] = dict(
        vikunja_api_url=settings.VIKUNJA_API_URL,
        vikunja_api_key=settings.VIKUNJA_API_KEY,
        vikunja_memory_path=settings.VIKUNJA_MEMORY_PATH,
        memory_owner="default",
        karakeep_api_url=settings.KARAKEEP_API_URL,
        karakeep_api_key=settings.KARAKEEP_API_KEY,
    )
    ctx["_user"] = None  # no per-user record on the legacy path
    return ctx
def _build_user_context_dict(user: Any, uctx: UserContext) -> dict[str, Any]:
    """Build a context dict from per-user credentials.

    Each service's context keys come either from the user's stored
    credential ("__credential__" sentinel) or from a settings attribute.
    Falsy values are omitted so downstream required-key checks work.
    """
    context: dict[str, Any] = {
        "vikunja_memory_path": settings.VIKUNJA_MEMORY_PATH,
        "memory_owner": user.id,
        "_user": user,  # passed through for meta-tools
    }
    for service, key_map in _SERVICE_CONTEXT_KEYS.items():
        for ctx_key, source in key_map.items():
            value = (
                uctx.get_credential(service)
                if source == "__credential__"
                else getattr(settings, source, "")
            )
            if value:
                context[ctx_key] = value
    return context

211
tool_registry.py Normal file
View file

@ -0,0 +1,211 @@
"""Unified tool registry for the agent backend.
A ToolSet groups related tools (e.g. "vikunja", "karakeep") with:
- a factory that builds Copilot SDK Tool objects given per-user context
- a system prompt fragment injected when the tool set is active
- an optional capability label (e.g. "tasks", "bookmarks")
The registry collects tool sets and resolves the right tools + prompt
fragments for a given user at request time.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
from copilot import define_tool
from copilot.tools import Tool, ToolInvocation
from pydantic import create_model
logger = logging.getLogger(__name__)
# ── ToolSet ──────────────────────────────────────────────────────

# A ToolFactory maps a per-user context dict (credentials, config) to
# concrete Copilot SDK Tool objects.
ToolFactory = Callable[[dict[str, Any]], list[Tool]]


@dataclass
class ToolSet:
    """A named collection of tools with a system prompt fragment.

    Attributes:
        name: Unique identifier (e.g. "vikunja", "karakeep").
        description: Human-readable description displayed during onboarding.
        capability: Abstract capability label (e.g. "tasks", "bookmarks").
            Multiple tool sets can share a capability, but a user
            should only have one active per capability.
        system_prompt_fragment: Text appended to the system message when
            this tool set is active for the user.
        build_tools: Factory function that receives a user_context dict
            (credentials, config) and returns Copilot SDK Tool
            instances. The dict keys are tool-set-specific.
        required_keys: Context keys that must be present for this tool set
            to be usable (e.g. ["vikunja_api_url", "vikunja_api_key"]).
            An empty list means the tool set is always usable.
    """

    name: str
    description: str
    capability: str
    system_prompt_fragment: str
    build_tools: ToolFactory
    required_keys: list[str] = field(default_factory=list)
# ── ToolRegistry ─────────────────────────────────────────────────
class ToolRegistry:
    """Central registry of available tool sets.

    Stores ToolSet objects by name and resolves, per request, which
    concrete tools and system prompt fragments apply.
    """

    def __init__(self) -> None:
        # name -> ToolSet
        self._toolsets: dict[str, ToolSet] = {}

    def register(self, toolset: ToolSet) -> None:
        """Add *toolset*, warning if it replaces a same-named entry."""
        if toolset.name in self._toolsets:
            logger.warning("Replacing existing tool set %r", toolset.name)
        self._toolsets[toolset.name] = toolset
        logger.info(
            "Registered tool set %r (%s, %d required keys)",
            toolset.name,
            toolset.capability,
            len(toolset.required_keys),
        )

    @property
    def available(self) -> dict[str, ToolSet]:
        """Shallow copy of the registered tool sets, keyed by name."""
        return dict(self._toolsets)

    def get_tools(
        self,
        active_names: list[str],
        user_context: dict[str, Any],
    ) -> list[Tool]:
        """Build Copilot SDK tools for the requested tool sets.

        Skips tool sets whose required context keys are missing.
        """
        built: list[Tool] = []
        for name in active_names:
            toolset = self._toolsets.get(name)
            if toolset is None:
                logger.warning("Requested unknown tool set %r — skipped", name)
                continue
            absent = [key for key in toolset.required_keys if not user_context.get(key)]
            if absent:
                logger.warning(
                    "Tool set %r skipped: missing context keys %s",
                    name,
                    absent,
                )
                continue
            try:
                built.extend(toolset.build_tools(user_context))
            except Exception:
                # One broken factory must not take down the whole request.
                logger.exception("Failed to build tools for %r", name)
        return built

    def get_system_prompt_fragments(self, active_names: list[str]) -> list[str]:
        """Return system prompt fragments for the given tool sets."""
        return [
            ts.system_prompt_fragment
            for name in active_names
            if (ts := self._toolsets.get(name)) and ts.system_prompt_fragment
        ]
# ── OpenAI schema bridge ─────────────────────────────────────────
def _json_type_to_python(json_type: str) -> type:
"""Map JSON Schema type strings to Python types for Pydantic."""
mapping: dict[str, type] = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict,
}
return mapping.get(json_type, str)
def openai_tools_to_copilot(
    schemas: list[dict[str, Any]],
    dispatcher: Callable[..., Awaitable[str]],
    context_kwargs: dict[str, Any] | None = None,
) -> list[Tool]:
    """Convert OpenAI function-calling tool schemas + a dispatcher into
    Copilot SDK Tool objects.

    For each schema, a Pydantic params model is generated dynamically from
    the JSON Schema "properties"/"required" sections, and a small async
    handler is created that forwards the parsed arguments to *dispatcher*.

    Args:
        schemas: List of OpenAI tool dicts ({"type":"function","function":{...}}).
        dispatcher: Async callable with signature
            ``async def dispatcher(name, arguments, **context_kwargs) -> str``.
            It receives the tool name, parsed argument dict, and any extra
            keyword arguments from *context_kwargs*.
        context_kwargs: Extra keyword arguments forwarded to every dispatcher
            call (e.g. vikunja client, memory store).

    Returns:
        List of Copilot SDK Tool objects ready to pass to create_session().
    """
    extra = context_kwargs or {}
    tools: list[Tool] = []
    for schema in schemas:
        func = schema.get("function", {})
        name: str = func.get("name", "")
        description: str = func.get("description", "")
        params_spec: dict = func.get("parameters", {})
        properties: dict = params_spec.get("properties", {})
        required_fields: list[str] = params_spec.get("required", [])
        # Schemas without a function name cannot be dispatched — skip them.
        if not name:
            continue
        # Build Pydantic model fields dynamically
        fields: dict[str, Any] = {}
        for prop_name, prop_def in properties.items():
            py_type = _json_type_to_python(prop_def.get("type", "string"))
            # All non-required fields are optional with None default
            if prop_name in required_fields:
                fields[prop_name] = (py_type, ...)
            else:
                fields[prop_name] = (py_type | None, None)
        # Create a unique Pydantic model class
        model_name = f"Params_{name}"
        params_model = create_model(model_name, **fields)  # type: ignore[call-overload]
        # Capture loop variables in closure — bound as keyword defaults below
        # so each handler keeps its own name/context instead of the last
        # iteration's (the classic late-binding-closure pitfall).
        _name = name
        _extra = extra
        async def _handler(
            params: Any,
            invocation: ToolInvocation,
            *,
            _tool_name: str = _name,
            _ctx: dict[str, Any] = _extra,
        ) -> str:
            # exclude_none keeps optional fields the model didn't fill out of
            # the argument dict, so dispatchers can rely on .get() defaults.
            args = params.model_dump(exclude_none=True)
            return await dispatcher(_tool_name, args, **_ctx)
        tool = define_tool(
            name=name,
            description=description,
            handler=_handler,
            params_type=params_model,
        )
        tools.append(tool)
    return tools
# ── Module-level singleton ───────────────────────────────────────
# Shared registry; tool modules register their ToolSet here at import time.
registry = ToolRegistry()

9
tools/__init__.py Normal file
View file

@ -0,0 +1,9 @@
"""Unified tool package.
Importing this module registers all available tool sets with the
global tool_registry.registry singleton via instance.register_tools().
"""
from instance import register_tools
register_tools()

299
tools/advisor/__init__.py Normal file
View file

@ -0,0 +1,299 @@
"""Advisor tool — escalate hard decisions to a stronger model.
The executor can call this tool to consult a stronger advisor model.
The advisor returns guidance text only; it cannot call tools or emit
user-facing text directly.
"""
from __future__ import annotations
import logging
from typing import Any
from openai import AsyncOpenAI
from config import settings
from model_selection import ModelSelection, build_provider_config, resolve_selection
from tool_registry import ToolSet, openai_tools_to_copilot
logger = logging.getLogger(__name__)
# ── System prompt fragment (T004) ────────────────────────────────
# Appended to the executor's system message when the advisor tool set is
# active; tells the model when (and when not) to escalate.
ADVISOR_SYSTEM_PROMPT = (
    "You have an `advisor` tool backed by a stronger model. "
    "Call it before committing to a non-trivial design choice, "
    "before deleting code you don't understand, "
    "when a test fails for a non-obvious reason, "
    "or when you are about to loop. "
    "Do not call it for typos, lint, or routine edits."
)

# ── Tool schema (T003) ──────────────────────────────────────────
# OpenAI function-calling schema; only "question" is required.
ADVISOR_TOOL_SCHEMA: list[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "advisor",
            "description": (
                "Consult a stronger model for a plan, correction, or stop signal. "
                "Call this when you are uncertain about architecture, root cause, or next step. "
                "You get back guidance text only; no tool calls are executed by the advisor."
            ),
            "parameters": {
                "type": "object",
                "required": ["question"],
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "What you need help deciding.",
                    },
                    "context_summary": {
                        "type": "string",
                        "description": "Short summary of relevant state the advisor must know.",
                    },
                    "stakes": {
                        "type": "string",
                        "enum": ["low", "medium", "high"],
                        "description": "How critical this decision is.",
                    },
                },
            },
        },
    },
]
# ── Per-run usage counter (T010) ─────────────────────────────────
_usage_counter: dict[str, int] = {}
def reset_advisor_counter(thread_id: str) -> None:
"""Reset the per-run advisor usage counter for a thread."""
_usage_counter.pop(thread_id, None)
def _check_and_increment(thread_id: str) -> str | None:
"""Increment counter; return error string if limit reached, else None."""
current = _usage_counter.get(thread_id, 0)
if current >= settings.ADVISOR_MAX_USES:
return f"Advisor limit reached (max {settings.ADVISOR_MAX_USES} per run). Proceed on your own."
_usage_counter[thread_id] = current + 1
return None
# ── Advisor trace + usage accumulators (T015, T021) ──────────────
_advisor_usage: dict[str, dict[str, int]] = {}
_advisor_traces: dict[str, list[dict[str, Any]]] = {}
def get_advisor_usage(thread_id: str) -> dict[str, int] | None:
"""Return accumulated advisor token usage for a thread, or None."""
return _advisor_usage.get(thread_id)
def get_advisor_traces(thread_id: str) -> list[dict[str, Any]]:
"""Return advisor trace records for a thread."""
return _advisor_traces.get(thread_id, [])
def _reset_advisor_state(thread_id: str) -> None:
"""Clear per-run advisor state (counter, usage, traces)."""
_usage_counter.pop(thread_id, None)
_advisor_usage.pop(thread_id, None)
_advisor_traces.pop(thread_id, None)
# ── Prompt builder (T007) ────────────────────────────────────────
_ADVISOR_SYSTEM_INSTRUCTION = (
    "You are an expert advisor. Provide concise, actionable guidance. "
    "Do not call tools or produce user-facing text. Focus on the question."
)

_SYSTEM_PROMPT_MAX_CHARS = 500


def _build_advisor_prompt(
    question: str,
    context_summary: str,
    stakes: str,
    system_prompt_excerpt: str,
) -> list[dict[str, str]]:
    """Build the message list for the advisor one-shot call."""
    # Truncate the executor system prompt so it cannot dominate the call.
    excerpt = system_prompt_excerpt
    if len(excerpt) > _SYSTEM_PROMPT_MAX_CHARS:
        excerpt = excerpt[:_SYSTEM_PROMPT_MAX_CHARS] + "\u2026"

    system_sections = [_ADVISOR_SYSTEM_INSTRUCTION]
    if excerpt:
        system_sections.append(f"Executor context (trimmed):\n{excerpt}")

    user_sections: list[str] = []
    if context_summary:
        user_sections.append(f"Context: {context_summary}")
    user_sections.append(f"Question [{stakes} stakes]: {question}")

    separator = "\n\n"
    return [
        {"role": "system", "content": separator.join(system_sections)},
        {"role": "user", "content": separator.join(user_sections)},
    ]
# ── One-shot advisor completion (T008) ───────────────────────────
async def _call_advisor_model(
    messages: list[dict[str, str]],
    model: str,
    max_tokens: int,
    provider_config: dict[str, Any] | None,
    temperature: float = 0.2,
) -> tuple[str, dict[str, int]]:
    """Send a one-shot, tool-less completion to the advisor model.

    Returns (response_text, {"prompt_tokens": N, "completion_tokens": N}).
    On any error, returns a fallback string and zero usage.
    """
    zero_usage = {"prompt_tokens": 0, "completion_tokens": 0}
    try:
        if provider_config is None:
            # Copilot provider — use the Copilot Models API endpoint
            from config import settings as _s

            client = AsyncOpenAI(
                base_url="https://api.githubcopilot.com",
                api_key=_s.GITHUB_TOKEN,
            )
        else:
            client = AsyncOpenAI(
                base_url=provider_config.get("base_url", ""),
                api_key=provider_config.get("api_key", ""),
            )
        response = await client.chat.completions.create(
            model=model,
            messages=messages,  # type: ignore[arg-type]
            max_tokens=max_tokens,
            temperature=temperature,
            stream=False,
        )
        choices = response.choices
        text = (choices[0].message.content or "") if choices else ""
        usage = dict(zero_usage)
        if response.usage:
            usage["prompt_tokens"] = response.usage.prompt_tokens or 0
            usage["completion_tokens"] = response.usage.completion_tokens or 0
        return text.strip(), usage
    except Exception as exc:
        # Best-effort by design: the executor must keep going without advice.
        logger.warning("Advisor call failed: %s", exc)
        return "Advisor unavailable. Proceed with your best judgment.", dict(zero_usage)
# ── Dispatcher (T009) ────────────────────────────────────────────
async def _handle_advisor_call(
    arguments: dict[str, Any],
    thread_id: str,
    system_prompt_excerpt: str,
) -> str:
    """Handle an advisor tool invocation.

    Args:
        arguments: Parsed tool-call arguments ("question", optional
            "context_summary" and "stakes"). May also carry the internal
            "_advisor_model_override" key injected by _build_advisor_tools.
        thread_id: Per-run identifier used for usage limits and traces.
        system_prompt_excerpt: Trimmed executor system prompt for context.

    Returns:
        Guidance text from the advisor model, or an explanatory fallback.
    """
    # Pop the internal override first so it never leaks into the prompt.
    # Fix: _build_advisor_tools injects this key, but it was previously
    # ignored here, so per-thread advisor model overrides had no effect.
    override_model = arguments.pop("_advisor_model_override", None)
    question = arguments.get("question", "")
    context_summary = arguments.get("context_summary", "")
    stakes = arguments.get("stakes", "medium")
    if not question:
        return "No question provided."
    # Check usage limit
    limit_msg = _check_and_increment(thread_id)
    if limit_msg:
        return limit_msg
    # Token/temperature params based on stakes — high-stakes questions get a
    # doubled budget and slightly more exploratory sampling.
    max_tokens = settings.ADVISOR_MAX_TOKENS
    temperature = 0.2
    if stakes == "high":
        max_tokens *= 2
        temperature = 0.4
    # Build prompt
    messages = _build_advisor_prompt(question, context_summary, stakes, system_prompt_excerpt)
    # Resolve advisor model (per-thread override or default)
    advisor_model_id = override_model or settings.ADVISOR_DEFAULT_MODEL
    try:
        advisor_selection = resolve_selection(model=advisor_model_id)
    except Exception as exc:
        logger.warning("Advisor model resolution failed: %s", exc)
        return f"Advisor model resolution failed ({exc}). Proceed with your best judgment."
    provider_config = build_provider_config(advisor_selection)
    # Call advisor
    response_text, usage_data = await _call_advisor_model(
        messages,
        advisor_selection.model,
        max_tokens,
        provider_config,
        temperature=temperature,
    )
    # Record usage (T015)
    existing = _advisor_usage.get(thread_id, {"prompt_tokens": 0, "completion_tokens": 0})
    existing["prompt_tokens"] = existing.get("prompt_tokens", 0) + usage_data["prompt_tokens"]
    existing["completion_tokens"] = existing.get("completion_tokens", 0) + usage_data["completion_tokens"]
    _advisor_usage[thread_id] = existing
    # Record trace (T021)
    total_tokens = usage_data["prompt_tokens"] + usage_data["completion_tokens"]
    _advisor_traces.setdefault(thread_id, []).append(
        {
            "kind": "advisor",
            "question": question,
            "guidance": response_text,
            "model": advisor_model_id,
            "tokens": total_tokens,
        }
    )
    return response_text
# ── Tool factory (T003) ─────────────────────────────────────────
def _build_advisor_tools(user_context: dict[str, Any]) -> list:
    """Factory that creates Copilot SDK advisor tools."""
    thread_id = user_context.get("_thread_id", "unknown")
    system_prompt = user_context.get("_system_prompt", "")
    advisor_model_override = user_context.get("advisor_model")

    async def dispatcher(name: str, arguments: dict, **_kw: Any) -> str:
        if name != "advisor":
            return f"Unknown advisor tool: {name}"
        # Allow per-thread model override
        if advisor_model_override:
            arguments.setdefault("_advisor_model_override", advisor_model_override)
        return await _handle_advisor_call(arguments, thread_id, system_prompt)

    return openai_tools_to_copilot(schemas=ADVISOR_TOOL_SCHEMA, dispatcher=dispatcher)
# ── ToolSet registration (T003) ──────────────────────────────────
# No required_keys: the advisor works for every user once registered.
advisor_toolset = ToolSet(
    name="advisor",
    description="Consult a stronger model on hard decisions",
    capability="advisor",
    system_prompt_fragment=ADVISOR_SYSTEM_PROMPT,
    build_tools=_build_advisor_tools,
    required_keys=[],
)

145
tools/meta/__init__.py Normal file
View file

@ -0,0 +1,145 @@
"""Credential meta-tools (T039-T041).
Provides tools for users to inspect their service access and credentials.
"""
from __future__ import annotations
from typing import Any
from tool_registry import ToolSet, openai_tools_to_copilot
META_SYSTEM_PROMPT = """\
You can help users check their service access and available integrations. \
When a user asks about their accounts, services, or credentials, use the \
list_my_services and get_my_credentials tools. \
Never reveal passwords unless the system configuration explicitly allows it.
"""
TOOLS = [
{
"type": "function",
"function": {
"name": "list_my_services",
"description": (
"List all services available to the user, showing which are active "
"(have credentials), which are available for setup, and their capabilities."
),
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
},
{
"type": "function",
"function": {
"name": "get_my_credentials",
"description": (
"Get the user's credentials for a specific service. Returns the username "
"and service URL. Password is only shown if system configuration allows it."
),
"parameters": {
"type": "object",
"properties": {
"service": {
"type": "string",
"description": "The service name (e.g. 'vikunja', 'karakeep')",
},
},
"required": ["service"],
},
},
},
]
def _build_meta_tools(user_context: dict[str, Any]) -> list:
"""Factory that creates Copilot SDK meta-tools."""
user = user_context.get("_user")
async def dispatcher(name: str, arguments: dict, **_kw: Any) -> str:
if name == "list_my_services":
return _handle_list_my_services(user)
if name == "get_my_credentials":
return _handle_get_my_credentials(user, arguments.get("service", ""))
return f"Unknown meta-tool: {name}"
return openai_tools_to_copilot(schemas=TOOLS, dispatcher=dispatcher)
def _handle_list_my_services(user: Any) -> str:
    """Render the user's active/available services as a markdown list."""
    if user is None:
        return "Unable to determine your identity."
    # Imported lazily to avoid a circular import at module load time.
    from tool_pipeline import get_provisionable_services

    services = get_provisionable_services(user)
    if not services:
        return "No services are currently configured."
    bullet_lines = [
        f"- **{entry['service']}** ({entry['capabilities'] or 'general'}): {entry['status'].upper()}"
        for entry in services
    ]
    return "Your services:\n" + "\n".join(bullet_lines)
def _handle_get_my_credentials(user: Any, service: str) -> str:
    """Render the user's stored credential for *service* as markdown.

    The raw token is included only when ALLOW_CREDENTIAL_REVEAL_IN_CHAT
    is enabled in settings.
    """
    if user is None:
        return "Unable to determine your identity."
    if not service:
        return "Please specify which service you want credentials for."
    from config import settings
    from user_store import get_store

    cred = get_store().get_credential(user.id, service)
    if cred is None:
        return f"You don't have credentials for {service}. Say 'set up {service}' to get started."

    parts = [f"**Service**: {service}"]
    if cred.service_username:
        parts.append(f"**Username**: {cred.service_username}")
    # Service URL from settings
    url = getattr(settings, f"{service.upper()}_API_URL", "")
    if url:
        # Strip /api/v1 suffix for display
        display_url = url.rstrip("/")
        for suffix in ("/api/v1", "/api"):
            if display_url.endswith(suffix):
                display_url = display_url[: -len(suffix)]
                break
        parts.append(f"**URL**: {display_url}")
    if cred.expires_at:
        parts.append(f"**Expires**: {cred.expires_at}")

    if not settings.ALLOW_CREDENTIAL_REVEAL_IN_CHAT:
        parts.append("**Password/Token**: Stored securely. Access the service directly at the URL above.")
        return "\n".join(parts)

    from user_store import decrypt

    try:
        token = decrypt(cred.encrypted_token)
        parts.append(f"**API Token**: `{token}`")
        parts.append("⚠️ Be careful — this token grants full access to your account.")
    except Exception:
        parts.append("**API Token**: (decryption failed)")
    return "\n".join(parts)
# Registered for every authenticated user (no required context keys).
meta_toolset = ToolSet(
    name="meta",
    description="Service access and credential information",
    capability="account_management",
    system_prompt_fragment=META_SYSTEM_PROMPT,
    build_tools=_build_meta_tools,
    required_keys=[],  # Available to all authenticated users
)

View file

@ -0,0 +1,227 @@
"""Site-editing tool set for BetterBot.
Provides list_files, read_file, and write_file tools that operate on
mounted project directories (Better Life SG website, Memoraiz frontend).
After writing, changes are committed and pushed via git.
"""
from __future__ import annotations
import logging
import pathlib
import subprocess
from typing import Any
from config import settings
from tool_registry import ToolSet, openai_tools_to_copilot
logger = logging.getLogger(__name__)
# ── Project definitions ──────────────────────────────────────────
def _build_projects() -> dict[str, dict[str, Any]]:
    """Build the project map from environment-configured directories.

    Only projects whose directory actually exists are included, so the
    bot never offers to edit an unmounted tree.
    """
    site_dir = pathlib.Path(getattr(settings, "SITE_DIR", "/site"))
    memoraiz_dir = pathlib.Path(getattr(settings, "MEMORAIZ_DIR", "/memoraiz"))
    specs = (
        ("betterlifesg", site_dir, "Better Life SG website"),
        ("memoraiz", memoraiz_dir, "Memoraiz app (React frontend)"),
    )
    projects: dict[str, dict[str, Any]] = {}
    for key, directory, label in specs:
        if directory.exists():
            projects[key] = {
                "dir": directory,
                "label": label,
                # The git checkout root is one level above the editable dir.
                "git_repo": directory.parent,
            }
    return projects


PROJECTS = _build_projects()
# Fall back to the canonical names so the tool schemas stay valid even when
# neither directory is mounted (e.g. local development).
PROJECT_NAMES = list(PROJECTS.keys()) or ["betterlifesg", "memoraiz"]
# ── Path safety ──────────────────────────────────────────────────
def _resolve(base: pathlib.Path, path: str) -> pathlib.Path:
    """Resolve a relative path inside a base dir, preventing path traversal.

    Args:
        base: Project root directory.
        path: Untrusted relative path supplied by the model.

    Returns:
        The fully resolved absolute path inside *base*.

    Raises:
        ValueError: If *path* escapes *base*.
    """
    resolved = (base / path).resolve()
    # Fix: the previous str.startswith() check wrongly accepted sibling
    # directories sharing a prefix (e.g. "/site-evil" for base "/site");
    # is_relative_to() compares whole path components.
    if not resolved.is_relative_to(base.resolve()):
        raise ValueError(f"Path traversal blocked: {path}")
    return resolved
# ── Git push ─────────────────────────────────────────────────────
def _git_push(project_key: str, changed_file: str) -> str:
    """Commit and push changes in the project's git repo.

    Args:
        project_key: Key into PROJECTS identifying the repo.
        changed_file: Relative path used in the commit message.

    Returns:
        A short human-readable status string (this function never raises).
    """
    proj = PROJECTS[project_key]
    repo = proj["git_repo"]
    if not (repo / ".git").exists():
        return "(no git repo — skipped push)"
    try:
        subprocess.run(["git", "add", "-A"], cwd=repo, check=True, capture_output=True)
        # Fix: "git commit" exits non-zero when there is nothing to commit
        # (e.g. the write produced identical content), which previously
        # surfaced as a bogus "Push failed". Detect the no-op case first.
        status = subprocess.run(
            ["git", "status", "--porcelain"],
            cwd=repo,
            check=True,
            capture_output=True,
            text=True,
        )
        if not status.stdout.strip():
            return "(no changes — skipped push)"
        subprocess.run(
            ["git", "commit", "-m", f"betterbot: update {changed_file}"],
            cwd=repo,
            check=True,
            capture_output=True,
        )
        result = subprocess.run(
            ["git", "push", "origin", "HEAD"],
            cwd=repo,
            check=True,
            capture_output=True,
            text=True,
        )
        # git prints push progress on stderr even on success.
        logger.info("Git push for %s: %s", project_key, result.stderr.strip())
        return f"Pushed {project_key} to git"
    except subprocess.CalledProcessError as e:
        logger.error("Git/deploy error: %s\nstdout: %s\nstderr: %s", e, e.stdout, e.stderr)
        return f"Push failed: {e.stderr or e.stdout or str(e)}"
# ── Tool implementations ─────────────────────────────────────────
def handle_tool_call(name: str, args: dict) -> str:
    """Execute a site-editing tool call and return the result as a string.

    Always returns a string — errors (bad path, traversal attempt,
    unreadable file) are reported as text instead of raising, so the
    agent sees the failure and can correct itself.
    """
    project_key = args.get("project", "betterlifesg")
    if project_key not in PROJECTS:
        return f"Unknown project: {project_key}. Available: {', '.join(PROJECTS.keys())}"
    base = PROJECTS[project_key]["dir"]
    try:
        if name == "list_files":
            subdir = args.get("subdirectory", "")
            target = _resolve(base, subdir) if subdir else base
            files = []
            for p in sorted(target.rglob("*")):
                if p.is_file() and not any(
                    part in (".git", "node_modules", "__pycache__") for part in p.parts
                ):
                    files.append(str(p.relative_to(base)))
            # Cap the listing so huge trees don't blow up the context window.
            return "\n".join(files[:200]) if files else "(no files found)"
        if name == "read_file":
            path = _resolve(base, args["path"])
            if not path.exists():
                return f"Error: {args['path']} does not exist."
            if path.suffix in (".png", ".jpg", ".jpeg", ".gif", ".webp", ".ico"):
                return f"[Binary image file: {args['path']}, {path.stat().st_size} bytes]"
            return path.read_text(encoding="utf-8")
        if name == "write_file":
            path = _resolve(base, args["path"])
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(args["content"], encoding="utf-8")
            push_result = _git_push(project_key, args["path"])
            return f"OK — wrote {len(args['content'])} chars to {args['path']}. {push_result}"
    except KeyError as exc:
        # Fix: a model call missing a required argument previously raised.
        return f"Error: missing required argument {exc}."
    except ValueError as exc:
        # Fix: _resolve() traversal rejections previously propagated.
        return f"Error: {exc}"
    except (OSError, UnicodeDecodeError) as exc:
        # Fix: directory paths / undecodable files previously raised.
        return f"Error accessing {args.get('path', '')!r}: {exc}"
    return f"Unknown tool: {name}"
# ── OpenAI function-calling tool schemas ─────────────────────────
# Project enums are built from PROJECT_NAMES so only mounted (or canonical
# fallback) projects are offered to the model.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "list_files",
            "description": "List all files in a project directory.",
            "parameters": {
                "type": "object",
                "properties": {
                    "project": {
                        "type": "string",
                        "enum": PROJECT_NAMES,
                        "description": "Which project to list files for.",
                    },
                    "subdirectory": {
                        "type": "string",
                        "description": "Optional subdirectory to list, e.g. 'src/pages'. Defaults to root.",
                        "default": "",
                    },
                },
                "required": ["project"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read the full contents of a project file.",
            "parameters": {
                "type": "object",
                "properties": {
                    "project": {
                        "type": "string",
                        "enum": PROJECT_NAMES,
                        "description": "Which project the file belongs to.",
                    },
                    "path": {
                        "type": "string",
                        "description": "Relative path inside the project directory.",
                    },
                },
                "required": ["project", "path"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write (create or overwrite) a text file in a project directory. After writing, changes are committed and pushed to git automatically.",
            "parameters": {
                "type": "object",
                "properties": {
                    "project": {
                        "type": "string",
                        "enum": PROJECT_NAMES,
                        "description": "Which project the file belongs to.",
                    },
                    "path": {
                        "type": "string",
                        "description": "Relative path inside the project directory.",
                    },
                    "content": {
                        "type": "string",
                        "description": "The full file content to write.",
                    },
                },
                "required": ["project", "path", "content"],
            },
        },
    },
]
SYSTEM_PROMPT = """\
You have access to site-editing tools for managing project files. \
When asked to change site content, use list_files to see what's available, \
read_file to understand the current state, then write_file to apply changes. \
Always write the COMPLETE file content never partial. \
Changes are committed and pushed to git automatically after writing.\
"""
# ── ToolSet registration ─────────────────────────────────────────
def _factory(context: dict) -> list:
    """Build Copilot SDK tools from the OpenAI schemas.

    Fix: openai_tools_to_copilot takes ``schemas`` and an async
    ``dispatcher`` (name, arguments, **kwargs); the previous
    ``handler=handle_tool_call`` keyword raised a TypeError at call time.
    """

    async def _dispatch(name: str, arguments: dict, **_kw) -> str:
        # handle_tool_call is synchronous; wrap it for the async dispatcher.
        return handle_tool_call(name, arguments)

    return openai_tools_to_copilot(schemas=TOOLS, dispatcher=_dispatch)


# Fix: ToolSet has no system_prompt/openai_schemas/factory fields — its
# dataclass fields are name, description, capability,
# system_prompt_fragment, build_tools, required_keys; the previous keyword
# names raised a TypeError at import.
site_editing_toolset = ToolSet(
    name="site_editing",
    description="Edit and deploy files in the mounted site projects",
    capability="site_editing",
    system_prompt_fragment=SYSTEM_PROMPT,
    build_tools=_factory,
    required_keys=[],
)

769
user_store.py Normal file
View file

@ -0,0 +1,769 @@
"""User identity store — SQLite-backed user management, credential vault,
conversation storage, and event tracking.
Provides:
- Schema migrations applied idempotently at startup
- Fernet encryption for per-user API tokens at rest
- User resolution (provider + external_id internal User)
- Credential CRUD with lazy decryption support
- Session / message / event persistence
- Owner bootstrap migration from env-var credentials
"""
from __future__ import annotations
import logging
import sqlite3
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from cryptography.fernet import Fernet
from config import settings
logger = logging.getLogger(__name__)
# ── Data classes ─────────────────────────────────────────────────
def _now_iso() -> str:
    """Current UTC time as an ISO-8601 string (the DB's timestamp format)."""
    return datetime.now(tz=timezone.utc).isoformat()
@dataclass
class User:
    """An internal user record (one row in the ``users`` table)."""

    id: str
    display_name: str
    email: str | None
    created_at: str  # ISO-8601 UTC timestamp
    updated_at: str  # ISO-8601 UTC timestamp
    is_owner: bool
    onboarding_complete: bool
    telegram_approved: bool | None = None  # None = pending, True = approved, False = denied
    is_new: bool = False  # transient — not persisted
@dataclass
class ExternalIdentity:
    """Links a (provider, external_id) pair to an internal user."""

    id: int
    user_id: str
    provider: str  # e.g. an auth surface name — confirm values against callers
    external_id: str
    metadata_json: str | None
    created_at: str
@dataclass
class ServiceCredential:
    """A per-user API credential for one service, encrypted at rest."""

    id: int
    user_id: str
    service: str
    encrypted_token: str  # Fernet ciphertext (base64 TEXT); see encrypt()/decrypt()
    service_user_id: str | None
    service_username: str | None
    created_at: str
    expires_at: str | None
    last_used_at: str | None
@dataclass
class ProvisioningLogEntry:
    """Audit record of a provisioning action for a user/service pair."""

    id: int
    user_id: str
    service: str
    action: str
    detail_json: str | None
    created_at: str
@dataclass
class Session:
    """A conversation session on one surface."""

    id: str
    user_id: str
    surface: str  # originating surface — presumably e.g. "telegram"; confirm with callers
    topic_id: str | None  # likely a forum-topic/thread id — confirm with callers
    title: str | None
    created_at: str
    last_active_at: str
@dataclass
class Message:
    """A single chat message stored within a session."""

    id: str
    session_id: str
    role: str  # chat role, e.g. "user"/"assistant" — confirm with writers
    content: str
    created_at: str
@dataclass
class Event:
    """An inbound event row; consumed_at is NULL until processed."""

    id: int
    user_id: str
    source: str
    event_type: str
    summary: str
    detail_json: str | None
    created_at: str
    consumed_at: str | None  # presumably set when the agent consumes the event — confirm
# ── Fernet helpers (T012) ────────────────────────────────────────
def _get_fernet() -> Fernet:
    """Return a Fernet instance keyed by CREDENTIAL_VAULT_KEY."""
    key = settings.CREDENTIAL_VAULT_KEY
    if not key:
        raise RuntimeError("CREDENTIAL_VAULT_KEY is not set — cannot encrypt/decrypt credentials")
    if isinstance(key, str):
        key = key.encode()
    return Fernet(key)


def encrypt(plaintext: str) -> str:
    """Encrypt *plaintext* and return a base64-encoded TEXT string."""
    return _get_fernet().encrypt(plaintext.encode()).decode()


def decrypt(ciphertext: str) -> str:
    """Decrypt a Fernet ciphertext (base64 TEXT) back to the original string."""
    return _get_fernet().decrypt(ciphertext.encode()).decode()
# ── Schema migrations (T011) ────────────────────────────────────
_MIGRATIONS: list[tuple[int, str, str]] = [
(
1,
"initial_schema",
"""
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
display_name TEXT NOT NULL,
email TEXT UNIQUE,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
is_owner INTEGER NOT NULL DEFAULT 0,
onboarding_complete INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS external_identities (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL REFERENCES users(id),
provider TEXT NOT NULL,
external_id TEXT NOT NULL,
metadata_json TEXT,
created_at TEXT NOT NULL,
UNIQUE(provider, external_id)
);
CREATE INDEX IF NOT EXISTS idx_external_identities_lookup
ON external_identities(provider, external_id);
CREATE TABLE IF NOT EXISTS service_credentials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL REFERENCES users(id),
service TEXT NOT NULL,
encrypted_token TEXT NOT NULL,
service_user_id TEXT,
service_username TEXT,
created_at TEXT NOT NULL,
expires_at TEXT,
last_used_at TEXT,
UNIQUE(user_id, service)
);
CREATE INDEX IF NOT EXISTS idx_service_credentials_user_id
ON service_credentials(user_id);
CREATE TABLE IF NOT EXISTS provisioning_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL REFERENCES users(id),
service TEXT NOT NULL,
action TEXT NOT NULL,
detail_json TEXT,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id),
surface TEXT NOT NULL,
topic_id TEXT,
title TEXT,
created_at TEXT NOT NULL,
last_active_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id
ON sessions(user_id);
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL REFERENCES sessions(id),
role TEXT NOT NULL,
content TEXT NOT NULL,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_messages_session_id
ON messages(session_id);
CREATE INDEX IF NOT EXISTS idx_messages_created_at
ON messages(created_at);
CREATE TABLE IF NOT EXISTS events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL REFERENCES users(id),
source TEXT NOT NULL,
event_type TEXT NOT NULL,
summary TEXT NOT NULL,
detail_json TEXT,
created_at TEXT NOT NULL,
consumed_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_events_user_created
ON events(user_id, created_at);
""",
),
(
2,
"add_telegram_approved",
"""
ALTER TABLE users ADD COLUMN telegram_approved INTEGER;
""",
),
]
def _apply_migrations(conn: sqlite3.Connection) -> None:
    """Ensure all migrations are applied idempotently.

    Creates the ``schema_migrations`` bookkeeping table on first use, then
    executes every ``_MIGRATIONS`` entry whose version is not yet recorded.
    """
    conn.execute(
        """CREATE TABLE IF NOT EXISTS schema_migrations (
            version INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            applied_at TEXT NOT NULL
        )"""
    )
    applied = {row[0] for row in conn.execute("SELECT version FROM schema_migrations")}
    for version, name, ddl in _MIGRATIONS:
        if version in applied:
            continue
        logger.info("Applying migration %d: %s", version, name)
        # NOTE(review): executescript() implicitly COMMITs any pending
        # transaction before running, so the DDL and the version INSERT
        # below are not atomic — a crash in between leaves a migration
        # applied but unrecorded (a rerun of e.g. ALTER TABLE would then
        # fail). Confirm this is acceptable for this deployment.
        conn.executescript(ddl)
        conn.execute(
            "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, ?)",
            (version, name, _now_iso()),
        )
        conn.commit()
    logger.info("Schema migrations up to date (latest: %d)", max(v for v, _, _ in _MIGRATIONS))
# ── SQLite connection (T010) ─────────────────────────────────────
_MAX_WRITE_RETRIES = 3
_RETRY_BASE_DELAY = 0.05  # seconds


def _connect(db_path: str | Path) -> sqlite3.Connection:
    """Open a SQLite connection with WAL mode and foreign keys enabled."""
    conn = sqlite3.connect(str(db_path), timeout=10)
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        conn.execute(pragma)
    conn.row_factory = sqlite3.Row
    return conn


def _write_with_retry(conn: sqlite3.Connection, fn: Any) -> Any:
    """Execute *fn(conn)* with exponential-backoff retry on SQLITE_BUSY."""
    final_attempt = _MAX_WRITE_RETRIES - 1
    for attempt in range(_MAX_WRITE_RETRIES):
        try:
            return fn(conn)
        except sqlite3.OperationalError as exc:
            # Re-raise anything that isn't lock contention, or the last try.
            if "database is locked" not in str(exc) or attempt == final_attempt:
                raise
            delay = _RETRY_BASE_DELAY * (2**attempt)
            logger.warning("SQLITE_BUSY — retrying in %.2fs (attempt %d)", delay, attempt + 1)
            time.sleep(delay)
    return None  # unreachable
# ── UserStore ────────────────────────────────────────────────────
class UserStore:
    """Central access layer for user identity, credentials, sessions, and events.

    Wraps a single SQLite connection (WAL mode, see ``_connect``); every
    write goes through ``_write_with_retry`` to survive transient
    SQLITE_BUSY errors. Timestamps are stored as ISO-8601 strings.
    """

    def __init__(self, db_path: str | Path | None = None) -> None:
        # Default location is <DATA_DIR>/users.db; parents created on demand.
        if db_path is None:
            db_path = Path(settings.DATA_DIR) / "users.db"
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = _connect(self._db_path)
        _apply_migrations(self._conn)

    @property
    def conn(self) -> sqlite3.Connection:
        """The underlying shared SQLite connection."""
        return self._conn

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    # ── User resolution (T013) ───────────────────────────────────
    def resolve_or_create_user(
        self,
        provider: str,
        external_id: str,
        display_name: str,
        email: str | None = None,
    ) -> User:
        """Resolve (provider, external_id) → User. Creates if not found."""

        def _do(conn: sqlite3.Connection) -> User:
            row = conn.execute(
                """SELECT u.id, u.display_name, u.email, u.created_at, u.updated_at,
                          u.is_owner, u.onboarding_complete, u.telegram_approved
                   FROM external_identities ei
                   JOIN users u ON u.id = ei.user_id
                   WHERE ei.provider = ? AND ei.external_id = ?""",
                (provider, external_id),
            ).fetchone()
            if row:
                # telegram_approved is tri-state: None = undecided, 0/1 = denied/approved.
                _raw_approved = row["telegram_approved"]
                return User(
                    id=row["id"],
                    display_name=row["display_name"],
                    email=row["email"],
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                    is_owner=bool(row["is_owner"]),
                    onboarding_complete=bool(row["onboarding_complete"]),
                    telegram_approved=None if _raw_approved is None else bool(_raw_approved),
                    is_new=False,
                )
            # Create new user
            now = _now_iso()
            user_id = str(uuid.uuid4())
            conn.execute(
                """INSERT INTO users (id, display_name, email, created_at, updated_at,
                                      is_owner, onboarding_complete)
                   VALUES (?, ?, ?, ?, ?, 0, 0)""",
                (user_id, display_name, email, now, now),
            )
            conn.execute(
                """INSERT INTO external_identities (user_id, provider, external_id, created_at)
                   VALUES (?, ?, ?, ?)""",
                (user_id, provider, external_id, now),
            )
            conn.commit()
            return User(
                id=user_id,
                display_name=display_name,
                email=email,
                created_at=now,
                updated_at=now,
                is_owner=False,
                onboarding_complete=False,
                is_new=True,
            )

        return _write_with_retry(self._conn, _do)

    def get_user(self, user_id: str) -> User | None:
        """Fetch a user by internal id, or None when it does not exist."""
        row = self._conn.execute(
            """SELECT id, display_name, email, created_at, updated_at,
                      is_owner, onboarding_complete, telegram_approved
               FROM users WHERE id = ?""",
            (user_id,),
        ).fetchone()
        if not row:
            return None
        _raw_approved = row["telegram_approved"]
        return User(
            id=row["id"],
            display_name=row["display_name"],
            email=row["email"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            is_owner=bool(row["is_owner"]),
            onboarding_complete=bool(row["onboarding_complete"]),
            telegram_approved=None if _raw_approved is None else bool(_raw_approved),
        )

    def set_onboarding_complete(self, user_id: str) -> None:
        """Mark the user's onboarding as finished."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                "UPDATE users SET onboarding_complete = 1, updated_at = ? WHERE id = ?",
                (_now_iso(), user_id),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    def set_telegram_approval(self, user_id: str, approved: bool) -> None:
        """Set telegram_approved to 1 (approved) or 0 (denied)."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                "UPDATE users SET telegram_approved = ?, updated_at = ? WHERE id = ?",
                (int(approved), _now_iso(), user_id),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    # ── Credential CRUD (T014) ───────────────────────────────────
    def get_credentials(self, user_id: str) -> dict[str, ServiceCredential]:
        """Return all credentials for a user keyed by service name."""
        rows = self._conn.execute(
            """SELECT id, user_id, service, encrypted_token, service_user_id,
                      service_username, created_at, expires_at, last_used_at
               FROM service_credentials WHERE user_id = ?""",
            (user_id,),
        ).fetchall()
        return {
            row["service"]: ServiceCredential(
                id=row["id"],
                user_id=row["user_id"],
                service=row["service"],
                encrypted_token=row["encrypted_token"],
                service_user_id=row["service_user_id"],
                service_username=row["service_username"],
                created_at=row["created_at"],
                expires_at=row["expires_at"],
                last_used_at=row["last_used_at"],
            )
            for row in rows
        }

    def get_credential(self, user_id: str, service: str) -> ServiceCredential | None:
        """Return one credential for (user, service), or None when absent."""
        row = self._conn.execute(
            """SELECT id, user_id, service, encrypted_token, service_user_id,
                      service_username, created_at, expires_at, last_used_at
               FROM service_credentials WHERE user_id = ? AND service = ?""",
            (user_id, service),
        ).fetchone()
        if not row:
            return None
        return ServiceCredential(
            id=row["id"],
            user_id=row["user_id"],
            service=row["service"],
            encrypted_token=row["encrypted_token"],
            service_user_id=row["service_user_id"],
            service_username=row["service_username"],
            created_at=row["created_at"],
            expires_at=row["expires_at"],
            last_used_at=row["last_used_at"],
        )

    def store_credential(
        self,
        user_id: str,
        service: str,
        token: str,
        service_user_id: str | None = None,
        service_username: str | None = None,
        expires_at: str | None = None,
    ) -> None:
        """Store (or replace) a credential. Encrypts the token before writing."""
        enc_token = encrypt(token)
        now = _now_iso()

        def _do(conn: sqlite3.Connection) -> None:
            # Upsert keyed on (user_id, service); replacing resets last_used_at.
            conn.execute(
                """INSERT INTO service_credentials
                       (user_id, service, encrypted_token, service_user_id,
                        service_username, created_at, expires_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?)
                   ON CONFLICT(user_id, service) DO UPDATE SET
                       encrypted_token = excluded.encrypted_token,
                       service_user_id = excluded.service_user_id,
                       service_username = excluded.service_username,
                       expires_at = excluded.expires_at,
                       last_used_at = NULL""",
                (user_id, service, enc_token, service_user_id, service_username, now, expires_at),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    def delete_credential(self, user_id: str, service: str) -> None:
        """Delete the credential for (user, service) if present."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                "DELETE FROM service_credentials WHERE user_id = ? AND service = ?",
                (user_id, service),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    def touch_credential(self, user_id: str, service: str) -> None:
        """Update last_used_at for a credential."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                "UPDATE service_credentials SET last_used_at = ? WHERE user_id = ? AND service = ?",
                (_now_iso(), user_id, service),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    # ── Provisioning log ─────────────────────────────────────────
    def log_provisioning(
        self,
        user_id: str,
        service: str,
        action: str,
        detail_json: str | None = None,
    ) -> None:
        """Append an audit entry for a provisioning action on a service."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """INSERT INTO provisioning_log (user_id, service, action, detail_json, created_at)
                   VALUES (?, ?, ?, ?, ?)""",
                (user_id, service, action, detail_json, _now_iso()),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    # ── Session + message CRUD (T023) ────────────────────────────
    def get_or_create_session(
        self,
        user_id: str,
        surface: str,
        topic_id: str | None = None,
    ) -> Session:
        """Resolve or create a session for the given user/surface/topic."""

        def _do(conn: sqlite3.Connection) -> Session:
            # topic_id = NULL is a distinct session slot, so it needs its own query.
            if topic_id is not None:
                row = conn.execute(
                    """SELECT id, user_id, surface, topic_id, title, created_at, last_active_at
                       FROM sessions WHERE user_id = ? AND surface = ? AND topic_id = ?""",
                    (user_id, surface, topic_id),
                ).fetchone()
            else:
                row = conn.execute(
                    """SELECT id, user_id, surface, topic_id, title, created_at, last_active_at
                       FROM sessions WHERE user_id = ? AND surface = ? AND topic_id IS NULL""",
                    (user_id, surface),
                ).fetchone()
            if row:
                now = _now_iso()
                conn.execute(
                    "UPDATE sessions SET last_active_at = ? WHERE id = ?",
                    (now, row["id"]),
                )
                conn.commit()
                return Session(
                    id=row["id"],
                    user_id=row["user_id"],
                    surface=row["surface"],
                    topic_id=row["topic_id"],
                    title=row["title"],
                    created_at=row["created_at"],
                    last_active_at=now,
                )
            now = _now_iso()
            session_id = str(uuid.uuid4())
            conn.execute(
                """INSERT INTO sessions (id, user_id, surface, topic_id, title, created_at, last_active_at)
                   VALUES (?, ?, ?, ?, NULL, ?, ?)""",
                (session_id, user_id, surface, topic_id, now, now),
            )
            conn.commit()
            return Session(
                id=session_id,
                user_id=user_id,
                surface=surface,
                topic_id=topic_id,
                title=None,
                created_at=now,
                last_active_at=now,
            )

        return _write_with_retry(self._conn, _do)

    def log_message(self, session_id: str, role: str, content: str) -> None:
        """Append one chat message (role + content) to a session."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """INSERT INTO messages (id, session_id, role, content, created_at)
                   VALUES (?, ?, ?, ?, ?)""",
                (str(uuid.uuid4()), session_id, role, content, _now_iso()),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    def get_session_messages(self, session_id: str) -> list[Message]:
        """Return all messages of a session ordered oldest-first."""
        rows = self._conn.execute(
            """SELECT id, session_id, role, content, created_at
               FROM messages WHERE session_id = ? ORDER BY created_at""",
            (session_id,),
        ).fetchall()
        return [
            Message(
                id=row["id"],
                session_id=row["session_id"],
                role=row["role"],
                content=row["content"],
                created_at=row["created_at"],
            )
            for row in rows
        ]

    # ── Event CRUD (T025) ────────────────────────────────────────
    def store_event(
        self,
        user_id: str,
        source: str,
        event_type: str,
        summary: str,
        detail_json: str | None = None,
    ) -> None:
        """Record an inbound event for later consumption (consumed_at starts NULL)."""

        def _do(conn: sqlite3.Connection) -> None:
            conn.execute(
                """INSERT INTO events (user_id, source, event_type, summary, detail_json, created_at)
                   VALUES (?, ?, ?, ?, ?, ?)""",
                (user_id, source, event_type, summary, detail_json, _now_iso()),
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    def get_recent_events(self, user_id: str, window_hours: int = 24) -> list[Event]:
        """Return events created within the last *window_hours*, newest first."""
        from datetime import timedelta  # local import: file-level import block is outside this view

        # ISO-8601 UTC timestamps sort lexicographically, so the cutoff can be
        # compared as a plain string. (The previous version passed "now" through
        # SQLite's datetime(), which renders a space-separated timestamp; since
        # 'T' > ' ', stored 'T'-separated values from the cutoff *day* compared
        # greater even when earlier than the cutoff *time*, widening the window.)
        cutoff = (datetime.now(timezone.utc) - timedelta(hours=window_hours)).isoformat()
        rows = self._conn.execute(
            """SELECT id, user_id, source, event_type, summary, detail_json, created_at, consumed_at
               FROM events
               WHERE user_id = ? AND created_at >= ?
               ORDER BY created_at DESC""",
            (user_id, cutoff),
        ).fetchall()
        return [
            Event(
                id=row["id"],
                user_id=row["user_id"],
                source=row["source"],
                event_type=row["event_type"],
                summary=row["summary"],
                detail_json=row["detail_json"],
                created_at=row["created_at"],
                consumed_at=row["consumed_at"],
            )
            for row in rows
        ]

    def mark_events_consumed(self, event_ids: list[int]) -> None:
        """Stamp consumed_at on the given event ids (no-op for an empty list)."""
        if not event_ids:
            return

        def _do(conn: sqlite3.Connection) -> None:
            placeholders = ",".join("?" for _ in event_ids)
            conn.execute(
                f"UPDATE events SET consumed_at = ? WHERE id IN ({placeholders})",  # noqa: S608
                [_now_iso(), *event_ids],
            )
            conn.commit()

        _write_with_retry(self._conn, _do)

    # ── Owner bootstrap (T016) ───────────────────────────────────
    def bootstrap_owner(self) -> User:
        """Idempotent: ensure the owner user exists, migrate env-var credentials."""

        def _do(conn: sqlite3.Connection) -> User:
            row = conn.execute(
                """SELECT id, display_name, email, created_at, updated_at,
                          is_owner, onboarding_complete
                   FROM users WHERE is_owner = 1""",
            ).fetchone()
            if row:
                owner = User(
                    id=row["id"],
                    display_name=row["display_name"],
                    email=row["email"],
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                    is_owner=True,
                    onboarding_complete=True,
                )
                # Still migrate any new env-var credentials that weren't there before
                self._migrate_env_credentials(owner.id)
                return owner
            now = _now_iso()
            owner_id = str(uuid.uuid4())
            conn.execute(
                """INSERT INTO users (id, display_name, email, created_at, updated_at,
                                      is_owner, onboarding_complete)
                   VALUES (?, ?, NULL, ?, ?, 1, 1)""",
                (owner_id, "Owner", now, now),
            )
            # Create "web" external identity for owner
            conn.execute(
                """INSERT INTO external_identities (user_id, provider, external_id, created_at)
                   VALUES (?, 'web', 'owner', ?)""",
                (owner_id, now),
            )
            conn.commit()
            logger.info("Created owner user %s", owner_id)
            self._migrate_env_credentials(owner_id)
            return User(
                id=owner_id,
                display_name="Owner",
                email=None,
                created_at=now,
                updated_at=now,
                is_owner=True,
                onboarding_complete=True,
                is_new=True,
            )

        return _write_with_retry(self._conn, _do)

    def _migrate_env_credentials(self, owner_id: str) -> None:
        """Migrate static env-var API keys into the owner's credential vault."""
        migrations = [
            ("vikunja", settings.VIKUNJA_API_KEY, settings.VIKUNJA_API_URL),
            ("karakeep", settings.KARAKEEP_API_KEY, settings.KARAKEEP_API_URL),
        ]
        for service, api_key, _url in migrations:
            if not api_key:
                continue
            # Never overwrite a credential the owner stored interactively.
            existing = self.get_credential(owner_id, service)
            if existing:
                continue
            self.store_credential(owner_id, service, api_key)
            self.log_provisioning(owner_id, service, "env_migrated", '{"source": "env_var"}')
            logger.info("Migrated %s env-var credential for owner %s", service, owner_id)
# ── Module-level singleton ───────────────────────────────────────
_store: UserStore | None = None  # created lazily by get_store() / init_store()
def get_store() -> UserStore:
    """Return the module-level UserStore singleton, creating it if needed."""
    global _store  # noqa: PLW0603
    if _store is None:
        _store = UserStore()
    return _store
def init_store(db_path: str | Path | None = None) -> UserStore:
    """Initialise the UserStore singleton explicitly (e.g. in FastAPI lifespan)."""
    global _store  # noqa: PLW0603
    store = UserStore(db_path)
    _store = store
    return store

616
ux.py Normal file
View file

@ -0,0 +1,616 @@
"""User-facing messaging helpers for Telegram and web chat surfaces."""
from __future__ import annotations
import html
import json
import re
from typing import Any, Literal, cast
from copilot.generated.session_events import Data, SessionEvent, SessionEventType, ToolRequest
Surface = Literal["telegram", "web"]
def extract_final_text(events: list[SessionEvent]) -> str:
    """Walk collected events and return the final assistant message text."""
    # Prefer the most recent complete ASSISTANT_MESSAGE event.
    for event in reversed(events):
        if event.type == SessionEventType.ASSISTANT_MESSAGE and event.data and event.data.content:
            return event.data.content.strip()
    # Fallback: stitch the streamed delta fragments together in order.
    fragments = [
        event.data.delta_content
        for event in events
        if event.type == SessionEventType.ASSISTANT_MESSAGE_DELTA and event.data and event.data.delta_content
    ]
    return "".join(fragments).strip()
def working_message(*, surface: Surface) -> str:
    """Placeholder text shown while the agent is still producing a reply."""
    return "Thinking ..." if surface == "telegram" else "Working on it"
def busy_message(*, surface: Surface) -> str:
    """Reply sent when a new prompt arrives while a previous one is running."""
    telegram_text = "Still working on the previous message in this chat. Wait for that reply, or send /new to reset."
    web_text = "Still working on the previous message. Wait for that reply before sending another one."
    return telegram_text if surface == "telegram" else web_text
def format_session_error(*, surface: Surface, error: Exception | str | None = None) -> str:
    """Build a user-facing failure message with targeted recovery guidance."""
    detail: str = _format_error_detail(error)
    sections: list[str] = ["Run failed with an internal exception."]
    if detail:
        sections.append(f"Exception: {detail}")
    # Add one category-specific hint, most specific match first.
    if _looks_image_unsupported(detail):
        sections.append(
            "This model does not support image inputs. "
            "Switch to a vision model (e.g. gpt-4o, claude-sonnet, gemini-2.5-pro) or resend without the image."
        )
    elif _looks_rate_limited(detail):
        sections.append(
            "The provider is rate-limiting requests (HTTP 429). The SDK already retried several times before giving up."
        )
    elif _looks_retryable(detail):
        sections.append("All automatic retries were exhausted.")
    sections.append(_retry_guidance(surface))
    return "\n\n".join(sections)
def extract_intent_from_tool(event: SessionEvent) -> str | None:
    """If the event is a report_intent tool call, return the intent text."""
    if event.type != SessionEventType.TOOL_EXECUTION_START:
        return None
    if _event_tool_name(event) != "report_intent":
        return None
    raw_args = event.data and event.data.arguments
    if not raw_args:
        return None
    # Arguments may arrive as a JSON string; decode before inspecting.
    if isinstance(raw_args, str):
        try:
            raw_args = json.loads(raw_args)
        except Exception:
            return None
    if not isinstance(raw_args, dict):
        return None
    intent = cast(dict[str, Any], raw_args).get("intent", "")
    if isinstance(intent, str) and intent.strip():
        return intent.strip()
    return None
def extract_tool_intent_summary(event: SessionEvent) -> str | None:
    """Extract intent_summary from tool_requests on any event.

    The Copilot SDK can attach ``tool_requests`` to events (e.g. before tool
    execution starts). Each tool request may carry an ``intent_summary``
    describing *why* the agent wants to call that tool.

    Returns the non-empty summaries joined with newlines, or ``None``.
    """
    data: Data | None = getattr(event, "data", None)
    if data is None:
        return None
    # Default to None: a Data object without the attribute previously raised
    # an (uncaught) AttributeError from the bare getattr.
    tool_requests: list[ToolRequest] = getattr(data, "tool_requests", None) or []
    if not tool_requests:
        return None
    try:
        summary: str = "\n".join(
            intent_summary
            for request in tool_requests
            # `or ""` guards against an explicit None intent_summary value.
            if (intent_summary := (getattr(request, "intent_summary", "") or "").strip())
        )
    except (AttributeError, IndexError, TypeError, KeyError):
        return None
    return summary or None
def stream_status_updates(event: Any, *, include_reasoning_status: bool = True) -> list[str]:
    """Return ordered, deduplicated user-facing status updates for a Copilot SDK event.

    Probes the event's ``data`` payload for every field worth surfacing
    (progress messages, intents, warnings, errors, metrics, git context, ...)
    and normalizes each into a short display line. Duplicates are dropped
    case-insensitively; known noise lines are filtered out.
    """
    event_type: SessionEventType = event.type
    data: Data | None = getattr(event, "data", None)
    updates: list[str] = []
    seen: set[str] = set()
    # Entries must be lowercase: add() compares `text.strip().lower()` against
    # this set. ("Thinking" was previously listed capitalized, so the lowered
    # text never matched and that noise line leaked through the filter.)
    noise_texts: set[str] = {"tool done", "thinking"}
    ugly_texts: dict[str, str] = {
        "Running view": "Viewing file(s)",
    }

    def add(text: Any, *, prefix: str | None = None, limit: int = 220) -> None:
        # Normalize → de-noise → prettify → dedupe → append.
        if text in (None, ""):
            return
        elif isinstance(text, str) and text.strip().lower() in noise_texts:
            return
        elif isinstance(text, str):
            text = ugly_texts.get(text.strip(), text)
        normalized: str = _normalize_status_text(text, prefix=prefix, limit=limit)
        if not normalized:
            return
        dedupe_key = normalized.casefold()
        if dedupe_key in seen:
            return
        seen.add(dedupe_key)
        updates.append(normalized)

    add(getattr(data, "progress_message", None), limit=240)
    add(extract_tool_intent_summary(event), limit=240)
    if event_type == SessionEventType.TOOL_EXECUTION_START:
        tool_name: str = _event_tool_name(event)
        intent: str | None = extract_intent_from_tool(event)
        if tool_name == "report_intent":
            pass
        elif not intent and not tool_name:
            pass
        else:
            # Prefer "tool_name: intent"; fall back to a generic running line.
            kwargs: dict = {"text": intent, "prefix": tool_name} if intent else {"text": f"Running {tool_name}"}
            add(limit=160, **kwargs)
    if event_type == SessionEventType.TOOL_EXECUTION_COMPLETE:
        tool_name: str = _event_tool_name(event)
        if tool_name != "report_intent":
            add(f"{tool_name} done", limit=160)
    if event_type == SessionEventType.SUBAGENT_SELECTED:
        add(f"Routed to {_event_agent_name(event)}", limit=180)
    if event_type == SessionEventType.SUBAGENT_STARTED:
        add(f"{_event_agent_name(event)} working", limit=180)
    if event_type == SessionEventType.SESSION_COMPACTION_START:
        add("Compacting context", limit=120)
    if event_type == SessionEventType.SESSION_COMPACTION_COMPLETE:
        add("Context compacted", limit=120)
    # if event_type == SessionEventType.ASSISTANT_TURN_START:
    #     add("Thinking", limit=80)
    if event_type == SessionEventType.ASSISTANT_INTENT:
        add(getattr(data, "intent", None), limit=240)
    if include_reasoning_status and event_type in {
        SessionEventType.ASSISTANT_REASONING,
        SessionEventType.ASSISTANT_REASONING_DELTA,
    }:
        reasoning = (data and data.reasoning_text) or ""
        if reasoning.strip():
            # Only surface reasoning lines that explicitly announce an intent.
            first_line = reasoning.strip().splitlines()[0].strip()
            if first_line.lower().startswith(("intent:", "intent ")):
                add(first_line, limit=240)
    add(getattr(data, "message", None), limit=240)
    add(getattr(data, "title", None), prefix="Title", limit=200)
    add(getattr(data, "summary", None), prefix="Summary", limit=240)
    add(getattr(data, "summary_content", None), prefix="Context summary", limit=240)
    add(getattr(data, "warning_type", None), prefix="Warning type", limit=160)
    for warning in _iter_status_values(getattr(data, "warnings", None)):
        add(warning, prefix="Warning", limit=240)
    add(getattr(data, "error_reason", None), prefix="Error", limit=240)
    for error in _iter_status_values(getattr(data, "errors", None)):
        add(error, prefix="Error", limit=240)
    add(getattr(data, "reason", None), prefix="Stop reason", limit=200)
    add(_format_server_status(getattr(data, "status", None)), prefix="Server", limit=160)
    add(getattr(data, "phase", None), prefix="Phase", limit=120)
    add(getattr(data, "mcp_tool_name", None), prefix="MCP tool", limit=180)
    add(_format_code_changes_status(getattr(data, "code_changes", None)), limit=200)
    # add(_format_cache_status(data), limit=180)
    # The SDK's `duration` is only a subtotal for the current API round-trip.
    # Total turn runtime is tracked by the caller and surfaced as a live
    # elapsed clock while the overall request is still running.
    total_premium_requests = getattr(data, "total_premium_requests", None)
    if total_premium_requests not in (None, ""):
        add(f"Premium requests: {_format_metric_number(total_premium_requests)}", limit=140)
    add(getattr(data, "branch", None), prefix="Branch", limit=160)
    add(getattr(data, "cwd", None), prefix="CWD", limit=220)
    add(getattr(data, "git_root", None), prefix="Git root", limit=220)
    head_commit = getattr(data, "head_commit", None)
    if head_commit:
        add(f"Head: {str(head_commit).strip()[:12]}", limit=80)
    if getattr(data, "reasoning_opaque", None):
        add("Encrypted reasoning attached", limit=120)
    if model := getattr(data, "model", None):
        add(model, prefix="\n🤖", limit=160)
    return updates
def stream_status_text(event: Any) -> str:
    """Return a single concatenated status string for compatibility call sites."""
    lines = stream_status_updates(event)
    return "\n".join(lines)
def stream_trace_event(event: Any) -> dict[str, str] | None:
    """Extract structured trace entries for tool activity and subagent routing.

    Returns a flat string dict describing the trace entry, or ``None`` when
    the event is not trace-worthy. Start/complete entries for the same tool
    call share a "key" so the renderer can collapse them into one row.
    """
    event_type = event.type
    if event_type == SessionEventType.TOOL_EXECUTION_START:
        tool_name = _event_tool_name(event)
        if tool_name == "report_intent":
            return None  # suppress report_intent from trace
        # Key on the call id when present; fall back to the tool name.
        tool_call_id = (event.data and event.data.tool_call_id) or ""
        detail = _stringify_trace_detail(event.data and event.data.arguments)
        return {
            "kind": "trace",
            "category": "tool_call",
            "key": f"tool:{tool_call_id or tool_name}",
            "tool_name": tool_name,
            "title": f"Tool call: {tool_name}",
            "summary": f"Called {tool_name}",
            "text": f"Called {tool_name}",
            "detail": detail or "No arguments exposed.",
        }
    if event_type == SessionEventType.TOOL_EXECUTION_COMPLETE:
        tool_name = _event_tool_name(event)
        if tool_name == "report_intent":
            return None  # suppress report_intent from trace
        tool_call_id = (event.data and event.data.tool_call_id) or ""
        output_detail = _stringify_trace_detail(event.data and event.data.output)
        return {
            "kind": "trace",
            "category": "tool_call",
            "key": f"tool:{tool_call_id or tool_name}",
            "tool_name": tool_name,
            "title": f"Tool call: {tool_name}",
            "summary": f"{tool_name} done",
            "text": f"{tool_name} done",
            "output_detail": output_detail or "Tool finished with no readable output.",
        }
    if event_type == SessionEventType.SUBAGENT_SELECTED:
        agent_name = _event_agent_name(event)
        return {
            "kind": "trace",
            "category": "subagent",
            "key": f"agent:{agent_name}",
            "title": f"Subagent: {agent_name}",
            "summary": f"Routed to {agent_name}",
            "text": f"Routed to {agent_name}",
            "detail": f"The run is now executing inside the {agent_name} subagent.",
        }
    return None
def stream_reasoning_event(event: Any) -> tuple[str, str, bool] | None:
    """Extract reasoning text from a Copilot SDK event when available.

    Returns ``(reasoning_id, text, is_final)``: deltas are non-final, full
    ASSISTANT_REASONING payloads are final. ``None`` for other events or
    blank text.
    """
    finality = {
        SessionEventType.ASSISTANT_REASONING_DELTA: False,
        SessionEventType.ASSISTANT_REASONING: True,
    }
    is_final = finality.get(event.type)
    if is_final is None:
        return None
    reasoning_id = (event.data and event.data.reasoning_id) or "reasoning"
    text = (event.data and event.data.reasoning_text) or ""
    if text.strip():
        return reasoning_id, text, is_final
    return None
def format_tool_counts(tool_counts: dict[str, int], *, current_status: str = "") -> str:
    """Build a compact one-line summary of tool call counts."""
    if not tool_counts:
        return current_status or ""
    # Most-used tools first; drop zero/negative counts.
    ordered = sorted(tool_counts.items(), key=lambda kv: -kv[1])
    pieces = [f"{count} {name}" for name, count in ordered if count > 0]
    lines: list[str] = []
    if current_status:
        lines.append(current_status)
    if pieces:
        lines.append(f"🔧 {' · '.join(pieces)}")
    return "\n".join(lines)
def format_elapsed_status(elapsed_seconds: float) -> str:
    """Render a human-friendly turn runtime for live status displays."""
    total = max(0, int(elapsed_seconds))  # clamp negatives from clock skew
    hours = total // 3600
    minutes = (total % 3600) // 60
    seconds = total % 60
    if hours:
        return f"Elapsed: {hours}h {minutes:02d}m {seconds:02d}s"
    if minutes:
        return f"Elapsed: {minutes}m {seconds:02d}s"
    return f"Elapsed: {seconds}s"
def append_elapsed_status(text: Any, *, elapsed_seconds: float) -> str:
    """Append the current turn runtime to a status line without duplicating it."""
    # Drop any stale "Elapsed:" line before appending the fresh one.
    kept = [
        line
        for line in str(text or "").splitlines()
        if not line.strip().lower().startswith("elapsed:")
    ]
    base = "\n".join(kept).strip()
    elapsed_line = format_elapsed_status(elapsed_seconds)
    return f"{base}\n{elapsed_line}" if base else elapsed_line
async def safe_delete_message(message: Any) -> None:
    """Best-effort delete of a chat message; never raises.

    Swallows every error (already-deleted message, missing permissions,
    transient API failure) because deletion here is purely cosmetic.
    """
    if message is None:
        return
    try:
        await message.delete()
    except Exception:
        return
def markdown_to_telegram_html(text: str) -> str:
    """Convert common Markdown to Telegram-compatible HTML.

    Handles fenced code blocks, inline code, bold, italic, strikethrough,
    and links. Falls back gracefully — anything it can't convert is
    HTML-escaped and sent as plain text.
    """
    # Match ```lang\n...\n``` fences (language tag optional); everything
    # between fences is handed to the inline converter.
    fence_re = re.compile(r"```(\w*)\n(.*?)```", re.DOTALL)
    rendered: list[str] = []
    cursor = 0
    for match in fence_re.finditer(text):
        if match.start() > cursor:
            rendered.append(_md_inline_to_html(text[cursor : match.start()]))
        lang = match.group(1)
        escaped = html.escape(match.group(2).rstrip("\n"))
        if lang:
            rendered.append(f'<pre><code class="language-{html.escape(lang)}">{escaped}</code></pre>')
        else:
            rendered.append(f"<pre>{escaped}</pre>")
        cursor = match.end()
    if cursor < len(text):
        rendered.append(_md_inline_to_html(text[cursor:]))
    return "".join(rendered)
def _md_inline_to_html(text: str) -> str:
"""Convert inline Markdown (outside code blocks) to Telegram HTML."""
# First, protect inline code spans so their contents aren't modified
inline_code_re = re.compile(r"`([^`]+)`")
placeholder = "\x00CODE\x00"
codes: list[str] = []
def _save_code(m: re.Match) -> str:
codes.append(html.escape(m.group(1)))
return f"{placeholder}{len(codes) - 1}{placeholder}"
text = inline_code_re.sub(_save_code, text)
# Escape HTML entities in the remaining text
text = html.escape(text)
# Bold: **text** or __text__
text = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", text)
text = re.sub(r"__(.+?)__", r"<b>\1</b>", text)
# Italic: *text* or _text_ (but not inside words like foo_bar)
text = re.sub(r"(?<!\w)\*([^*]+?)\*(?!\w)", r"<i>\1</i>", text)
text = re.sub(r"(?<!\w)_([^_]+?)_(?!\w)", r"<i>\1</i>", text)
# Strikethrough: ~~text~~
text = re.sub(r"~~(.+?)~~", r"<s>\1</s>", text)
# Links: [text](url)
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r'<a href="\2">\1</a>', text)
# Restore inline code spans
for i, code_html in enumerate(codes):
text = text.replace(f"{placeholder}{i}{placeholder}", f"<code>{code_html}</code>")
return text
# ── Private helpers ──────────────────────────────────────────────────
def _event_tool_name(event: Any) -> str:
if event.data:
return event.data.tool_name or event.data.name or "tool"
return "tool"
def _event_agent_name(event: Any) -> str:
if event.data:
return event.data.agent_name or event.data.agent_display_name or "specialist"
return "specialist"
def _retry_guidance(surface: Surface) -> str:
if surface == "telegram":
return "Reply with a narrower follow-up, switch model/provider, or send /new to reset the session."
return "Retry with a narrower follow-up, switch model/provider, or start a fresh chat."
def _looks_retryable(detail: str) -> bool:
    """Check if the error detail matches a transient failure pattern."""
    if not detail:
        return False
    # Rate-limit errors get their own dedicated messaging elsewhere.
    if _looks_rate_limited(detail):
        return False
    haystack = detail.lower()
    transient_markers = (
        "failed to get response",
        "operation was aborted",
        "timed out",
        "timeout",
        "502",
        "503",
        "504",
        "service unavailable",
        "overloaded",
    )
    return any(marker in haystack for marker in transient_markers)
def _looks_rate_limited(detail: str) -> bool:
"""Check if the error is specifically a 429 / rate-limit."""
if not detail:
return False
lower = detail.lower()
return any(p in lower for p in ("429", "rate limit", "rate_limit", "too many requests"))
def _looks_image_unsupported(detail: str) -> bool:
"""Check if the error indicates the model does not accept image inputs."""
if not detail:
return False
lower = detail.lower()
return any(
p in lower
for p in ("0 image(s) may be provided", "does not support image", "image input is not supported", "images are not supported")
)
def _format_error_detail(error: Exception | str | None) -> str:
if error is None:
return ""
if isinstance(error, str):
return error.strip()
name = type(error).__name__
message = str(error).strip()
if not message or message == name:
return name
return f"{name}: {message}"
def _normalize_status_text(text: Any, *, prefix: str | None = None, limit: int = 220) -> str:
if text in (None, ""):
return ""
rendered = " ".join(str(text).split())
if not rendered:
return ""
if prefix:
rendered = f"{prefix}: {rendered}"
if limit > 0 and len(rendered) > limit:
return f"{rendered[: limit - 3].rstrip()}..."
return rendered
def _format_metric_number(value: Any) -> str:
try:
number = float(value)
except (TypeError, ValueError):
return str(value).strip()
if number.is_integer():
return f"{int(number):,}"
return f"{number:,.2f}".rstrip("0").rstrip(".")
def _format_code_changes_status(code_changes: Any) -> str:
    """Summarise file/line deltas reported by the SDK, or '' when absent."""
    if not code_changes:
        return ""

    def _field(name: str) -> Any:
        # Support both attribute-style objects and plain dicts.
        value = getattr(code_changes, name, None)
        if value is None and isinstance(code_changes, dict):
            value = code_changes.get(name)
        return value

    files_modified = _field("files_modified")
    lines_added = _field("lines_added")
    lines_removed = _field("lines_removed")
    fragments: list[str] = []
    if isinstance(files_modified, (list, tuple, set)):
        fragments.append(f"{len(files_modified)} files")
    elif files_modified:
        fragments.append("1 file")
    if lines_added not in (None, "") or lines_removed not in (None, ""):
        fragments.append(f"+{_format_metric_number(lines_added or 0)}/-{_format_metric_number(lines_removed or 0)} lines")
    if not fragments:
        return "Code changes recorded"
    return f"Code changes: {', '.join(fragments)}"
def _format_cache_status(data: Data | None) -> str:
    """Describe prompt-cache token usage, or '' when nothing was reported."""
    if data is None:
        return ""
    read_tokens = getattr(data, "cache_read_tokens", None)
    write_tokens = getattr(data, "cache_write_tokens", None)
    if read_tokens in (None, "") and write_tokens in (None, ""):
        return ""
    fragments: list[str] = []
    if read_tokens not in (None, ""):
        fragments.append(f"read {_format_metric_number(read_tokens)}")
    if write_tokens not in (None, ""):
        fragments.append(f"wrote {_format_metric_number(write_tokens)}")
    return f"Prompt cache: {', '.join(fragments)}"
def _format_server_status(status: Any) -> str:
if status in (None, ""):
return ""
if isinstance(status, str):
return status.strip()
for attr in ("value", "status", "name"):
value = getattr(status, attr, None)
if isinstance(value, str) and value.strip():
return value.strip()
return str(status).strip()
def _iter_status_values(value: Any) -> list[Any]:
if value in (None, ""):
return []
if isinstance(value, (list, tuple, set)):
return list(value)
return [value]
def _stringify_trace_detail(value: Any, *, limit: int = 1800) -> str:
if value in (None, ""):
return ""
rendered = ""
if isinstance(value, str):
candidate = value.strip()
if candidate:
if candidate[:1] in {"{", "["}:
try:
rendered = json.dumps(json.loads(candidate), indent=2, ensure_ascii=False)
except Exception:
rendered = candidate
else:
rendered = candidate
elif isinstance(value, (dict, list, tuple)):
rendered = json.dumps(value, indent=2, ensure_ascii=False)
else:
rendered = str(value).strip()
if len(rendered) <= limit:
return rendered
return f"{rendered[: limit - 3].rstrip()}..."

767
web_fallback_store.py Normal file
View file

@ -0,0 +1,767 @@
"""Persistent thread store shared by the web UI and Telegram bot sessions."""
from __future__ import annotations
import asyncio
import json
import uuid
from collections import defaultdict
from decimal import Decimal
from pathlib import Path
from threading import RLock
from typing import Any
from local_media_store import LocalMediaStore
# Fallback values applied when a thread is created without explicit metadata.
DEFAULT_TITLE = "New chat"
DEFAULT_PROVIDER = "openai"
DEFAULT_MODEL = "gpt-5.4"
DEFAULT_SOURCE = "web"
# Provenance of a thread title: default placeholder, auto-derived from the
# first user message, manually set, or set via the "magic" source.
TITLE_SOURCE_DEFAULT = "default"
TITLE_SOURCE_AUTO = "auto"
TITLE_SOURCE_MANUAL = "manual"
TITLE_SOURCE_MAGIC = "magic"
# Accepted values for a thread's "title_source" field (see _normalize_title_source).
KNOWN_TITLE_SOURCES = {
    TITLE_SOURCE_DEFAULT,
    TITLE_SOURCE_AUTO,
    TITLE_SOURCE_MANUAL,
    TITLE_SOURCE_MAGIC,
}
class WebFallbackStore:
def __init__(self, data_dir: str = "/data", media_store: LocalMediaStore | None = None):
    """Load persisted state from ``<data_dir>/web_threads.json`` on startup.

    Args:
        data_dir: Directory holding the JSON thread store and media files.
        media_store: Optional pre-built media store; created from *data_dir*
            when omitted.
    """
    self._threads: dict[str, dict[str, Any]] = {}  # thread_id -> raw thread record
    self._session_index: dict[str, str] = {}  # session key -> thread_id
    self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)  # per-thread async locks
    self._state_lock = RLock()  # guards all mutations and persistence
    self._file_path = Path(data_dir) / "web_threads.json"
    self._media_store = media_store or LocalMediaStore(data_dir)
    self._last_tg_update_id: int = 0  # highest processed Telegram update_id
    self._topic_system_prompts: dict[str, dict[str, Any]] = {}  # session key -> prompt record
    self._global_learnings: list[dict[str, Any]] = []
    self._load()
    self._normalize_threads()  # repair/upgrade records loaded from disk
def list_threads(self) -> list[dict[str, Any]]:
    """Return summaries (no messages) of every thread, newest-updated first."""
    records = self._threads.values()
    ordered = sorted(records, key=lambda rec: rec["updated_at"], reverse=True)
    return [self._serialize_thread(rec, include_messages=False) for rec in ordered]
def search_threads(self, query: str) -> list[dict[str, Any]]:
    """Search threads by keyword across title and message content."""
    needle = query.strip().lower()
    if not needle:
        # Blank query degenerates to a full listing.
        return self.list_threads()
    hits: list[dict[str, Any]] = []
    for record in self._threads.values():
        if needle in record.get("title", "").lower() or any(
            needle in str(entry.get("content", "")).lower()
            for entry in record.get("messages", [])
        ):
            hits.append(record)
    hits.sort(key=lambda rec: rec["updated_at"], reverse=True)
    return [self._serialize_thread(rec, include_messages=False) for rec in hits]
def create_thread(
    self,
    title: str,
    model: str,
    provider: str = DEFAULT_PROVIDER,
    *,
    source: str = DEFAULT_SOURCE,
    session_label: str | None = None,
) -> dict[str, Any]:
    """Create, persist, and return a new thread (serialized form)."""
    with self._state_lock:
        thread = self._build_thread(
            title=title, model=model, provider=provider, source=source, session_label=session_label
        )
        self._threads[thread["id"]] = thread
        self._persist_locked()
        return self._serialize_thread(thread)
def get_thread(self, thread_id: str, *, include_messages: bool = True) -> dict[str, Any]:
    """Return the serialized thread; raises KeyError when the id is unknown."""
    record = self._require_thread(thread_id)
    return self._serialize_thread(record, include_messages=include_messages)
def get_session_thread(self, session_key: str, *, include_messages: bool = True) -> dict[str, Any] | None:
    """Return the thread bound to *session_key*, or None when unmapped."""
    with self._state_lock:
        record = self._get_session_thread_record(session_key)
        if record is None:
            return None
        return self._serialize_thread(record, include_messages=include_messages)
def get_or_create_session_thread(
    self,
    *,
    session_key: str,
    title: str,
    model: str,
    provider: str = DEFAULT_PROVIDER,
    source: str = "telegram",
    session_label: str | None = None,
) -> dict[str, Any]:
    """Return the thread mapped to *session_key*, creating one if missing.

    Also refreshes the stored ``session_label`` when it changed. Persists
    only when something was actually created or updated.
    """
    with self._state_lock:
        thread = self._get_session_thread_record(session_key)
        changed = False
        if thread is None:
            thread = self._build_thread(
                title=title,
                model=model,
                provider=provider,
                source=source,
                session_label=session_label,
            )
            self._threads[thread["id"]] = thread
            self._session_index[session_key] = thread["id"]
            changed = True
        elif session_label and thread.get("session_label") != session_label:
            thread["session_label"] = session_label
            changed = True
        if changed:
            self._persist_locked()
        return self._serialize_thread(thread)
def start_new_session_thread(
    self,
    *,
    session_key: str,
    title: str,
    model: str,
    provider: str = DEFAULT_PROVIDER,
    source: str = "telegram",
    session_label: str | None = None,
) -> dict[str, Any]:
    """Create a fresh thread and repoint *session_key* at it.

    Any previously mapped thread stays in the store but is unlinked from
    the session.
    """
    with self._state_lock:
        thread = self._build_thread(
            title=title, model=model, provider=provider, source=source, session_label=session_label
        )
        self._threads[thread["id"]] = thread
        self._session_index[session_key] = thread["id"]
        self._persist_locked()
        return self._serialize_thread(thread)
def update_thread(
    self,
    thread_id: str,
    *,
    title: str | None = None,
    title_source: str | None = None,
    model: str | None = None,
    provider: str | None = None,
) -> dict[str, Any]:
    """Patch thread metadata; fields left as ``None`` are untouched.

    Raises:
        KeyError: unknown *thread_id*.
        ValueError: empty *provider* or *model*.
    """
    with self._state_lock:
        thread = self._require_thread(thread_id)
        if title is not None:
            normalized_title = str(title).strip() or DEFAULT_TITLE
            thread["title"] = normalized_title
            # Unless told otherwise, an explicit title edit counts as manual.
            thread["title_source"] = self._normalize_title_source(title_source) or TITLE_SOURCE_MANUAL
        if provider is not None:
            normalized_provider = str(provider).strip().lower()
            if not normalized_provider:
                raise ValueError("Provider cannot be empty")
            thread["provider"] = normalized_provider
        if model is not None:
            normalized_model = str(model).strip()
            if not normalized_model:
                raise ValueError("Model cannot be empty")
            thread["model"] = normalized_model
        thread["updated_at"] = self._timestamp()
        self._persist_locked()
        return self._serialize_thread(thread)
def delete_thread(self, thread_id: str) -> None:
    """Remove a thread plus any session-index entries and lock pointing at it.

    Raises:
        KeyError: unknown *thread_id*.
    """
    with self._state_lock:
        if thread_id not in self._threads:
            raise KeyError(thread_id)
        del self._threads[thread_id]
        # Drop every session mapping that referenced the deleted thread.
        stale_session_keys = [
            session_key
            for session_key, mapped_thread_id in self._session_index.items()
            if mapped_thread_id == thread_id
        ]
        for session_key in stale_session_keys:
            self._session_index.pop(session_key, None)
        self._locks.pop(thread_id, None)
        self._persist_locked()
def add_message(
    self,
    thread_id: str,
    role: str,
    content: str,
    *,
    parts: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Append a message to a thread, persist, and return the serialized message.

    A first user message on a default-titled thread also auto-derives the
    thread title (slash-commands like "/start" are excluded).
    """
    with self._state_lock:
        thread = self._require_thread(thread_id)
        now = self._timestamp()
        normalized_parts = self._normalize_message_parts(parts)
        message = {
            "id": uuid.uuid4().hex,
            "role": role,
            "content": str(content or ""),
            "created_at": now,
        }
        if normalized_parts:
            message["parts"] = normalized_parts
        thread["messages"].append(message)
        thread["updated_at"] = now
        title_source = self._message_text(message)
        # Only a non-command user message may replace the default title.
        if (
            role == "user"
            and thread["title"] == DEFAULT_TITLE
            and title_source
            and not title_source.startswith("/")
        ):
            thread["title"] = self._derive_title(title_source)
            thread["title_source"] = TITLE_SOURCE_AUTO
        self._persist_locked()
        return self._serialize_message(message)
def remove_message(self, thread_id: str, message_id: str) -> None:
    """Delete a message by id; silently no-ops when the id is unknown.

    After a removal the thread title is reset to the default placeholder,
    since it may have been derived from the removed message.
    """
    with self._state_lock:
        thread = self._require_thread(thread_id)
        original_count = len(thread["messages"])
        thread["messages"] = [message for message in thread["messages"] if message["id"] != message_id]
        if len(thread["messages"]) == original_count:
            return  # nothing was removed — skip the disk write
        # Roll updated_at back to the latest remaining message (or creation time).
        if thread["messages"]:
            thread["updated_at"] = thread["messages"][-1]["created_at"]
        else:
            thread["updated_at"] = thread["created_at"]
        if thread["title"] != DEFAULT_TITLE:
            thread["title"] = DEFAULT_TITLE
            thread["title_source"] = TITLE_SOURCE_DEFAULT
        self._persist_locked()
def build_history(self, thread_id: str, *, limit: int | None = None) -> list[dict[str, str]]:
thread = self._require_thread(thread_id)
messages = thread["messages"][-limit:] if limit is not None else thread["messages"]
return [
{
"role": str(message.get("role") or "assistant"),
"content": self._message_text(message)
if str(message.get("role") or "").strip().lower() == "user"
else str(message.get("content") or ""),
}
for message in messages
]
def build_agent_history(self, thread_id: str, *, limit: int | None = None) -> list[dict[str, Any]]:
    """Build model-ready history; user image parts become data-URL inputs.

    User messages with stored parts expand into multimodal content
    (``input_text`` / ``input_image``); other roles contribute plain text.
    Empty messages are skipped.
    """
    thread = self._require_thread(thread_id)
    messages = thread["messages"][-limit:] if limit is not None else thread["messages"]
    history: list[dict[str, Any]] = []
    for message in messages:
        role = str(message.get("role") or "assistant").strip().lower() or "assistant"
        parts = self._normalize_message_parts(message.get("parts"))
        if role == "user" and parts:
            content_parts: list[dict[str, Any]] = []
            for part in parts:
                if part["type"] == "input_text":
                    content_parts.append({"type": "input_text", "text": part["text"]})
                    continue
                try:
                    content_parts.append(
                        {
                            "type": "input_image",
                            "image_url": self._media_store.build_data_url(part["file_id"]),
                            "detail": part.get("detail") or "auto",
                        }
                    )
                except KeyError:
                    # Assumes build_data_url raises KeyError for missing media
                    # (TODO confirm) — degrade to a textual placeholder.
                    missing_name = str(part.get("name") or "image")
                    content_parts.append(
                        {
                            "type": "input_text",
                            "text": f"A previously attached image named {missing_name} is no longer available.",
                        }
                    )
            if content_parts:
                history.append({"role": role, "content": content_parts})
            continue
        content = self._message_text(message) if role == "user" else str(message.get("content") or "")
        if content:
            history.append({"role": role, "content": content})
    return history
def save_upload(self, *, name: str, mime_type: str, data: bytes, file_id: str | None = None) -> dict[str, Any]:
    """Persist raw upload bytes via the media store and return public metadata."""
    metadata = self._media_store.save_bytes(data, name=name, mime_type=mime_type, file_id=file_id)
    return self._serialize_upload(metadata)
def get_upload(self, file_id: str) -> dict[str, Any]:
    """Return public metadata for a stored upload."""
    return self._serialize_upload(self._media_store.get_meta(file_id))
def read_upload(self, file_id: str) -> bytes:
    """Return the raw bytes of a stored upload."""
    return self._media_store.read_bytes(file_id)
def delete_upload(self, file_id: str) -> None:
    """Delete a stored upload from the media store."""
    self._media_store.delete(file_id)
def lock(self, thread_id: str) -> asyncio.Lock:
    """Return the per-thread asyncio lock; raises KeyError for unknown threads."""
    self._require_thread(thread_id)  # existence check only
    return self._locks[thread_id]
def get_topic_system_prompt(self, session_key: str) -> dict[str, Any] | None:
with self._state_lock:
record = self._topic_system_prompts.get(str(session_key))
if not isinstance(record, dict):
return None
prompt = str(record.get("prompt") or "").strip()
if not prompt:
return None
return {
"session_key": str(session_key),
"prompt": prompt,
"updated_at": str(record.get("updated_at") or ""),
}
def set_topic_system_prompt(self, session_key: str, prompt: str) -> dict[str, Any]:
    """Store (and persist) a per-session system prompt override.

    Raises:
        ValueError: empty session key or empty prompt.
    """
    normalized_session_key = str(session_key).strip()
    if not normalized_session_key:
        raise ValueError("Session key cannot be empty")
    normalized_prompt = str(prompt or "").strip()
    if not normalized_prompt:
        raise ValueError("System prompt cannot be empty")
    with self._state_lock:
        record = {
            "prompt": normalized_prompt,
            "updated_at": self._timestamp(),
        }
        self._topic_system_prompts[normalized_session_key] = record
        self._persist_locked()
        return {
            "session_key": normalized_session_key,
            "prompt": record["prompt"],
            "updated_at": record["updated_at"],
        }
def clear_topic_system_prompt(self, session_key: str) -> bool:
    """Remove a per-session prompt override; True when one was removed."""
    normalized_session_key = str(session_key).strip()
    if not normalized_session_key:
        return False
    with self._state_lock:
        removed = self._topic_system_prompts.pop(normalized_session_key, None)
        if removed is None:
            return False
        self._persist_locked()
        return True
def add_usage(self, thread_id: str, usage: dict[str, Any] | None) -> dict[str, Any]:
    """Accumulate a request's usage into the thread's totals, persist, and return them."""
    normalized_usage = self._normalize_usage(usage)
    with self._state_lock:
        thread = self._require_thread(thread_id)
        thread["usage"] = self._merge_usage(thread.get("usage"), normalized_usage)
        thread["updated_at"] = self._timestamp()
        self._persist_locked()
        return self._serialize_usage(thread.get("usage"))
def get_total_usage(self, thread_id: str) -> dict[str, Any]:
"""Return the cumulative usage for a thread."""
with self._state_lock:
thread = self._require_thread(thread_id)
return self._serialize_usage(thread.get("usage"))
# ── Project learnings (per-thread) ────────────────────────────────
# Hard cap on stored per-thread learnings; oldest entries are dropped first.
MAX_PROJECT_LEARNINGS = 30
def get_project_learnings(self, thread_id: str) -> list[dict[str, Any]]:
    """Return a copy of the project learnings recorded for a thread."""
    with self._state_lock:
        thread = self._require_thread(thread_id)
        return list(thread.get("project_learnings") or [])
def add_project_learning(self, thread_id: str, fact: str, *, category: str = "general") -> None:
    """Add a project-scoped learning to a thread.

    Empty facts and case-insensitive duplicates are ignored; the list is
    capped at ``MAX_PROJECT_LEARNINGS`` (oldest entries dropped first).
    """
    fact = str(fact or "").strip()
    if not fact:
        return
    with self._state_lock:
        thread = self._require_thread(thread_id)
        learnings: list[dict[str, Any]] = list(thread.get("project_learnings") or [])
        # Skip near-duplicates (same fact text ignoring case/whitespace)
        normalized = fact.lower().strip()
        for existing in learnings:
            if existing.get("fact", "").lower().strip() == normalized:
                return
        learnings.append({"fact": fact, "category": str(category or "general").strip(), "updated_at": self._timestamp()})
        # Cap size — keep most recent
        if len(learnings) > self.MAX_PROJECT_LEARNINGS:
            learnings = learnings[-self.MAX_PROJECT_LEARNINGS:]
        thread["project_learnings"] = learnings
        thread["updated_at"] = self._timestamp()
        self._persist_locked()
# ── Global learnings (user preferences, cross-project) ────────────
# Hard cap on stored global learnings; oldest entries are dropped first.
MAX_GLOBAL_LEARNINGS = 40
def get_global_learnings(self) -> list[dict[str, Any]]:
    """Return a copy of all global (cross-project) learnings."""
    with self._state_lock:
        return list(self._global_learnings)
def add_global_learning(self, fact: str, *, category: str = "general") -> None:
    """Add a global learning (user preference, personality, contacts).

    Empty facts and case-insensitive duplicates are ignored; the list is
    capped at ``MAX_GLOBAL_LEARNINGS`` (oldest entries dropped first).
    """
    fact = str(fact or "").strip()
    if not fact:
        return
    with self._state_lock:
        normalized = fact.lower().strip()
        for existing in self._global_learnings:
            if existing.get("fact", "").lower().strip() == normalized:
                return
        self._global_learnings.append(
            {"fact": fact, "category": str(category or "general").strip(), "updated_at": self._timestamp()}
        )
        if len(self._global_learnings) > self.MAX_GLOBAL_LEARNINGS:
            self._global_learnings = self._global_learnings[-self.MAX_GLOBAL_LEARNINGS:]
        self._persist_locked()
def _build_thread(
    self,
    *,
    title: str,
    model: str,
    provider: str,
    source: str,
    session_label: str | None,
) -> dict[str, Any]:
    """Assemble a fresh in-memory thread record with normalized fields."""
    now = self._timestamp()
    clean_title = str(title).strip() or DEFAULT_TITLE
    record: dict[str, Any] = {
        "id": uuid.uuid4().hex,
        "title": clean_title,
        # A caller-supplied title counts as manual; the placeholder is "default".
        "title_source": TITLE_SOURCE_DEFAULT if clean_title == DEFAULT_TITLE else TITLE_SOURCE_MANUAL,
        "provider": str(provider).strip().lower() or DEFAULT_PROVIDER,
        "model": str(model).strip() or DEFAULT_MODEL,
        "source": str(source).strip().lower() or DEFAULT_SOURCE,
        "created_at": now,
        "updated_at": now,
        "messages": [],
        "usage": self._serialize_usage(None),
    }
    label = str(session_label).strip() if session_label else ""
    if label:
        record["session_label"] = label
    return record
def _normalize_usage(self, usage: dict[str, Any] | None) -> dict[str, Any]:
if not isinstance(usage, dict):
return {}
token_fields = (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"reasoning_tokens",
"cached_tokens",
"web_search_requests",
"request_count",
)
money_fields = ("cost_usd",)
text_fields = ("pricing_source",)
normalized: dict[str, Any] = {}
for field in token_fields:
value = usage.get(field)
if value in (None, ""):
continue
try:
normalized[field] = int(value)
except (TypeError, ValueError):
continue
for field in money_fields:
value = usage.get(field)
if value in (None, ""):
continue
try:
normalized[field] = format(Decimal(str(value)), "f")
except Exception:
continue
for field in text_fields:
value = usage.get(field)
if value not in (None, ""):
normalized[field] = str(value)
breakdown = usage.get("cost_breakdown")
if isinstance(breakdown, dict):
clean_breakdown: dict[str, str] = {}
for key, value in breakdown.items():
if value in (None, ""):
continue
try:
clean_breakdown[str(key)] = format(Decimal(str(value)), "f")
except Exception:
continue
if clean_breakdown:
normalized["cost_breakdown"] = clean_breakdown
return normalized
def _merge_usage(self, existing: Any, incoming: dict[str, Any]) -> dict[str, Any]:
current = self._normalize_usage(existing if isinstance(existing, dict) else {})
if not incoming:
return self._serialize_usage(current)
for field in (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"reasoning_tokens",
"cached_tokens",
"web_search_requests",
"request_count",
):
current[field] = int(current.get(field, 0) or 0) + int(incoming.get(field, 0) or 0)
current_cost = Decimal(str(current.get("cost_usd", "0") or "0"))
incoming_cost = Decimal(str(incoming.get("cost_usd", "0") or "0"))
current["cost_usd"] = format(current_cost + incoming_cost, "f")
existing_breakdown = current.get("cost_breakdown") if isinstance(current.get("cost_breakdown"), dict) else {}
incoming_breakdown = incoming.get("cost_breakdown") if isinstance(incoming.get("cost_breakdown"), dict) else {}
merged_breakdown: dict[str, str] = {}
for key in set(existing_breakdown) | set(incoming_breakdown):
total = Decimal(str(existing_breakdown.get(key, "0") or "0")) + Decimal(
str(incoming_breakdown.get(key, "0") or "0")
)
if total:
merged_breakdown[key] = format(total, "f")
if merged_breakdown:
current["cost_breakdown"] = merged_breakdown
else:
current.pop("cost_breakdown", None)
pricing_source = incoming.get("pricing_source") or current.get("pricing_source")
if pricing_source:
current["pricing_source"] = str(pricing_source)
return self._serialize_usage(current)
def _serialize_usage(self, usage: Any) -> dict[str, Any]:
normalized = self._normalize_usage(usage if isinstance(usage, dict) else {})
normalized.setdefault("prompt_tokens", 0)
normalized.setdefault("completion_tokens", 0)
normalized.setdefault("total_tokens", normalized["prompt_tokens"] + normalized["completion_tokens"])
normalized.setdefault("reasoning_tokens", 0)
normalized.setdefault("cached_tokens", 0)
normalized.setdefault("web_search_requests", 0)
normalized.setdefault("request_count", 0)
normalized.setdefault("cost_usd", "0")
normalized.setdefault("cost_breakdown", {})
return normalized
def _serialize_thread(self, thread: dict[str, Any], *, include_messages: bool = True) -> dict[str, Any]:
    """Project an internal thread record into its public dict shape."""
    payload: dict[str, Any] = {
        "id": thread["id"],
        "title": thread["title"],
        "title_source": self._normalize_title_source(thread.get("title_source")) or TITLE_SOURCE_DEFAULT,
        "provider": thread.get("provider") or DEFAULT_PROVIDER,
        "model": thread["model"],
        "source": thread.get("source") or DEFAULT_SOURCE,
        "created_at": thread["created_at"],
        "updated_at": thread["updated_at"],
        "message_count": len(thread["messages"]),
        "usage": self._serialize_usage(thread.get("usage")),
    }
    label = thread.get("session_label")
    if label:
        payload["session_label"] = label
    learnings = thread.get("project_learnings")
    if learnings:
        payload["project_learnings"] = list(learnings)
    if include_messages:
        payload["messages"] = [self._serialize_message(entry) for entry in thread["messages"]]
    return payload
def _require_thread(self, thread_id: str) -> dict[str, Any]:
if thread_id not in self._threads:
raise KeyError(thread_id)
return self._threads[thread_id]
def _get_session_thread_record(self, session_key: str) -> dict[str, Any] | None:
    """Resolve *session_key* to its raw thread record, or None.

    A stale index entry (thread since deleted) is pruned and the store
    persisted. All callers in this file invoke this under ``_state_lock``.
    """
    thread_id = self._session_index.get(session_key)
    if not thread_id:
        return None
    thread = self._threads.get(thread_id)
    if thread is None:
        # Index points at a vanished thread — drop the mapping.
        self._session_index.pop(session_key, None)
        self._persist_locked()
        return None
    return thread
def is_tg_update_seen(self, update_id: int) -> bool:
"""Return True if this Telegram update_id has already been processed."""
return update_id <= self._last_tg_update_id
def mark_tg_update(self, update_id: int) -> None:
    """Record a Telegram update_id as processed.

    Only a strictly higher id advances (and persists) the stored offset,
    so replays and out-of-order updates are ignored.
    """
    with self._state_lock:
        if update_id > self._last_tg_update_id:
            self._last_tg_update_id = update_id
            self._persist_locked()
def _persist_locked(self) -> None:
self._file_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"threads": self._threads,
"session_index": self._session_index,
"last_tg_update_id": self._last_tg_update_id,
"topic_system_prompts": self._topic_system_prompts,
"global_learnings": self._global_learnings,
}
self._file_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def _load(self) -> None:
    """Restore state from the JSON store; start fresh on a missing/bad file."""
    if not self._file_path.exists():
        return
    try:
        payload = json.loads(self._file_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return  # corrupt or unreadable — keep empty in-memory state
    self._threads = payload.get("threads") or {}
    self._session_index = payload.get("session_index") or {}
    self._last_tg_update_id = int(payload.get("last_tg_update_id") or 0)
    # Defensive typing: older/corrupted payloads may hold the wrong shapes.
    raw_prompts = payload.get("topic_system_prompts") or {}
    self._topic_system_prompts = raw_prompts if isinstance(raw_prompts, dict) else {}
    raw_learnings = payload.get("global_learnings") or []
    self._global_learnings = raw_learnings if isinstance(raw_learnings, list) else []
def _normalize_threads(self) -> None:
    """Repair records loaded from disk so they match the current schema.

    Drops malformed topic-prompt entries, backfills missing/invalid thread
    fields (messages list, title, title_source, provider, source, usage),
    and persists once if anything changed.
    """
    changed = False
    cleaned_topic_prompts: dict[str, dict[str, Any]] = {}
    for session_key, raw_record in self._topic_system_prompts.items():
        if not isinstance(raw_record, dict):
            changed = True
            continue
        normalized_session_key = str(session_key).strip()
        normalized_prompt = str(raw_record.get("prompt") or "").strip()
        if not normalized_session_key or not normalized_prompt:
            changed = True
            continue
        normalized_record = {
            "prompt": normalized_prompt,
            "updated_at": str(raw_record.get("updated_at") or self._timestamp()),
        }
        cleaned_topic_prompts[normalized_session_key] = normalized_record
        if normalized_session_key != session_key or raw_record != normalized_record:
            changed = True
    if cleaned_topic_prompts != self._topic_system_prompts:
        self._topic_system_prompts = cleaned_topic_prompts
        changed = True
    for thread in self._threads.values():
        if not isinstance(thread.get("messages"), list):
            thread["messages"] = []
            changed = True
        thread["title"] = str(thread.get("title") or DEFAULT_TITLE).strip() or DEFAULT_TITLE
        normalized_title_source = self._normalize_title_source(thread.get("title_source"))
        expected_title_source = normalized_title_source
        if expected_title_source is None:
            # Missing/unknown title_source — infer it from the title itself.
            expected_title_source = (
                TITLE_SOURCE_DEFAULT if thread["title"] == DEFAULT_TITLE else TITLE_SOURCE_MANUAL
            )
        if thread.get("title_source") != expected_title_source:
            thread["title_source"] = expected_title_source
            changed = True
        provider = str(thread.get("provider") or DEFAULT_PROVIDER).strip().lower() or DEFAULT_PROVIDER
        if thread.get("provider") != provider:
            thread["provider"] = provider
            changed = True
        source = str(thread.get("source") or DEFAULT_SOURCE).strip().lower() or DEFAULT_SOURCE
        if thread.get("source") != source:
            thread["source"] = source
            changed = True
        usage = self._serialize_usage(thread.get("usage"))
        if thread.get("usage") != usage:
            thread["usage"] = usage
            changed = True
    if changed:
        self._persist_locked()
def _serialize_message(self, message: dict[str, Any]) -> dict[str, Any]:
serialized = {
"id": message["id"],
"role": message["role"],
"content": message["content"],
"created_at": message["created_at"],
}
normalized_parts = self._normalize_message_parts(message.get("parts"))
if normalized_parts:
serialized["parts"] = normalized_parts
return serialized
def _serialize_upload(self, metadata: dict[str, Any]) -> dict[str, Any]:
return {
"id": str(metadata.get("id") or ""),
"name": str(metadata.get("name") or "upload"),
"mime_type": str(metadata.get("mime_type") or "application/octet-stream"),
"size": int(metadata.get("size") or 0),
"preview_url": f"/uploads/{metadata.get('id')}"
if str(metadata.get("mime_type") or "").startswith("image/")
else None,
}
def _normalize_message_parts(self, parts: Any) -> list[dict[str, Any]]:
if not isinstance(parts, list):
return []
normalized_parts: list[dict[str, Any]] = []
for part in parts:
if not isinstance(part, dict):
continue
part_type = str(part.get("type") or "").strip()
if part_type == "input_text":
text = str(part.get("text") or "")
if text:
normalized_parts.append({"type": "input_text", "text": text})
continue
if part_type == "input_image":
file_id = str(part.get("file_id") or "").strip()
if not file_id:
continue
normalized_part = {"type": "input_image", "file_id": file_id}
if part.get("name"):
normalized_part["name"] = str(part.get("name"))
if part.get("mime_type"):
normalized_part["mime_type"] = str(part.get("mime_type"))
if part.get("detail"):
normalized_part["detail"] = str(part.get("detail"))
normalized_parts.append(normalized_part)
return normalized_parts
def _message_text(self, message: dict[str, Any]) -> str:
parts = self._normalize_message_parts(message.get("parts"))
if parts:
texts = [str(part.get("text") or "") for part in parts if part.get("type") == "input_text"]
text = "\n".join(part for part in texts if part).strip()
if text:
return text
return str(message.get("content") or "").strip()
def _derive_title(self, content: str) -> str:
single_line = " ".join(str(content).split())
if len(single_line) <= 48:
return single_line
return single_line[:45].rstrip() + "..."
def _normalize_title_source(self, raw_value: Any) -> str | None:
    """Map arbitrary input onto a known title source, or None if unrecognized."""
    candidate = str(raw_value or "").strip().lower()
    return candidate if candidate in KNOWN_TITLE_SOURCES else None
def _timestamp(self) -> str:
from datetime import datetime, timezone
return datetime.now(timezone.utc).isoformat()