"""Copilot SDK runtime — manages CopilotClient lifecycle and session creation."""

import asyncio
import logging
import random
from typing import Any, AsyncIterator

from copilot import CopilotClient, SubprocessConfig
from copilot.generated.session_events import SessionEventType
from copilot.session import CopilotSession, PermissionHandler, SessionEvent
from copilot.tools import Tool

from config import settings
from instance import skill_directories as _default_skill_directories

logger = logging.getLogger(__name__)

# ── Retry configuration ─────────────────────────────────────────────
MAX_SESSION_RETRIES = 3
RETRY_BASE_DELAY = 3.0  # seconds
RETRY_MAX_DELAY = 30.0  # seconds

# Lower-case substrings that mark an error message as transient.
_RETRYABLE_PATTERNS = (
    "failed to get response",
    "operation was aborted",
    "timed out",
    "timeout",
    "502",
    "503",
    "504",
    "service unavailable",
    "overloaded",
)

# Rate-limit errors: the SDK already retries 5 times internally, so our
# outer retry loop should NOT retry these (it would just multiply the wait).
_RATE_LIMIT_PATTERNS = ("429", "rate limit", "rate_limit", "too many requests")


def _is_rate_limited(error_msg: str) -> bool:
    """Check if an error is a rate-limit (429) from the provider.

    The match is a case-insensitive substring test against
    ``_RATE_LIMIT_PATTERNS``.
    """
    lower = error_msg.lower()
    return any(p in lower for p in _RATE_LIMIT_PATTERNS)


def _is_retryable(error_msg: str) -> bool:
    """Check if an error message indicates a transient failure worth retrying.

    Rate-limit errors are excluded — the SDK already burned through 5 retries
    internally, so adding more at our level just wastes time.
    """
    # Rate-limit check comes first so that a message matching both pattern
    # sets is treated as non-retryable.
    if _is_rate_limited(error_msg):
        return False
    lower = error_msg.lower()
    return any(p in lower for p in _RETRYABLE_PATTERNS)


def _backoff_delay(attempt: int) -> float:
    """Exponential backoff with jitter.

    The un-jittered delay is ``RETRY_BASE_DELAY * 2**attempt`` capped at
    ``RETRY_MAX_DELAY``; the result is that value scaled by a random factor
    in ``[0.5, 1.0)``. Every call draws fresh jitter, so callers that both log
    and sleep must reuse a single returned value.
    """
    delay = min(RETRY_BASE_DELAY * (2**attempt), RETRY_MAX_DELAY)
    return delay * (0.5 + random.random() * 0.5)


# Event types after which no further events are awaited for a session.
TERMINAL_EVENT_TYPES = frozenset(
    {
        SessionEventType.SESSION_IDLE,
        SessionEventType.SESSION_ERROR,
        SessionEventType.SESSION_SHUTDOWN,
    }
)


# Use PermissionHandler.approve_all from the SDK (single-user trusted environment)


def format_prompt_with_history(
    history: list[dict[str, Any]],
    new_message: str,
) -> str:
    """Format conversation history + new message into a single prompt string.

    The history list typically includes the latest user message as the last
    item (just added to the store). We split context (all but last) from the
    actual prompt (last).

    Each context message is rendered as ``[role]: content``. When ``content``
    is a list, the ``text`` of its ``input_text`` parts is joined with spaces
    and the number of ``input_image`` parts is appended as a marker. Messages
    with empty content are skipped. If nothing remains, ``new_message`` is
    returned unchanged.
    """
    if not history:
        return new_message

    # All but last message as context; last is the new user message
    context_messages = history[:-1]
    if not context_messages:
        return new_message

    lines: list[str] = []
    for msg in context_messages:
        role = msg.get("role", "assistant")
        content = msg.get("content", "")
        if isinstance(content, list):
            # Multi-part content: keep the text, summarise images by count.
            text_parts = [
                p.get("text", "")
                for p in content
                if isinstance(p, dict) and p.get("type") == "input_text"
            ]
            image_count = sum(
                1 for p in content if isinstance(p, dict) and p.get("type") == "input_image"
            )
            content = " ".join(t for t in text_parts if t)
            if image_count:
                content += f" [+{image_count} image(s)]" if content else f"[{image_count} image(s)]"
        if not content:
            continue
        lines.append(f"[{role}]: {content}")

    if not lines:
        return new_message
    return "\n" + "\n".join(lines) + "\n\n\n" + new_message


