Replace standalone Telegram bot with full CodeAnywhere framework fork. BetterBot shares all framework code and customizes only: - instance.py: BetterBot identity, system prompt, feature flags - tools/site_editing/: list_files, read_file, write_file with auto git push - .env: model defaults and site directory paths - compose/: Docker setup with betterlifesg + memoraiz mounts - deploy script: RackNerd with Infisical secrets
767 lines | 31 KiB | Python
"""Persistent thread store shared by the web UI and Telegram bot sessions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from collections import defaultdict
|
|
from decimal import Decimal
|
|
from pathlib import Path
|
|
from threading import RLock
|
|
from typing import Any
|
|
|
|
from local_media_store import LocalMediaStore
|
|
|
|
# Defaults applied when a thread is created without explicit values.
DEFAULT_TITLE = "New chat"
DEFAULT_PROVIDER = "openai"
DEFAULT_MODEL = "gpt-5.4"
DEFAULT_SOURCE = "web"

# How a thread obtained its title:
#   default - still the placeholder title
#   auto    - derived from the first user message (see add_message)
#   manual  - explicitly set by the user (see update_thread)
#   magic   - set by some external titling flow (stored/serialized only here)
TITLE_SOURCE_DEFAULT = "default"
TITLE_SOURCE_AUTO = "auto"
TITLE_SOURCE_MANUAL = "manual"
TITLE_SOURCE_MAGIC = "magic"

# The only values accepted for a thread's "title_source" field.
KNOWN_TITLE_SOURCES = {
    TITLE_SOURCE_DEFAULT,
    TITLE_SOURCE_AUTO,
    TITLE_SOURCE_MANUAL,
    TITLE_SOURCE_MAGIC,
}
|
|
|
|
|
|
class WebFallbackStore:
    """In-memory thread store persisted to a single JSON file.

    Holds chat threads, a session-key -> thread index (used by the Telegram
    bot), per-topic system prompts, and global learnings. All mutations are
    guarded by one re-entrant lock and written back via ``_persist_locked``.
    """

    def __init__(self, data_dir: str = "/data", media_store: LocalMediaStore | None = None):
        # thread_id -> raw thread record (title, messages, usage, ...).
        self._threads: dict[str, dict[str, Any]] = {}
        # External session key (e.g. a Telegram chat/topic) -> thread_id.
        self._session_index: dict[str, str] = {}
        # Per-thread asyncio locks, created lazily on first access.
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        # Guards all in-memory state and file persistence (re-entrant).
        self._state_lock = RLock()
        self._file_path = Path(data_dir) / "web_threads.json"
        self._media_store = media_store or LocalMediaStore(data_dir)
        # Highest Telegram update_id already processed (dedupe across restarts).
        self._last_tg_update_id: int = 0
        # session_key -> {"prompt", "updated_at"} custom system prompts.
        self._topic_system_prompts: dict[str, dict[str, Any]] = {}
        # Cross-thread user-preference facts (see add_global_learning).
        self._global_learnings: list[dict[str, Any]] = []
        self._load()
        self._normalize_threads()
|
|
|
|
def list_threads(self) -> list[dict[str, Any]]:
|
|
threads = sorted(self._threads.values(), key=lambda thread: thread["updated_at"], reverse=True)
|
|
return [self._serialize_thread(thread, include_messages=False) for thread in threads]
|
|
|
|
def search_threads(self, query: str) -> list[dict[str, Any]]:
|
|
"""Search threads by keyword across title and message content."""
|
|
needle = query.strip().lower()
|
|
if not needle:
|
|
return self.list_threads()
|
|
matches: list[dict[str, Any]] = []
|
|
for thread in self._threads.values():
|
|
if needle in thread.get("title", "").lower():
|
|
matches.append(thread)
|
|
continue
|
|
for msg in thread.get("messages", []):
|
|
if needle in str(msg.get("content", "")).lower():
|
|
matches.append(thread)
|
|
break
|
|
matches.sort(key=lambda t: t["updated_at"], reverse=True)
|
|
return [self._serialize_thread(t, include_messages=False) for t in matches]
|
|
|
|
def create_thread(
|
|
self,
|
|
title: str,
|
|
model: str,
|
|
provider: str = DEFAULT_PROVIDER,
|
|
*,
|
|
source: str = DEFAULT_SOURCE,
|
|
session_label: str | None = None,
|
|
) -> dict[str, Any]:
|
|
with self._state_lock:
|
|
thread = self._build_thread(
|
|
title=title, model=model, provider=provider, source=source, session_label=session_label
|
|
)
|
|
self._threads[thread["id"]] = thread
|
|
self._persist_locked()
|
|
return self._serialize_thread(thread)
|
|
|
|
def get_thread(self, thread_id: str, *, include_messages: bool = True) -> dict[str, Any]:
|
|
thread = self._require_thread(thread_id)
|
|
return self._serialize_thread(thread, include_messages=include_messages)
|
|
|
|
def get_session_thread(self, session_key: str, *, include_messages: bool = True) -> dict[str, Any] | None:
|
|
with self._state_lock:
|
|
thread = self._get_session_thread_record(session_key)
|
|
if thread is None:
|
|
return None
|
|
return self._serialize_thread(thread, include_messages=include_messages)
|
|
|
|
    def get_or_create_session_thread(
        self,
        *,
        session_key: str,
        title: str,
        model: str,
        provider: str = DEFAULT_PROVIDER,
        source: str = "telegram",
        session_label: str | None = None,
    ) -> dict[str, Any]:
        """Return the thread mapped to ``session_key``, creating one if needed.

        When a mapping already exists, only its ``session_label`` is refreshed
        (and only if a non-empty label differs from the stored one). The store
        is persisted only when something actually changed.
        """
        with self._state_lock:
            thread = self._get_session_thread_record(session_key)
            changed = False
            if thread is None:
                # No mapping yet: build a new thread and index it by session key.
                thread = self._build_thread(
                    title=title,
                    model=model,
                    provider=provider,
                    source=source,
                    session_label=session_label,
                )
                self._threads[thread["id"]] = thread
                self._session_index[session_key] = thread["id"]
                changed = True
            elif session_label and thread.get("session_label") != session_label:
                # Existing thread: keep its content, just update the label.
                thread["session_label"] = session_label
                changed = True

            if changed:
                self._persist_locked()
            return self._serialize_thread(thread)
|
|
|
|
def start_new_session_thread(
|
|
self,
|
|
*,
|
|
session_key: str,
|
|
title: str,
|
|
model: str,
|
|
provider: str = DEFAULT_PROVIDER,
|
|
source: str = "telegram",
|
|
session_label: str | None = None,
|
|
) -> dict[str, Any]:
|
|
with self._state_lock:
|
|
thread = self._build_thread(
|
|
title=title, model=model, provider=provider, source=source, session_label=session_label
|
|
)
|
|
self._threads[thread["id"]] = thread
|
|
self._session_index[session_key] = thread["id"]
|
|
self._persist_locked()
|
|
return self._serialize_thread(thread)
|
|
|
|
def update_thread(
|
|
self,
|
|
thread_id: str,
|
|
*,
|
|
title: str | None = None,
|
|
title_source: str | None = None,
|
|
model: str | None = None,
|
|
provider: str | None = None,
|
|
) -> dict[str, Any]:
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
if title is not None:
|
|
normalized_title = str(title).strip() or DEFAULT_TITLE
|
|
thread["title"] = normalized_title
|
|
thread["title_source"] = self._normalize_title_source(title_source) or TITLE_SOURCE_MANUAL
|
|
if provider is not None:
|
|
normalized_provider = str(provider).strip().lower()
|
|
if not normalized_provider:
|
|
raise ValueError("Provider cannot be empty")
|
|
thread["provider"] = normalized_provider
|
|
if model is not None:
|
|
normalized_model = str(model).strip()
|
|
if not normalized_model:
|
|
raise ValueError("Model cannot be empty")
|
|
thread["model"] = normalized_model
|
|
thread["updated_at"] = self._timestamp()
|
|
self._persist_locked()
|
|
return self._serialize_thread(thread)
|
|
|
|
def delete_thread(self, thread_id: str) -> None:
|
|
with self._state_lock:
|
|
if thread_id not in self._threads:
|
|
raise KeyError(thread_id)
|
|
del self._threads[thread_id]
|
|
stale_session_keys = [
|
|
session_key
|
|
for session_key, mapped_thread_id in self._session_index.items()
|
|
if mapped_thread_id == thread_id
|
|
]
|
|
for session_key in stale_session_keys:
|
|
self._session_index.pop(session_key, None)
|
|
self._locks.pop(thread_id, None)
|
|
self._persist_locked()
|
|
|
|
def add_message(
|
|
self,
|
|
thread_id: str,
|
|
role: str,
|
|
content: str,
|
|
*,
|
|
parts: list[dict[str, Any]] | None = None,
|
|
) -> dict[str, Any]:
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
now = self._timestamp()
|
|
normalized_parts = self._normalize_message_parts(parts)
|
|
message = {
|
|
"id": uuid.uuid4().hex,
|
|
"role": role,
|
|
"content": str(content or ""),
|
|
"created_at": now,
|
|
}
|
|
if normalized_parts:
|
|
message["parts"] = normalized_parts
|
|
thread["messages"].append(message)
|
|
thread["updated_at"] = now
|
|
title_source = self._message_text(message)
|
|
if (
|
|
role == "user"
|
|
and thread["title"] == DEFAULT_TITLE
|
|
and title_source
|
|
and not title_source.startswith("/")
|
|
):
|
|
thread["title"] = self._derive_title(title_source)
|
|
thread["title_source"] = TITLE_SOURCE_AUTO
|
|
self._persist_locked()
|
|
return self._serialize_message(message)
|
|
|
|
def remove_message(self, thread_id: str, message_id: str) -> None:
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
original_count = len(thread["messages"])
|
|
thread["messages"] = [message for message in thread["messages"] if message["id"] != message_id]
|
|
if len(thread["messages"]) == original_count:
|
|
return
|
|
if thread["messages"]:
|
|
thread["updated_at"] = thread["messages"][-1]["created_at"]
|
|
else:
|
|
thread["updated_at"] = thread["created_at"]
|
|
if thread["title"] != DEFAULT_TITLE:
|
|
thread["title"] = DEFAULT_TITLE
|
|
thread["title_source"] = TITLE_SOURCE_DEFAULT
|
|
self._persist_locked()
|
|
|
|
def build_history(self, thread_id: str, *, limit: int | None = None) -> list[dict[str, str]]:
|
|
thread = self._require_thread(thread_id)
|
|
messages = thread["messages"][-limit:] if limit is not None else thread["messages"]
|
|
return [
|
|
{
|
|
"role": str(message.get("role") or "assistant"),
|
|
"content": self._message_text(message)
|
|
if str(message.get("role") or "").strip().lower() == "user"
|
|
else str(message.get("content") or ""),
|
|
}
|
|
for message in messages
|
|
]
|
|
|
|
    def build_agent_history(self, thread_id: str, *, limit: int | None = None) -> list[dict[str, Any]]:
        """Build provider-ready history, inlining stored images as data URLs.

        User messages with structured parts become multi-part content
        (``input_text`` / ``input_image``); an image whose stored file can no
        longer be resolved is replaced by a textual placeholder so the model
        keeps the context. All other messages are emitted as plain
        role/content strings, and empty-content messages are dropped.
        """
        thread = self._require_thread(thread_id)
        messages = thread["messages"][-limit:] if limit is not None else thread["messages"]
        history: list[dict[str, Any]] = []

        for message in messages:
            role = str(message.get("role") or "assistant").strip().lower() or "assistant"
            parts = self._normalize_message_parts(message.get("parts"))

            if role == "user" and parts:
                content_parts: list[dict[str, Any]] = []
                for part in parts:
                    if part["type"] == "input_text":
                        content_parts.append({"type": "input_text", "text": part["text"]})
                        continue

                    try:
                        # Inline the stored image bytes as a data URL.
                        content_parts.append(
                            {
                                "type": "input_image",
                                "image_url": self._media_store.build_data_url(part["file_id"]),
                                "detail": part.get("detail") or "auto",
                            }
                        )
                    except KeyError:
                        # NOTE(review): presumably the media store raises
                        # KeyError for a missing/expired file_id — confirm.
                        missing_name = str(part.get("name") or "image")
                        content_parts.append(
                            {
                                "type": "input_text",
                                "text": f"A previously attached image named {missing_name} is no longer available.",
                            }
                        )

                if content_parts:
                    history.append({"role": role, "content": content_parts})
                continue

            # Plain path: user text via _message_text, other roles verbatim.
            content = self._message_text(message) if role == "user" else str(message.get("content") or "")
            if content:
                history.append({"role": role, "content": content})

        return history
|
|
|
|
def save_upload(self, *, name: str, mime_type: str, data: bytes, file_id: str | None = None) -> dict[str, Any]:
|
|
metadata = self._media_store.save_bytes(data, name=name, mime_type=mime_type, file_id=file_id)
|
|
return self._serialize_upload(metadata)
|
|
|
|
def get_upload(self, file_id: str) -> dict[str, Any]:
|
|
return self._serialize_upload(self._media_store.get_meta(file_id))
|
|
|
|
def read_upload(self, file_id: str) -> bytes:
|
|
return self._media_store.read_bytes(file_id)
|
|
|
|
def delete_upload(self, file_id: str) -> None:
|
|
self._media_store.delete(file_id)
|
|
|
|
def lock(self, thread_id: str) -> asyncio.Lock:
|
|
self._require_thread(thread_id)
|
|
return self._locks[thread_id]
|
|
|
|
def get_topic_system_prompt(self, session_key: str) -> dict[str, Any] | None:
|
|
with self._state_lock:
|
|
record = self._topic_system_prompts.get(str(session_key))
|
|
if not isinstance(record, dict):
|
|
return None
|
|
prompt = str(record.get("prompt") or "").strip()
|
|
if not prompt:
|
|
return None
|
|
return {
|
|
"session_key": str(session_key),
|
|
"prompt": prompt,
|
|
"updated_at": str(record.get("updated_at") or ""),
|
|
}
|
|
|
|
def set_topic_system_prompt(self, session_key: str, prompt: str) -> dict[str, Any]:
|
|
normalized_session_key = str(session_key).strip()
|
|
if not normalized_session_key:
|
|
raise ValueError("Session key cannot be empty")
|
|
normalized_prompt = str(prompt or "").strip()
|
|
if not normalized_prompt:
|
|
raise ValueError("System prompt cannot be empty")
|
|
with self._state_lock:
|
|
record = {
|
|
"prompt": normalized_prompt,
|
|
"updated_at": self._timestamp(),
|
|
}
|
|
self._topic_system_prompts[normalized_session_key] = record
|
|
self._persist_locked()
|
|
return {
|
|
"session_key": normalized_session_key,
|
|
"prompt": record["prompt"],
|
|
"updated_at": record["updated_at"],
|
|
}
|
|
|
|
def clear_topic_system_prompt(self, session_key: str) -> bool:
|
|
normalized_session_key = str(session_key).strip()
|
|
if not normalized_session_key:
|
|
return False
|
|
with self._state_lock:
|
|
removed = self._topic_system_prompts.pop(normalized_session_key, None)
|
|
if removed is None:
|
|
return False
|
|
self._persist_locked()
|
|
return True
|
|
|
|
def add_usage(self, thread_id: str, usage: dict[str, Any] | None) -> dict[str, Any]:
|
|
normalized_usage = self._normalize_usage(usage)
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
thread["usage"] = self._merge_usage(thread.get("usage"), normalized_usage)
|
|
thread["updated_at"] = self._timestamp()
|
|
self._persist_locked()
|
|
return self._serialize_usage(thread.get("usage"))
|
|
|
|
def get_total_usage(self, thread_id: str) -> dict[str, Any]:
|
|
"""Return the cumulative usage for a thread."""
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
return self._serialize_usage(thread.get("usage"))
|
|
|
|
    # ── Project learnings (per-thread) ────────────────────────────────

    # Cap on learnings kept per thread; oldest entries are dropped first.
    MAX_PROJECT_LEARNINGS = 30
|
|
|
|
def get_project_learnings(self, thread_id: str) -> list[dict[str, Any]]:
|
|
"""Return project learnings for a thread."""
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
return list(thread.get("project_learnings") or [])
|
|
|
|
def add_project_learning(self, thread_id: str, fact: str, *, category: str = "general") -> None:
|
|
"""Add a project-scoped learning to a thread."""
|
|
fact = str(fact or "").strip()
|
|
if not fact:
|
|
return
|
|
with self._state_lock:
|
|
thread = self._require_thread(thread_id)
|
|
learnings: list[dict[str, Any]] = list(thread.get("project_learnings") or [])
|
|
# Skip near-duplicates (same fact text ignoring case/whitespace)
|
|
normalized = fact.lower().strip()
|
|
for existing in learnings:
|
|
if existing.get("fact", "").lower().strip() == normalized:
|
|
return
|
|
learnings.append({"fact": fact, "category": str(category or "general").strip(), "updated_at": self._timestamp()})
|
|
# Cap size — keep most recent
|
|
if len(learnings) > self.MAX_PROJECT_LEARNINGS:
|
|
learnings = learnings[-self.MAX_PROJECT_LEARNINGS:]
|
|
thread["project_learnings"] = learnings
|
|
thread["updated_at"] = self._timestamp()
|
|
self._persist_locked()
|
|
|
|
    # ── Global learnings (user preferences, cross-project) ────────────

    # Cap on global learnings; oldest entries are dropped first.
    MAX_GLOBAL_LEARNINGS = 40
|
|
|
|
def get_global_learnings(self) -> list[dict[str, Any]]:
|
|
"""Return all global (cross-project) learnings."""
|
|
with self._state_lock:
|
|
return list(self._global_learnings)
|
|
|
|
def add_global_learning(self, fact: str, *, category: str = "general") -> None:
|
|
"""Add a global learning (user preference, personality, contacts)."""
|
|
fact = str(fact or "").strip()
|
|
if not fact:
|
|
return
|
|
with self._state_lock:
|
|
normalized = fact.lower().strip()
|
|
for existing in self._global_learnings:
|
|
if existing.get("fact", "").lower().strip() == normalized:
|
|
return
|
|
self._global_learnings.append(
|
|
{"fact": fact, "category": str(category or "general").strip(), "updated_at": self._timestamp()}
|
|
)
|
|
if len(self._global_learnings) > self.MAX_GLOBAL_LEARNINGS:
|
|
self._global_learnings = self._global_learnings[-self.MAX_GLOBAL_LEARNINGS:]
|
|
self._persist_locked()
|
|
|
|
def _build_thread(
|
|
self,
|
|
*,
|
|
title: str,
|
|
model: str,
|
|
provider: str,
|
|
source: str,
|
|
session_label: str | None,
|
|
) -> dict[str, Any]:
|
|
now = self._timestamp()
|
|
normalized_session_label = str(session_label).strip() if session_label else ""
|
|
thread = {
|
|
"id": uuid.uuid4().hex,
|
|
"title": str(title).strip() or DEFAULT_TITLE,
|
|
"title_source": TITLE_SOURCE_DEFAULT
|
|
if (str(title).strip() or DEFAULT_TITLE) == DEFAULT_TITLE
|
|
else TITLE_SOURCE_MANUAL,
|
|
"provider": str(provider).strip().lower() or DEFAULT_PROVIDER,
|
|
"model": str(model).strip() or DEFAULT_MODEL,
|
|
"source": str(source).strip().lower() or DEFAULT_SOURCE,
|
|
"created_at": now,
|
|
"updated_at": now,
|
|
"messages": [],
|
|
"usage": self._serialize_usage(None),
|
|
}
|
|
if normalized_session_label:
|
|
thread["session_label"] = normalized_session_label
|
|
return thread
|
|
|
|
def _normalize_usage(self, usage: dict[str, Any] | None) -> dict[str, Any]:
|
|
if not isinstance(usage, dict):
|
|
return {}
|
|
|
|
token_fields = (
|
|
"prompt_tokens",
|
|
"completion_tokens",
|
|
"total_tokens",
|
|
"reasoning_tokens",
|
|
"cached_tokens",
|
|
"web_search_requests",
|
|
"request_count",
|
|
)
|
|
money_fields = ("cost_usd",)
|
|
text_fields = ("pricing_source",)
|
|
|
|
normalized: dict[str, Any] = {}
|
|
for field in token_fields:
|
|
value = usage.get(field)
|
|
if value in (None, ""):
|
|
continue
|
|
try:
|
|
normalized[field] = int(value)
|
|
except (TypeError, ValueError):
|
|
continue
|
|
|
|
for field in money_fields:
|
|
value = usage.get(field)
|
|
if value in (None, ""):
|
|
continue
|
|
try:
|
|
normalized[field] = format(Decimal(str(value)), "f")
|
|
except Exception:
|
|
continue
|
|
|
|
for field in text_fields:
|
|
value = usage.get(field)
|
|
if value not in (None, ""):
|
|
normalized[field] = str(value)
|
|
|
|
breakdown = usage.get("cost_breakdown")
|
|
if isinstance(breakdown, dict):
|
|
clean_breakdown: dict[str, str] = {}
|
|
for key, value in breakdown.items():
|
|
if value in (None, ""):
|
|
continue
|
|
try:
|
|
clean_breakdown[str(key)] = format(Decimal(str(value)), "f")
|
|
except Exception:
|
|
continue
|
|
if clean_breakdown:
|
|
normalized["cost_breakdown"] = clean_breakdown
|
|
|
|
return normalized
|
|
|
|
def _merge_usage(self, existing: Any, incoming: dict[str, Any]) -> dict[str, Any]:
|
|
current = self._normalize_usage(existing if isinstance(existing, dict) else {})
|
|
if not incoming:
|
|
return self._serialize_usage(current)
|
|
|
|
for field in (
|
|
"prompt_tokens",
|
|
"completion_tokens",
|
|
"total_tokens",
|
|
"reasoning_tokens",
|
|
"cached_tokens",
|
|
"web_search_requests",
|
|
"request_count",
|
|
):
|
|
current[field] = int(current.get(field, 0) or 0) + int(incoming.get(field, 0) or 0)
|
|
|
|
current_cost = Decimal(str(current.get("cost_usd", "0") or "0"))
|
|
incoming_cost = Decimal(str(incoming.get("cost_usd", "0") or "0"))
|
|
current["cost_usd"] = format(current_cost + incoming_cost, "f")
|
|
|
|
existing_breakdown = current.get("cost_breakdown") if isinstance(current.get("cost_breakdown"), dict) else {}
|
|
incoming_breakdown = incoming.get("cost_breakdown") if isinstance(incoming.get("cost_breakdown"), dict) else {}
|
|
merged_breakdown: dict[str, str] = {}
|
|
for key in set(existing_breakdown) | set(incoming_breakdown):
|
|
total = Decimal(str(existing_breakdown.get(key, "0") or "0")) + Decimal(
|
|
str(incoming_breakdown.get(key, "0") or "0")
|
|
)
|
|
if total:
|
|
merged_breakdown[key] = format(total, "f")
|
|
if merged_breakdown:
|
|
current["cost_breakdown"] = merged_breakdown
|
|
else:
|
|
current.pop("cost_breakdown", None)
|
|
|
|
pricing_source = incoming.get("pricing_source") or current.get("pricing_source")
|
|
if pricing_source:
|
|
current["pricing_source"] = str(pricing_source)
|
|
|
|
return self._serialize_usage(current)
|
|
|
|
def _serialize_usage(self, usage: Any) -> dict[str, Any]:
|
|
normalized = self._normalize_usage(usage if isinstance(usage, dict) else {})
|
|
normalized.setdefault("prompt_tokens", 0)
|
|
normalized.setdefault("completion_tokens", 0)
|
|
normalized.setdefault("total_tokens", normalized["prompt_tokens"] + normalized["completion_tokens"])
|
|
normalized.setdefault("reasoning_tokens", 0)
|
|
normalized.setdefault("cached_tokens", 0)
|
|
normalized.setdefault("web_search_requests", 0)
|
|
normalized.setdefault("request_count", 0)
|
|
normalized.setdefault("cost_usd", "0")
|
|
normalized.setdefault("cost_breakdown", {})
|
|
return normalized
|
|
|
|
    def _serialize_thread(self, thread: dict[str, Any], *, include_messages: bool = True) -> dict[str, Any]:
        """Return the public representation of a raw thread record.

        Optional fields (``session_label``, ``project_learnings``,
        ``messages``) are included only when present / requested.
        """
        data = {
            "id": thread["id"],
            "title": thread["title"],
            # Unknown/missing source values fall back to "default".
            "title_source": self._normalize_title_source(thread.get("title_source")) or TITLE_SOURCE_DEFAULT,
            "provider": thread.get("provider") or DEFAULT_PROVIDER,
            "model": thread["model"],
            "source": thread.get("source") or DEFAULT_SOURCE,
            "created_at": thread["created_at"],
            "updated_at": thread["updated_at"],
            "message_count": len(thread["messages"]),
            "usage": self._serialize_usage(thread.get("usage")),
        }
        if thread.get("session_label"):
            data["session_label"] = thread["session_label"]
        if thread.get("project_learnings"):
            # Copy so callers cannot mutate stored learnings.
            data["project_learnings"] = list(thread["project_learnings"])
        if include_messages:
            data["messages"] = [self._serialize_message(message) for message in thread["messages"]]
        return data
|
|
|
|
def _require_thread(self, thread_id: str) -> dict[str, Any]:
|
|
if thread_id not in self._threads:
|
|
raise KeyError(thread_id)
|
|
return self._threads[thread_id]
|
|
|
|
def _get_session_thread_record(self, session_key: str) -> dict[str, Any] | None:
|
|
thread_id = self._session_index.get(session_key)
|
|
if not thread_id:
|
|
return None
|
|
thread = self._threads.get(thread_id)
|
|
if thread is None:
|
|
self._session_index.pop(session_key, None)
|
|
self._persist_locked()
|
|
return None
|
|
return thread
|
|
|
|
def is_tg_update_seen(self, update_id: int) -> bool:
|
|
"""Return True if this Telegram update_id has already been processed."""
|
|
return update_id <= self._last_tg_update_id
|
|
|
|
def mark_tg_update(self, update_id: int) -> None:
|
|
"""Record a Telegram update_id as processed."""
|
|
with self._state_lock:
|
|
if update_id > self._last_tg_update_id:
|
|
self._last_tg_update_id = update_id
|
|
self._persist_locked()
|
|
|
|
def _persist_locked(self) -> None:
|
|
self._file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
payload = {
|
|
"threads": self._threads,
|
|
"session_index": self._session_index,
|
|
"last_tg_update_id": self._last_tg_update_id,
|
|
"topic_system_prompts": self._topic_system_prompts,
|
|
"global_learnings": self._global_learnings,
|
|
}
|
|
self._file_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
|
|
|
def _load(self) -> None:
|
|
if not self._file_path.exists():
|
|
return
|
|
try:
|
|
payload = json.loads(self._file_path.read_text(encoding="utf-8"))
|
|
except (json.JSONDecodeError, OSError):
|
|
return
|
|
self._threads = payload.get("threads") or {}
|
|
self._session_index = payload.get("session_index") or {}
|
|
self._last_tg_update_id = int(payload.get("last_tg_update_id") or 0)
|
|
raw_prompts = payload.get("topic_system_prompts") or {}
|
|
self._topic_system_prompts = raw_prompts if isinstance(raw_prompts, dict) else {}
|
|
raw_learnings = payload.get("global_learnings") or []
|
|
self._global_learnings = raw_learnings if isinstance(raw_learnings, list) else []
|
|
|
|
    def _normalize_threads(self) -> None:
        """Post-``_load`` cleanup: coerce persisted records into the current
        schema, re-persisting only when something actually changed.
        """
        changed = False
        # Drop malformed topic-prompt records (non-dict, empty key/prompt)
        # and reshape valid ones to exactly {"prompt", "updated_at"}.
        cleaned_topic_prompts: dict[str, dict[str, Any]] = {}
        for session_key, raw_record in self._topic_system_prompts.items():
            if not isinstance(raw_record, dict):
                changed = True
                continue
            normalized_session_key = str(session_key).strip()
            normalized_prompt = str(raw_record.get("prompt") or "").strip()
            if not normalized_session_key or not normalized_prompt:
                changed = True
                continue
            normalized_record = {
                "prompt": normalized_prompt,
                "updated_at": str(raw_record.get("updated_at") or self._timestamp()),
            }
            cleaned_topic_prompts[normalized_session_key] = normalized_record
            if normalized_session_key != session_key or raw_record != normalized_record:
                changed = True
        if cleaned_topic_prompts != self._topic_system_prompts:
            self._topic_system_prompts = cleaned_topic_prompts
            changed = True
        for thread in self._threads.values():
            if not isinstance(thread.get("messages"), list):
                thread["messages"] = []
                changed = True
            # Title is normalized in place (this assignment alone is not
            # tracked by ``changed``).
            thread["title"] = str(thread.get("title") or DEFAULT_TITLE).strip() or DEFAULT_TITLE
            normalized_title_source = self._normalize_title_source(thread.get("title_source"))
            expected_title_source = normalized_title_source
            if expected_title_source is None:
                # Unknown/missing source: infer from whether the title is default.
                expected_title_source = (
                    TITLE_SOURCE_DEFAULT if thread["title"] == DEFAULT_TITLE else TITLE_SOURCE_MANUAL
                )
            if thread.get("title_source") != expected_title_source:
                thread["title_source"] = expected_title_source
                changed = True
            provider = str(thread.get("provider") or DEFAULT_PROVIDER).strip().lower() or DEFAULT_PROVIDER
            if thread.get("provider") != provider:
                thread["provider"] = provider
                changed = True
            source = str(thread.get("source") or DEFAULT_SOURCE).strip().lower() or DEFAULT_SOURCE
            if thread.get("source") != source:
                thread["source"] = source
                changed = True
            usage = self._serialize_usage(thread.get("usage"))
            if thread.get("usage") != usage:
                thread["usage"] = usage
                changed = True
        if changed:
            self._persist_locked()
|
|
|
|
def _serialize_message(self, message: dict[str, Any]) -> dict[str, Any]:
|
|
serialized = {
|
|
"id": message["id"],
|
|
"role": message["role"],
|
|
"content": message["content"],
|
|
"created_at": message["created_at"],
|
|
}
|
|
normalized_parts = self._normalize_message_parts(message.get("parts"))
|
|
if normalized_parts:
|
|
serialized["parts"] = normalized_parts
|
|
return serialized
|
|
|
|
def _serialize_upload(self, metadata: dict[str, Any]) -> dict[str, Any]:
|
|
return {
|
|
"id": str(metadata.get("id") or ""),
|
|
"name": str(metadata.get("name") or "upload"),
|
|
"mime_type": str(metadata.get("mime_type") or "application/octet-stream"),
|
|
"size": int(metadata.get("size") or 0),
|
|
"preview_url": f"/uploads/{metadata.get('id')}"
|
|
if str(metadata.get("mime_type") or "").startswith("image/")
|
|
else None,
|
|
}
|
|
|
|
def _normalize_message_parts(self, parts: Any) -> list[dict[str, Any]]:
|
|
if not isinstance(parts, list):
|
|
return []
|
|
|
|
normalized_parts: list[dict[str, Any]] = []
|
|
for part in parts:
|
|
if not isinstance(part, dict):
|
|
continue
|
|
part_type = str(part.get("type") or "").strip()
|
|
if part_type == "input_text":
|
|
text = str(part.get("text") or "")
|
|
if text:
|
|
normalized_parts.append({"type": "input_text", "text": text})
|
|
continue
|
|
if part_type == "input_image":
|
|
file_id = str(part.get("file_id") or "").strip()
|
|
if not file_id:
|
|
continue
|
|
normalized_part = {"type": "input_image", "file_id": file_id}
|
|
if part.get("name"):
|
|
normalized_part["name"] = str(part.get("name"))
|
|
if part.get("mime_type"):
|
|
normalized_part["mime_type"] = str(part.get("mime_type"))
|
|
if part.get("detail"):
|
|
normalized_part["detail"] = str(part.get("detail"))
|
|
normalized_parts.append(normalized_part)
|
|
return normalized_parts
|
|
|
|
def _message_text(self, message: dict[str, Any]) -> str:
|
|
parts = self._normalize_message_parts(message.get("parts"))
|
|
if parts:
|
|
texts = [str(part.get("text") or "") for part in parts if part.get("type") == "input_text"]
|
|
text = "\n".join(part for part in texts if part).strip()
|
|
if text:
|
|
return text
|
|
return str(message.get("content") or "").strip()
|
|
|
|
def _derive_title(self, content: str) -> str:
|
|
single_line = " ".join(str(content).split())
|
|
if len(single_line) <= 48:
|
|
return single_line
|
|
return single_line[:45].rstrip() + "..."
|
|
|
|
def _normalize_title_source(self, raw_value: Any) -> str | None:
|
|
value = str(raw_value or "").strip().lower()
|
|
if value in KNOWN_TITLE_SOURCES:
|
|
return value
|
|
return None
|
|
|
|
def _timestamp(self) -> str:
|
|
from datetime import datetime, timezone
|
|
|
|
return datetime.now(timezone.utc).isoformat()
|