from __future__ import annotations

import logging
import time
from decimal import ROUND_HALF_UP, Decimal
from typing import Any

from config import settings

logger = logging.getLogger(__name__)

OPENROUTER_PRICING_SOURCE = "https://openrouter.ai/api/v1/models"
VERCEL_GENERATION_URL = "https://ai-gateway.vercel.sh/v1/generation"
ZERO = Decimal("0")
USD_DISPLAY_QUANT = Decimal("0.000001")
# Precision of the machine-readable cost strings produced by _format_usd.
_USD_STORE_QUANT = Decimal("0.000000000001")
_PRICING_CACHE_TTL = 3600  # 1 hour

# OpenRouter ``pricing`` keys that extract_usage_and_cost knows how to apply.
_OPENROUTER_PRICING_KEYS = ("prompt", "completion", "input_cache_read", "web_search")

# Token counters summed from Copilot SDK usage events / model metrics.
_SDK_TOKEN_FIELDS = ("input_tokens", "output_tokens", "cache_read_tokens", "cache_write_tokens")

# In-memory cache: model_id → (pricing_dict, fetched_at)
# fetched_at is a time.monotonic() timestamp compared against _PRICING_CACHE_TTL.
_pricing_cache: dict[str, tuple[dict[str, Decimal], float]] = {}

# Maps bare model prefixes to their OpenRouter vendor slug so we can look up
# pricing even when the model was routed directly (e.g. openai provider).
_VENDOR_PREFIXES: tuple[tuple[str, str], ...] = (
    ("gpt-", "openai"),
    ("chatgpt-", "openai"),
    ("o1", "openai"),
    ("o3", "openai"),
    ("o4", "openai"),
    ("computer-use-", "openai"),
    ("claude-", "anthropic"),
    ("gemini-", "google"),
    ("deepseek-", "deepseek"),
)


def _to_openrouter_model_id(model: str, provider: str) -> str:
    """Normalise a model string into the ``vendor/model`` format OpenRouter uses.

    Ids that already contain ``/`` are returned unchanged. Bare ids matching a
    prefix in ``_VENDOR_PREFIXES`` get that vendor prepended. Anything else is
    returned stripped but otherwise as-is.

    ``provider`` is accepted for interface compatibility but is not consulted.
    """
    model = str(model or "").strip()
    if "/" in model:
        return model
    lowered = model.lower()
    for prefix, vendor in _VENDOR_PREFIXES:
        if lowered.startswith(prefix):
            return f"{vendor}/{model}"
    return model


def _to_decimal(value: Any) -> Decimal:
    """Best-effort conversion of *value* to a finite ``Decimal``.

    ``None``, ``""``, unparsable input and non-finite values (NaN / Infinity)
    all map to ``ZERO``. Downstream ``quantize`` calls and ordering comparisons
    raise ``InvalidOperation`` on those special values, so they must not leak out.
    """
    try:
        if value in (None, ""):
            return ZERO
        amount = Decimal(str(value))
    except Exception:  # deliberate best-effort: any bad input counts as zero
        return ZERO
    return amount if amount.is_finite() else ZERO


def _to_int(value: Any) -> int:
    """Best-effort conversion of *value* to ``int``; returns 0 on ``None``, ``""`` or failure."""
    try:
        if value in (None, ""):
            return 0
        return int(value)
    except Exception:
        return 0


def fetch_vercel_generation_cost(api_call_ids: list[str]) -> Decimal | None:
    """Look up exact cost from Vercel AI Gateway for generation IDs.

    Returns the total cost in USD across all generations. Returns None if no
    lookup yielded a ``total_cost``, or if the API key, httpx or the IDs are
    unavailable. Individual failed lookups are logged and skipped, so the sum
    may cover only some of the IDs.
    """
    if not settings.VERCEL_API_KEY or not api_call_ids:
        return None
    try:
        import httpx
    except Exception:
        return None

    headers = {"Authorization": f"Bearer {settings.VERCEL_API_KEY}"}
    total_cost = ZERO
    found_any = False
    with httpx.Client(timeout=10.0) as client:
        for gen_id in api_call_ids:
            try:
                response = client.get(VERCEL_GENERATION_URL, params={"id": gen_id}, headers=headers)
                response.raise_for_status()
                data = response.json()
                cost = data.get("total_cost")
            except Exception:
                logger.warning("Failed to fetch Vercel generation cost for %s", gen_id, exc_info=True)
                continue
            if cost is not None:
                total_cost += _to_decimal(cost)
                found_any = True
            else:
                logger.debug("Vercel generation %s returned no total_cost: %s", gen_id, data)
    return total_cost if found_any else None