class CopilotRuntime:
    """Manages a long-lived CopilotClient and creates per-request sessions."""

    def __init__(self) -> None:
        # None until start() has run, and again after stop().
        self._client: CopilotClient | None = None

    async def start(self) -> None:
        """Create and start the CopilotClient rooted at ``settings.REPOS_DIR``."""
        # A placeholder token is passed when no GitHub token is configured.
        github_token = settings.GITHUB_TOKEN or "not-needed"
        self._client = CopilotClient(
            SubprocessConfig(
                cwd=settings.REPOS_DIR,
                github_token=github_token,
                use_logged_in_user=False,
            )
        )
        await self._client.start()
        logger.info(
            "Copilot SDK client started (cwd=%s, copilot_auth=%s)",
            settings.REPOS_DIR,
            "yes" if settings.GITHUB_TOKEN else "no",
        )

    async def stop(self) -> None:
        """Stop the client if one is running; errors during stop are logged, not raised."""
        if self._client:
            try:
                await self._client.stop()
            # NOTE(review): BaseException also swallows asyncio.CancelledError
            # and KeyboardInterrupt — confirm this is intentional for shutdown.
            except BaseException:
                logger.warning("Copilot SDK client stop errors", exc_info=True)
            self._client = None
            logger.info("Copilot SDK client stopped")

    @property
    def client(self) -> CopilotClient:
        """Return the started client.

        Raises:
            RuntimeError: if ``start()`` has not been called (or ``stop()`` has).
        """
        if self._client is None:
            raise RuntimeError("CopilotRuntime not started — call start() first")
        return self._client

    async def create_session(
        self,
        *,
        model: str,
        provider_config: dict[str, Any] | None,
        system_message: str,
        tools: list[Tool] | None = None,
        streaming: bool = True,
    ) -> CopilotSession:
        """Create a new SDK session on the shared client.

        The system message replaces the default one, all permission requests
        are auto-approved, and the ``task`` tool is excluded. ``provider`` is
        only passed to the SDK when ``provider_config`` is not None.

        Raises:
            RuntimeError: if the runtime has not been started.
        """
        kwargs: dict[str, Any] = {
            "on_permission_request": PermissionHandler.approve_all,
            "model": model,
            "streaming": streaming,
            "working_directory": settings.REPOS_DIR,
            "system_message": {"mode": "replace", "content": system_message},
            "tools": tools or None,  # an empty list is normalised to None
            "excluded_tools": ["task"],
            "skill_directories": _default_skill_directories(),
        }
        if provider_config is not None:
            kwargs["provider"] = provider_config
        return await self.client.create_session(**kwargs)


def _reset_advisor_counter(thread_id: str | None) -> None:
    """Reset the advisor state for *thread_id*; no-op when it is None.

    The advisor module is imported lazily and an ImportError is ignored, so
    the reset is skipped when ``tools.advisor`` cannot be imported.
    """
    if thread_id is None:
        return
    try:
        from tools.advisor import _reset_advisor_state

        _reset_advisor_state(thread_id)
    except ImportError:
        pass


