cancel_stream() now pops STREAMS/CANCEL_FLAGS/AGENT_INSTANCES and clears session.active_stream_id immediately after signalling cancel. Fixes sessions permanently stuck at 409 when the agent thread is blocked in a bad tool call. Session cleanup runs outside STREAMS_LOCK to preserve lock ordering. Fixes #653 Co-authored-by: bergeouss <bergeouss@users.noreply.github.com>
171 lines
6.1 KiB
Python
171 lines
6.1 KiB
Python
"""
|
|
Test plan for the #653 fix (eager session lock release in cancel_stream).
|
|
|
|
These tests verify that after cancel_stream() is called:
|
|
1. STREAMS is popped (so the 409 guard passes)
|
|
2. CANCEL_FLAGS is popped
|
|
3. AGENT_INSTANCES is popped
|
|
4. Session active_stream_id is cleared (when agent is available)
|
|
5. Session pending fields are cleared (when agent is available)
|
|
|
|
All tests are isolated and clean up after themselves.
|
|
"""
|
|
|
|
import pytest
|
|
import queue
|
|
import threading
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
|
|
from api.streaming import cancel_stream
|
|
from api.config import AGENT_INSTANCES, STREAMS, STREAMS_LOCK, CANCEL_FLAGS
|
|
|
|
|
|
class TestCancelStreamEagerRelease:
|
|
"""Test suite for #653: eager session lock release on cancel."""
|
|
|
|
def setup_method(self):
|
|
"""Clean up before each test."""
|
|
AGENT_INSTANCES.clear()
|
|
STREAMS.clear()
|
|
CANCEL_FLAGS.clear()
|
|
|
|
def teardown_method(self):
|
|
"""Clean up after each test."""
|
|
AGENT_INSTANCES.clear()
|
|
STREAMS.clear()
|
|
CANCEL_FLAGS.clear()
|
|
|
|
def test_cancel_pops_stream_from_streams_dict(self):
|
|
"""After cancel, stream_id should no longer be in STREAMS."""
|
|
stream_id = "test_eager_pop"
|
|
q = queue.Queue()
|
|
STREAMS[stream_id] = q
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
|
|
result = cancel_stream(stream_id)
|
|
|
|
assert result is True
|
|
assert stream_id not in STREAMS, \
|
|
"cancel_stream() should eagerly pop from STREAMS to release the session lock"
|
|
|
|
def test_cancel_pops_cancel_flags(self):
|
|
"""After cancel, stream_id should no longer be in CANCEL_FLAGS."""
|
|
stream_id = "test_eager_flags"
|
|
STREAMS[stream_id] = queue.Queue()
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
|
|
cancel_stream(stream_id)
|
|
|
|
assert stream_id not in CANCEL_FLAGS, \
|
|
"cancel_stream() should eagerly pop from CANCEL_FLAGS"
|
|
|
|
def test_cancel_pops_agent_instances(self):
|
|
"""After cancel, stream_id should no longer be in AGENT_INSTANCES."""
|
|
stream_id = "test_eager_agent"
|
|
mock_agent = Mock()
|
|
mock_agent.interrupt = Mock()
|
|
STREAMS[stream_id] = queue.Queue()
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
AGENT_INSTANCES[stream_id] = mock_agent
|
|
|
|
cancel_stream(stream_id)
|
|
|
|
assert stream_id not in AGENT_INSTANCES, \
|
|
"cancel_stream() should eagerly pop from AGENT_INSTANCES"
|
|
|
|
def test_cancel_clears_session_active_stream_id(self):
|
|
"""After cancel, session.active_stream_id should be None."""
|
|
stream_id = "test_session_clear"
|
|
session_id = "sess_abc123"
|
|
mock_agent = Mock()
|
|
mock_agent.interrupt = Mock()
|
|
mock_agent.session_id = session_id
|
|
|
|
mock_session = Mock()
|
|
mock_session.active_stream_id = stream_id
|
|
mock_session.pending_user_message = "hello"
|
|
mock_session.pending_attachments = ["file.txt"]
|
|
mock_session.pending_started_at = 1234567890.0
|
|
|
|
STREAMS[stream_id] = queue.Queue()
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
AGENT_INSTANCES[stream_id] = mock_agent
|
|
|
|
with patch('api.streaming.get_session', return_value=mock_session):
|
|
cancel_stream(stream_id)
|
|
|
|
assert mock_session.active_stream_id is None, \
|
|
"cancel_stream() should clear session.active_stream_id"
|
|
assert mock_session.pending_user_message is None, \
|
|
"cancel_stream() should clear session.pending_user_message"
|
|
assert mock_session.pending_attachments == [], \
|
|
"cancel_stream() should clear session.pending_attachments"
|
|
assert mock_session.pending_started_at is None, \
|
|
"cancel_stream() should clear session.pending_started_at"
|
|
mock_session.save.assert_called_once()
|
|
|
|
def test_cancel_without_agent_still_pops_streams(self):
|
|
"""Cancel should pop STREAMS even when no agent instance exists."""
|
|
stream_id = "test_no_agent"
|
|
STREAMS[stream_id] = queue.Queue()
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
# No AGENT_INSTANCES entry
|
|
|
|
cancel_stream(stream_id)
|
|
|
|
assert stream_id not in STREAMS, \
|
|
"cancel_stream() should pop STREAMS even without agent instance"
|
|
assert stream_id not in CANCEL_FLAGS
|
|
|
|
def test_cancel_sentinel_still_queued(self):
|
|
"""Cancel sentinel should still be queued before popping STREAMS."""
|
|
stream_id = "test_sentinel"
|
|
q = queue.Queue()
|
|
STREAMS[stream_id] = q
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
|
|
cancel_stream(stream_id)
|
|
|
|
# The cancel sentinel should have been queued before the pop
|
|
assert not q.empty()
|
|
event_type, data = q.get_nowait()
|
|
assert event_type == 'cancel'
|
|
assert data['message'] == 'Cancelled by user'
|
|
|
|
def test_double_cancel_is_safe(self):
|
|
"""Calling cancel_stream() twice should not raise."""
|
|
stream_id = "test_double"
|
|
mock_agent = Mock()
|
|
mock_agent.interrupt = Mock()
|
|
mock_agent.session_id = "sess_xyz"
|
|
|
|
STREAMS[stream_id] = queue.Queue()
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
AGENT_INSTANCES[stream_id] = mock_agent
|
|
|
|
# First cancel
|
|
result1 = cancel_stream(stream_id)
|
|
assert result1 is True
|
|
assert stream_id not in STREAMS
|
|
|
|
# Second cancel (stream already popped)
|
|
result2 = cancel_stream(stream_id)
|
|
assert result2 is False
|
|
|
|
def test_cancel_handle_get_session_failure(self):
|
|
"""Cancel should not raise even if get_session fails."""
|
|
stream_id = "test_session_fail"
|
|
mock_agent = Mock()
|
|
mock_agent.interrupt = Mock()
|
|
mock_agent.session_id = "sess_nonexistent"
|
|
|
|
STREAMS[stream_id] = queue.Queue()
|
|
CANCEL_FLAGS[stream_id] = threading.Event()
|
|
AGENT_INSTANCES[stream_id] = mock_agent
|
|
|
|
with patch('api.streaming.get_session', side_effect=KeyError("Session not found")):
|
|
# Should not raise
|
|
result = cancel_stream(stream_id)
|
|
|
|
assert result is True
|
|
assert stream_id not in STREAMS |