""" Test plan for the #653 fix (eager session lock release in cancel_stream). These tests verify that after cancel_stream() is called: 1. STREAMS is popped (so the 409 guard passes) 2. CANCEL_FLAGS is popped 3. AGENT_INSTANCES is popped 4. Session active_stream_id is cleared (when agent is available) 5. Session pending fields are cleared (when agent is available) All tests are isolated and clean up after themselves. """ import pytest import queue import threading from unittest.mock import Mock, patch, MagicMock from api.streaming import cancel_stream from api.config import AGENT_INSTANCES, STREAMS, STREAMS_LOCK, CANCEL_FLAGS class TestCancelStreamEagerRelease: """Test suite for #653: eager session lock release on cancel.""" def setup_method(self): """Clean up before each test.""" AGENT_INSTANCES.clear() STREAMS.clear() CANCEL_FLAGS.clear() def teardown_method(self): """Clean up after each test.""" AGENT_INSTANCES.clear() STREAMS.clear() CANCEL_FLAGS.clear() def test_cancel_pops_stream_from_streams_dict(self): """After cancel, stream_id should no longer be in STREAMS.""" stream_id = "test_eager_pop" q = queue.Queue() STREAMS[stream_id] = q CANCEL_FLAGS[stream_id] = threading.Event() result = cancel_stream(stream_id) assert result is True assert stream_id not in STREAMS, \ "cancel_stream() should eagerly pop from STREAMS to release the session lock" def test_cancel_pops_cancel_flags(self): """After cancel, stream_id should no longer be in CANCEL_FLAGS.""" stream_id = "test_eager_flags" STREAMS[stream_id] = queue.Queue() CANCEL_FLAGS[stream_id] = threading.Event() cancel_stream(stream_id) assert stream_id not in CANCEL_FLAGS, \ "cancel_stream() should eagerly pop from CANCEL_FLAGS" def test_cancel_pops_agent_instances(self): """After cancel, stream_id should no longer be in AGENT_INSTANCES.""" stream_id = "test_eager_agent" mock_agent = Mock() mock_agent.interrupt = Mock() STREAMS[stream_id] = queue.Queue() CANCEL_FLAGS[stream_id] = threading.Event() AGENT_INSTANCES[stream_id] = mock_agent cancel_stream(stream_id) assert stream_id not in AGENT_INSTANCES, \ "cancel_stream() should eagerly pop from AGENT_INSTANCES" def test_cancel_clears_session_active_stream_id(self): """After cancel, session.active_stream_id should be None.""" stream_id = "test_session_clear" session_id = "sess_abc123" mock_agent = Mock() mock_agent.interrupt = Mock() mock_agent.session_id = session_id mock_session = Mock() mock_session.active_stream_id = stream_id mock_session.pending_user_message = "hello" mock_session.pending_attachments = ["file.txt"] mock_session.pending_started_at = 1234567890.0 STREAMS[stream_id] = queue.Queue() CANCEL_FLAGS[stream_id] = threading.Event() AGENT_INSTANCES[stream_id] = mock_agent with patch('api.streaming.get_session', return_value=mock_session): cancel_stream(stream_id) assert mock_session.active_stream_id is None, \ "cancel_stream() should clear session.active_stream_id" assert mock_session.pending_user_message is None, \ "cancel_stream() should clear session.pending_user_message" assert mock_session.pending_attachments == [], \ "cancel_stream() should clear session.pending_attachments" assert mock_session.pending_started_at is None, \ "cancel_stream() should clear session.pending_started_at" mock_session.save.assert_called_once() def test_cancel_without_agent_still_pops_streams(self): """Cancel should pop STREAMS even when no agent instance exists.""" stream_id = "test_no_agent" STREAMS[stream_id] = queue.Queue() CANCEL_FLAGS[stream_id] = threading.Event() # No AGENT_INSTANCES entry cancel_stream(stream_id) assert stream_id not in STREAMS, \ "cancel_stream() should pop STREAMS even without agent instance" assert stream_id not in CANCEL_FLAGS def test_cancel_sentinel_still_queued(self): """Cancel sentinel should still be queued before popping STREAMS.""" stream_id = "test_sentinel" q = queue.Queue() STREAMS[stream_id] = q CANCEL_FLAGS[stream_id] = threading.Event() cancel_stream(stream_id) # The cancel sentinel should have been queued before the pop assert not q.empty() event_type, data = q.get_nowait() assert event_type == 'cancel' assert data['message'] == 'Cancelled by user' def test_double_cancel_is_safe(self): """Calling cancel_stream() twice should not raise.""" stream_id = "test_double" mock_agent = Mock() mock_agent.interrupt = Mock() mock_agent.session_id = "sess_xyz" STREAMS[stream_id] = queue.Queue() CANCEL_FLAGS[stream_id] = threading.Event() AGENT_INSTANCES[stream_id] = mock_agent # First cancel result1 = cancel_stream(stream_id) assert result1 is True assert stream_id not in STREAMS # Second cancel (stream already popped) result2 = cancel_stream(stream_id) assert result2 is False def test_cancel_handle_get_session_failure(self): """Cancel should not raise even if get_session fails.""" stream_id = "test_session_fail" mock_agent = Mock() mock_agent.interrupt = Mock() mock_agent.session_id = "sess_nonexistent" STREAMS[stream_id] = queue.Queue() CANCEL_FLAGS[stream_id] = threading.Event() AGENT_INSTANCES[stream_id] = mock_agent with patch('api.streaming.get_session', side_effect=KeyError("Session not found")): # Should not raise result = cancel_stream(stream_id) assert result is True assert stream_id not in STREAMS