betterbot/llm_costs.py
Andre K e68c84424f
Some checks failed
Deploy BetterBot / deploy (push) Failing after 3s
Deploy BetterBot / notify (push) Successful in 3s
feat: fork from CodeAnywhere framework
Replace standalone Telegram bot with full CodeAnywhere framework fork.
BetterBot shares all framework code and customizes only:
- instance.py: BetterBot identity, system prompt, feature flags
- tools/site_editing/: list_files, read_file, write_file with auto git push
- .env: model defaults and site directory paths
- compose/: Docker setup with betterlifesg + memoraiz mounts
- deploy script: RackNerd with Infisical secrets
2026-04-19 08:01:27 +08:00

450 lines
16 KiB
Python

from __future__ import annotations
import logging
import time
from decimal import ROUND_HALF_UP, Decimal
from typing import Any
from config import settings
logger = logging.getLogger(__name__)
# Public OpenRouter model catalog; each entry carries per-token pricing fields.
OPENROUTER_PRICING_SOURCE = "https://openrouter.ai/api/v1/models"
# Vercel AI Gateway endpoint that returns the exact cost for one generation id.
VERCEL_GENERATION_URL = "https://ai-gateway.vercel.sh/v1/generation"
# Shared Decimal zero so comparisons/accumulations avoid re-parsing "0".
ZERO = Decimal("0")
# Quantum used when rendering USD amounts for display (6 decimal places).
USD_DISPLAY_QUANT = Decimal("0.000001")
_PRICING_CACHE_TTL = 3600  # 1 hour
# In-memory cache: model_id → (pricing_dict, fetched_at)
_pricing_cache: dict[str, tuple[dict[str, Decimal], float]] = {}
# Maps bare model prefixes to their OpenRouter vendor slug so we can look up
# pricing even when the model was routed directly (e.g. openai provider).
_VENDOR_PREFIXES: tuple[tuple[str, str], ...] = (
    ("gpt-", "openai"),
    ("chatgpt-", "openai"),
    ("o1", "openai"),
    ("o3", "openai"),
    ("o4", "openai"),
    ("computer-use-", "openai"),
    ("claude-", "anthropic"),
    ("gemini-", "google"),
    ("deepseek-", "deepseek"),
)
def _to_openrouter_model_id(model: str, provider: str) -> str:
    """Normalise a model string into the ``vendor/model`` format OpenRouter uses."""
    name = str(model or "").strip()
    # Already vendor-qualified — pass through untouched.
    if "/" in name:
        return name
    # Match known bare-model prefixes (case-insensitive) to their vendor slug.
    folded = name.lower()
    vendor = next(
        (slug for prefix, slug in _VENDOR_PREFIXES if folded.startswith(prefix)),
        None,
    )
    return f"{vendor}/{name}" if vendor is not None else name
def _to_decimal(value: Any) -> Decimal:
    """Best-effort Decimal coercion; None, "" or unparseable input yields ZERO."""
    try:
        # The membership test stays inside the try: exotic values may raise on ==.
        return ZERO if value in (None, "") else Decimal(str(value))
    except Exception:
        return ZERO
def _to_int(value: Any) -> int:
try:
if value in (None, ""):
return 0
return int(value)
except Exception:
return 0
def fetch_vercel_generation_cost(api_call_ids: list[str]) -> Decimal | None:
    """Look up exact cost from Vercel AI Gateway for generation IDs.

    Returns total cost in USD across all generations, or None if lookup fails.
    """
    if not settings.VERCEL_API_KEY or not api_call_ids:
        return None
    try:
        import httpx
    except Exception:
        # httpx is optional at runtime; without it we cannot query the gateway.
        return None
    auth_headers = {"Authorization": f"Bearer {settings.VERCEL_API_KEY}"}
    running_total = ZERO
    matched = 0
    with httpx.Client(timeout=10.0) as client:
        for gen_id in api_call_ids:
            # Best-effort per generation: a single failure must not abort the rest.
            try:
                response = client.get(
                    VERCEL_GENERATION_URL,
                    params={"id": gen_id},
                    headers=auth_headers,
                )
                response.raise_for_status()
                data = response.json()
                cost = data.get("total_cost")
                if cost is None:
                    logger.debug("Vercel generation %s returned no total_cost: %s", gen_id, data)
                    continue
                running_total += _to_decimal(cost)
                matched += 1
            except Exception:
                logger.warning("Failed to fetch Vercel generation cost for %s", gen_id)
                continue
    # Only report a cost when at least one generation resolved successfully.
    return running_total if matched else None
