cancel_stream() now pops STREAMS, CANCEL_FLAGS and AGENT_INSTANCES and clears session.active_stream_id immediately after signalling cancellation. This fixes sessions that became permanently stuck returning 409 when the agent thread was blocked in a bad tool call. Session cleanup runs outside STREAMS_LOCK to preserve lock ordering.

Fixes #653

Co-authored-by: bergeouss <bergeouss@users.noreply.github.com>
123 lines
4.2 KiB
Python
123 lines
4.2 KiB
Python
"""
Unit tests for cancel/interrupt functionality.

Tests the integration between cancel_stream() and agent.interrupt(), plus the
observable side effects of cancel_stream() on the module-level registries:
the stream's cancel Event is set, a ('cancel', {'message': ...}) sentinel is
put on the stream's queue, and the stream's CANCEL_FLAGS entry is popped.
Cancelling an unknown stream id returns False.
"""
|
|
import pytest
|
|
import queue
|
|
import threading
|
|
from unittest.mock import Mock
|
|
|
|
from api.streaming import cancel_stream
|
|
from api.config import AGENT_INSTANCES, STREAMS, CANCEL_FLAGS
|
|
|
|
|
|
class TestCancelInterrupt:
    """Test suite for cancel/interrupt functionality.

    Every test starts and ends with empty AGENT_INSTANCES, STREAMS and
    CANCEL_FLAGS registries so module-level state cannot leak between tests.
    """

    @staticmethod
    def _clear_registries():
        """Empty the module-level registries that cancel_stream() operates on."""
        AGENT_INSTANCES.clear()
        STREAMS.clear()
        CANCEL_FLAGS.clear()

    def setup_method(self):
        """Clean up before each test"""
        self._clear_registries()

    def teardown_method(self):
        """Clean up after each test"""
        self._clear_registries()

    def test_cancel_calls_agent_interrupt(self):
        """Verify that cancel_stream() calls agent.interrupt() when agent exists"""
        # Setup
        stream_id = "test_stream_123"
        # Mock auto-creates `interrupt` as a child mock on first access.
        mock_agent = Mock()
        # Keep our own reference: cancel_stream() pops CANCEL_FLAGS, so the
        # event can no longer be inspected through the dict afterwards.
        cancel_event = threading.Event()

        STREAMS[stream_id] = queue.Queue()
        CANCEL_FLAGS[stream_id] = cancel_event
        AGENT_INSTANCES[stream_id] = mock_agent

        # Execute
        result = cancel_stream(stream_id)

        # Assert
        assert result is True
        mock_agent.interrupt.assert_called_once_with("Cancelled by user")
        # The eager pop (#776 fix) must not replace signalling: the agent
        # thread holds this event object and relies on it being set.
        assert cancel_event.is_set()
        assert stream_id not in CANCEL_FLAGS, \
            "cancel_stream() should eagerly pop CANCEL_FLAGS after signalling"

    def test_cancel_handles_interrupt_exception(self):
        """Verify that cancel_stream() handles interrupt() exceptions gracefully"""
        stream_id = "test_stream_456"
        mock_agent = Mock()
        mock_agent.interrupt = Mock(side_effect=RuntimeError("Agent error"))

        STREAMS[stream_id] = queue.Queue()
        CANCEL_FLAGS[stream_id] = threading.Event()
        AGENT_INSTANCES[stream_id] = mock_agent

        # Should not raise exception
        result = cancel_stream(stream_id)

        # Assert
        assert result is True
        mock_agent.interrupt.assert_called_once()
        assert stream_id not in CANCEL_FLAGS, \
            "cancel_stream() should eagerly pop CANCEL_FLAGS even on interrupt exception"

    def test_cancel_before_agent_ready(self):
        """Test cancel when agent not yet stored in AGENT_INSTANCES (race condition)"""
        stream_id = "test_stream_789"
        cancel_event = threading.Event()

        STREAMS[stream_id] = queue.Queue()
        CANCEL_FLAGS[stream_id] = cancel_event
        # Note: AGENT_INSTANCES[stream_id] not set (simulating race condition)

        # Should succeed even without agent
        result = cancel_stream(stream_id)

        # Assert
        assert result is True
        # CANCEL_FLAGS is eagerly popped; the agent thread checks the event
        # object it already has a reference to -- pop doesn't clear the event
        assert stream_id not in CANCEL_FLAGS, \
            "cancel_stream() should eagerly pop CANCEL_FLAGS even without an agent"
        # The late-starting agent will check this event (it holds a reference
        # to the same object), so it must already be set.
        assert cancel_event.is_set()

    def test_cancel_nonexistent_stream(self):
        """Test cancel for a stream that doesn't exist"""
        result = cancel_stream("nonexistent_stream")
        assert result is False

    def test_cancel_sets_cancel_event(self):
        """Verify that cancel_stream() sets the cancel_event flag"""
        stream_id = "test_stream_event"

        STREAMS[stream_id] = queue.Queue()
        cancel_event = threading.Event()
        CANCEL_FLAGS[stream_id] = cancel_event

        result = cancel_stream(stream_id)

        assert result is True
        assert cancel_event.is_set()

    def test_cancel_puts_sentinel_in_queue(self):
        """Verify that cancel_stream() puts cancel sentinel in queue"""
        stream_id = "test_stream_queue"
        q = queue.Queue()

        STREAMS[stream_id] = q
        CANCEL_FLAGS[stream_id] = threading.Event()

        result = cancel_stream(stream_id)

        assert result is True
        # Check that cancel message was queued
        assert not q.empty()
        event_type, data = q.get_nowait()
        assert event_type == 'cancel'
        assert data['message'] == 'Cancelled by user'