#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
import socket
import threading
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any

# Gateway identity: sent as the User-Agent ("NAME/VERSION") on every API request
# and reported by Gateway.status.
NAME = "omemai-gateway"
VERSION = "0.2.6"


class ConfigError(RuntimeError):
    """Raised when required gateway configuration is missing or empty."""


@dataclass
class Config:
    """Gateway settings resolved from ``OMEMAI_*`` environment variables."""

    api_base: str  # API root; from_env normalizes it to end with "/api"
    api_key: str
    host: str  # OMEMAI_GATEWAY_HOST
    port: int  # OMEMAI_GATEWAY_PORT
    default_project_id: str | None  # used when no current project is recorded
    auto_checkpoint_messages: int
    auto_checkpoint_minutes: int
    context_load_on_start: bool
    search_on_user_message: bool
    state_path: Path  # JSON file backing StateStore
    timeout: float  # seconds, passed to urlopen by Client.request
    client_source: str  # sent as the X-OMemAI-Client header

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Read a boolean env var; 1/true/yes/on (any case, surrounding spaces ignored) mean True."""
        return os.environ.get(name, default).strip().lower() in {"1", "true", "yes", "on"}

    @staticmethod
    def _env_number(name: str, default: str, kind: type) -> Any:
        """Read a numeric env var with *kind* (int or float).

        Raises:
            ValueError: the value cannot be parsed; the message names the variable.
        """
        raw = os.environ.get(name, default).strip()
        try:
            return kind(raw)
        except ValueError:
            raise ValueError(f"{name} must be a valid {kind.__name__}, got {raw!r}") from None

    @classmethod
    def from_env(cls) -> "Config":
        """Build a Config from the process environment.

        Raises:
            ConfigError: OMEMAI_API_BASE is blank or OMEMAI_API_KEY is missing.
            ValueError: a numeric setting (port, checkpoint thresholds, timeout)
                cannot be parsed.
        """
        base = os.environ.get("OMEMAI_API_BASE", "https://omemai.com").strip().rstrip("/")
        if not base:
            raise ConfigError("OMEMAI_API_BASE is empty")
        if not base.endswith("/api"):
            base += "/api"
        key = os.environ.get("OMEMAI_API_KEY", "").strip()
        if not key:
            raise ConfigError("OMEMAI_API_KEY is required")
        # The default state file lives next to this script.
        root = Path(__file__).resolve().parent
        return cls(
            api_base=base,
            api_key=key,
            host=os.environ.get("OMEMAI_GATEWAY_HOST", "127.0.0.1"),
            port=cls._env_number("OMEMAI_GATEWAY_PORT", "8777", int),
            # A blank/whitespace-only id means "no default project".
            default_project_id=(os.environ.get("OMEMAI_DEFAULT_PROJECT_ID") or "").strip() or None,
            auto_checkpoint_messages=cls._env_number("OMEMAI_AUTO_CHECKPOINT_MESSAGES", "10", int),
            auto_checkpoint_minutes=cls._env_number("OMEMAI_AUTO_CHECKPOINT_MINUTES", "20", int),
            context_load_on_start=cls._env_flag("OMEMAI_CONTEXT_LOAD_ON_START", "true"),
            search_on_user_message=cls._env_flag("OMEMAI_SEARCH_ON_USER_MESSAGE", "true"),
            state_path=Path(os.environ.get("OMEMAI_GATEWAY_STATE", str(root / "state.json"))),
            timeout=cls._env_number("OMEMAI_TIMEOUT_SECONDS", "120", float),
            client_source=os.environ.get("OMEMAI_CLIENT_SOURCE", "gateway"),
        )

    @property
    def key_prefix(self) -> str:
        """Return a display-safe hint of the API key (first 12 chars of an ``omem_`` key)."""
        if self.api_key.startswith("omem_"):
            return self.api_key[:12] + "..."
        return "[configured]"


class StateStore:
    """JSON-file persistence for gateway state.

    Gateway updates this store from background checkpoint/close threads as well
    as from the calling thread, so read-modify-write cycles are serialized with
    a per-instance lock.
    """

    def __init__(self, path: Path):
        self.path = path
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Reentrant because update() calls save(), which also takes the lock.
        self._lock = threading.RLock()

    def load(self) -> dict[str, Any]:
        """Return the persisted state.

        Returns {} when the file does not exist, and {"last_error": "state_load_failed"}
        when it cannot be read or decoded, or does not hold a JSON object.
        """
        if not self.path.exists():
            return {}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):  # ValueError covers JSON and UTF-8 decode errors
            return {"last_error": "state_load_failed"}
        if not isinstance(data, dict):
            return {"last_error": "state_load_failed"}
        return data

    def save(self, state: dict[str, Any]) -> None:
        """Atomically replace the state file with *state*, minus credential-like top-level keys."""
        safe = {k: v for k, v in state.items() if k not in {"api_key", "token", "password", "secret"}}
        tmp = self.path.with_name(f".{self.path.name}.{os.getpid()}.{threading.get_ident()}.tmp")
        with self._lock:
            try:
                # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII text (e.g. Chinese
                # project names) that the platform default encoding may not support.
                tmp.write_text(json.dumps(safe, ensure_ascii=False, indent=2), encoding="utf-8")
                os.replace(tmp, self.path)
            finally:
                # After a successful replace the tmp file is gone; after a failure, remove it.
                tmp.unlink(missing_ok=True)

    def update(self, **changes: Any) -> dict[str, Any]:
        """Merge *changes* into the stored state, stamp last_activity_at, persist, and return it."""
        with self._lock:
            state = self.load()
            state.update(changes)
            state["last_activity_at"] = now_iso()
            self.save(state)
        return state


class Client:
    """Minimal JSON-over-HTTP client for the OMemAI agent API."""

    def __init__(self, cfg: Config):
        self.cfg = cfg

    def request(self, method: str, path: str, payload: dict | None = None) -> Any:
        """Send *method* to ``cfg.api_base + path`` and return the decoded JSON body.

        An empty response body yields None. HTTP error statuses are re-raised as
        RuntimeError carrying the status code and the server's ``detail`` field
        (or the first 500 characters of the error body).
        """
        body: bytes | None = None
        headers = {
            "Authorization": f"Bearer {self.cfg.api_key}",
            "X-OMemAI-Client": self.cfg.client_source,
            "User-Agent": f"{NAME}/{VERSION}",
        }
        if payload is not None:
            headers["Content-Type"] = "application/json"
            body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        url = self.cfg.api_base + path
        outgoing = urllib.request.Request(url, data=body, headers=headers, method=method)
        try:
            with urllib.request.urlopen(outgoing, timeout=self.cfg.timeout) as response:
                raw = response.read().decode("utf-8")
        except urllib.error.HTTPError as http_err:
            snippet = http_err.read().decode("utf-8", "ignore")[:500]
            try:
                detail = json.loads(snippet).get("detail", snippet)
            except Exception:
                # Non-JSON (or non-object) error bodies: fall back to the raw text or reason.
                detail = snippet or http_err.reason
            raise RuntimeError(f"OMemAI API error {http_err.code}: {detail}") from None
        return json.loads(raw) if raw else None


def now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with a +00:00 offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()


def parse_iso(value: str | None) -> datetime | None:
    """Parse an ISO-8601 timestamp into a timezone-aware datetime.

    Returns None for empty, non-string, or unparseable input. A trailing "Z" is
    accepted on every Python version (datetime.fromisoformat only supports it
    from 3.11), and a timestamp without an offset is taken to be UTC. Results
    can therefore always be compared with datetime.now(timezone.utc) and with
    values produced by now_iso().
    """
    if not value or not isinstance(value, str):
        return None
    text = value.strip()
    if text[-1:] in ("Z", "z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def redact(text: str) -> str:
    """Mask the configured OMEMAI_API_KEY in *text* and cap the result at 1000 characters."""
    secret = os.environ.get("OMEMAI_API_KEY", "")
    masked = text.replace(secret, "[REDACTED]") if secret else text
    return masked[:1000]


# Substrings that importance_signals() looks for (case-insensitively) to decide
# that a message is worth checkpointing. Rough groups, in order:
#   - explicit memory cues: key facts, fine-grained facts, "please remember",
#     remember, define, configure/set, save, record;
#   - decisions and conventions: decision, decide, conclusion, agreement,
#     principle, preference, rule, architecture, plan;
#   - risks and follow-ups: risk, blocked, blocker, bug, TODO, to-do, next step,
#     next time, continue;
#   - corrections that supersede earlier values: correction, "change to",
#     "update to", void, deprecated, old value, "no longer used", replaced,
#     superseded.
IMPORTANT_PATTERNS = (
    "关键事实", "细粒度事实", "请记住", "记住", "定义", "设定", "设置", "保存", "记录",
    "决策", "决定", "结论", "约定", "原则", "偏好", "规则", "架构", "方案",
    "风险", "阻塞", "blocker", "bug", "TODO", "待办", "下一步", "下次", "继续",
    "更正", "改为", "更新为", "作废", "废弃", "旧值", "不再使用", "replaced", "superseded",
)


def importance_signals(text: str) -> list[str]:
    """Return up to 12 distinct reasons why *text* looks worth checkpointing.

    Reasons are the IMPORTANT_PATTERNS entries found in the text
    (case-insensitive), plus "assignment" when an ASCII or full-width "=" is
    present, and "fact_list" when an explicit memory cue appears together with
    the Chinese enumeration comma "、" (a list of facts after a cue).
    """
    lowered = text.lower()
    found = [pattern for pattern in IMPORTANT_PATTERNS if pattern.lower() in lowered]
    if any(sign in text for sign in ("=", "＝")):
        found.append("assignment")
    list_cues = ("关键事实", "细粒度事实", "请记住", "记住", "定义", "约定")
    if "、" in text and any(cue in text for cue in list_cues):
        found.append("fact_list")
    unique = list(dict.fromkeys(found))
    return unique[:12]


class Gateway:
    def __init__(self):
        """Load configuration from the environment and wire up the state store and API client."""
        cfg = Config.from_env()
        self.cfg = cfg
        self.state = StateStore(cfg.state_path)
        self.client = Client(cfg)

    def _list_projects(self) -> list[dict[str, Any]]:
        projects = self.client.request("GET", "/agent/projects")
        return projects if isinstance(projects, list) else list(projects.get("projects", [])) if isinstance(projects, dict) else []

    def _project_brief(self, project: dict[str, Any] | None) -> dict[str, Any] | None:
        if not isinstance(project, dict):
            return None
        return {
            "id": project.get("id"),
            "name": project.get("name"),
            "slug": project.get("slug"),
            "status": project.get("status"),
        }

    def _get_project(self, project_id: str | None) -> dict[str, Any] | None:
        if not project_id:
            return None
        return self.client.request("GET", f"/agent/projects/{urllib.parse.quote(str(project_id))}")

    def _project_confirmation(self, project: dict[str, Any] | None, *, action: str = "current") -> dict[str, Any]:
        brief = self._project_brief(project)
        name = brief.get("name") if brief else None
        pid = brief.get("id") if brief else None
        verbs = {
            "created": "已新建并切换到项目",
            "switched": "已切换到项目",
            "started": "当前项目",
            "current": "当前项目",
            "write": "本轮记忆写入",
        }
        prefix = verbs.get(action, "当前项目")
        return {
            "current_project": brief,
            "memory_write_target": brief,
            "human_confirmation": f"{prefix}：{name}（project_id={pid}）。后续内容将写入该项目。" if brief else "未绑定项目。请先指定项目。",
            "recommended_commands": ["切到项目：项目名", "新建项目：项目名", "继续项目：项目名"],
        }

    def _resolve_project_arg(self, args: dict[str, Any], *, allow_current: bool = True) -> dict[str, Any]:
        state = self.state.load()
        explicit_id = args.get("project_id")
        name = str(args.get("project_name") or args.get("project") or "").strip()
        if explicit_id:
            project = self._get_project(str(explicit_id))
            return {"ok": True, "project": project, "source": "project_id"}
        if name:
            projects = self._list_projects()
            norm = name.casefold()
            exact = [p for p in projects if str(p.get("name", "")).casefold() == norm or str(p.get("slug", "")).casefold() == norm]
            if len(exact) == 1:
                return {"ok": True, "project": exact[0], "source": "project_name"}
            partial = [p for p in projects if norm in str(p.get("name", "")).casefold() or norm in str(p.get("slug", "")).casefold()]
            if len(partial) == 1:
                return {"ok": True, "project": partial[0], "source": "project_name_partial"}
            return {
                "ok": False,
                "needs_project_confirmation": True,
                "reason": "ambiguous_project" if partial else "project_not_found",
                "message": "项目名称不唯一，请明确选择 project_id。" if partial else "没有找到该项目。请确认项目名，或使用“新建项目：项目名”。",
                "candidates": [self._project_brief(p) for p in partial[:10]],
                "projects": [self._project_brief(p) for p in projects[:20]],
                "recommended_commands": ["切到项目：项目名", "新建项目：项目名", "继续项目：项目名"],
            }
        if allow_current:
            project_id = state.get("current_project_id") or self.cfg.default_project_id
            if project_id:
                return {"ok": True, "project": self._get_project(str(project_id)), "source": "current_or_default"}
        return {
            "ok": False,
            "needs_project": True,
            "message": "当前未绑定项目。请先指定：切到项目：项目名；或新建项目：项目名。",
            "projects": [self._project_brief(p) for p in self._list_projects()[:20]],
            "recommended_commands": ["切到项目：项目名", "新建项目：项目名", "继续项目：项目名"],
        }

    def _mentioned_other_projects(self, content: str, current_project_id: str | None) -> list[dict[str, Any]]:
        text = content.casefold()
        if not text.strip():
            return []
        out: list[dict[str, Any]] = []
        for p in self._list_projects():
            pid = str(p.get("id") or "")
            if current_project_id and pid == str(current_project_id):
                continue
            names = [str(p.get("name") or ""), str(p.get("slug") or "")]
            if any(n and len(n) >= 2 and n.casefold() in text for n in names):
                out.append(self._project_brief(p) or {})
            if len(out) >= 5:
                break
        return out

    def projects(self, _: dict[str, Any]) -> dict[str, Any]:
        return {"projects": self._list_projects(), "recommended_commands": ["切到项目：项目名", "新建项目：项目名", "继续项目：项目名"]}

    def projects_create(self, args: dict[str, Any]) -> dict[str, Any]:
        """Create a project (``name`` via require(), optional slug/description) and make it current.

        The current session is cleared so the next start_task opens one in the new project.
        """
        body = {"name": require(args, "name"), "slug": args.get("slug"), "description": args.get("description")}
        created = self.client.request("POST", "/agent/projects", body)
        new_id = created["id"]
        self.state.update(
            current_project_id=new_id,
            current_session_id=None,
            last_error=None,
            current_project_name=created.get("name"),
        )
        return {"project": created, "current_project_id": new_id, **self._project_confirmation(created, action="created")}

    def projects_current(self, _: dict[str, Any]) -> dict[str, Any]:
        state = self.state.load()
        resolved = self._resolve_project_arg({}, allow_current=True)
        if not resolved.get("ok"):
            return resolved
        project = resolved["project"]
        return {"project": project, "current_project_id": project.get("id"), "current_session_id": state.get("current_session_id"), **self._project_confirmation(project, action="current")}

    def projects_set_current(self, args: dict[str, Any]) -> dict[str, Any]:
        resolved = self._resolve_project_arg(args, allow_current=False)
        if not resolved.get("ok"):
            return resolved
        project = resolved["project"]
        project_id = str(project["id"])
        self.state.update(current_project_id=project_id, current_project_name=project.get("name"), current_session_id=None, message_count_since_checkpoint=0, last_error=None, last_project_switch_at=now_iso())
        confirmation = self._project_confirmation(project, action="switched")
        return {"project": project, "current_project_id": project_id, **confirmation}

    def projects_rename(self, args: dict[str, Any]) -> dict[str, Any]:
        state = self._require_session(allow_no_session=True)
        project_id = args.get("project_id") or state.get("current_project_id") or self.cfg.default_project_id
        if not project_id:
            raise ValueError("project_id is required when no current/default project is set")
        payload = {k: v for k, v in {"name": args.get("name"), "description": args.get("description"), "status": args.get("status")}.items() if v is not None}
        if not payload:
            raise ValueError("At least one of name, description, or status is required")
        project = self.client.request("PATCH", f"/agent/projects/{urllib.parse.quote(str(project_id))}", payload)
        self.state.update(current_project_id=project["id"], last_error=None)
        return {"project": project, "current_project_id": project["id"]}

    def status(self, _: dict[str, Any]) -> dict[str, Any]:
        """Probe API health and summarize gateway, project, and checkpoint/close state.

        A failed health probe is persisted (redacted) as last_error. A failed
        project lookup is tolerated and reported as "no current project".
        """
        state = self.state.load()
        health: Any = None
        try:
            health = self.client.request("GET", "/agent/health/me")
        except Exception as exc:
            api_ok = False
            last_error = redact(str(exc))
            self.state.update(last_error=last_error)
        else:
            api_ok = True
            last_error = None

        project_ref = state.get("current_project_id") or self.cfg.default_project_id
        try:
            current_project = self._get_project(str(project_ref)) if project_ref else None
        except Exception:
            # Best effort: status must still render when the project lookup fails.
            current_project = None

        report: dict[str, Any] = {
            "gateway": {"name": NAME, "version": VERSION, "state_path": str(self.cfg.state_path), "checkpoint_mode": "queued-local-background", "automation_mode": "smart-default", "project_boundary_mode": "confirm-visible"},
            "automation": {
                "smart_checkpoint": True,
                "important_pending_checkpoint": bool(state.get("important_pending_checkpoint")),
                "last_importance_signals": state.get("last_importance_signals", []),
                "recall_health_status": state.get("last_recall_health_status"),
                "explain_states": True,
            },
            "api_ok": api_ok,
            "current_project_id": (current_project or {}).get("id") or state.get("current_project_id"),
            "current_project": self._project_brief(current_project),
            "memory_write_target": self._project_brief(current_project),
            "project_confirmation": self._project_confirmation(current_project, action="current"),
        }
        # Session/checkpoint/close bookkeeping is echoed straight from the state file.
        for key in (
            "current_session_id",
            "last_checkpoint_at",
            "last_checkpoint_status",
            "last_checkpoint_job_id",
            "last_checkpoint_requested_at",
            "last_close_status",
            "last_close_job_id",
            "last_close_requested_at",
            "last_close_at",
        ):
            report[key] = state.get(key)
        report["last_error"] = last_error
        report["agent_health"] = health
        return report

    def _context_key_fact_results(self, context: Any, limit: int = 30) -> list[dict[str, Any]]:
        if not isinstance(context, dict):
            return []
        out: list[dict[str, Any]] = []
        seen: set[str] = set()
        buckets = [
            ("current_facts", "recent_current_fact", 12.0, "Recent current fact"),
            ("key_facts", "recent_key_fact", 10.0, "Recent key fact"),
            ("superseded_facts", "recent_superseded_fact", 7.0, "Recent superseded fact"),
        ]
        for sess in context.get("recent_sessions") or []:
            if not isinstance(sess, dict):
                continue
            session_id = str(sess.get("session_id") or "")
            for field, source_type, score, title in buckets:
                for fact in sess.get(field) or []:
                    text = str(fact).strip()
                    if not text or text.lower() in seen:
                        continue
                    seen.add(text.lower())
                    item = {
                        "source_type": source_type,
                        "source": f"session:{session_id}",
                        "source_title": title,
                        "snippet": text,
                        "score": score,
                        "session_id": session_id or None,
                    }
                    if source_type == "recent_current_fact":
                        item["fact_status"] = "current"
                    elif source_type == "recent_superseded_fact":
                        item["fact_status"] = "superseded"
                    out.append(item)
                    if len(out) >= limit:
                        return out
        return out

    def _merge_memory_results(self, primary: list[Any], extra: list[Any], limit: int = 20) -> list[Any]:
        out: list[Any] = []
        seen: set[str] = set()
        for item in [*(primary or []), *(extra or [])]:
            key = json.dumps(item, ensure_ascii=False, sort_keys=True)[:1000] if isinstance(item, dict) else str(item)
            if key in seen:
                continue
            seen.add(key)
            out.append(item)
            if len(out) >= limit:
                break
        return out

    def _recall_brief(self, context: Any, memory_results: list[Any]) -> dict[str, Any]:
        current: list[str] = []
        superseded: list[str] = []
        key_facts: list[str] = []
        next_steps: list[str] = []
        if isinstance(context, dict):
            for sess in context.get("recent_sessions") or []:
                if not isinstance(sess, dict):
                    continue
                current.extend(str(x) for x in (sess.get("current_facts") or []) if str(x).strip())
                superseded.extend(str(x) for x in (sess.get("superseded_facts") or []) if str(x).strip())
                key_facts.extend(str(x) for x in (sess.get("key_facts") or []) if str(x).strip())
                hint = str(sess.get("next_start_hint") or "").strip()
                if hint:
                    next_steps.append(hint)
        def uniq(values: list[str], limit: int) -> list[str]:
            out: list[str] = []
            seen: set[str] = set()
            for v in values:
                if v.lower() in seen:
                    continue
                seen.add(v.lower())
                out.append(v)
                if len(out) >= limit:
                    break
            return out
        return {
            "recall_status": "ready" if (current or key_facts or memory_results) else "empty_or_new_project",
            "current_facts": uniq(current, 12),
            "key_facts": uniq(key_facts, 12),
            "superseded_facts": uniq(superseded, 8),
            "next_steps": uniq(next_steps, 5),
            "project_boundary": "Results are scoped to the current/default project unless scope=all is explicitly used.",
            "instructions": "Prefer current_facts. Treat superseded_facts as history only; use replaced_by/current facts for final answers.",
        }

    def _state_explanation(self, *, session_saved: bool = False, checkpoint: Any = None, searchable: str | None = None, importance: list[str] | None = None) -> dict[str, Any]:
        checkpoint_status = None
        if isinstance(checkpoint, dict):
            checkpoint_status = checkpoint.get("status") or checkpoint.get("last_checkpoint_status")
        return {
            "session_write_status": "saved" if session_saved else "not_written",
            "memory_settlement_status": checkpoint_status or ("pending_checkpoint" if importance else "not_required_yet"),
            "searchable_status": searchable or ("pending_checkpoint" if importance else "current_index_only"),
            "importance_signals": importance or [],
            "human_hint": "重要内容已保存到会话；系统会自动排队 checkpoint。完成后会进入可搜索/可回忆层。" if importance else "内容已保存；未检测到必须立即沉淀的强信号。",
        }

    def _queue_checkpoint_if_needed(self, reason: str, importance: list[str]) -> dict[str, Any] | None:
        if not importance:
            return None
        state = self._require_session()
        running = state.get("last_checkpoint_status") in {"queued", "running"}
        self.state.update(important_pending_checkpoint=True, last_importance_signals=importance, last_important_reason=reason)
        if running:
            return {"status": "already_queued", "reason": reason, "importance_signals": importance}
        result = self.checkpoint({"reason": reason, "importance_signals": importance})
        self.state.update(important_pending_checkpoint=False)
        return result

    def start_task(self, args: dict[str, Any]) -> dict[str, Any]:
        resolved = self._resolve_project_arg(args, allow_current=True)
        task = str(args.get("task") or "")
        if not resolved.get("ok"):
            return {**resolved, "session_write_status": "not_written", "instructions": "先确认项目，再调用 gateway.start_task；不要把第一条重要消息写入不确定项目。"}
        project = resolved["project"]
        project_id = str(project["id"])
        recent = self.client.request("GET", "/agent/sessions/recent?" + urllib.parse.urlencode({"project_id": project_id, "include_closed": "false"}))
        started = self.client.request("POST", "/agent/session/start", {"project_id": project_id, "session_id": args.get("client_session_id"), "resume": True})
        session_id = started["session"]["session_id"]
        context = self.client.request("GET", f"/agent/projects/{urllib.parse.quote(str(project_id))}/context") if self.cfg.context_load_on_start else None
        memory_results = []
        if self.cfg.search_on_user_message and task:
            memory_results = self.client.request("POST", "/agent/memory/search", {"project_id": project_id, "query": task[:200], "scope": "current_project", "limit": 10})
        memory_results = self._merge_memory_results(memory_results, self._context_key_fact_results(context))
        recall_brief = self._recall_brief(context, memory_results)
        confirmation = self._project_confirmation(project, action="started")
        self.state.update(current_project_id=project_id, current_project_name=project.get("name"), current_session_id=session_id, last_started_at=now_iso(), message_count_since_checkpoint=0, last_error=None, last_recall_health_status=recall_brief.get("recall_status"))
        return {
            "session_id": session_id,
            **confirmation,
            "resumed_existing": started.get("resumed_existing"),
            "context": context,
            "recent_sessions": recent,
            "memory_results": memory_results,
            "recall_brief": recall_brief,
            "state_explanation": self._state_explanation(session_saved=True, searchable="ready_for_existing_memory"),
            "instructions": "Use recall_brief.current_facts before answering. Treat recall_brief.superseded_facts as history only. Gateway will auto-checkpoint important facts/decisions/corrections.",
        }

    def user_message(self, args: dict[str, Any]) -> dict[str, Any]:
        """Record a user turn in the current session and return memory relevant to it.

        The response also reports the write-target project and warns, without
        switching, when the message names another project. It flags whether a
        checkpoint is due. Importance signals accumulate in the state file until
        a checkpoint consumes them, so a later unimportant message cannot cancel
        the pending checkpoint of an earlier important one.
        """
        state = self._require_session()
        content = require(args, "content")
        project_id = state["current_project_id"]
        session_id = state["current_session_id"]
        current_project = self._get_project(str(project_id))
        mentioned_other_projects = self._mentioned_other_projects(content, str(project_id))
        result = self.client.request("POST", "/agent/session/message", {"project_id": project_id, "session_id": session_id, "role": "user", "content": content, "metadata": args.get("metadata") or {}})
        count = int(state.get("message_count_since_checkpoint") or 0) + 1
        importance = importance_signals(content)
        # Merge with signals still awaiting a checkpoint. Overwriting them would let
        # a follow-up message without signals clear important_pending_checkpoint,
        # and assistant_message would then never checkpoint the earlier content.
        carried = list(state.get("last_importance_signals") or []) if state.get("important_pending_checkpoint") else []
        pending_signals = list(dict.fromkeys([*carried, *importance]))
        self.state.update(message_count_since_checkpoint=count, last_importance_signals=pending_signals, important_pending_checkpoint=bool(pending_signals))
        memory: Any = []
        if self.cfg.search_on_user_message:
            memory = self.client.request("POST", "/agent/memory/search", {"project_id": project_id, "query": content[:200], "scope": "current_project", "limit": 10})
            try:
                context = self.client.request("GET", f"/agent/projects/{urllib.parse.quote(str(project_id))}/context")
                memory = self._merge_memory_results(memory, self._context_key_fact_results(context))
            except Exception:
                # Context facts are an optional enrichment; plain search results still suffice.
                pass
        # An empty response body decodes to None; only a dict can carry should_checkpoint.
        server_requested = isinstance(result, dict) and bool(result.get("should_checkpoint"))
        should_checkpoint = server_requested or self._should_checkpoint(count) or bool(pending_signals)
        return {
            "ok": True,
            "relevant_memory": memory,
            "should_checkpoint": should_checkpoint,
            "auto_checkpoint_reason": "important_user_message" if importance else None,
            "current_project": self._project_brief(current_project),
            "memory_write_target": self._project_brief(current_project),
            "project_confirmation": self._project_confirmation(current_project, action="write"),
            "project_boundary_warning": {
                "mentioned_other_projects": mentioned_other_projects,
                "message": "检测到你提到了其他项目；本条消息仍写入当前项目。如需切换，请先调用 gateway.projects.set_current。",
                "recommended_commands": ["切到项目：项目名", "新建项目：项目名", "继续项目：项目名"],
            } if mentioned_other_projects else None,
            "state_explanation": self._state_explanation(session_saved=True, importance=importance),
            "warnings": ["检测到其他项目名；未自动切换，避免串写。"] if mentioned_other_projects else [],
        }

    def assistant_message(self, args: dict[str, Any]) -> dict[str, Any]:
        """Record an assistant turn and checkpoint when the turn or pending signals warrant it.

        Importance signals still pending from the user turn are combined with
        those found in this reply. If any exist, an importance checkpoint is
        queued. Otherwise an interval checkpoint runs when the server or the
        message counter asks for one.
        """
        state = self._require_session()
        content = require(args, "content")
        project_id, session_id = state["current_project_id"], state["current_session_id"]
        current_project = self._get_project(str(project_id))
        result = self.client.request("POST", "/agent/session/message", {"project_id": project_id, "session_id": session_id, "role": "assistant", "content": content, "metadata": args.get("metadata") or {}})
        count = int(state.get("message_count_since_checkpoint") or 0) + 1
        after = self.state.update(message_count_since_checkpoint=count)
        carried = list(after.get("last_importance_signals") or []) if after.get("important_pending_checkpoint") else []
        combined_importance = list(dict.fromkeys(carried + importance_signals(content)))
        if combined_importance:
            checkpoint_result = self._queue_checkpoint_if_needed("important_turn", combined_importance)
        elif result.get("should_checkpoint") or self._should_checkpoint(count):
            checkpoint_result = self.checkpoint({"reason": "interval"})
        else:
            checkpoint_result = None
        resume_prompt = checkpoint_result.get("next_session_resume_prompt") if isinstance(checkpoint_result, dict) else None
        searchable = "pending_checkpoint" if checkpoint_result else "current_index_only"
        return {
            "ok": True,
            "current_project": self._project_brief(current_project),
            "memory_write_target": self._project_brief(current_project),
            "project_confirmation": self._project_confirmation(current_project, action="write"),
            "checkpoint_result": checkpoint_result,
            "next_session_resume_prompt": resume_prompt,
            "state_explanation": self._state_explanation(session_saved=True, checkpoint=checkpoint_result, searchable=searchable, importance=combined_importance),
        }

    def _checkpoint_payload(self, state: dict[str, Any], args: dict[str, Any]) -> dict[str, Any]:
        # If extraction is omitted, send a lightweight structured payload instead
        # of forcing the cloud API to call an LLM during checkpoint. This keeps
        # Gateway checkpoints fast and avoids tying up later local RPC calls.
        summary = args.get("summary")
        extraction = args.get("extraction")
        if extraction is None:
            reason = str(args.get("reason", "manual"))
            session_id = str(state["current_session_id"])
            summary = summary or f"Gateway checkpoint saved session {session_id}. Reason: {reason}."
            extraction = {
                "session_summary": summary,
                "status_changes": [summary],
                "next_start_hint": "继续前先读取 OMemAI 项目 context；需要完整上下文时查看会话日志。",
            }
        return {"project_id": state["current_project_id"], "session_id": state["current_session_id"], "summary": summary, "extraction": extraction}

    def _checkpoint_sync(self, args: dict[str, Any]) -> dict[str, Any]:
        """Run a checkpoint now and record its lifecycle in the state file.

        Status moves running -> completed. If the API call raises, status moves
        running -> failed, the redacted error is stored, and the exception is
        re-raised. Recording the failure here covers the wait=True callers
        (checkpoint() and _close_sync). Otherwise the status would stay "running"
        and _queue_checkpoint_if_needed would keep answering "already_queued".

        Returns:
            The API response dict with session_write_status/searchable_status
            defaults filled in. A non-dict response is wrapped as {"response": ...}.
        """
        state = self._require_session()
        payload = self._checkpoint_payload(state, args)
        reason = args.get("reason", "manual")
        self.state.update(last_checkpoint_status="running", last_checkpoint_requested_at=now_iso(), last_checkpoint_reason=reason)
        try:
            result = self.client.request("POST", "/agent/session/checkpoint", payload)
        except Exception as exc:
            self.state.update(last_checkpoint_status="failed", last_error=redact(str(exc)))
            raise
        self.state.update(message_count_since_checkpoint=0, last_checkpoint_at=now_iso(), last_checkpoint_status="completed", last_checkpoint_reason=reason, important_pending_checkpoint=False, last_searchable_status="searchable_after_checkpoint")
        if not isinstance(result, dict):
            # An empty body decodes to None; without this, setdefault would raise
            # after the checkpoint had already been recorded as completed.
            result = {"response": result}
        result.setdefault("session_write_status", "saved")
        result.setdefault("searchable_status", "searchable_after_checkpoint")
        return result

    def _checkpoint_background(self, args: dict[str, Any]) -> None:
        try:
            self._checkpoint_sync(args)
        except Exception as exc:
            self.state.update(last_checkpoint_status="failed", last_error=redact(str(exc)))

    def checkpoint(self, args: dict[str, Any]) -> dict[str, Any]:
        wait = bool(args.get("wait", False))
        if wait:
            return self._checkpoint_sync(args)
        state = self._require_session()
        job_id = f"checkpoint-{datetime.now(timezone.utc).strftime('%Y%m%d-%H%M%S')}-{state.get('current_session_id')}"
        self.state.update(last_checkpoint_status="queued", last_checkpoint_job_id=job_id, last_checkpoint_requested_at=now_iso(), last_checkpoint_reason=args.get("reason", "manual"))
        thread = threading.Thread(target=self._checkpoint_background, args=(dict(args),), daemon=True)
        thread.start()
        return {
            "status": "queued",
            "job_id": job_id,
            "session_write_status": "saved",
            "searchable_status": "pending_checkpoint",
            "instructions": "Checkpoint accepted. Use gateway.status to confirm last_checkpoint_status=completed before relying on search for the newest messages.",
        }

    def _close_sync(self, args: dict[str, Any]) -> dict[str, Any]:
        """Checkpoint synchronously, close the current session, and clear it from state.

        Returns the checkpoint result, the close response, and the checkpoint's
        next_session_resume_prompt as final_resume_prompt.
        """
        state = self._require_session()
        summary = args.get("summary")
        checkpoint = self.checkpoint({"reason": "task_done", "summary": summary, "wait": True})
        closing_summary = summary or "Gateway closed task."
        close = self.client.request("POST", "/agent/session/close", {
            "project_id": state["current_project_id"],
            "session_id": state["current_session_id"],
            "summary": closing_summary,
            "messages": [],
            "create_proposal": False,
            "proposal": {"session_summary": closing_summary},
        })
        self.state.update(current_session_id=None, message_count_since_checkpoint=0, last_close_status="completed", last_close_at=now_iso())
        return {"checkpoint": checkpoint, "close": close, "final_resume_prompt": checkpoint.get("next_session_resume_prompt")}

    def _close_background(self, args: dict[str, Any]) -> None:
        """Thread target for an asynchronous close.

        Nothing joins this thread, so a failure is written to state (status
        ``failed`` plus a redacted error message) rather than raised.
        """
        try:
            self._close_sync(args)
        except Exception as failure:
            details = redact(str(failure))
            self.state.update(last_close_status="failed", last_error=details)

    def close_task(self, args: dict[str, Any]) -> dict[str, Any]:
        """Finish the active task: a final checkpoint followed by a session close.

        With ``args["wait"]`` truthy this runs ``_close_sync`` inline and returns
        its result. Otherwise a ``queued`` close job is recorded in state, a
        daemon thread performs the close, and an acknowledgement is returned at
        once. Callers poll ``gateway.status`` for ``last_close_status``.

        Raises:
            RuntimeError: on the asynchronous path when no Gateway session is active.
        """
        if args.get("wait", False):
            return self._close_sync(args)
        session_id = self._require_session().get("current_session_id")
        stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
        job_id = f"close-{stamp}-{session_id}"
        self.state.update(last_close_status="queued", last_close_job_id=job_id, last_close_requested_at=now_iso())
        worker = threading.Thread(target=self._close_background, args=(dict(args),), daemon=True)
        worker.start()
        acknowledgement = {
            "status": "queued",
            "job_id": job_id,
            "session_write_status": "saved",
            "instructions": "Close accepted. Use gateway.status to confirm last_close_status=completed.",
        }
        return acknowledgement

    def recall_health(self, args: dict[str, Any]) -> dict[str, Any]:
        """Score (0-100) how trustworthy recall currently is for a project.

        The project is taken from ``args["project_id"]``, else the session's
        current project, else the configured default. Without one, a
        ``needs_project`` result is returned and nothing is fetched. Otherwise the
        project context is fetched, a recall brief is derived from it, and five
        checks are scored at 20 points each. The resulting status is written to
        state as ``last_recall_health_status`` / ``last_recall_health_at``.

        Optional args:
            project_id: project to inspect (see resolution order above).
            stuck_after_minutes: how long a checkpoint may remain ``queued`` or
                ``running`` before it is reported as stuck (default 15).

        Returns:
            ``{"status", "score", "checks", "warnings", "recall_brief"}``, where
            status is ``healthy``, ``watch`` or ``needs_attention``.
        """
        state = self._require_session(allow_no_session=True)
        project_id = args.get("project_id") or state.get("current_project_id") or self.cfg.default_project_id
        if not project_id:
            return {"status": "needs_project", "score": 0, "warnings": ["No current/default project set."]}
        context = self.client.request("GET", f"/agent/projects/{urllib.parse.quote(str(project_id))}/context")
        brief = self._recall_brief(context, [])

        checkpoint_pending = state.get("last_checkpoint_status") in {"running", "queued"}
        checkpoint_stuck = False
        if checkpoint_pending:
            # checkpoint() always stores last_checkpoint_requested_at together with
            # "queued", so checking only that the timestamp exists never fails.
            # Judge by age instead. Workers are daemon threads, so if the process
            # exits mid-checkpoint the status can remain queued/running in state
            # indefinitely (assuming state survives restarts — TODO confirm).
            stuck_after = timedelta(minutes=float(args.get("stuck_after_minutes", 15)))
            requested_at = parse_iso(state.get("last_checkpoint_requested_at"))
            checkpoint_stuck = not requested_at or datetime.now(timezone.utc) - requested_at > stuck_after

        warnings: list[str] = []
        checks = {
            # The context request above raises on failure, so reaching here means the API answered.
            "api_ok": True,
            "has_context": isinstance(context, dict),
            "has_recent_recall": bool(brief.get("current_facts") or brief.get("key_facts") or brief.get("superseded_facts")),
            "checkpoint_not_stuck": not checkpoint_stuck,
            "project_boundary_present": True,
        }
        if checkpoint_pending:
            warnings.append("Checkpoint is still queued/running; newest facts may not be searchable yet.")
        if checkpoint_stuck:
            warnings.append("Checkpoint has been queued/running longer than expected and may have been interrupted; run gateway.checkpoint with wait=true to retry.")
        if not checks["has_recent_recall"]:
            warnings.append("No recent key/current facts found; this may be a new project or recall extraction may need validation.")
        score = sum(20 for ok in checks.values() if ok)
        status = "healthy" if score >= 80 and not warnings else "watch" if score >= 60 else "needs_attention"
        result = {"status": status, "score": score, "checks": checks, "warnings": warnings, "recall_brief": brief}
        self.state.update(last_recall_health_status=status, last_recall_health_at=now_iso())
        return result

    def search(self, args: dict[str, Any]) -> dict[str, Any]:
        """Search project memory and report which project the results came from.

        The project is taken from ``args["project_id"]``, else the session's
        current project, else the configured default.

        Args (in ``args``):
            query: required search text.
            scope: forwarded to the API (default ``"current_project"``).
            limit: maximum result count, coerced with ``int`` (default 10).

        Raises:
            ValueError: when no project can be resolved, ``query`` is missing or
                empty, or ``limit`` is not an integer.
        """
        state = self._require_session(allow_no_session=True)
        project_id = args.get("project_id") or state.get("current_project_id") or self.cfg.default_project_id
        if not project_id:
            raise ValueError("project_id is required when no current/default project is set")
        # Validate caller input before any remote call, so a malformed request
        # fails fast instead of first spending a project lookup round-trip.
        query = require(args, "query")
        limit = int(args.get("limit", 10))
        project = self._get_project(str(project_id))
        results = self.client.request("POST", "/agent/memory/search", {"project_id": project_id, "query": query, "scope": args.get("scope", "current_project"), "limit": limit})
        brief = self._project_brief(project)
        return {"results": results, "current_project": brief, "recall_source": brief, "project_confirmation": self._project_confirmation(project, action="current")}

    def _require_session(self, allow_no_session: bool = False) -> dict[str, Any]:
        state = self.state.load()
        if not allow_no_session and not (state.get("current_project_id") and state.get("current_session_id")):
            raise RuntimeError("No active Gateway session. Call gateway.start_task first.")
        return state

    def _should_checkpoint(self, count: int) -> bool:
        if count >= self.cfg.auto_checkpoint_messages:
            return True
        state = self.state.load()
        last = parse_iso(state.get("last_checkpoint_at"))
        if not last:
            started = parse_iso(state.get("last_started_at"))
            last = started
        if last and datetime.now(timezone.utc) - last >= timedelta(minutes=self.cfg.auto_checkpoint_minutes):
            return True
        return False


def require(args: dict[str, Any], key: str) -> str:
    """Return ``args[key]`` as a string, rejecting absent, ``None`` or empty values.

    Falsy but meaningful values such as ``0`` or ``False`` are accepted and stringified.

    Raises:
        ValueError: when the argument is missing, ``None`` or ``""``.
    """
    raw = args.get(key)
    if raw is not None and raw != "":
        return str(raw)
    raise ValueError(f"Missing required argument: {key}")


# Public JSON-RPC method name -> name of the Gateway method that serves it.
# Every RPC is "gateway.<name>", and the Gateway attribute is <name> with any
# remaining dots turned into underscores (e.g. "gateway.projects.set_current"
# -> "projects_set_current"). Only names listed here are dispatchable.
METHODS = {
    f"gateway.{rpc_name}": rpc_name.replace(".", "_")
    for rpc_name in (
        "status",
        "projects",
        "projects.create",
        "projects.rename",
        "projects.current",
        "projects.set_current",
        "start_task",
        "user_message",
        "assistant_message",
        "checkpoint",
        "close_task",
        "recall_health",
        "search",
    )
}


def handle_rpc(message: Any) -> dict[str, Any]:
    """Dispatch one decoded JSON-RPC 2.0 request to a fresh Gateway and build the reply.

    Error codes follow JSON-RPC 2.0:
      * -32600 Invalid Request: the decoded payload is not a JSON object.
      * -32601 Method not found: ``method`` is not a key of ``METHODS``.
      * -32602 Invalid params: ``params`` is present but not a JSON object.
      * -32000: any exception raised while building the Gateway or running the
        method (the message is passed through ``redact``).

    Args:
        message: the decoded request body. Any JSON value is accepted, so that
            malformed requests get a proper error response instead of raising.

    Returns:
        A JSON-RPC response dict carrying either ``result`` or ``error``.
    """
    if not isinstance(message, dict):
        # Batches (arrays) and scalars are not request objects; without an object
        # there is no id to echo, so the spec requires a null id.
        return {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Invalid Request: expected a JSON object"}}
    mid = message.get("id")
    method = message.get("method")
    # The isinstance guard also keeps unhashable values (lists/objects) from
    # raising TypeError in the dict membership test.
    if not isinstance(method, str) or method not in METHODS:
        return {"jsonrpc": "2.0", "id": mid, "error": {"code": -32601, "message": f"Method not found: {method}"}}
    params = message.get("params") or {}
    if not isinstance(params, dict):
        # Every Gateway method reads named arguments via args.get(...), so positional params cannot be served.
        return {"jsonrpc": "2.0", "id": mid, "error": {"code": -32602, "message": "Invalid params: expected a JSON object"}}
    try:
        gateway = Gateway()
        result = getattr(gateway, METHODS[method])(params)
    except Exception as exc:
        return {"jsonrpc": "2.0", "id": mid, "error": {"code": -32000, "message": redact(str(exc))}}
    return {"jsonrpc": "2.0", "id": mid, "result": result}


class Handler(BaseHTTPRequestHandler):
    """HTTP front end: one JSON-RPC request per POST body, answered with a JSON body.

    A successful dispatch is answered with HTTP 200 and whatever ``handle_rpc``
    returns. Any other failure while producing the reply gets HTTP 400 with a
    JSON-RPC ``-32700`` error. That covers a bad content-length, a non-UTF-8
    body, malformed JSON, or an exception escaping ``handle_rpc`` or
    serialization. Client disconnects and socket timeouts end the request
    silently.
    """

    def _send_json_body(self, status: int, body: bytes) -> None:
        """Write a complete response: status line, JSON headers, then *body*."""
        self.send_response(status)
        self.send_header("content-type", "application/json; charset=utf-8")
        self.send_header("content-length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self) -> None:
        """Read, dispatch and answer a single JSON-RPC POST."""
        try:
            try:
                # Length parsing and decoding sit inside the guarded region, so a
                # malformed header or non-UTF-8 body yields a 400 reply rather than
                # an exception that leaves the client without any response.
                length = int(self.headers.get("content-length", "0"))
                if length < 0:
                    # rfile.read(-1) would read until EOF and stall on a keep-alive connection.
                    raise ValueError(f"invalid content-length: {length}")
                raw = self.rfile.read(length).decode("utf-8")
                response = handle_rpc(json.loads(raw))
                body = json.dumps(response, ensure_ascii=False).encode("utf-8")
                status = 200
            except (BrokenPipeError, ConnectionResetError, socket.timeout):
                # Transport failures are not client errors; let the outer handler drop them.
                raise
            except Exception as exc:
                body = json.dumps({"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": redact(str(exc))}}, ensure_ascii=False).encode("utf-8")
                status = 400
            self._send_json_body(status, body)
        except (BrokenPipeError, ConnectionResetError, socket.timeout):
            return

    def log_message(self, fmt: str, *args: Any) -> None:
        """Write access-log lines to stderr, ignoring a closed stderr pipe."""
        try:
            sys.stderr.write("gateway-http " + fmt % args + "\n")
        except (BrokenPipeError, ConnectionResetError):
            return


class QuietThreadingHTTPServer(ThreadingHTTPServer):
    """ThreadingHTTPServer that stays silent when a client drops the connection.

    socketserver's default ``handle_error`` prints a traceback to stderr for every
    exception that escapes a request handler. Disconnects and socket timeouts are
    routine for a local gateway, so only other errors are reported.
    """

    # Treated as routine disconnects. Matching uses isinstance(), so subclasses
    # and platform variants (e.g. ConnectionAbortedError on Windows) stay quiet too.
    _QUIET_ERRORS = (BrokenPipeError, ConnectionResetError, ConnectionAbortedError, socket.timeout)

    def handle_error(self, request: Any, client_address: Any) -> None:
        """Suppress disconnect/timeout noise; defer everything else to the base class."""
        exc = sys.exc_info()[1]
        if isinstance(exc, self._QUIET_ERRORS):
            return
        super().handle_error(request, client_address)


def run_http(host: str, port: int) -> None:
    """Log the effective configuration, then serve JSON-RPC over HTTP on host:port.

    A ConfigError is reported to stderr but is not fatal: the server still
    starts. NOTE(review): presumably deliberate so that each request's own
    ``Gateway()`` construction surfaces the error to RPC clients — confirm.
    Blocks until ``serve_forever`` exits; the listening socket is then closed.
    """
    try:
        cfg = Config.from_env()
        print(f"{NAME} listening on {host}:{port}; api_base={cfg.api_base}; key_prefix={cfg.key_prefix}; state={cfg.state_path}", file=sys.stderr)
    except ConfigError as exc:
        print(f"{NAME} config error: {exc}", file=sys.stderr)
    # The context manager calls server_close(), releasing the listening socket
    # even when serve_forever() is ended by an exception such as KeyboardInterrupt.
    with QuietThreadingHTTPServer((host, port), Handler) as server:
        server.serve_forever()


def main() -> None:
    """Command-line entry point: parse --host/--port and run the HTTP gateway.

    Defaults come from OMEMAI_GATEWAY_HOST (``127.0.0.1``) and
    OMEMAI_GATEWAY_PORT (``8777``).
    """
    default_host = os.environ.get("OMEMAI_GATEWAY_HOST", "127.0.0.1")
    default_port = int(os.environ.get("OMEMAI_GATEWAY_PORT", "8777"))
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default=default_host)
    cli.add_argument("--port", type=int, default=default_port)
    options = cli.parse_args()
    run_http(options.host, options.port)


# Run the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