def find_openrouter_pricing(model_id: str) -> dict[str, Decimal]:
    """Return OpenRouter per-token pricing for ``model_id``.

    The result maps a subset of {"prompt", "completion", "input_cache_read",
    "web_search"} to non-zero Decimal rates; {} when the model is unknown,
    httpx is unavailable, or the catalog fetch fails.

    One catalog fetch returns pricing for EVERY model, so we cache all
    entries from the response — not just the queried model — which avoids a
    refetch for each distinct model within the TTL window.
    """
    try:
        import httpx
    except Exception:
        # httpx is optional at runtime; without it we cannot fetch the catalog.
        return {}
    model_id = str(model_id or "").strip()
    if not model_id:
        return {}
    # Serve from the in-memory cache while the entry is fresh.
    cached = _pricing_cache.get(model_id)
    if cached is not None:
        pricing_data, fetched_at = cached
        if time.monotonic() - fetched_at < _PRICING_CACHE_TTL:
            return pricing_data
    headers = {}
    if settings.OPENROUTER_API_KEY:
        headers["Authorization"] = f"Bearer {settings.OPENROUTER_API_KEY}"
    try:
        with httpx.Client(timeout=20.0) as client:
            response = client.get(OPENROUTER_PRICING_SOURCE, headers=headers)
            response.raise_for_status()
            payload = response.json()
    except Exception:
        logger.warning("Failed to fetch OpenRouter pricing for %s", model_id)
        return {}
    # Populate the cache for every model in the catalog in one pass.
    now = time.monotonic()
    items = payload.get("data", []) if isinstance(payload, dict) else []
    for item in items:
        if not isinstance(item, dict):
            continue
        item_id = str(item.get("id") or "").strip()
        if not item_id:
            continue
        pricing = item.get("pricing") or {}
        result: dict[str, Decimal] = {}
        for key in ("prompt", "completion", "input_cache_read", "web_search"):
            amount = _to_decimal(pricing.get(key))
            if amount:
                result[key] = amount
        _pricing_cache[item_id] = (result, now)
    entry = _pricing_cache.get(model_id)
    if entry is not None and entry[1] == now:
        # The queried model was present in this fetch.
        return entry[0]
    # Cache the miss too so we don't re-fetch for unknown models
    _pricing_cache[model_id] = ({}, now)
    return {}