def _parse_openrouter_pricing(pricing: Any) -> dict[str, Decimal]:
    """Extract the positive rates for ``_OPENROUTER_PRICING_KEYS`` from an OpenRouter ``pricing`` object."""
    if not isinstance(pricing, dict):
        return {}
    result: dict[str, Decimal] = {}
    for key in _OPENROUTER_PRICING_KEYS:
        amount = _to_decimal(pricing.get(key))
        # Zero, missing or negative rates would only produce zero or negative
        # costs, so they are dropped.
        if amount > ZERO:
            result[key] = amount
    return result


def find_openrouter_pricing(model_id: str) -> dict[str, Decimal]:
    """Return per-unit USD rates for *model_id* from the OpenRouter model catalog.

    The result maps a subset of ``prompt`` / ``completion`` /
    ``input_cache_read`` / ``web_search`` to positive ``Decimal`` rates. An
    empty dict means an unknown model, no usable pricing, or a failed fetch.

    A single catalog download prices every listed model, so all of them are
    cached for ``_PRICING_CACHE_TTL`` seconds, not only *model_id*. Misses are
    cached too. Fetch failures are not cached, so the next call retries.
    """
    try:
        import httpx
    except Exception:
        return {}

    model_id = str(model_id or "").strip()
    if not model_id:
        return {}

    # Check in-memory cache
    cached = _pricing_cache.get(model_id)
    if cached is not None:
        pricing_data, fetched_at = cached
        if time.monotonic() - fetched_at < _PRICING_CACHE_TTL:
            return pricing_data

    headers: dict[str, str] = {}
    if settings.OPENROUTER_API_KEY:
        headers["Authorization"] = f"Bearer {settings.OPENROUTER_API_KEY}"

    try:
        with httpx.Client(timeout=20.0) as client:
            response = client.get(OPENROUTER_PRICING_SOURCE, headers=headers)
            response.raise_for_status()
            payload = response.json()
    except Exception:
        logger.warning("Failed to fetch OpenRouter pricing for %s", model_id, exc_info=True)
        return {}

    items = payload.get("data") if isinstance(payload, dict) else None
    catalog: dict[str, dict[str, Decimal]] = {}
    for item in items or []:
        if not isinstance(item, dict):
            continue
        item_id = str(item.get("id") or "").strip()
        # First occurrence wins, as with a first-match lookup.
        if item_id and item_id not in catalog:
            catalog[item_id] = _parse_openrouter_pricing(item.get("pricing"))

    fetched_at = time.monotonic()
    for item_id, pricing in catalog.items():
        _pricing_cache[item_id] = (pricing, fetched_at)

    result = catalog.get(model_id, {})
    # Cache the miss too so we don't re-fetch for unknown models
    _pricing_cache[model_id] = (result, fetched_at)
    return result


def _empty_sdk_usage(api_call_ids: list[str]) -> dict[str, Any]:
    """Return a zeroed SDK usage aggregate that carries *api_call_ids*."""
    agg: dict[str, Any] = {field: 0 for field in _SDK_TOKEN_FIELDS}
    agg["request_count"] = 0
    agg["api_call_ids"] = api_call_ids
    return agg


