"""Persistent thread store shared by the web UI and Telegram bot sessions."""

from __future__ import annotations

import asyncio
import json
import os
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from decimal import Decimal, InvalidOperation
from pathlib import Path
from threading import RLock
from typing import Any

from local_media_store import LocalMediaStore

DEFAULT_TITLE = "New chat"
DEFAULT_PROVIDER = "openai"
DEFAULT_MODEL = "gpt-5.4"
DEFAULT_SOURCE = "web"

TITLE_SOURCE_DEFAULT = "default"
TITLE_SOURCE_AUTO = "auto"
TITLE_SOURCE_MANUAL = "manual"
TITLE_SOURCE_MAGIC = "magic"
KNOWN_TITLE_SOURCES = {
    TITLE_SOURCE_DEFAULT,
    TITLE_SOURCE_AUTO,
    TITLE_SOURCE_MANUAL,
    TITLE_SOURCE_MAGIC,
}

# Integer usage counters; shared by normalization and merging so the two lists never drift apart.
_USAGE_COUNTER_FIELDS = (
    "prompt_tokens",
    "completion_tokens",
    "total_tokens",
    "reasoning_tokens",
    "cached_tokens",
    "web_search_requests",
    "request_count",
)


def _parse_decimal(value: Any) -> Decimal | None:
    """Parse ``value`` as a finite Decimal; return None for empty, invalid or non-finite input."""
    if value in (None, ""):
        return None
    try:
        parsed = Decimal(str(value))
    except (InvalidOperation, ValueError, TypeError):
        return None
    # NaN/Infinity would poison every later cost sum, so they are treated as invalid.
    return parsed if parsed.is_finite() else None