def _aggregate_sdk_usage(events: list[Any]) -> dict[str, int]:
    """Aggregate token counts from Copilot SDK events.

    Prefers session.usage_info (aggregated model_metrics) when available,
    otherwise sums individual assistant.usage events.
    """
    from copilot.generated.session_events import SessionEventType

    _TOKEN_FIELDS = ("input_tokens", "output_tokens", "cache_read_tokens", "cache_write_tokens")

    def _data_of(event: Any, wanted: Any) -> Any:
        # event.data when the event is of the wanted type, else None.
        if getattr(event, "type", None) != wanted:
            return None
        return getattr(event, "data", None)

    def _fresh_agg() -> dict[str, Any]:
        counters: dict[str, Any] = {field: 0 for field in _TOKEN_FIELDS}
        counters["request_count"] = 0
        counters["api_call_ids"] = api_call_ids
        return counters

    # Collect api_call_ids from ASSISTANT_USAGE events (used for Vercel cost lookup)
    api_call_ids: list[str] = []
    for event in events:
        data = _data_of(event, SessionEventType.ASSISTANT_USAGE)
        call_id = getattr(data, "api_call_id", None) if data else None
        if call_id:
            api_call_ids.append(str(call_id))
    if api_call_ids:
        logger.debug("Collected %d api_call_ids (first: %s)", len(api_call_ids), api_call_ids[0])
    else:
        logger.debug("No api_call_ids found in %d events", len(events))

    # Prefer session.usage_info — scanned newest-first, the first event carrying
    # non-empty model_metrics wins.
    for event in reversed(events):
        data = _data_of(event, SessionEventType.SESSION_USAGE_INFO)
        metrics = getattr(data, "model_metrics", None) if data else None
        if not metrics:
            continue
        agg = _fresh_agg()
        for metric in metrics.values():
            usage = getattr(metric, "usage", None)
            if usage:
                for field in _TOKEN_FIELDS:
                    agg[field] += _to_int(getattr(usage, field, 0))
            requests = getattr(metric, "requests", None)
            if requests:
                agg["request_count"] += _to_int(getattr(requests, "count", 0))
        return agg

    # Fall back to summing individual assistant.usage events
    agg = _fresh_agg()
    seen_usage = False
    for event in events:
        data = _data_of(event, SessionEventType.ASSISTANT_USAGE)
        if data is None:
            continue
        seen_usage = True
        for field in _TOKEN_FIELDS:
            agg[field] += _to_int(getattr(data, field, 0))
        agg["request_count"] += 1
    # {} signals "no usage information at all" to the caller.
    return agg if seen_usage else {}