def _aggregate_sdk_usage(events: list[Any]) -> dict[str, Any]:
    """Aggregate token counts from Copilot SDK events.

    Prefers session.usage_info (aggregated model_metrics) when available,
    otherwise sums individual assistant.usage events.

    Returns a dict with integer counters (``input_tokens``, ``output_tokens``,
    ``cache_read_tokens``, ``cache_write_tokens``, ``request_count``) plus
    ``api_call_ids`` (a list of str). Returns ``{}`` if no usage event was found.
    """
    from copilot.generated.session_events import SessionEventType

    # Collect api_call_ids from ASSISTANT_USAGE events (used for Vercel cost lookup)
    api_call_ids: list[str] = []
    for event in events:
        if getattr(event, "type", None) != SessionEventType.ASSISTANT_USAGE:
            continue
        data = getattr(event, "data", None)
        call_id = getattr(data, "api_call_id", None) if data else None
        if call_id:
            api_call_ids.append(str(call_id))

    if api_call_ids:
        logger.debug("Collected %d api_call_ids (first: %s)", len(api_call_ids), api_call_ids[0])
    else:
        logger.debug("No api_call_ids found in %d events", len(events))

    # Try session.usage_info first — it carries aggregated model_metrics.
    # Iterate newest-first so the most recent snapshot wins.
    for event in reversed(events):
        if getattr(event, "type", None) != SessionEventType.SESSION_USAGE_INFO:
            continue
        data = getattr(event, "data", None)
        metrics = getattr(data, "model_metrics", None)
        if not metrics:
            continue
        agg = _empty_sdk_usage(api_call_ids)
        for metric in metrics.values():
            usage = getattr(metric, "usage", None)
            if usage:
                for field in _SDK_TOKEN_FIELDS:
                    agg[field] += _to_int(getattr(usage, field, 0))
            requests = getattr(metric, "requests", None)
            if requests:
                agg["request_count"] += _to_int(getattr(requests, "count", 0))
        return agg

    # Fall back to summing individual assistant.usage events
    agg = _empty_sdk_usage(api_call_ids)
    found = False
    for event in events:
        if getattr(event, "type", None) != SessionEventType.ASSISTANT_USAGE:
            continue
        data = getattr(event, "data", None)
        if data is None:
            continue
        found = True
        for field in _SDK_TOKEN_FIELDS:
            agg[field] += _to_int(getattr(data, field, 0))
        agg["request_count"] += 1
    return agg if found else {}