async def stream_session(
    rt: CopilotRuntime,
    *,
    model: str,
    provider_config: dict[str, Any] | None,
    system_message: str,
    prompt: str,
    tools: list[Tool] | None = None,
    attachments: list[dict[str, Any]] | None = None,
    thread_id: str | None = None,
) -> AsyncIterator[SessionEvent]:
    """Create a session, send a prompt, yield events until idle, then destroy.

    Retries transparently on transient session errors (model timeouts, 5xx)
    with exponential backoff + jitter, up to ``MAX_SESSION_RETRIES`` attempts.
    Rate-limit errors are not retried (see ``_is_retryable``). A retried
    SESSION_ERROR event is not yielded; a non-retryable one, or the one from
    the final attempt, is yielded as the last event.

    Only SESSION_ERROR events trigger a retry; exceptions raised by the SDK
    calls themselves propagate to the caller after the session is destroyed.
    """
    # Reset advisor counter at the start of each run
    _reset_advisor_counter(thread_id)

    last_error_msg = ""
    for attempt in range(MAX_SESSION_RETRIES):
        session = await rt.create_session(
            model=model,
            provider_config=provider_config,
            system_message=system_message,
            tools=tools,
        )
        queue: asyncio.Queue[SessionEvent | None] = asyncio.Queue()

        def on_event(event: SessionEvent) -> None:
            queue.put_nowait(event)
            if event.type in TERMINAL_EVENT_TYPES:
                queue.put_nowait(None)  # sentinel

        unsub = session.on(on_event)
        # Set when this attempt ended in a retryable error. Computed once so
        # the delay that is logged is the delay that is actually slept.
        retry_delay: float | None = None
        try:
            await session.send(prompt, attachments=attachments)
            while True:
                item = await queue.get()
                if item is None:
                    break

                # Intercept SESSION_ERROR for retry
                if item.type == SessionEventType.SESSION_ERROR:
                    error_msg = (item.data and item.data.message) or "Unknown session error"
                    last_error_msg = error_msg
                    retries_left = MAX_SESSION_RETRIES - attempt - 1
                    if _is_retryable(error_msg) and retries_left > 0:
                        retry_delay = _backoff_delay(attempt)
                        logger.warning(
                            "Retryable session error (attempt %d/%d, next in %.1fs): %s",
                            attempt + 1,
                            MAX_SESSION_RETRIES,
                            retry_delay,
                            error_msg,
                        )
                        break

                    # Not retryable or out of retries — yield the error to caller
                    if retries_left == 0 and _is_retryable(error_msg):
                        logger.error(
                            "All %d session retries exhausted: %s",
                            MAX_SESSION_RETRIES,
                            error_msg,
                        )
                    yield item
                    return

                yield item
        finally:
            # Runs on normal completion, retry, error, and generator close.
            unsub()
            await session.destroy()

        if retry_delay is not None:
            # Sleep only after the failed session has been torn down.
            await asyncio.sleep(retry_delay)
            continue

        # Completed normally
        return

    # Should not reach here, but safety net
    logger.error("stream_session fell through retry loop; last error: %s", last_error_msg)


async def run_session(
    rt: CopilotRuntime,
    *,
    model: str,
    provider_config: dict[str, Any] | None,
    system_message: str,
    prompt: str,
    tools: list[Tool] | None = None,
    attachments: list[dict[str, Any]] | None = None,
    thread_id: str | None = None,
) -> Any:
    """Create a session, send a prompt, wait for completion, return raw result.

    Retries transparently on transient errors with exponential backoff +
    jitter. If the awaited result has no content, the session's message
    history is searched from newest to oldest for an assistant message with
    content; if none is found, the original (possibly empty) result is
    returned.

    Raises:
        Exception: whatever the SDK raised, once the error is not retryable
            or the retries are exhausted.
    """
    # Reset advisor counter at the start of each run
    _reset_advisor_counter(thread_id)

    last_exc: Exception | None = None
    for attempt in range(MAX_SESSION_RETRIES):
        session = await rt.create_session(
            model=model,
            provider_config=provider_config,
            system_message=system_message,
            tools=tools,
        )
        try:
            result = await session.send_and_wait(prompt, attachments=attachments, timeout=300)
            if result and result.data and result.data.content:
                return result

            # Fallback: check message history
            messages = await session.get_messages()
            for msg in reversed(messages):
                if msg.type == SessionEventType.ASSISTANT_MESSAGE and msg.data and msg.data.content:
                    return msg
            return result
        except Exception as exc:
            last_exc = exc
            retries_left = MAX_SESSION_RETRIES - attempt - 1
            if _is_retryable(str(exc)) and retries_left > 0:
                delay = _backoff_delay(attempt)
                logger.warning(
                    "Retryable run_session error (attempt %d/%d, next in %.1fs): %s",
                    attempt + 1,
                    MAX_SESSION_RETRIES,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)
                continue
            raise
        finally:
            # Runs on every exit path from the try, including `continue`.
            await session.destroy()

    # All retries exhausted — raise the last exception
    if last_exc is not None:
        raise last_exc


# Module-level singleton
copilot = CopilotRuntime()