perf: TTL cache for model list + incremental session index (#780)
Fixes AWS IMDS timeout on model dropdown. Incremental index writes. Co-authored-by: starship-s <starship-s@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,11 @@
|
|||||||
# Hermes Web UI -- Changelog
|
# Hermes Web UI -- Changelog
|
||||||
|
|
||||||
|
## [v0.50.121] — 2026-04-20
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
- **Model list no longer re-scans on every session load** — `get_available_models()` now caches its result for 60 seconds (configurable via `_AVAILABLE_MODELS_CACHE_TTL`). Config file changes (mtime) invalidate the cache immediately. This eliminates the ~4s AWS IMDS timeout that blocked the model dropdown on every page load for users on EC2 without an IAM role. Thread-safe via a dedicated lock; callers receive a `copy.deepcopy()` so mutations don't pollute the cache. (credit: @starship-s)
|
||||||
|
- **Session saves no longer trigger a full O(n) index rebuild** — `_write_session_index()` now does an incremental read-patch-write of the existing index JSON when called from `Session.save()`, rather than re-scanning every session file on disk. Falls back to a full rebuild when the index is missing or corrupt. Atomic write via `.tmp` + `os.replace()`. At 100+ sessions this is a meaningful speedup. (credit: @starship-s)
|
||||||
|
|
||||||
## [v0.50.120] — 2026-04-20
|
## [v0.50.120] — 2026-04-20
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ Discovery order for all paths:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -802,6 +803,26 @@ def set_hermes_default_model(model_id: str) -> dict:
|
|||||||
return get_available_models()
|
return get_available_models()
|
||||||
|
|
||||||
|
|
||||||
|
# ── TTL cache for get_available_models() ─────────────────────────────────────
|
||||||
|
_available_models_cache: dict | None = None
|
||||||
|
_available_models_cache_ts: float = 0.0
|
||||||
|
_AVAILABLE_MODELS_CACHE_TTL: float = 60.0 # seconds — refresh at most once per minute
|
||||||
|
_available_models_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_models_cache():
|
||||||
|
"""Force the TTL cache for get_available_models() to be cleared.
|
||||||
|
|
||||||
|
Call this after modifying config.cfg in-memory (e.g. in tests) so
|
||||||
|
the next call to get_available_models() picks up the changes rather
|
||||||
|
than returning a stale cached result.
|
||||||
|
"""
|
||||||
|
global _available_models_cache, _available_models_cache_ts
|
||||||
|
with _available_models_cache_lock:
|
||||||
|
_available_models_cache = None
|
||||||
|
_available_models_cache_ts = 0.0
|
||||||
|
|
||||||
|
|
||||||
def get_available_models() -> dict:
|
def get_available_models() -> dict:
|
||||||
"""
|
"""
|
||||||
Return available models grouped by provider.
|
Return available models grouped by provider.
|
||||||
@@ -821,12 +842,24 @@ def get_available_models() -> dict:
|
|||||||
# Reload config from disk if config.yaml has changed since last load.
|
# Reload config from disk if config.yaml has changed since last load.
|
||||||
# This ensures CLI model changes are picked up on page refresh without
|
# This ensures CLI model changes are picked up on page refresh without
|
||||||
# a server restart, while avoiding clearing in-memory mocks during tests. (#585)
|
# a server restart, while avoiding clearing in-memory mocks during tests. (#585)
|
||||||
|
# Must run BEFORE the TTL check so config edits within the 60s window are visible.
|
||||||
|
global _available_models_cache, _available_models_cache_ts
|
||||||
|
with _available_models_cache_lock:
|
||||||
try:
|
try:
|
||||||
_current_mtime = Path(_get_config_path()).stat().st_mtime
|
_current_mtime = Path(_get_config_path()).stat().st_mtime
|
||||||
except OSError:
|
except OSError:
|
||||||
_current_mtime = 0.0
|
_current_mtime = 0.0
|
||||||
|
# Note: env-var changes (e.g. API key rotation) are not detected by mtime;
|
||||||
|
# cache will be stale for up to TTL seconds in that case.
|
||||||
if _current_mtime != _cfg_mtime:
|
if _current_mtime != _cfg_mtime:
|
||||||
reload_config()
|
reload_config()
|
||||||
|
# Config changed — force cache invalidation
|
||||||
|
_available_models_cache = None
|
||||||
|
_available_models_cache_ts = 0.0
|
||||||
|
# Serve from TTL cache if fresh.
|
||||||
|
now = time.monotonic()
|
||||||
|
if _available_models_cache is not None and (now - _available_models_cache_ts) < _AVAILABLE_MODELS_CACHE_TTL:
|
||||||
|
return copy.deepcopy(_available_models_cache)
|
||||||
active_provider = None
|
active_provider = None
|
||||||
default_model = get_effective_default_model(cfg)
|
default_model = get_effective_default_model(cfg)
|
||||||
groups = []
|
groups = []
|
||||||
@@ -1277,11 +1310,16 @@ def get_available_models() -> dict:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"active_provider": active_provider,
|
"active_provider": active_provider,
|
||||||
"default_model": default_model,
|
"default_model": default_model,
|
||||||
"groups": groups,
|
"groups": groups,
|
||||||
}
|
}
|
||||||
|
# Cache the result for TTL seconds
|
||||||
|
with _available_models_cache_lock:
|
||||||
|
_available_models_cache = result
|
||||||
|
_available_models_cache_ts = time.monotonic()
|
||||||
|
return copy.deepcopy(result)
|
||||||
|
|
||||||
|
|
||||||
# ── Static file path ─────────────────────────────────────────────────────────
|
# ── Static file path ─────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Hermes Web UI -- Session model and in-memory session store.
|
|||||||
import collections
|
import collections
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -19,8 +20,16 @@ from api.workspace import get_last_workspace
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _write_session_index():
|
def _write_session_index(updates=None):
|
||||||
"""Rebuild the session index file for O(1) future reads."""
|
"""Update the session index file.
|
||||||
|
|
||||||
|
When *updates* is provided (a list of Session objects whose compact
|
||||||
|
entries should be refreshed), this does a targeted in-place update of
|
||||||
|
the existing index — O(1) for single-session changes. When *updates*
|
||||||
|
is None, a full rebuild is performed (used on startup / first call).
|
||||||
|
"""
|
||||||
|
# Lazy full-rebuild path — used when index doesn't exist yet.
|
||||||
|
if updates is None or not SESSION_INDEX_FILE.exists():
|
||||||
entries = []
|
entries = []
|
||||||
for p in SESSION_DIR.glob('*.json'):
|
for p in SESSION_DIR.glob('*.json'):
|
||||||
if p.name.startswith('_'): continue
|
if p.name.startswith('_'): continue
|
||||||
@@ -34,7 +43,40 @@ def _write_session_index():
|
|||||||
if not any(e['session_id'] == s.session_id for e in entries):
|
if not any(e['session_id'] == s.session_id for e in entries):
|
||||||
entries.append(s.compact())
|
entries.append(s.compact())
|
||||||
entries.sort(key=lambda s: s['updated_at'], reverse=True)
|
entries.sort(key=lambda s: s['updated_at'], reverse=True)
|
||||||
SESSION_INDEX_FILE.write_text(json.dumps(entries, ensure_ascii=False, indent=2), encoding='utf-8')
|
_tmp = SESSION_INDEX_FILE.with_suffix('.tmp')
|
||||||
|
_tmp.write_text(json.dumps(entries, ensure_ascii=False, indent=2), encoding='utf-8')
|
||||||
|
os.replace(_tmp, SESSION_INDEX_FILE)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fast path: patch existing index with updated sessions.
|
||||||
|
# This avoids loading every session file on every single save().
|
||||||
|
# LOCK covers the entire read-patch-write to prevent concurrent save() calls
|
||||||
|
# from both reading the same baseline and one losing its update.
|
||||||
|
_fallback = False
|
||||||
|
try:
|
||||||
|
with LOCK:
|
||||||
|
existing = json.loads(SESSION_INDEX_FILE.read_text(encoding='utf-8'))
|
||||||
|
# Build lookup of updated entries
|
||||||
|
updated_map = {s.session_id: s.compact() for s in updates}
|
||||||
|
existing_ids = {e.get('session_id') for e in existing}
|
||||||
|
# Add any updated entries not yet in the index
|
||||||
|
for sid, entry in updated_map.items():
|
||||||
|
if sid not in existing_ids:
|
||||||
|
existing.append(entry)
|
||||||
|
# Replace matching entries in-place
|
||||||
|
for i, e in enumerate(existing):
|
||||||
|
sid = e.get('session_id')
|
||||||
|
if sid in updated_map:
|
||||||
|
existing[i] = updated_map[sid]
|
||||||
|
existing.sort(key=lambda s: s.get('updated_at', 0), reverse=True)
|
||||||
|
_tmp = SESSION_INDEX_FILE.with_suffix('.tmp')
|
||||||
|
_tmp.write_text(json.dumps(existing, ensure_ascii=False, indent=2), encoding='utf-8')
|
||||||
|
os.replace(_tmp, SESSION_INDEX_FILE)
|
||||||
|
except Exception:
|
||||||
|
_fallback = True
|
||||||
|
if _fallback:
|
||||||
|
# Corrupt or missing index — fall back to full rebuild (called outside LOCK to avoid deadlock)
|
||||||
|
_write_session_index(updates=None)
|
||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
@@ -86,7 +128,7 @@ class Session:
|
|||||||
json.dumps(self.__dict__, ensure_ascii=False, indent=2),
|
json.dumps(self.__dict__, ensure_ascii=False, indent=2),
|
||||||
encoding='utf-8',
|
encoding='utf-8',
|
||||||
)
|
)
|
||||||
_write_session_index()
|
_write_session_index(updates=[self])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, sid):
|
def load(cls, sid):
|
||||||
|
|||||||
@@ -342,6 +342,33 @@ def base_url():
|
|||||||
return TEST_BASE
|
return TEST_BASE
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-test model cache invalidation ────────────────────────────────────────
|
||||||
|
# The TTL cache for get_available_models() persists across tests within the
|
||||||
|
# same process. Tests that modify cfg in-memory won't trigger the mtime path,
|
||||||
|
# so an autouse fixture explicitly invalidates the cache before and after every
|
||||||
|
# test, not only after those that exercise provider/model detection.
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _invalidate_models_cache_after_test():
    """Clear the get_available_models() TTL cache around every test.

    The cache lives for the whole process, so a test that populates it would
    otherwise leak its result into the next test even when that test has
    mutated the in-memory config. Clearing happens both before the test body
    runs and again once it has finished.
    """

    def _flush():
        # Best-effort: if api.config cannot be imported or the call fails,
        # the test run carries on untouched.
        try:
            from api.config import invalidate_models_cache
            invalidate_models_cache()
        except Exception:
            pass

    _flush()
    yield
    _flush()
