fix(streaming): eagerly release session lock in cancel_stream() (#778)
cancel_stream() now pops STREAMS/CANCEL_FLAGS/AGENT_INSTANCES and clears session.active_stream_id immediately after signalling cancel. Fixes sessions permanently stuck at 409 when the agent thread is blocked in a bad tool call. Session cleanup runs outside STREAMS_LOCK to preserve lock ordering.

Fixes #653

Co-authored-by: bergeouss <bergeouss@users.noreply.github.com>
This commit is contained in:
@@ -43,7 +43,10 @@ class TestCancelInterrupt:
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_agent.interrupt.assert_called_once_with("Cancelled by user")
|
||||
assert CANCEL_FLAGS[stream_id].is_set()
|
||||
# CANCEL_FLAGS is eagerly popped after cancel (#776 fix) so the flag
|
||||
# is no longer in the dict — verify the pop happened instead
|
||||
assert stream_id not in CANCEL_FLAGS, \
|
||||
"cancel_stream() should eagerly pop CANCEL_FLAGS after signalling"
|
||||
|
||||
def test_cancel_handles_interrupt_exception(self):
|
||||
"""Verify that cancel_stream() handles interrupt() exceptions gracefully"""
|
||||
@@ -61,7 +64,8 @@ class TestCancelInterrupt:
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_agent.interrupt.assert_called_once()
|
||||
assert CANCEL_FLAGS[stream_id].is_set()
|
||||
assert stream_id not in CANCEL_FLAGS, \
|
||||
"cancel_stream() should eagerly pop CANCEL_FLAGS even on interrupt exception"
|
||||
|
||||
def test_cancel_before_agent_ready(self):
|
||||
"""Test cancel when agent not yet stored in AGENT_INSTANCES (race condition)"""
|
||||
@@ -76,8 +80,11 @@ class TestCancelInterrupt:
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert CANCEL_FLAGS[stream_id].is_set()
|
||||
# Agent will check this flag when it starts
|
||||
# CANCEL_FLAGS is eagerly popped; the agent thread checks the event
|
||||
# object it already has a reference to — pop doesn't clear the event
|
||||
assert stream_id not in CANCEL_FLAGS, \
|
||||
"cancel_stream() should eagerly pop CANCEL_FLAGS even without an agent"
|
||||
# Agent will check this flag (it holds a reference to the event object)
|
||||
|
||||
def test_cancel_nonexistent_stream(self):
|
||||
"""Test cancel for a stream that doesn't exist"""
|
||||
|
||||
171
tests/test_sprint51.py
Normal file
171
tests/test_sprint51.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Test plan for the #653 fix (eager session lock release in cancel_stream).
|
||||
|
||||
These tests verify that after cancel_stream() is called:
|
||||
1. STREAMS is popped (so the 409 guard passes)
|
||||
2. CANCEL_FLAGS is popped
|
||||
3. AGENT_INSTANCES is popped
|
||||
4. Session active_stream_id is cleared (when agent is available)
|
||||
5. Session pending fields are cleared (when agent is available)
|
||||
|
||||
All tests are isolated and clean up after themselves.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import queue
|
||||
import threading
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from api.streaming import cancel_stream
|
||||
from api.config import AGENT_INSTANCES, STREAMS, STREAMS_LOCK, CANCEL_FLAGS
|
||||
|
||||
|
||||
class TestCancelStreamEagerRelease:
    """Regression tests for #653: cancel_stream() releases per-stream state eagerly.

    Each test registers a stream in the module-level registries, calls
    cancel_stream(), and then checks that the corresponding entries are gone
    (or that the session mock was reset).
    """

    @staticmethod
    def _reset_registries():
        """Empty the three shared registries so tests cannot leak into each other."""
        for registry in (AGENT_INSTANCES, STREAMS, CANCEL_FLAGS):
            registry.clear()

    @staticmethod
    def _register(stream_id, agent=None):
        """Register a stream (queue + cancel flag, optionally an agent).

        Returns the queue stored in STREAMS so a test can inspect what was
        put on it.
        """
        stream_queue = queue.Queue()
        STREAMS[stream_id] = stream_queue
        CANCEL_FLAGS[stream_id] = threading.Event()
        if agent is not None:
            AGENT_INSTANCES[stream_id] = agent
        return stream_queue

    def setup_method(self):
        """Start every test from empty registries."""
        self._reset_registries()

    def teardown_method(self):
        """Leave empty registries behind after every test."""
        self._reset_registries()

    def test_cancel_pops_stream_from_streams_dict(self):
        """The stream id must be removed from STREAMS once cancelled."""
        stream_id = "test_eager_pop"
        self._register(stream_id)

        outcome = cancel_stream(stream_id)

        assert outcome is True
        assert stream_id not in STREAMS, (
            "cancel_stream() should eagerly pop from STREAMS to release the session lock"
        )

    def test_cancel_pops_cancel_flags(self):
        """The stream id must be removed from CANCEL_FLAGS once cancelled."""
        stream_id = "test_eager_flags"
        self._register(stream_id)

        cancel_stream(stream_id)

        assert stream_id not in CANCEL_FLAGS, (
            "cancel_stream() should eagerly pop from CANCEL_FLAGS"
        )

    def test_cancel_pops_agent_instances(self):
        """The stream id must be removed from AGENT_INSTANCES once cancelled."""
        stream_id = "test_eager_agent"
        agent = Mock(interrupt=Mock())
        self._register(stream_id, agent)

        cancel_stream(stream_id)

        assert stream_id not in AGENT_INSTANCES, (
            "cancel_stream() should eagerly pop from AGENT_INSTANCES"
        )

    def test_cancel_clears_session_active_stream_id(self):
        """The session's active stream id and pending fields are reset and saved."""
        stream_id = "test_session_clear"
        session_id = "sess_abc123"
        agent = Mock(interrupt=Mock(), session_id=session_id)

        # Session stub pre-populated as if a turn were still in flight.
        session = Mock()
        session.active_stream_id = stream_id
        session.pending_user_message = "hello"
        session.pending_attachments = ["file.txt"]
        session.pending_started_at = 1234567890.0

        self._register(stream_id, agent)

        with patch('api.streaming.get_session', return_value=session):
            cancel_stream(stream_id)

        assert session.active_stream_id is None, (
            "cancel_stream() should clear session.active_stream_id"
        )
        assert session.pending_user_message is None, (
            "cancel_stream() should clear session.pending_user_message"
        )
        assert session.pending_attachments == [], (
            "cancel_stream() should clear session.pending_attachments"
        )
        assert session.pending_started_at is None, (
            "cancel_stream() should clear session.pending_started_at"
        )
        session.save.assert_called_once()

    def test_cancel_without_agent_still_pops_streams(self):
        """STREAMS and CANCEL_FLAGS are popped even with no AGENT_INSTANCES entry."""
        stream_id = "test_no_agent"
        # Deliberately register no agent for this stream.
        self._register(stream_id)

        cancel_stream(stream_id)

        assert stream_id not in STREAMS, (
            "cancel_stream() should pop STREAMS even without agent instance"
        )
        assert stream_id not in CANCEL_FLAGS

    def test_cancel_sentinel_still_queued(self):
        """The cancel sentinel is still placed on the stream's queue."""
        stream_id = "test_sentinel"
        stream_queue = self._register(stream_id)

        cancel_stream(stream_id)

        # The queue object outlives the pop, so the sentinel must be on it.
        assert not stream_queue.empty()
        event_type, data = stream_queue.get_nowait()
        assert event_type == 'cancel'
        assert data['message'] == 'Cancelled by user'

    def test_double_cancel_is_safe(self):
        """A second cancel of the same stream returns False instead of raising."""
        stream_id = "test_double"
        agent = Mock(interrupt=Mock(), session_id="sess_xyz")
        self._register(stream_id, agent)

        # First call cancels and removes the stream.
        first = cancel_stream(stream_id)
        assert first is True
        assert stream_id not in STREAMS

        # Second call finds nothing left to cancel.
        second = cancel_stream(stream_id)
        assert second is False

    def test_cancel_handle_get_session_failure(self):
        """A failing get_session() must not stop the cancel from completing."""
        stream_id = "test_session_fail"
        agent = Mock(interrupt=Mock(), session_id="sess_nonexistent")
        self._register(stream_id, agent)

        lookup_error = KeyError("Session not found")
        with patch('api.streaming.get_session', side_effect=lookup_error):
            # Must return normally despite the lookup failure.
            outcome = cancel_stream(stream_id)

        assert outcome is True
        assert stream_id not in STREAMS
|
||||
Reference in New Issue
Block a user