class WebFallbackStore:
    """JSON-file backed store for chat threads, session mappings, topic prompts and learnings.

    In-memory state is guarded by ``_state_lock`` (a re-entrant lock, so methods holding it may
    call each other) and every mutation is written through to ``<data_dir>/web_threads.json``.
    :meth:`lock` hands out one ``asyncio.Lock`` per thread id for async callers.
    """

    def __init__(self, data_dir: str = "/data", media_store: LocalMediaStore | None = None):
        self._threads: dict[str, dict[str, Any]] = {}
        self._session_index: dict[str, str] = {}
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        self._state_lock = RLock()
        self._file_path = Path(data_dir) / "web_threads.json"
        self._media_store = media_store or LocalMediaStore(data_dir)
        self._last_tg_update_id: int = 0
        self._topic_system_prompts: dict[str, dict[str, Any]] = {}
        self._global_learnings: list[dict[str, Any]] = []
        self._load()
        self._normalize_threads()

    # ── Threads ───────────────────────────────────────────────────────

    def list_threads(self) -> list[dict[str, Any]]:
        """Return every thread (without messages), most recently updated first."""
        with self._state_lock:
            threads = sorted(self._threads.values(), key=lambda thread: thread["updated_at"], reverse=True)
            return [self._serialize_thread(thread, include_messages=False) for thread in threads]

    def search_threads(self, query: str) -> list[dict[str, Any]]:
        """Search threads by keyword across title and message content (case-insensitive).

        An empty/blank query returns the same result as :meth:`list_threads`.
        """
        needle = query.strip().lower()
        if not needle:
            return self.list_threads()
        with self._state_lock:
            matches = [thread for thread in self._threads.values() if self._thread_matches(thread, needle)]
            matches.sort(key=lambda t: t["updated_at"], reverse=True)
            return [self._serialize_thread(t, include_messages=False) for t in matches]

    @staticmethod
    def _thread_matches(thread: dict[str, Any], needle: str) -> bool:
        """Return True if the lowercase ``needle`` occurs in the thread title or any message content."""
        if needle in str(thread.get("title", "")).lower():
            return True
        return any(needle in str(msg.get("content", "")).lower() for msg in thread.get("messages", []))

    def create_thread(
        self,
        title: str,
        model: str,
        provider: str = DEFAULT_PROVIDER,
        *,
        source: str = DEFAULT_SOURCE,
        session_label: str | None = None,
    ) -> dict[str, Any]:
        """Create, persist and return a new thread that is not bound to any session key."""
        with self._state_lock:
            thread = self._build_thread(
                title=title, model=model, provider=provider, source=source, session_label=session_label
            )
            self._threads[thread["id"]] = thread
            self._persist_locked()
            return self._serialize_thread(thread)

    def get_thread(self, thread_id: str, *, include_messages: bool = True) -> dict[str, Any]:
        """Return a serialized thread. Raises KeyError if ``thread_id`` is unknown."""
        with self._state_lock:
            thread = self._require_thread(thread_id)
            return self._serialize_thread(thread, include_messages=include_messages)

    def get_session_thread(self, session_key: str, *, include_messages: bool = True) -> dict[str, Any] | None:
        """Return the thread currently bound to ``session_key``, or None if there is none."""
        with self._state_lock:
            thread = self._get_session_thread_record(session_key)
            if thread is None:
                return None
            return self._serialize_thread(thread, include_messages=include_messages)

    def get_or_create_session_thread(
        self,
        *,
        session_key: str,
        title: str,
        model: str,
        provider: str = DEFAULT_PROVIDER,
        source: str = "telegram",
        session_label: str | None = None,
    ) -> dict[str, Any]:
        """Return the thread bound to ``session_key``, creating and binding one if needed.

        An existing thread's ``session_label`` is refreshed when a different non-empty label is given.
        """
        with self._state_lock:
            thread = self._get_session_thread_record(session_key)
            changed = False
            if thread is None:
                thread = self._build_thread(
                    title=title,
                    model=model,
                    provider=provider,
                    source=source,
                    session_label=session_label,
                )
                self._threads[thread["id"]] = thread
                self._session_index[session_key] = thread["id"]
                changed = True
            elif session_label and thread.get("session_label") != session_label:
                thread["session_label"] = session_label
                changed = True
            if changed:
                self._persist_locked()
            return self._serialize_thread(thread)

    def start_new_session_thread(
        self,
        *,
        session_key: str,
        title: str,
        model: str,
        provider: str = DEFAULT_PROVIDER,
        source: str = "telegram",
        session_label: str | None = None,
    ) -> dict[str, Any]:
        """Create a fresh thread and rebind ``session_key`` to it (the old thread is kept)."""
        with self._state_lock:
            thread = self._build_thread(
                title=title, model=model, provider=provider, source=source, session_label=session_label
            )
            self._threads[thread["id"]] = thread
            self._session_index[session_key] = thread["id"]
            self._persist_locked()
            return self._serialize_thread(thread)

    def update_thread(
        self,
        thread_id: str,
        *,
        title: str | None = None,
        title_source: str | None = None,
        model: str | None = None,
        provider: str | None = None,
    ) -> dict[str, Any]:
        """Update the given thread attributes; ``None`` leaves an attribute unchanged.

        A new title is recorded with ``title_source`` (default: manual).
        Raises KeyError for an unknown thread and ValueError for a blank provider/model.
        """
        with self._state_lock:
            thread = self._require_thread(thread_id)
            if title is not None:
                normalized_title = str(title).strip() or DEFAULT_TITLE
                thread["title"] = normalized_title
                thread["title_source"] = self._normalize_title_source(title_source) or TITLE_SOURCE_MANUAL
            if provider is not None:
                normalized_provider = str(provider).strip().lower()
                if not normalized_provider:
                    raise ValueError("Provider cannot be empty")
                thread["provider"] = normalized_provider
            if model is not None:
                normalized_model = str(model).strip()
                if not normalized_model:
                    raise ValueError("Model cannot be empty")
                thread["model"] = normalized_model
            thread["updated_at"] = self._timestamp()
            self._persist_locked()
            return self._serialize_thread(thread)

    def delete_thread(self, thread_id: str) -> None:
        """Delete a thread plus any session bindings and async lock pointing at it.

        Raises KeyError if the thread does not exist.
        """
        with self._state_lock:
            if thread_id not in self._threads:
                raise KeyError(thread_id)
            del self._threads[thread_id]
            stale_session_keys = [
                session_key
                for session_key, mapped_thread_id in self._session_index.items()
                if mapped_thread_id == thread_id
            ]
            for session_key in stale_session_keys:
                self._session_index.pop(session_key, None)
            self._locks.pop(thread_id, None)
            self._persist_locked()

    # ── Messages ──────────────────────────────────────────────────────

    def add_message(
        self,
        thread_id: str,
        role: str,
        content: str,
        *,
        parts: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any]:
        """Append a message to a thread, auto-titling the thread if it is still untitled."""
        with self._state_lock:
            thread = self._require_thread(thread_id)
            now = self._timestamp()
            message: dict[str, Any] = {
                "id": uuid.uuid4().hex,
                "role": role,
                "content": str(content or ""),
                "created_at": now,
            }
            normalized_parts = self._normalize_message_parts(parts)
            if normalized_parts:
                message["parts"] = normalized_parts
            thread["messages"].append(message)
            thread["updated_at"] = now
            title_text = self._message_text(message)
            # Derive a title from user text while the thread still has the default title;
            # slash commands are not meaningful titles.
            if (
                role == "user"
                and thread["title"] == DEFAULT_TITLE
                and title_text
                and not title_text.startswith("/")
            ):
                thread["title"] = self._derive_title(title_text)
                thread["title_source"] = TITLE_SOURCE_AUTO
            self._persist_locked()
            return self._serialize_message(message)

    def remove_message(self, thread_id: str, message_id: str) -> None:
        """Remove a message by id; unknown message ids are ignored.

        ``updated_at`` is rolled back to the newest remaining message. If the thread becomes
        empty it also reverts to the default title, so the next user message can re-title it.
        """
        with self._state_lock:
            thread = self._require_thread(thread_id)
            remaining = [message for message in thread["messages"] if message.get("id") != message_id]
            if len(remaining) == len(thread["messages"]):
                return
            thread["messages"] = remaining
            if remaining:
                thread["updated_at"] = remaining[-1].get("created_at") or thread["created_at"]
            else:
                thread["updated_at"] = thread["created_at"]
                if thread["title"] != DEFAULT_TITLE:
                    thread["title"] = DEFAULT_TITLE
                    thread["title_source"] = TITLE_SOURCE_DEFAULT
            self._persist_locked()

    def build_history(self, thread_id: str, *, limit: int | None = None) -> list[dict[str, str]]:
        """Return ``{"role", "content"}`` text history for the last ``limit`` messages.

        ``limit=None`` returns all messages; ``limit <= 0`` returns none.
        User messages use their text parts (falling back to ``content``).
        """
        with self._state_lock:
            thread = self._require_thread(thread_id)
            messages = self._tail_messages(thread["messages"], limit)
        return [
            {
                "role": str(message.get("role") or "assistant"),
                "content": self._message_text(message)
                if str(message.get("role") or "").strip().lower() == "user"
                else str(message.get("content") or ""),
            }
            for message in messages
        ]

    def build_agent_history(self, thread_id: str, *, limit: int | None = None) -> list[dict[str, Any]]:
        """Return history for the agent, expanding user image parts into ``input_image`` data URLs.

        ``limit`` behaves as in :meth:`build_history`. Images missing from the media store are
        replaced by an explanatory ``input_text`` part; messages with no content are skipped.
        """
        # Snapshot under the lock, then do media-store reads without holding it.
        with self._state_lock:
            thread = self._require_thread(thread_id)
            messages = self._tail_messages(thread["messages"], limit)
        history: list[dict[str, Any]] = []
        for message in messages:
            role = str(message.get("role") or "assistant").strip().lower() or "assistant"
            parts = self._normalize_message_parts(message.get("parts"))
            if role == "user" and parts:
                content_parts = [self._agent_content_part(part) for part in parts]
                if content_parts:
                    history.append({"role": role, "content": content_parts})
                continue
            content = self._message_text(message) if role == "user" else str(message.get("content") or "")
            if content:
                history.append({"role": role, "content": content})
        return history

    def _agent_content_part(self, part: dict[str, Any]) -> dict[str, Any]:
        """Convert one normalized message part into an agent content part."""
        if part["type"] == "input_text":
            return {"type": "input_text", "text": part["text"]}
        try:
            image_url = self._media_store.build_data_url(part["file_id"])
        except KeyError:
            missing_name = str(part.get("name") or "image")
            return {
                "type": "input_text",
                "text": f"A previously attached image named {missing_name} is no longer available.",
            }
        return {
            "type": "input_image",
            "image_url": image_url,
            "detail": part.get("detail") or "auto",
        }

    @staticmethod
    def _tail_messages(messages: list[dict[str, Any]], limit: int | None) -> list[dict[str, Any]]:
        """Return a copy of the last ``limit`` messages (all if None, none if ``limit <= 0``).

        A bare ``messages[-limit:]`` would wrongly return *every* message for ``limit == 0``.
        """
        if limit is None:
            return list(messages)
        if limit <= 0:
            return []
        return messages[-limit:]

    # ── Uploads ───────────────────────────────────────────────────────

    def save_upload(self, *, name: str, mime_type: str, data: bytes, file_id: str | None = None) -> dict[str, Any]:
        """Store uploaded bytes in the media store and return their public metadata."""
        metadata = self._media_store.save_bytes(data, name=name, mime_type=mime_type, file_id=file_id)
        return self._serialize_upload(metadata)

    def get_upload(self, file_id: str) -> dict[str, Any]:
        """Return public metadata for a stored upload."""
        return self._serialize_upload(self._media_store.get_meta(file_id))

    def read_upload(self, file_id: str) -> bytes:
        """Return the raw bytes of a stored upload."""
        return self._media_store.read_bytes(file_id)

    def delete_upload(self, file_id: str) -> None:
        """Delete a stored upload from the media store."""
        self._media_store.delete(file_id)

    def lock(self, thread_id: str) -> asyncio.Lock:
        """Return the per-thread asyncio lock. Raises KeyError for an unknown thread."""
        with self._state_lock:
            self._require_thread(thread_id)
        return self._locks[thread_id]

    # ── Topic system prompts ──────────────────────────────────────────

    def get_topic_system_prompt(self, session_key: str) -> dict[str, Any] | None:
        """Return the stored system prompt for ``session_key``, or None if unset.

        The key is stripped, matching :meth:`set_topic_system_prompt` and
        :meth:`clear_topic_system_prompt`.
        """
        normalized_session_key = str(session_key).strip()
        with self._state_lock:
            record = self._topic_system_prompts.get(normalized_session_key)
        if not isinstance(record, dict):
            return None
        prompt = str(record.get("prompt") or "").strip()
        if not prompt:
            return None
        return {
            "session_key": normalized_session_key,
            "prompt": prompt,
            "updated_at": str(record.get("updated_at") or ""),
        }

    def set_topic_system_prompt(self, session_key: str, prompt: str) -> dict[str, Any]:
        """Store a system prompt for ``session_key``. Raises ValueError for a blank key or prompt."""
        normalized_session_key = str(session_key).strip()
        if not normalized_session_key:
            raise ValueError("Session key cannot be empty")
        normalized_prompt = str(prompt or "").strip()
        if not normalized_prompt:
            raise ValueError("System prompt cannot be empty")
        with self._state_lock:
            record = {
                "prompt": normalized_prompt,
                "updated_at": self._timestamp(),
            }
            self._topic_system_prompts[normalized_session_key] = record
            self._persist_locked()
        return {
            "session_key": normalized_session_key,
            "prompt": record["prompt"],
            "updated_at": record["updated_at"],
        }

    def clear_topic_system_prompt(self, session_key: str) -> bool:
        """Remove the system prompt for ``session_key``; return True if one was removed."""
        normalized_session_key = str(session_key).strip()
        if not normalized_session_key:
            return False
        with self._state_lock:
            removed = self._topic_system_prompts.pop(normalized_session_key, None)
            if removed is None:
                return False
            self._persist_locked()
        return True

    # ── Usage ─────────────────────────────────────────────────────────

    def add_usage(self, thread_id: str, usage: dict[str, Any] | None) -> dict[str, Any]:
        """Accumulate ``usage`` into the thread's running totals and return the new totals."""
        normalized_usage = self._normalize_usage(usage)
        with self._state_lock:
            thread = self._require_thread(thread_id)
            thread["usage"] = self._merge_usage(thread.get("usage"), normalized_usage)
            thread["updated_at"] = self._timestamp()
            self._persist_locked()
            return self._serialize_usage(thread.get("usage"))

    def get_total_usage(self, thread_id: str) -> dict[str, Any]:
        """Return the cumulative usage for a thread."""
        with self._state_lock:
            thread = self._require_thread(thread_id)
            return self._serialize_usage(thread.get("usage"))

    # ── Project learnings (per-thread) ────────────────────────────────

    MAX_PROJECT_LEARNINGS = 30

    def get_project_learnings(self, thread_id: str) -> list[dict[str, Any]]:
        """Return copies of the project learnings stored on a thread."""
        with self._state_lock:
            thread = self._require_thread(thread_id)
            return [dict(item) for item in (thread.get("project_learnings") or []) if isinstance(item, dict)]

    def add_project_learning(self, thread_id: str, fact: str, *, category: str = "general") -> None:
        """Add a project-scoped learning to a thread.

        Blank facts and case/whitespace-insensitive duplicates are ignored; only the newest
        ``MAX_PROJECT_LEARNINGS`` entries are kept.
        """
        fact = str(fact or "").strip()
        if not fact:
            return
        with self._state_lock:
            thread = self._require_thread(thread_id)
            learnings: list[dict[str, Any]] = list(thread.get("project_learnings") or [])
            normalized = fact.lower()
            if any(str(existing.get("fact") or "").lower().strip() == normalized for existing in learnings):
                return
            now = self._timestamp()
            learnings.append({"fact": fact, "category": str(category or "general").strip(), "updated_at": now})
            # Cap size — keep most recent.
            thread["project_learnings"] = learnings[-self.MAX_PROJECT_LEARNINGS:]
            thread["updated_at"] = now
            self._persist_locked()

    # ── Global learnings (user preferences, cross-project) ────────────

    MAX_GLOBAL_LEARNINGS = 40

    def get_global_learnings(self) -> list[dict[str, Any]]:
        """Return copies of all global (cross-project) learnings."""
        with self._state_lock:
            return [dict(item) for item in self._global_learnings]

    def add_global_learning(self, fact: str, *, category: str = "general") -> None:
        """Add a global learning (user preference, personality, contacts).

        Blank facts and case/whitespace-insensitive duplicates are ignored; only the newest
        ``MAX_GLOBAL_LEARNINGS`` entries are kept.
        """
        fact = str(fact or "").strip()
        if not fact:
            return
        with self._state_lock:
            normalized = fact.lower()
            if any(
                str(existing.get("fact") or "").lower().strip() == normalized
                for existing in self._global_learnings
            ):
                return
            self._global_learnings.append(
                {"fact": fact, "category": str(category or "general").strip(), "updated_at": self._timestamp()}
            )
            if len(self._global_learnings) > self.MAX_GLOBAL_LEARNINGS:
                self._global_learnings = self._global_learnings[-self.MAX_GLOBAL_LEARNINGS:]
            self._persist_locked()

    # ── Internal helpers ──────────────────────────────────────────────

    def _build_thread(
        self,
        *,
        title: str,
        model: str,
        provider: str,
        source: str,
        session_label: str | None,
    ) -> dict[str, Any]:
        """Return a new, empty thread record with normalized attributes (not yet stored)."""
        now = self._timestamp()
        normalized_title = str(title).strip() or DEFAULT_TITLE
        normalized_session_label = str(session_label).strip() if session_label else ""
        thread: dict[str, Any] = {
            "id": uuid.uuid4().hex,
            "title": normalized_title,
            "title_source": TITLE_SOURCE_DEFAULT if normalized_title == DEFAULT_TITLE else TITLE_SOURCE_MANUAL,
            "provider": str(provider).strip().lower() or DEFAULT_PROVIDER,
            "model": str(model).strip() or DEFAULT_MODEL,
            "source": str(source).strip().lower() or DEFAULT_SOURCE,
            "created_at": now,
            "updated_at": now,
            "messages": [],
            "usage": self._serialize_usage(None),
        }
        if normalized_session_label:
            thread["session_label"] = normalized_session_label
        return thread

    def _normalize_usage(self, usage: dict[str, Any] | None) -> dict[str, Any]:
        """Return a cleaned copy of ``usage`` containing only recognised, well-formed fields.

        Counters become ints and money values become plain decimal strings (``format(d, "f")``).
        Unparsable or non-finite values are dropped instead of raising.
        """
        if not isinstance(usage, dict):
            return {}
        normalized: dict[str, Any] = {}
        for field in _USAGE_COUNTER_FIELDS:
            value = usage.get(field)
            if value in (None, ""):
                continue
            try:
                normalized[field] = int(value)
            except (TypeError, ValueError, OverflowError):
                continue
        cost = _parse_decimal(usage.get("cost_usd"))
        if cost is not None:
            normalized["cost_usd"] = format(cost, "f")
        pricing_source = usage.get("pricing_source")
        if pricing_source not in (None, ""):
            normalized["pricing_source"] = str(pricing_source)
        breakdown = usage.get("cost_breakdown")
        if isinstance(breakdown, dict):
            clean_breakdown: dict[str, str] = {}
            for key, value in breakdown.items():
                amount = _parse_decimal(value)
                if amount is not None:
                    clean_breakdown[str(key)] = format(amount, "f")
            if clean_breakdown:
                normalized["cost_breakdown"] = clean_breakdown
        return normalized

    def _merge_usage(self, existing: Any, incoming: dict[str, Any]) -> dict[str, Any]:
        """Add normalized ``incoming`` usage onto ``existing`` and return the serialized sum.

        When ``incoming`` has no ``total_tokens`` it counts as prompt + completion tokens, the
        same rule :meth:`_serialize_usage` uses. Otherwise the running total would stop growing.
        """
        current = self._normalize_usage(existing if isinstance(existing, dict) else {})
        if not incoming:
            return self._serialize_usage(current)

        def as_int(value: Any) -> int:
            return int(value or 0)

        def as_decimal(value: Any) -> Decimal:
            parsed = _parse_decimal(value)
            return Decimal(0) if parsed is None else parsed

        incoming_total = incoming.get("total_tokens")
        if incoming_total is None:
            incoming_total = as_int(incoming.get("prompt_tokens")) + as_int(incoming.get("completion_tokens"))
        for field in _USAGE_COUNTER_FIELDS:
            delta = incoming_total if field == "total_tokens" else incoming.get(field)
            current[field] = as_int(current.get(field)) + as_int(delta)

        current["cost_usd"] = format(as_decimal(current.get("cost_usd")) + as_decimal(incoming.get("cost_usd")), "f")

        existing_breakdown = current.get("cost_breakdown") if isinstance(current.get("cost_breakdown"), dict) else {}
        incoming_breakdown = incoming.get("cost_breakdown") if isinstance(incoming.get("cost_breakdown"), dict) else {}
        merged_breakdown: dict[str, str] = {}
        for key in set(existing_breakdown) | set(incoming_breakdown):
            total = as_decimal(existing_breakdown.get(key)) + as_decimal(incoming_breakdown.get(key))
            if total:
                merged_breakdown[key] = format(total, "f")
        if merged_breakdown:
            current["cost_breakdown"] = merged_breakdown
        else:
            current.pop("cost_breakdown", None)

        pricing_source = incoming.get("pricing_source") or current.get("pricing_source")
        if pricing_source:
            current["pricing_source"] = str(pricing_source)
        return self._serialize_usage(current)

    def _serialize_usage(self, usage: Any) -> dict[str, Any]:
        """Return normalized usage with every counter/cost key present (zero-filled).

        A missing ``total_tokens`` defaults to prompt + completion tokens.
        """
        normalized = self._normalize_usage(usage if isinstance(usage, dict) else {})
        normalized.setdefault("prompt_tokens", 0)
        normalized.setdefault("completion_tokens", 0)
        normalized.setdefault("total_tokens", normalized["prompt_tokens"] + normalized["completion_tokens"])
        normalized.setdefault("reasoning_tokens", 0)
        normalized.setdefault("cached_tokens", 0)
        normalized.setdefault("web_search_requests", 0)
        normalized.setdefault("request_count", 0)
        normalized.setdefault("cost_usd", "0")
        normalized.setdefault("cost_breakdown", {})
        return normalized

    def _serialize_thread(self, thread: dict[str, Any], *, include_messages: bool = True) -> dict[str, Any]:
        """Return a detached, API-facing copy of a thread record."""
        data: dict[str, Any] = {
            "id": thread["id"],
            "title": thread["title"],
            "title_source": self._normalize_title_source(thread.get("title_source")) or TITLE_SOURCE_DEFAULT,
            "provider": thread.get("provider") or DEFAULT_PROVIDER,
            "model": thread.get("model") or DEFAULT_MODEL,
            "source": thread.get("source") or DEFAULT_SOURCE,
            "created_at": thread["created_at"],
            "updated_at": thread["updated_at"],
            "message_count": len(thread["messages"]),
            "usage": self._serialize_usage(thread.get("usage")),
        }
        if thread.get("session_label"):
            data["session_label"] = thread["session_label"]
        if thread.get("project_learnings"):
            # Copy each entry so callers cannot mutate the stored learnings.
            data["project_learnings"] = [
                dict(item) for item in thread["project_learnings"] if isinstance(item, dict)
            ]
        if include_messages:
            data["messages"] = [self._serialize_message(message) for message in thread["messages"]]
        return data

    def _require_thread(self, thread_id: str) -> dict[str, Any]:
        """Return the stored thread record. Raises KeyError if it does not exist."""
        if thread_id not in self._threads:
            raise KeyError(thread_id)
        return self._threads[thread_id]

    def _get_session_thread_record(self, session_key: str) -> dict[str, Any] | None:
        """Return the thread bound to ``session_key``, pruning a binding to a deleted thread.

        Caller must hold ``_state_lock`` (this may persist).
        """
        thread_id = self._session_index.get(session_key)
        if not thread_id:
            return None
        thread = self._threads.get(thread_id)
        if thread is None:
            self._session_index.pop(session_key, None)
            self._persist_locked()
            return None
        return thread

    # ── Telegram update de-duplication ────────────────────────────────

    def is_tg_update_seen(self, update_id: int) -> bool:
        """Return True if this Telegram update_id has already been processed."""
        with self._state_lock:
            return update_id <= self._last_tg_update_id

    def mark_tg_update(self, update_id: int) -> None:
        """Record a Telegram update_id as processed (only ever moves the high-water mark forward)."""
        with self._state_lock:
            if update_id > self._last_tg_update_id:
                self._last_tg_update_id = update_id
                self._persist_locked()

    # ── Persistence ───────────────────────────────────────────────────

    def _persist_locked(self) -> None:
        """Write the complete store state to disk. Caller must hold ``_state_lock``.

        The JSON is written to a sibling temp file, fsynced and then ``os.replace``d over the
        real file. A crash mid-write therefore leaves the previous file intact instead of a
        truncated one.
        """
        self._file_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "threads": self._threads,
            "session_index": self._session_index,
            "last_tg_update_id": self._last_tg_update_id,
            "topic_system_prompts": self._topic_system_prompts,
            "global_learnings": self._global_learnings,
        }
        # Serialize first so an unserializable value cannot leave a half-written temp file behind.
        serialized = json.dumps(payload, indent=2)
        tmp_path = self._file_path.with_name(self._file_path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as handle:
            handle.write(serialized)
            handle.flush()
            os.fsync(handle.fileno())
        os.replace(tmp_path, self._file_path)

    def _load(self) -> None:
        """Populate in-memory state from the JSON file, tolerating a missing or damaged file.

        An undecodable file, or one whose top level is not an object, is moved aside (see
        :meth:`_quarantine_corrupt_file`) so the next write does not silently destroy it.
        """
        if not self._file_path.exists():
            return
        try:
            payload = json.loads(self._file_path.read_text(encoding="utf-8"))
        except OSError:
            return
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            self._quarantine_corrupt_file()
            return
        if not isinstance(payload, dict):
            self._quarantine_corrupt_file()
            return
        raw_threads = payload.get("threads")
        self._threads = raw_threads if isinstance(raw_threads, dict) else {}
        raw_index = payload.get("session_index")
        self._session_index = raw_index if isinstance(raw_index, dict) else {}
        try:
            self._last_tg_update_id = int(payload.get("last_tg_update_id") or 0)
        except (TypeError, ValueError, OverflowError):
            self._last_tg_update_id = 0
        raw_prompts = payload.get("topic_system_prompts") or {}
        self._topic_system_prompts = raw_prompts if isinstance(raw_prompts, dict) else {}
        raw_learnings = payload.get("global_learnings") or []
        self._global_learnings = raw_learnings if isinstance(raw_learnings, list) else []

    def _quarantine_corrupt_file(self) -> None:
        """Best-effort rename of an unreadable store file to ``<name>.corrupt-<UTC stamp>``."""
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        backup_path = self._file_path.with_name(f"{self._file_path.name}.corrupt-{stamp}")
        try:
            self._file_path.replace(backup_path)
        except OSError:
            # Best effort: if the rename fails the store still starts empty, as before.
            pass

    def _normalize_threads(self) -> None:
        """Repair state loaded from disk in place, persisting only if anything changed."""
        changed = self._normalize_topic_prompts()
        changed = self._normalize_global_learnings() or changed
        valid_threads = {
            thread_id: thread for thread_id, thread in self._threads.items() if isinstance(thread, dict)
        }
        if len(valid_threads) != len(self._threads):
            self._threads = valid_threads
            changed = True
        for thread_id, thread in self._threads.items():
            if self._normalize_thread_record(thread_id, thread):
                changed = True
        changed = self._normalize_session_index() or changed
        if changed:
            self._persist_locked()

    def _normalize_topic_prompts(self) -> bool:
        """Drop blank/invalid topic prompt records and strip keys; return True if anything changed."""
        changed = False
        cleaned_topic_prompts: dict[str, dict[str, Any]] = {}
        for session_key, raw_record in self._topic_system_prompts.items():
            if not isinstance(raw_record, dict):
                changed = True
                continue
            normalized_session_key = str(session_key).strip()
            normalized_prompt = str(raw_record.get("prompt") or "").strip()
            if not normalized_session_key or not normalized_prompt:
                changed = True
                continue
            normalized_record = {
                "prompt": normalized_prompt,
                "updated_at": str(raw_record.get("updated_at") or self._timestamp()),
            }
            cleaned_topic_prompts[normalized_session_key] = normalized_record
            if normalized_session_key != session_key or raw_record != normalized_record:
                changed = True
        if cleaned_topic_prompts != self._topic_system_prompts:
            self._topic_system_prompts = cleaned_topic_prompts
            changed = True
        return changed

    def _normalize_global_learnings(self) -> bool:
        """Drop non-dict global learnings; return True if anything was removed."""
        cleaned = [item for item in self._global_learnings if isinstance(item, dict)]
        if len(cleaned) == len(self._global_learnings):
            return False
        self._global_learnings = cleaned
        return True

    def _normalize_thread_record(self, thread_id: str, thread: dict[str, Any]) -> bool:
        """Fill in/repair one thread record's fields in place; return True if it was modified."""
        changed = False
        if thread.get("id") != thread_id:
            thread["id"] = thread_id
            changed = True

        messages = thread.get("messages")
        if not isinstance(messages, list):
            thread["messages"] = []
            changed = True
        else:
            valid_messages = [message for message in messages if isinstance(message, dict)]
            if len(valid_messages) != len(messages):
                thread["messages"] = valid_messages
                changed = True

        title = str(thread.get("title") or DEFAULT_TITLE).strip() or DEFAULT_TITLE
        if thread.get("title") != title:
            thread["title"] = title
            changed = True
        expected_title_source = self._normalize_title_source(thread.get("title_source"))
        if expected_title_source is None:
            expected_title_source = TITLE_SOURCE_DEFAULT if title == DEFAULT_TITLE else TITLE_SOURCE_MANUAL
        if thread.get("title_source") != expected_title_source:
            thread["title_source"] = expected_title_source
            changed = True

        normalized_values = {
            "provider": str(thread.get("provider") or DEFAULT_PROVIDER).strip().lower() or DEFAULT_PROVIDER,
            "source": str(thread.get("source") or DEFAULT_SOURCE).strip().lower() or DEFAULT_SOURCE,
            "model": str(thread.get("model") or DEFAULT_MODEL).strip() or DEFAULT_MODEL,
        }
        for field, value in normalized_values.items():
            if thread.get(field) != value:
                thread[field] = value
                changed = True

        # Sorting and serialization index these directly, so they must exist.
        for field in ("created_at", "updated_at"):
            value = thread.get(field)
            if not isinstance(value, str) or not value:
                thread[field] = self._timestamp()
                changed = True

        usage = self._serialize_usage(thread.get("usage"))
        if thread.get("usage") != usage:
            thread["usage"] = usage
            changed = True

        learnings = thread.get("project_learnings")
        if learnings is not None:
            cleaned = [item for item in learnings if isinstance(item, dict)] if isinstance(learnings, list) else []
            if cleaned != learnings:
                thread["project_learnings"] = cleaned
                changed = True
        return changed

    def _normalize_session_index(self) -> bool:
        """Drop session bindings that point at missing threads; return True if anything changed."""
        cleaned = {
            str(session_key): thread_id
            for session_key, thread_id in self._session_index.items()
            if isinstance(thread_id, str) and thread_id in self._threads
        }
        if cleaned == self._session_index:
            return False
        self._session_index = cleaned
        return True

    def _serialize_message(self, message: dict[str, Any]) -> dict[str, Any]:
        """Return an API-facing copy of a message, including normalized parts when present."""
        serialized = {
            "id": message.get("id"),
            "role": message.get("role"),
            "content": message.get("content", ""),
            "created_at": message.get("created_at"),
        }
        normalized_parts = self._normalize_message_parts(message.get("parts"))
        if normalized_parts:
            serialized["parts"] = normalized_parts
        return serialized

    def _serialize_upload(self, metadata: dict[str, Any]) -> dict[str, Any]:
        """Return public upload metadata; images with a known id get a ``/uploads/<id>`` preview URL."""
        file_id = str(metadata.get("id") or "")
        mime_type = str(metadata.get("mime_type") or "application/octet-stream")
        return {
            "id": file_id,
            "name": str(metadata.get("name") or "upload"),
            "mime_type": mime_type,
            "size": int(metadata.get("size") or 0),
            "preview_url": f"/uploads/{file_id}" if file_id and mime_type.startswith("image/") else None,
        }

    def _normalize_message_parts(self, parts: Any) -> list[dict[str, Any]]:
        """Return only well-formed ``input_text`` (non-empty) and ``input_image`` (with file_id) parts."""
        if not isinstance(parts, list):
            return []
        normalized_parts: list[dict[str, Any]] = []
        for part in parts:
            if not isinstance(part, dict):
                continue
            part_type = str(part.get("type") or "").strip()
            if part_type == "input_text":
                text = str(part.get("text") or "")
                if text:
                    normalized_parts.append({"type": "input_text", "text": text})
                continue
            if part_type == "input_image":
                file_id = str(part.get("file_id") or "").strip()
                if not file_id:
                    continue
                normalized_part = {"type": "input_image", "file_id": file_id}
                for optional_key in ("name", "mime_type", "detail"):
                    if part.get(optional_key):
                        normalized_part[optional_key] = str(part.get(optional_key))
                normalized_parts.append(normalized_part)
        return normalized_parts

    def _message_text(self, message: dict[str, Any]) -> str:
        """Return the message's joined text parts, falling back to its stripped ``content``."""
        parts = self._normalize_message_parts(message.get("parts"))
        if parts:
            texts = [str(part.get("text") or "") for part in parts if part.get("type") == "input_text"]
            text = "\n".join(part for part in texts if part).strip()
            if text:
                return text
        return str(message.get("content") or "").strip()

    def _derive_title(self, content: str) -> str:
        """Collapse whitespace and truncate to at most 48 characters (with "..." when cut)."""
        single_line = " ".join(str(content).split())
        if len(single_line) <= 48:
            return single_line
        return single_line[:45].rstrip() + "..."

    def _normalize_title_source(self, raw_value: Any) -> str | None:
        """Return the lowercase title source if it is one of KNOWN_TITLE_SOURCES, else None."""
        value = str(raw_value or "").strip().lower()
        if value in KNOWN_TITLE_SOURCES:
            return value
        return None

    def _timestamp(self) -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()