Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
- A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
- A GuardrailTripwireTriggered exception if a guardrail is tripped.
"""
cancelled = False
try:
while True:
self._check_errors()
Expand All @@ -320,7 +321,9 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
try:
item = await self._event_queue.get()
except asyncio.CancelledError:
break
cancelled = True
self.cancel()
raise

if isinstance(item, QueueCompleteSentinel):
# Await input guardrails if they are still running, so late
Expand All @@ -337,11 +340,16 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
yield item
self._event_queue.task_done()
finally:
# Ensure main execution completes before cleanup to avoid race conditions
# with session operations
await self._await_task_safely(self._run_impl_task)
# Safely terminate all background tasks after main execution has finished
self._cleanup_tasks()
if cancelled:
# Cancellation should return promptly, so avoid waiting on long-running tasks.
# Tasks have already been cancelled above.
self._cleanup_tasks()
else:
# Ensure main execution completes before cleanup to avoid race conditions
# with session operations
await self._await_task_safely(self._run_impl_task)
# Safely terminate all background tasks after main execution has finished
self._cleanup_tasks()

if self._stored_exception:
raise self._stored_exception
Expand Down
45 changes: 45 additions & 0 deletions tests/test_cancel_streaming.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
import asyncio
import json
import time

import pytest
from openai.types.responses import ResponseCompletedEvent

from agents import Agent, Runner
from agents.stream_events import RawResponsesStreamEvent

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message


class SlowCompleteFakeModel(FakeModel):
"""A FakeModel that delays before emitting the completed event in streaming."""

def __init__(self, delay_seconds: float):
super().__init__()
self._delay_seconds = delay_seconds

async def stream_response(self, *args, **kwargs):
async for ev in super().stream_response(*args, **kwargs):
if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0:
await asyncio.sleep(self._delay_seconds)
yield ev


@pytest.mark.asyncio
async def test_simple_streaming_with_cancel():
model = FakeModel()
Expand Down Expand Up @@ -131,3 +149,30 @@ async def test_cancel_immediate_mode_explicit():
assert result.is_complete
assert result._event_queue.empty()
assert result._cancel_mode == "immediate"


@pytest.mark.asyncio
async def test_stream_events_respects_asyncio_timeout_cancellation():
model = SlowCompleteFakeModel(delay_seconds=0.5)
model.set_next_output([get_text_message("Final response")])
agent = Agent(name="TimeoutTester", model=model)

result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
event_iter = result.stream_events().__aiter__()

# Consume events until the output item is done so the next event is delayed.
while True:
event = await asyncio.wait_for(event_iter.__anext__(), timeout=1.0)
if (
isinstance(event, RawResponsesStreamEvent)
and event.data.type == "response.output_item.done"
):
break

start = time.perf_counter()
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(event_iter.__anext__(), timeout=0.1)
elapsed = time.perf_counter() - start

assert elapsed < 0.3, "Cancellation should propagate promptly when waiting for events."
result.cancel()