fix: normalize stale session models after provider switch — v0.50.99 (#751)

## Summary

Rebased on behalf of @likawa3b (originally PR #748, which was opened against a stale base).

Sessions can outlive provider changes. When an old session still points to a model from a previous provider (e.g. `gemini-3.1-pro-preview` after switching the agent to OpenAI Codex), starting a chat hits the wrong backend and fails silently.

This PR adds a lightweight normalization pass:
- `_normalize_provider_id()` maps common prefixes to canonical provider IDs
- `_resolve_compatible_session_model()` checks the session model's provider against `active_provider` and returns the default model if they differ
- `_normalize_session_model_in_place()` is called at GET `/api/session` — corrects and persists stale models once
- Chat start also normalizes via `_resolve_compatible_session_model()` and returns `effective_model` in the response
- `messages.js` applies `effective_model` back to the UI/localStorage/dropdown if set

Closes #748

## Tests

1498 passed (2 pre-existing ordering failures unrelated to this PR; 5 new tests added in `test_provider_mismatch.py`).

**Original author:** @likawa3b
This commit is contained in:
nesquena-hermes
2026-04-19 23:22:26 -07:00
committed by GitHub
parent c68420d9aa
commit 7f16a41a31
4 changed files with 197 additions and 2 deletions

View File

@@ -17,6 +17,13 @@ from urllib.parse import parse_qs
logger = logging.getLogger(__name__)
_PROVIDER_ALIASES = {
"claude": "anthropic",
"gpt": "openai",
"gemini": "google",
"openai-codex": "openai",
}
from api.config import (
STATE_DIR,
SESSION_DIR,
@@ -158,6 +165,71 @@ def _check_csrf(handler) -> bool:
return False
def _normalize_provider_id(value: str | None) -> str:
    """Map a provider name (or provider-like prefix) to its canonical ID.

    The input is stringified, stripped and lowercased first. Resolution then
    proceeds in order:

    1. An empty/None input yields ``""``.
    2. An exact hit in ``_PROVIDER_ALIASES`` yields the aliased provider.
    3. The first entry of the prefix table that the value starts with yields
       that entry's provider (so e.g. a value beginning with ``"anthropic"``
       resolves to ``"anthropic"``).
    4. Anything else is returned as the cleaned (stripped, lowercased) value.
    """
    cleaned = str(value or "").strip().lower()
    if not cleaned:
        return ""

    # Exact alias match takes priority over prefix matching.
    canonical = _PROVIDER_ALIASES.get(cleaned)
    if canonical is not None:
        return canonical

    # Ordered (prefix, canonical provider) pairs; the first match wins.
    prefix_table = (
        ("openai-codex", "openai"),
        ("openai", "openai"),
        ("anthropic", "anthropic"),
        ("claude", "anthropic"),
        ("google", "google"),
        ("gemini", "google"),
        ("openrouter", "openrouter"),
        ("custom", "custom"),
    )
    return next(
        (provider for prefix, provider in prefix_table if cleaned.startswith(prefix)),
        cleaned,  # unknown providers pass through unchanged
    )
def _resolve_compatible_session_model(model_id: str | None) -> tuple[str, bool]:
    """Pick the model a persisted session should actually use.

    A session may have been saved while a different provider was active. If
    its stored model clearly belongs to another provider (for example a
    ``gemini/...`` ID after the agent was switched to OpenAI Codex), starting
    a chat with it would target the wrong backend, so the catalog's default
    model is substituted instead.

    Args:
        model_id: The model stored on the session (or requested by the
            caller); may be ``None`` or blank.

    Returns:
        ``(effective_model, was_normalized)``. ``was_normalized`` is True when
        the default model was substituted for ``model_id``.

    Rules, in order:
        * Blank ``model_id`` -> the default model (flag is True only if a
          default exists).
        * No active provider, or an active provider of ``custom`` /
          ``openrouter`` -> the model is kept as-is.
        * ``namespace/model`` IDs are attributed to the provider named by the
          text before the first ``/``.
        * Bare IDs are attributed to a provider only when they start with
          ``gpt``, ``claude`` or ``gemini``; other bare IDs are kept as-is.
        * The default is substituted only when the attributed provider differs
          from the active one and a default model is available.
    """
    available = get_available_models()
    fallback = str(available.get("default_model") or DEFAULT_MODEL or "").strip()

    requested = str(model_id or "").strip()
    if not requested:
        return fallback, bool(fallback)

    current_provider = _normalize_provider_id(available.get("active_provider"))
    # "" means the active provider is unknown; custom/openrouter setups can
    # legitimately serve models from any namespace, so never rewrite for them.
    if current_provider in {"", "custom", "openrouter"}:
        return requested, False

    namespace, separator, _ = requested.partition("/")
    if separator:
        # Namespaced ID: the part before the first "/" names the provider.
        owner = _normalize_provider_id(namespace)
    else:
        # Bare ID: infer the provider from a well-known model-family prefix.
        lowered = requested.lower()
        family = next(
            (prefix for prefix in ("gpt", "claude", "gemini") if lowered.startswith(prefix)),
            None,
        )
        owner = _normalize_provider_id(family) if family is not None else ""

    if owner and owner != current_provider and fallback:
        return fallback, True
    return requested, False
def _normalize_session_model_in_place(session) -> str:
    """Correct a session's stale model on the session object and persist it.

    Resolves the session's ``model`` attribute through
    ``_resolve_compatible_session_model``. When the resolver reports a
    substitution and the resulting model is non-empty and differs from what is
    stored, the session's ``model`` is overwritten and ``session.save`` is
    called with ``touch_updated_at=False``.

    Returns:
        The effective model, whether or not the session was modified.
    """
    stored = getattr(session, "model", None)
    resolved, was_normalized = _resolve_compatible_session_model(stored)

    # Nothing to persist unless the resolver substituted a real, different model.
    if not (was_normalized and resolved and stored != resolved):
        return resolved

    session.model = resolved
    session.save(touch_updated_at=False)
    return resolved
from api.models import (
Session,
get_session,
@@ -481,6 +553,7 @@ def handle_get(handler, parsed) -> bool:
return j(handler, {"error": "session_id is required"}, status=400)
try:
s = get_session(sid)
_normalize_session_model_in_place(s)
raw = s.compact() | {
"messages": s.messages,
"tool_calls": getattr(s, "tool_calls", []),
@@ -2071,7 +2144,8 @@ def _handle_chat_start(handler, body):
workspace = str(resolve_trusted_workspace(body.get("workspace") or s.workspace))
except ValueError as e:
return bad(handler, str(e))
model = body.get("model") or s.model
requested_model = body.get("model") or s.model
model, normalized_model = _resolve_compatible_session_model(requested_model)
# Prevent duplicate runs in the same session while a stream is still active.
# This commonly happens after page refresh/reconnect races and can produce
# duplicated clarify cards for what appears to be a single user request.
@@ -2108,7 +2182,10 @@ def _handle_chat_start(handler, body):
daemon=True,
)
thr.start()
return j(handler, {"stream_id": stream_id, "session_id": s.session_id})
response = {"stream_id": stream_id, "session_id": s.session_id}
if normalized_model:
response["effective_model"] = model
return j(handler, response)
def _handle_chat_sync(handler, body):