def extract_usage_and_cost(
    model: str,
    provider: str,
    result: Any,
    advisor_usage: dict[str, int] | None = None,
) -> dict[str, Any]:
    """Build a usage/cost summary for a completed LLM call.

    ``result`` is either a list of SDK events (aggregated via
    ``_aggregate_sdk_usage``) or a single response object carrying a legacy
    usage payload. Cost is resolved in priority order: exact Vercel gateway
    cost (when ``provider == "vercel"`` and api_call_ids are present), then
    OpenRouter token-rate estimation, else the summary keeps ``cost_usd: "0"``.
    ``advisor_usage``, when given, is merged in as advisor_* sidecar fields.
    """
    # When given a list of SDK events, aggregate directly from typed events
    if isinstance(result, list):
        sdk_usage = _aggregate_sdk_usage(result)
    else:
        sdk_usage = {}
    # Fall back to legacy extraction for non-list results
    if not sdk_usage:
        # For lists with no SDK usage, scan for the newest event with a payload.
        single = _find_usage_event(result) if isinstance(result, list) else result
        usage = _extract_usage_payload(single)
        # Key aliases cover both snake_case and camelCase provider payloads.
        prompt_tokens = _pick_first_int(usage, "input_tokens", "prompt_tokens", "promptTokens")
        completion_tokens = _pick_first_int(
            usage, "output_tokens", "completion_tokens", "completionTokens", "outputTokens"
        )
        reasoning_tokens = _pick_nested_int(
            usage, [("output_tokens_details", "reasoning_tokens"), ("completion_tokens_details", "reasoning_tokens")]
        )
        cached_tokens = _pick_nested_int(
            usage, [("input_tokens_details", "cached_tokens"), ("prompt_tokens_details", "cached_tokens")]
        )
        request_count = 1
    else:
        prompt_tokens = sdk_usage.get("input_tokens", 0)
        completion_tokens = sdk_usage.get("output_tokens", 0)
        cached_tokens = sdk_usage.get("cache_read_tokens", 0)
        # SDK aggregation does not expose reasoning tokens separately.
        reasoning_tokens = 0
        request_count = sdk_usage.get("request_count", 1) or 1
    api_call_ids: list[str] = sdk_usage.get("api_call_ids", []) if sdk_usage else []
    total_tokens = prompt_tokens + completion_tokens
    # Baseline summary; cost fields are overwritten below when pricing resolves.
    summary: dict[str, Any] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "reasoning_tokens": reasoning_tokens,
        "cached_tokens": cached_tokens,
        "request_count": request_count,
        "cost_usd": "0",
        "cost_breakdown": {},
    }
    # Merge advisor sidecar usage when provided
    if advisor_usage:
        summary["advisor_prompt_tokens"] = advisor_usage.get("prompt_tokens", 0)
        summary["advisor_completion_tokens"] = advisor_usage.get("completion_tokens", 0)
        summary["advisor_total_tokens"] = (
            summary["advisor_prompt_tokens"] + summary["advisor_completion_tokens"]
        )
    # Prefer Vercel generation cost lookup (exact cost from gateway)
    if provider == "vercel" and api_call_ids:
        vercel_cost = fetch_vercel_generation_cost(api_call_ids)
        if vercel_cost is not None:
            summary.update(
                {
                    "cost_usd": _format_usd(vercel_cost),
                    "pricing_source": VERCEL_GENERATION_URL,
                }
            )
            return summary
        logger.debug("Vercel generation cost lookup returned None for %d IDs", len(api_call_ids))
    elif provider == "vercel":
        logger.debug("Vercel provider but no api_call_ids available — falling back to OpenRouter pricing")
    # Fall back to OpenRouter pricing catalog (token-rate estimation)
    pricing = find_openrouter_pricing(_to_openrouter_model_id(model, provider))
    if not pricing:
        return summary
    # Rates are treated as USD per token; cached prompt tokens are billed at
    # the (usually cheaper) cache-read rate and excluded from the prompt rate.
    prompt_rate = pricing.get("prompt", ZERO)
    completion_rate = pricing.get("completion", ZERO)
    cache_rate = pricing.get("input_cache_read", ZERO)
    web_search_rate = pricing.get("web_search", ZERO)
    uncached_prompt_tokens = max(prompt_tokens - cached_tokens, 0)
    prompt_cost = Decimal(uncached_prompt_tokens) * prompt_rate
    completion_cost = Decimal(completion_tokens) * completion_rate
    cache_cost = Decimal(cached_tokens) * cache_rate
    # web_search is billed per tool invocation, not per token.
    web_search_requests = _count_web_search_requests(result)
    web_search_cost = Decimal(web_search_requests) * web_search_rate
    total_cost = prompt_cost + completion_cost + cache_cost + web_search_cost
    # Only non-zero components appear in the breakdown.
    breakdown = {}
    if prompt_cost:
        breakdown["prompt"] = _format_usd(prompt_cost)
    if completion_cost:
        breakdown["completion"] = _format_usd(completion_cost)
    if cache_cost:
        breakdown["cache_read"] = _format_usd(cache_cost)
    if web_search_cost:
        breakdown["web_search"] = _format_usd(web_search_cost)
    summary.update(
        {
            "web_search_requests": web_search_requests,
            "cost_usd": _format_usd(total_cost),
            "cost_breakdown": breakdown,
            "pricing_source": OPENROUTER_PRICING_SOURCE,
        }
    )
    return summary
def format_usage_line(usage: dict[str, Any] | None) -> str:
    """Render a one-line token/cost summary; "" when *usage* is not a dict."""
    if not isinstance(usage, dict):
        return ""
    tokens_in = _to_int(usage.get("prompt_tokens"))
    tokens_out = _to_int(usage.get("completion_tokens"))
    tokens_total = _to_int(usage.get("total_tokens"))
    cost = _format_display_usd(usage.get("cost_usd"))
    return f"Usage: {tokens_total:,} tok ({tokens_in:,} in / {tokens_out:,} out) · Cost: ${cost}"
