diff --git a/laygo/pipeline.py b/laygo/pipeline.py
index ae2e953..df4e1c5 100644
--- a/laygo/pipeline.py
+++ b/laygo/pipeline.py
@@ -11,6 +11,7 @@
 
 from laygo.helpers import PipelineContext
 from laygo.helpers import is_context_aware
+from laygo.transformers.threaded import ThreadedTransformer
 from laygo.transformers.transformer import Transformer
 
 T = TypeVar("T")
@@ -120,7 +121,18 @@ def apply[U](
 
         return self  # type: ignore
 
-    # ... The rest of the Pipeline class (transform, __iter__, to_list, etc.) remains unchanged ...
+    def buffer(self, size: int) -> "Pipeline[T]":
+        """Append a ``ThreadedTransformer`` stage and return this pipeline.
+
+        Args:
+            size: Forwarded as ``max_workers`` to ``ThreadedTransformer``.
+
+        Returns:
+            This same pipeline instance, so calls can be chained.
+        """
+        self.apply(ThreadedTransformer(max_workers=size))
+        return self
+
     def __iter__(self) -> Iterator[T]:
         """Allows the pipeline to be iterated over."""
         yield from self.processed_data
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 150470f..d9443c6 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -171,3 +171,43 @@ def test_chunked_processing_consistency(self):
 
         expected = list(range(1, 101))  # [1, 2, 3, ..., 100]
         assert result == expected
+
+    def test_buffer_with_two_maps(self):
+        """Test that buffer function works correctly with two sequential map operations."""
+        # Create a pipeline with two map operations and buffering
+        data = list(range(10))
+
+        # Record every call; the assertions below check counts and values, not ordering
+        execution_order = []
+
+        def first_map(x):
+            execution_order.append(f"first_map({x})")
+            return x * 2
+
+        def second_map(x):
+            execution_order.append(f"second_map({x})")
+            return x + 1
+
+        # Apply buffering with 2 workers between two map operations
+        result = (
+            Pipeline(data)
+            .transform(lambda t: t.map(first_map))
+            .buffer(2)  # Buffer with 2 workers
+            .transform(lambda t: t.map(second_map))
+            .to_list()
+        )
+
+        # Verify the final result is correct
+        expected = [(x * 2) + 1 for x in range(10)]  # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
+        assert result == expected
+
+        # Verify both map operations were called for each element
+        assert len([call for call in execution_order if "first_map" in call]) == 10
+        assert len([call for call in execution_order if "second_map" in call]) == 10
+
+        # Verify all expected values were processed
+        first_map_values = [int(call.split("(")[1].split(")")[0]) for call in execution_order if "first_map" in call]
+        second_map_values = [int(call.split("(")[1].split(")")[0]) for call in execution_order if "second_map" in call]
+
+        assert sorted(first_map_values) == list(range(10))
+        assert sorted(second_map_values) == [x * 2 for x in range(10)]