def extract_usage_and_cost(
    model: str,
    provider: str,
    result: Any,
    advisor_usage: dict[str, int] | None = None,
) -> dict[str, Any]:
    """Build a usage and cost summary for one run.

    Args:
        model: Model identifier as routed (bare or ``vendor/model``).
        provider: Routing provider. ``"vercel"`` enables the exact gateway cost lookup.
        result: Either a list of Copilot SDK events, or a single result
            object/dict carrying a ``usage`` payload.
        advisor_usage: Optional sidecar token counts (``prompt_tokens`` /
            ``completion_tokens``). They are copied into the summary but not priced.

    Returns:
        A dict with token counts, ``request_count``, ``cost_usd`` (a decimal
        string) and ``cost_breakdown``. When a price source was used it also
        contains ``pricing_source``. ``cost_usd`` stays ``"0"`` when no pricing
        is available.
    """
    # When given a list of SDK events, aggregate directly from typed events
    sdk_usage = _aggregate_sdk_usage(result) if isinstance(result, list) else {}

    if sdk_usage:
        prompt_tokens = sdk_usage.get("input_tokens", 0)
        completion_tokens = sdk_usage.get("output_tokens", 0)
        cached_tokens = sdk_usage.get("cache_read_tokens", 0)
        reasoning_tokens = 0
        request_count = sdk_usage.get("request_count", 1) or 1
        api_call_ids: list[str] = sdk_usage.get("api_call_ids", [])
    else:
        # Legacy extraction: non-list results, or lists without SDK usage events
        single = _find_usage_event(result) if isinstance(result, list) else result
        usage = _extract_usage_payload(single)
        prompt_tokens = _pick_first_int(usage, "input_tokens", "prompt_tokens", "promptTokens")
        completion_tokens = _pick_first_int(
            usage, "output_tokens", "completion_tokens", "completionTokens", "outputTokens"
        )
        reasoning_tokens = _pick_nested_int(
            usage, [("output_tokens_details", "reasoning_tokens"), ("completion_tokens_details", "reasoning_tokens")]
        )
        cached_tokens = _pick_nested_int(
            usage, [("input_tokens_details", "cached_tokens"), ("prompt_tokens_details", "cached_tokens")]
        )
        request_count = 1
        api_call_ids = []

    total_tokens = prompt_tokens + completion_tokens
    summary: dict[str, Any] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "reasoning_tokens": reasoning_tokens,
        "cached_tokens": cached_tokens,
        "request_count": request_count,
        "cost_usd": "0",
        "cost_breakdown": {},
    }

    # Merge advisor sidecar usage when provided
    if advisor_usage:
        summary["advisor_prompt_tokens"] = advisor_usage.get("prompt_tokens", 0)
        summary["advisor_completion_tokens"] = advisor_usage.get("completion_tokens", 0)
        summary["advisor_total_tokens"] = (
            summary["advisor_prompt_tokens"] + summary["advisor_completion_tokens"]
        )

    # Prefer Vercel generation cost lookup (exact cost from gateway)
    if provider == "vercel" and api_call_ids:
        vercel_cost = fetch_vercel_generation_cost(api_call_ids)
        if vercel_cost is not None:
            summary.update(
                {
                    "cost_usd": _format_usd(vercel_cost),
                    "pricing_source": VERCEL_GENERATION_URL,
                }
            )
            return summary
        logger.debug("Vercel generation cost lookup returned None for %d IDs", len(api_call_ids))
    elif provider == "vercel":
        logger.debug("Vercel provider but no api_call_ids available — falling back to OpenRouter pricing")

    # Fall back to OpenRouter pricing catalog (token-rate estimation)
    pricing = find_openrouter_pricing(_to_openrouter_model_id(model, provider))
    if not pricing:
        return summary

    prompt_rate = pricing.get("prompt", ZERO)
    completion_rate = pricing.get("completion", ZERO)
    cache_rate = pricing.get("input_cache_read", ZERO)
    web_search_rate = pricing.get("web_search", ZERO)

    # Cached prompt tokens are billed at the cache-read rate instead of the prompt rate.
    uncached_prompt_tokens = max(prompt_tokens - cached_tokens, 0)
    prompt_cost = Decimal(uncached_prompt_tokens) * prompt_rate
    completion_cost = Decimal(completion_tokens) * completion_rate
    cache_cost = Decimal(cached_tokens) * cache_rate
    web_search_requests = _count_web_search_requests(result)
    web_search_cost = Decimal(web_search_requests) * web_search_rate
    total_cost = prompt_cost + completion_cost + cache_cost + web_search_cost

    breakdown: dict[str, str] = {}
    if prompt_cost:
        breakdown["prompt"] = _format_usd(prompt_cost)
    if completion_cost:
        breakdown["completion"] = _format_usd(completion_cost)
    if cache_cost:
        breakdown["cache_read"] = _format_usd(cache_cost)
    if web_search_cost:
        breakdown["web_search"] = _format_usd(web_search_cost)

    summary.update(
        {
            "web_search_requests": web_search_requests,
            "cost_usd": _format_usd(total_cost),
            "cost_breakdown": breakdown,
            "pricing_source": OPENROUTER_PRICING_SOURCE,
        }
    )
    return summary


def format_usage_line(usage: dict[str, Any] | None) -> str:
    """Render a one-line human-readable usage/cost summary; ``""`` for non-dict input."""
    if not isinstance(usage, dict):
        return ""
    prompt_tokens = _to_int(usage.get("prompt_tokens"))
    completion_tokens = _to_int(usage.get("completion_tokens"))
    total_tokens = _to_int(usage.get("total_tokens"))
    cost_usd = _format_display_usd(usage.get("cost_usd"))
    return f"Usage: {total_tokens:,} tok ({prompt_tokens:,} in / {completion_tokens:,} out) · Cost: ${cost_usd}"


def _extract_usage_payload(result: Any) -> dict[str, Any]:
    """Return the ``usage`` payload of *result* as a dict, or ``{}`` if absent.

    Checked in order: ``result.usage``, ``result.data.usage``, ``result["usage"]``.
    """
    usage = getattr(result, "usage", None)
    if usage:
        return _to_dict(usage)
    data = getattr(result, "data", None)
    if data is not None:
        usage = getattr(data, "usage", None)
        if usage:
            return _to_dict(usage)
    if isinstance(result, dict):
        usage = result.get("usage")
        if usage:
            return _to_dict(usage)
    return {}