|
||||||
|
|
||||||
|
|
||||||
# ── Per-test session cleanup ──────────────────────────────────────────────────
|
# ── Per-test session cleanup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
|||||||
350
tests/test_session_index.py
Normal file
350
tests/test_session_index.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
"""
|
||||||
|
Tests for the incremental session index in api/models.py.
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
- Incremental patch correctness (existing entries preserved, updated)
|
||||||
|
- New session appended to existing index
|
||||||
|
- First call (no index file) triggers full rebuild
|
||||||
|
- Corrupt index triggers fallback to full rebuild
|
||||||
|
- Concurrent saves don't lose data
|
||||||
|
- Atomic write leaves no .tmp file behind
|
||||||
|
- Deadlock guard on fallback path
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import api.models as models
|
||||||
|
from api.models import Session, _write_session_index
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _isolate_session_dir(tmp_path, monkeypatch):
    """Point the session store at a per-test temporary directory.

    Rebinds ``models.SESSION_DIR`` and ``models.SESSION_INDEX_FILE`` to paths
    under ``tmp_path`` so no test touches the real session store, and empties
    the in-memory ``models.SESSIONS`` cache before and after the test.

    ``_write_session_index`` reads both names as module globals at call time,
    so patching the module attributes is enough to redirect it.
    NOTE(review): presumably ``Session.path`` resolves against
    ``models.SESSION_DIR`` the same way — confirm against the class.

    Yields:
        tuple[Path, Path]: ``(session_dir, index_file)`` inside ``tmp_path``.
    """
    session_dir = tmp_path / "sessions"
    session_dir.mkdir()
    index_file = session_dir / "_index.json"

    # monkeypatch restores both attributes automatically at teardown.
    monkeypatch.setattr(models, "SESSION_DIR", session_dir)
    monkeypatch.setattr(models, "SESSION_INDEX_FILE", index_file)

    # Clear the in-memory SESSIONS cache to avoid bleed from earlier tests.
    models.SESSIONS.clear()

    yield session_dir, index_file

    # ...and leave nothing behind for the next test either.
    models.SESSIONS.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(session_id, title="Untitled", updated_at=None):
    """Build a Session with a fixed id, title and a single user message.

    When *updated_at* is given it overrides the timestamp the constructor
    assigned; otherwise the session is returned as constructed.
    """
    session = Session(
        session_id=session_id,
        title=title,
        messages=[{"role": "user", "content": "hi"}],
    )
    if updated_at is None:
        return session
    session.updated_at = updated_at
    return session
|
||||||
|
|
||||||
|
|
||||||
|
def _write_index_file(index_file, entries):
    """Serialise *entries* to *index_file* via a sibling ``.tmp`` file.

    The JSON is written to the temporary file first and then moved over the
    target with ``os.replace``, so the index is never seen half-written.
    """
    staging = index_file.with_suffix(".tmp")
    payload = json.dumps(entries, ensure_ascii=False, indent=2)
    staging.write_text(payload, encoding="utf-8")
    os.replace(str(staging), str(index_file))
|
||||||
|
|
||||||
|
|
||||||
|
def _read_index(index_file):
    """Return the parsed JSON content of the session index at *index_file*."""
    raw = index_file.read_text(encoding="utf-8")
    return json.loads(raw)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 6. test_incremental_patch_correctness ─────────────────────────────────
|
||||||
|
|
||||||
|
def test_incremental_patch_correctness():
    """Patching one session must leave the others alone and keep the order.

    Builds an index holding A, B and C through a full rebuild, then pushes a
    renamed B through the incremental path. A and C must keep their titles,
    B must carry the new one, and entries stay sorted newest-first.
    """
    # The autouse fixture has already redirected the module-level path.
    index_path = models.SESSION_INDEX_FILE

    seeds = [
        _make_session("sess_a", "Alpha", updated_at=100.0),
        _make_session("sess_b", "Bravo", updated_at=200.0),
        _make_session("sess_c", "Charlie", updated_at=300.0),
    ]
    # The full rebuild scans the directory, so the files must exist on disk.
    for seed in seeds:
        seed.path.write_text(
            json.dumps(seed.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
        )

    _write_session_index(updates=None)
    assert len(_read_index(index_path)) == 3

    renamed = _make_session("sess_b", "Bravo Updated", updated_at=250.0)
    renamed.path.write_text(
        json.dumps(renamed.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    _write_session_index(updates=[renamed])

    entries = _read_index(index_path)
    by_id = {entry["session_id"]: entry for entry in entries}

    assert by_id["sess_a"]["title"] == "Alpha", "A should be unchanged"
    assert by_id["sess_c"]["title"] == "Charlie", "C should be unchanged"
    assert by_id["sess_b"]["title"] == "Bravo Updated", "B should have new title"

    # Sort order: Charlie (300) > Bravo Updated (250) > Alpha (100)
    assert entries[0]["session_id"] == "sess_c"
    assert entries[1]["session_id"] == "sess_b"
    assert entries[2]["session_id"] == "sess_a"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 7. test_new_session_appended_to_index ─────────────────────────────────
|
||||||
|
|
||||||
|
def test_new_session_appended_to_index():
    """A session unknown to the index is added by an incremental update.

    The index is seeded with A and B via a full rebuild; C is then passed
    through ``updates`` and all three ids must be present afterwards.
    """
    index_path = models.SESSION_INDEX_FILE

    seeds = [
        _make_session("sess_a", "Alpha", updated_at=100.0),
        _make_session("sess_b", "Bravo", updated_at=200.0),
    ]
    for seed in seeds:
        seed.path.write_text(
            json.dumps(seed.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
        )

    _write_session_index(updates=None)

    # C exists on disk but has never been indexed.
    newcomer = _make_session("sess_c", "Charlie", updated_at=300.0)
    newcomer.path.write_text(
        json.dumps(newcomer.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    _write_session_index(updates=[newcomer])

    indexed_ids = {entry["session_id"] for entry in _read_index(index_path)}
    assert "sess_c" in indexed_ids, "New session C should appear in the index"
    assert "sess_a" in indexed_ids
    assert "sess_b" in indexed_ids
|
||||||
|
|
||||||
|
|
||||||
|
# ── 8. test_first_call_full_rebuild ──────────────────────────────────────
|
||||||
|
|
||||||
|
def test_first_call_full_rebuild():
    """With no index on disk, an incremental call must still create one.

    ``_write_session_index(updates=[...])`` takes the full-rebuild branch
    when the index file is absent, so the saved session has to show up in
    the freshly created index.
    """
    index_path = models.SESSION_INDEX_FILE

    # Precondition: the fixture starts every test without an index.
    assert not index_path.exists()

    only = _make_session("sess_a", "Alpha", updated_at=100.0)
    only.path.write_text(
        json.dumps(only.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    _write_session_index(updates=[only])

    assert index_path.exists(), "Index file should be created"

    rebuilt_ids = {entry["session_id"] for entry in _read_index(index_path)}
    assert "sess_a" in rebuilt_ids, "Session A should appear in the rebuilt index"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 9. test_corrupt_index_fallback ────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_corrupt_index_fallback():
    """An unparseable index is replaced by a full rebuild, not an error.

    The index file is filled with text that is not JSON; the incremental
    call must swallow the parse failure, rebuild from the session files and
    leave a valid JSON list containing the saved session.
    """
    index_path = models.SESSION_INDEX_FILE

    index_path.write_text("THIS IS NOT JSON {{{", encoding="utf-8")

    survivor = _make_session("sess_a", "Alpha", updated_at=100.0)
    survivor.path.write_text(
        json.dumps(survivor.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    # Must not raise even though the existing index cannot be parsed.
    _write_session_index(updates=[survivor])

    assert index_path.exists()
    rebuilt = _read_index(index_path)
    assert isinstance(rebuilt, list), "Index should be a list"

    rebuilt_ids = {entry["session_id"] for entry in rebuilt}
    assert "sess_a" in rebuilt_ids, "Session A should appear after fallback rebuild"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 10. test_concurrent_saves_dont_lose_data ────────────────────────────
|
||||||
|
|
||||||
|
def test_concurrent_saves_dont_lose_data():
    """Two concurrent ``Session.save()`` calls must both land in the index.

    An index containing A and B is built first. Two worker threads then wait
    on a shared ``threading.Event`` start gate, mutate one session each and
    call ``save()``. Both new titles must be present in the final index,
    i.e. neither read-patch-write cycle may overwrite the other.
    """
    index_file = models.SESSION_INDEX_FILE

    sA = _make_session("sess_a", "Alpha", updated_at=100.0)
    sB = _make_session("sess_b", "Bravo", updated_at=200.0)

    for s in (sA, sB):
        s.path.write_text(json.dumps(s.__dict__, ensure_ascii=False, indent=2), encoding="utf-8")

    # Build the initial index so both saves take the incremental path.
    _write_session_index(updates=None)

    start_gate = threading.Event()
    errors = []

    def _update_session(session, new_title, new_updated_at):
        try:
            start_gate.wait(timeout=5)
            session.title = new_title
            session.updated_at = new_updated_at
            session.save()
        except Exception as e:
            errors.append(e)

    # Daemon threads: a hung save() must fail the test, not block interpreter exit.
    workers = [
        threading.Thread(target=_update_session, args=(sA, "Alpha V2", 150.0), daemon=True),
        threading.Thread(target=_update_session, args=(sB, "Bravo V2", 250.0), daemon=True),
    ]
    for worker in workers:
        worker.start()

    # Release both threads simultaneously.
    start_gate.set()

    for worker in workers:
        worker.join(timeout=10)

    # join() returns silently on timeout, so check explicitly before reading
    # an index that a still-running save() could be rewriting.
    hung = [worker.name for worker in workers if worker.is_alive()]
    assert not hung, f"save() did not finish within the timeout: {hung}"
    assert not errors, f"Errors during concurrent saves: {errors}"

    index = _read_index(index_file)
    index_map = {e["session_id"]: e for e in index}

    assert "sess_a" in index_map, "Session A should be in index"
    assert "sess_b" in index_map, "Session B should be in index"
    assert index_map["sess_a"]["title"] == "Alpha V2", "Session A title should be updated"
    assert index_map["sess_b"]["title"] == "Bravo V2", "Session B title should be updated"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 11. test_atomic_write_no_tmp_remains ─────────────────────────────────
|
||||||
|
|
||||||
|
def test_atomic_write_no_tmp_remains():
    """No ``*.tmp`` staging file may survive an index write.

    Checked twice: after the first call (no index yet, so a full rebuild)
    and after a second call that goes through the incremental path.
    """
    store_dir = models.SESSION_DIR

    sess = _make_session("sess_a", "Alpha", updated_at=100.0)
    sess.path.write_text(
        json.dumps(sess.__dict__, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    # First write: the index does not exist yet.
    _write_session_index(updates=[sess])
    leftovers = list(store_dir.glob("*.tmp"))
    assert len(leftovers) == 0, f"Unexpected .tmp files remain: {leftovers}"

    # Second write: the index exists, so this exercises the incremental path.
    sess.title = "Alpha V2"
    sess.updated_at = 200.0
    _write_session_index(updates=[sess])
    leftovers = list(store_dir.glob("*.tmp"))
    assert len(leftovers) == 0, f"Unexpected .tmp files after incremental write: {leftovers}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 12. test_deadlock_guard_on_fallback ──────────────────────────────────
|
||||||
|
|
||||||
|
def test_deadlock_guard_on_fallback():
    """The fallback rebuild must run without deadlocking on LOCK.

    A valid index is written so the incremental path is attempted, then the
    first read of the index file is forced to raise ``OSError``. That pushes
    ``_write_session_index`` into its fallback (full rebuild). The call runs
    on a worker thread so that a deadlock shows up as a timeout instead of
    hanging the test run.
    """
    index_file = models.SESSION_INDEX_FILE

    # Create a valid index file so the incremental path is attempted.
    _write_index_file(index_file, [
        {"session_id": "sess_a", "title": "Alpha", "updated_at": 100.0,
         "workspace": "/tmp", "model": "test", "message_count": 0,
         "created_at": 100.0, "pinned": False, "archived": False},
    ])

    sB = _make_session("sess_b", "Bravo", updated_at=200.0)
    sB.path.write_text(json.dumps(sB.__dict__, ensure_ascii=False, indent=2), encoding="utf-8")

    original_read_text = Path.read_text
    failed_index_reads = 0

    def _broken_read_text(self, *args, **kwargs):
        nonlocal failed_index_reads
        # Only break the first index read; session files and any later index
        # reads go through to the real implementation.
        if str(self) == str(index_file) and failed_index_reads == 0:
            failed_index_reads += 1
            raise OSError("Simulated corrupt index read")
        return original_read_text(self, *args, **kwargs)

    done = threading.Event()
    exc = [None]

    def _run():
        try:
            _write_session_index(updates=[sB])
        except Exception as e:
            exc[0] = e
        finally:
            done.set()

    with patch.object(Path, "read_text", _broken_read_text):
        # Daemon thread: if the call really deadlocks, the assertion below
        # fails and the stuck thread cannot keep the interpreter alive.
        worker = threading.Thread(target=_run, daemon=True)
        worker.start()
        finished = done.wait(timeout=10)

    assert finished, "_write_session_index hung — likely deadlock in fallback path"
    worker.join(timeout=1)
    assert exc[0] is None, f"Unexpected exception: {exc[0]}"
    # Guard against a vacuous pass: the simulated failure must have fired.
    assert failed_index_reads == 1, "Incremental path never read the index file"

    # The index should still be valid after fallback.
    index = _read_index(index_file)
    assert isinstance(index, list)
    assert "sess_b" in {e["session_id"] for e in index}, (
        "Session B should appear after fallback rebuild"
    )
|
||||||
226
tests/test_ttl_cache.py
Normal file
226
tests/test_ttl_cache.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
"""
|
||||||
|
Tests for the TTL cache in api/config.py — get_available_models().
|
||||||
|
|
||||||
|
Validates:
|
||||||
|
- Cache hit within TTL window
|
||||||
|
- TTL expiry triggers re-scan
|
||||||
|
- Config mtime change invalidates cache before TTL check
|
||||||
|
- copy.deepcopy() isolation (mutating returned dict doesn't pollute cache)
|
||||||
|
- invalidate_models_cache() direct invalidation
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import api.config as config
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_cache():
    """Reset the get_available_models() TTL cache to its empty state.

    Delegates to ``config.invalidate_models_cache()`` so the reset takes the
    same lock production code uses, instead of writing the module globals
    directly.
    """
    config.invalidate_models_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 1. test_cache_hit_within_ttl ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_cache_hit_within_ttl():
    """A second call inside the TTL window is served from the cache.

    ``reload_config`` only runs on a config mtime mismatch, so counting its
    calls alone cannot prove a cache hit. The test therefore also checks
    that the cached object and its timestamp are untouched by the second
    call: a miss would store a newly built dict and a new timestamp.
    """
    _reset_cache()
    original_reload = config.reload_config
    saved_mtime = config._cfg_mtime

    with patch.object(config, "reload_config", wraps=original_reload) as reload_spy:
        try:
            # Force an mtime mismatch so the first call reloads and fills the cache.
            config._cfg_mtime = 0.0
            result1 = config.get_available_models()
            reloads_after_first = reload_spy.call_count
            cached_obj = config._available_models_cache
            cached_ts = config._available_models_cache_ts
            assert cached_obj is not None, "First call should populate the cache"

            # Sync _cfg_mtime to the actual file so the second call is not
            # invalidated by the mtime check and reaches the TTL check.
            try:
                config._cfg_mtime = config.Path(config._get_config_path()).stat().st_mtime
            except OSError:
                config._cfg_mtime = 0.0

            result2 = config.get_available_models()

            assert "groups" in result1
            assert "groups" in result2

            # Cache hit: same stored object, same fill timestamp.
            assert config._available_models_cache is cached_obj, (
                "Second call rebuilt the cache instead of serving from it"
            )
            assert config._available_models_cache_ts == cached_ts
            # Callers receive equal but independent copies.
            assert result2 == result1
            assert result2 is not result1

            assert reload_spy.call_count == reloads_after_first, (
                f"Expected no extra reload_config calls, but got "
                f"{reload_spy.call_count - reloads_after_first} extra"
            )
        finally:
            config._cfg_mtime = saved_mtime
            _reset_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 2. test_ttl_expiry ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_ttl_expiry():
    """Once the TTL has elapsed the next call rebuilds the cache.

    The cache is filled, then ``time.monotonic`` is patched to report a time
    more than ``_AVAILABLE_MODELS_CACHE_TTL`` seconds later. The following
    call must store a new cache object with a newer timestamp.
    """
    _reset_cache()
    saved_mtime = config._cfg_mtime
    try:
        # Match _cfg_mtime to the file so the mtime check does not invalidate.
        try:
            config._cfg_mtime = config.Path(config._get_config_path()).stat().st_mtime
        except OSError:
            config._cfg_mtime = 0.0

        # First call populates the cache.
        config.get_available_models()
        stale_cache = config._available_models_cache
        assert stale_cache is not None, "Cache should be populated"
        cache_ts = config._available_models_cache_ts

        # Shift the clock past the TTL (TTL + 10s beyond the real reading).
        original_monotonic = time.monotonic
        offset = config._AVAILABLE_MODELS_CACHE_TTL + 10.0

        with patch.object(time, "monotonic", side_effect=lambda: original_monotonic() + offset):
            config.get_available_models()

        assert config._available_models_cache_ts > cache_ts, (
            "Cache should have been refreshed after TTL expiry"
        )
        assert config._available_models_cache is not stale_cache, (
            "Expired entry should be replaced by a newly built result"
        )
    finally:
        # Restore module state even when an assertion fails.
        config._cfg_mtime = saved_mtime
        _reset_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 3. test_mtime_invalidation ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_mtime_invalidation():
    """Verify that a config mtime mismatch forces a cache rebuild.

    Populates the cache, then sets ``config._cfg_mtime`` to a value that does
    not match the config file on disk, simulating an edit.  The next
    ``get_available_models()`` call must rebuild the result instead of
    returning the previously cached entry.
    """
    # Remember the module state we are about to overwrite so it can be put
    # back even when an assertion fails.
    saved_mtime = config._cfg_mtime
    _reset_cache()
    try:
        # Align _cfg_mtime with the file so the first call is not affected by
        # the mtime check.
        try:
            real_mtime = config.Path(config._get_config_path()).stat().st_mtime
        except OSError:
            real_mtime = 0.0
        config._cfg_mtime = real_mtime

        # First call populates the cache.
        config.get_available_models()
        assert config._available_models_cache is not None
        old_ts = config._available_models_cache_ts

        # Tag the cached entry.  If the next call serves the cache, the tag
        # shows up in the returned copy; if it rebuilds, it cannot.  A plain
        # "timestamp > 0" check would pass in both cases because the first
        # call already stored a positive timestamp.
        config._available_models_cache["__stale_marker__"] = True

        # Simulate config.yaml changing on disk.  real_mtime - 1.0 always
        # differs from real_mtime, including when the file could not be
        # stat'ed and real_mtime fell back to 0.0.
        # NOTE(review): assumes get_available_models() treats an unreadable
        # config file as mtime 0.0 — confirm against the implementation.
        config._cfg_mtime = real_mtime - 1.0

        result2 = config.get_available_models()

        assert "__stale_marker__" not in result2, (
            "Stale cache entry was served despite config mtime mismatch"
        )
        assert config._available_models_cache is not None, (
            "Cache should be re-populated after invalidation"
        )
        assert config._available_models_cache_ts >= old_ts, (
            "Cache timestamp should be updated after invalidation + rebuild"
        )
    finally:
        config._cfg_mtime = saved_mtime
        _reset_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 4. test_deepcopy_isolation ────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_deepcopy_isolation():
    """Verify that callers cannot corrupt the cache through the returned dict.

    Mutates the dict returned by ``get_available_models()`` (top-level key,
    the ``groups`` list, and the first group's ``models`` list) and checks
    that a subsequent call returns none of those mutations.
    """
    # Remember the module state we are about to overwrite so it can be put
    # back even when an assertion fails.
    saved_mtime = config._cfg_mtime
    _reset_cache()
    try:
        # Align _cfg_mtime with the file so the mtime check doesn't invalidate.
        try:
            config._cfg_mtime = config.Path(config._get_config_path()).stat().st_mtime
        except OSError:
            config._cfg_mtime = 0.0

        # First call populates the cache.
        result1 = config.get_available_models()
        groups = result1["groups"]

        # Record *before* mutating whether the first group really has models.
        # Checking after the mutation is meaningless: the list has been
        # cleared and a fake group appended by then.
        first_group_had_models = bool(groups and groups[0].get("models"))

        # Mutate the returned dict at every nesting level.
        if first_group_had_models:
            groups[0]["models"].clear()
        groups.append({"provider": "FAKE", "models": [{"id": "fake-model"}]})
        result1["active_provider"] = "HACKED"

        # Second call should return an unmutated copy.
        result2 = config.get_available_models()

        assert result2["active_provider"] != "HACKED", "Mutation leaked into cache"
        assert not any(
            g.get("provider") == "FAKE" for g in result2["groups"]
        ), "Fake provider leaked into cache"

        # Only meaningful when the first group had models to begin with;
        # otherwise an empty list in result2 is the genuine value, not a leak.
        if first_group_had_models and result2["groups"]:
            assert len(result2["groups"][0].get("models", [])) > 0, (
                "Mutation of result1 cleared models in result2 — deepcopy failed"
            )
    finally:
        config._cfg_mtime = saved_mtime
        _reset_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 5. test_invalidate_models_cache_direct ───────────────────────────────
|
||||||
|
|
||||||
|
def test_invalidate_models_cache_direct():
    """Verify that ``invalidate_models_cache()`` drops the cached entry.

    After populating the cache, a direct ``invalidate_models_cache()`` call
    must leave ``config._available_models_cache`` as ``None``, and the next
    ``get_available_models()`` call must populate it again.
    """
    # Remember the module state we are about to overwrite so it can be put
    # back even when an assertion fails.
    saved_mtime = config._cfg_mtime
    _reset_cache()
    try:
        # Align _cfg_mtime with the file so the mtime check doesn't invalidate.
        try:
            config._cfg_mtime = config.Path(config._get_config_path()).stat().st_mtime
        except OSError:
            config._cfg_mtime = 0.0

        # First call populates the cache.
        config.get_available_models()
        assert config._available_models_cache is not None, "Cache should be populated"
        first_ts = config._available_models_cache_ts

        # Directly invalidate.
        config.invalidate_models_cache()

        # Cache must be cleared.
        assert config._available_models_cache is None, (
            "invalidate_models_cache() should set _available_models_cache to None"
        )

        # Next call should re-scan and produce a fresh cache.
        config.get_available_models()
        assert config._available_models_cache is not None, "Cache should be re-populated"
        # >= rather than >: a coarse monotonic clock may return the same
        # value for two calls made in quick succession.
        assert config._available_models_cache_ts >= first_ts, (
            "Cache timestamp should be updated after re-scan"
        )
    finally:
        config._cfg_mtime = saved_mtime
        _reset_cache()
|
||||||
Reference in New Issue
Block a user