def _extract_usage_payload(result: Any) -> dict[str, Any]:
    """Pull a usage mapping out of *result*.

    Checks, in order: ``result.usage``, ``result.data.usage``, and — for plain
    dicts — ``result["usage"]``. Returns {} when nothing truthy is found.
    """
    direct = getattr(result, "usage", None)
    if direct:
        return _to_dict(direct)
    data = getattr(result, "data", None)
    if data is not None:
        nested = getattr(data, "usage", None)
        if nested:
            return _to_dict(nested)
    if isinstance(result, dict):
        mapped = result.get("usage")
        if mapped:
            return _to_dict(mapped)
    return {}
def _to_dict(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
if hasattr(value, "model_dump"):
try:
return value.model_dump()
except Exception:
return {}
if hasattr(value, "dict"):
try:
return value.dict()
except Exception:
return {}
return {k: getattr(value, k) for k in dir(value) if not k.startswith("_") and not callable(getattr(value, k))}
def _pick_first_int(data: dict[str, Any], *keys: str) -> int:
    """Return the first truthy int found in *data* under any of *keys*, else 0."""
    for candidate in keys:
        if candidate not in data:
            continue
        number = _to_int(data.get(candidate))
        if number:
            return number
    return 0
def _pick_nested_int(data: dict[str, Any], paths: list[tuple[str, str]]) -> int:
    """Return the first truthy int at any ``data[outer][inner]`` path, else 0."""
    for outer, inner in paths:
        nested = data.get(outer)
        if not isinstance(nested, dict) or inner not in nested:
            continue
        number = _to_int(nested.get(inner))
        if number:
            return number
    return 0
def _count_web_search_requests(result: Any) -> int:
events = getattr(result, "events", None)
if not events:
return 0
count = 0
for event in events:
tool_name = getattr(getattr(event, "data", None), "tool_name", None) or getattr(
getattr(event, "data", None), "name", None
)
if str(tool_name or "").strip().lower() == "web_search":
count += 1
return count
def _format_usd(value: Decimal) -> str:
return format(value.quantize(Decimal("0.000000000001"), rounding=ROUND_HALF_UP).normalize(), "f")
def _format_display_usd(value: Any) -> str:
    """Render any cost value as a fixed-point USD string at display precision."""
    rounded = _to_decimal(value).quantize(USD_DISPLAY_QUANT, rounding=ROUND_HALF_UP)
    return format(rounded, "f")
def _find_usage_event(events: list[Any]) -> Any:
    """Return the newest event carrying a usage payload, else the last event,
    else None for an empty list."""
    if not events:
        return None
    for candidate in reversed(events):
        if _extract_usage_payload(candidate):
            return candidate
    return events[-1]
def format_cost_value(value: Any) -> str:
    """Public alias: render a cost value as a display-precision USD string."""
    return _format_display_usd(value)
def summarize_usage(usage: dict[str, Any] | None) -> dict[str, Any] | None:
    """Normalise a raw usage dict into a compact summary.

    Returns None for non-dict input or when the record is effectively empty
    (no requests, no tokens, no cost).
    """
    if not isinstance(usage, dict):
        return None
    tokens_in = _to_int(usage.get("prompt_tokens"))
    tokens_out = _to_int(usage.get("completion_tokens"))
    # Derive the total from the parts when it is missing or zero.
    tokens_total = _to_int(usage.get("total_tokens")) or (tokens_in + tokens_out)
    requests = _to_int(usage.get("request_count"))
    cost = _to_decimal(usage.get("cost_usd"))
    if requests <= 0 and tokens_total <= 0 and cost <= ZERO:
        return None
    return {
        "prompt_tokens": tokens_in,
        "completion_tokens": tokens_out,
        "total_tokens": tokens_total,
        "reasoning_tokens": _to_int(usage.get("reasoning_tokens")),
        "cached_tokens": _to_int(usage.get("cached_tokens")),
        "request_count": max(requests, 0),
        "cost_usd": format(cost, "f"),
    }