def _to_dict(value: Any) -> dict[str, Any]:
    """Best-effort conversion of *value* to a plain dict.

    Dicts are returned as-is. Pydantic-style objects go through
    ``model_dump()`` or ``dict()``. Any other object falls back to its public,
    non-callable attributes; attributes whose access raises are skipped.
    """
    if isinstance(value, dict):
        return value
    if hasattr(value, "model_dump"):
        try:
            return value.model_dump()
        except Exception:
            return {}
    if hasattr(value, "dict"):
        try:
            return value.dict()
        except Exception:
            return {}
    result: dict[str, Any] = {}
    for key in dir(value):
        if key.startswith("_"):
            continue
        try:
            attr = getattr(value, key)
        except Exception:  # e.g. a property whose getter raises
            continue
        if not callable(attr):
            result[key] = attr
    return result


def _pick_first_int(data: dict[str, Any], *keys: str) -> int:
    """Return the first non-zero integer found under *keys* in *data*, else 0."""
    for key in keys:
        if key in data:
            value = _to_int(data.get(key))
            if value:
                return value
    return 0


def _pick_nested_int(data: dict[str, Any], paths: list[tuple[str, str]]) -> int:
    """Return the first non-zero ``data[first][second]`` integer over *paths*, else 0."""
    for first, second in paths:
        parent = data.get(first)
        if isinstance(parent, dict) and second in parent:
            value = _to_int(parent.get(second))
            if value:
                return value
    return 0


def _count_web_search_requests(result: Any) -> int:
    """Count ``web_search`` tool invocations in *result*.

    *result* may be a list of events (as passed to extract_usage_and_cost) or
    an object exposing an ``events`` attribute. An event counts when its
    ``data.tool_name`` (or ``data.name``) is ``web_search``.
    """
    events = result if isinstance(result, list) else getattr(result, "events", None)
    if not events:
        return 0
    seen_call_ids: set[str] = set()
    count = 0
    for event in events:
        data = getattr(event, "data", None)
        tool_name = getattr(data, "tool_name", None) or getattr(data, "name", None)
        if str(tool_name or "").strip().lower() != "web_search":
            continue
        # NOTE(review): assumes several events for one tool call (e.g. start and
        # complete) share a tool_call_id. Count each id once so a single search
        # is not billed twice. Events without an id are each counted.
        call_id = getattr(data, "tool_call_id", None)
        if call_id:
            call_key = str(call_id)
            if call_key in seen_call_ids:
                continue
            seen_call_ids.add(call_key)
        count += 1
    return count


def _format_usd(value: Decimal) -> str:
    """Format a cost for storage: 12 decimal places, trailing zeros stripped, never exponent notation."""
    return format(value.quantize(_USD_STORE_QUANT, rounding=ROUND_HALF_UP).normalize(), "f")


def _format_display_usd(value: Any) -> str:
    """Format a cost for display, rounded half-up to 6 decimal places."""
    amount = _to_decimal(value).quantize(USD_DISPLAY_QUANT, rounding=ROUND_HALF_UP)
    return format(amount, "f")


def _find_usage_event(events: list[Any]) -> Any:
    """Return the newest event carrying a usage payload, else the last event (or None if empty)."""
    for event in reversed(events):
        if _extract_usage_payload(event):
            return event
    return events[-1] if events else None


def format_cost_value(value: Any) -> str:
    """Public alias of the display formatter: cost rounded to 6 decimal places."""
    return _format_display_usd(value)


def summarize_usage(usage: dict[str, Any] | None) -> dict[str, Any] | None:
    """Normalise a stored usage dict; returns None when it is not a dict or records no activity.

    "No activity" means no requests, no tokens and no positive cost.
    """
    if not isinstance(usage, dict):
        return None
    prompt_tokens = _to_int(usage.get("prompt_tokens"))
    completion_tokens = _to_int(usage.get("completion_tokens"))
    total_tokens = _to_int(usage.get("total_tokens")) or (prompt_tokens + completion_tokens)
    request_count = _to_int(usage.get("request_count"))
    cost_usd = _to_decimal(usage.get("cost_usd"))
    if request_count <= 0 and total_tokens <= 0 and cost_usd <= ZERO:
        return None
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "reasoning_tokens": _to_int(usage.get("reasoning_tokens")),
        "cached_tokens": _to_int(usage.get("cached_tokens")),
        "request_count": max(request_count, 0),
        "cost_usd": format(cost_usd, "f"),
    }