From 94e4fc69e7150dabced8346f1826f0075238123f Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 26 Nov 2025 15:23:20 +0900 Subject: [PATCH 1/4] [passes] Introduce LowerCopy and remove copy from serialization Let's introduce new LowerCopy pass instead of serialization of copy with redundant variance. TICO-DCO-1.0-Signed-off-by: Dayoung Lee --- test/modules/op/copy.py | 123 ++++++++++++++++++ tico/passes/lower_copy.py | 95 ++++++++++++++ tico/serialize/operators/op_copy.py | 187 ---------------------------- tico/utils/convert.py | 2 + 4 files changed, 220 insertions(+), 187 deletions(-) create mode 100644 tico/passes/lower_copy.py delete mode 100644 tico/serialize/operators/op_copy.py diff --git a/test/modules/op/copy.py b/test/modules/op/copy.py index f773636c..cee39d8b 100644 --- a/test/modules/op/copy.py +++ b/test/modules/op/copy.py @@ -18,6 +18,11 @@ class SimpleCopy(TestModuleBase): + """ + Test case: Same shape copy (should be folded away by ConvertCopyToReshape pass) + 5x5 -> 5x5 + """ + def __init__(self): super().__init__() @@ -30,6 +35,11 @@ def get_example_inputs(self): class SimpleCopyWithBroadcastTo(TestModuleBase): + """ + Test case: Broadcast from 1x5 to 5x5 + This tests the expand + reshape path in ConvertCopyToReshape pass + """ + def __init__(self): super().__init__() @@ -39,3 +49,116 @@ def forward(self, dst, src): def get_example_inputs(self): return (torch.randn(5, 5), torch.randn(1, 5)), {} + + +class CopyWithScalarBroadcast(TestModuleBase): + """ + Test case: Broadcast from 1x1 to 3x3 (scalar-like broadcast) + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(3, 3), torch.randn(1, 1)), {} + + +class CopyWithRowBroadcast(TestModuleBase): + """ + Test case: Broadcast from 1x4 to 3x4 (row broadcast) + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(3, 4), torch.randn(1, 4)), {} + + +class CopyWithColumnBroadcast(TestModuleBase): + """ + Test case: Broadcast from 3x1 to 3x4 (column broadcast) + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(3, 4), torch.randn(3, 1)), {} + + +class CopyWith3DTensor(TestModuleBase): + """ + Test case: 3D tensor copy with same shape + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(2, 3, 4), torch.randn(2, 3, 4)), {} + + +class CopyWith3DBroadcast(TestModuleBase): + """ + Test case: 3D tensor broadcast from 1x3x4 to 2x3x4 + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(2, 3, 4), torch.randn(1, 3, 4)), {} + + +class CopyWithMultiDimBroadcast(TestModuleBase): + """ + Test case: Multi-dimensional broadcast from 1x1x4 to 2x3x4 + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(2, 3, 4), torch.randn(1, 1, 4)), {} + + +class CopyWith4DTensor(TestModuleBase): + """ + Test case: 4D tensor copy (batch, channel, height, width) + Broadcast from 1x3x1x1 to 2x3x4x4 + """ + + def __init__(self): + super().__init__() + + def forward(self, dst, 
src): + dst.copy_(src) + return dst + + def get_example_inputs(self): + return (torch.randn(2, 3, 4, 4), torch.randn(1, 3, 1, 1)), {} diff --git a/tico/passes/lower_copy.py b/tico/passes/lower_copy.py new file mode 100644 index 00000000..28f8dd6c --- /dev/null +++ b/tico/passes/lower_copy.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from torch.export import ExportedProgram + +from tico.passes import ops +from tico.serialize.circle_mapping import extract_shape +from tico.utils import logging +from tico.utils.graph import create_node +from tico.utils.passes import PassBase, PassResult +from tico.utils.trace_decorators import trace_graph_diff_on_pass +from tico.utils.validate_args_kwargs import CopyArgs + + +@trace_graph_diff_on_pass +class LowerCopy(PassBase): + """ + This pass lowers `aten.copy.default` to simpler broadcast operations. + + - If src and dst shapes are the same, the copy is redundant and folded away. + - If src and dst shapes differ, it's replaced with expand (broadcast). + + This simplifies serialization by handling copy logic at the pass level. + """ + + def __init__(self): + super().__init__() + + def call(self, exported_program: ExportedProgram) -> PassResult: + logger = logging.getLogger(__name__) + + graph_module = exported_program.graph_module + graph = graph_module.graph + modified = False + + for node in graph.nodes: + if not node.op == "call_function": + continue + + if node.target != torch.ops.aten.copy.default: + continue + + args = CopyArgs(*node.args, **node.kwargs) # type: ignore[arg-type] + dst = args.dst + src = args.src + + dst_shape = list(extract_shape(dst)) + src_shape = list(extract_shape(src)) + + # Case 1: Same shape - copy is redundant, just use src + if dst_shape == src_shape: + logger.debug( + f"{node.name}: Same shape {dst_shape}, replacing with src directly" + ) + node.replace_all_uses_with(src, propagate_meta=False) + modified = True + continue + + # Case 2: Different shapes - need expand + logger.debug( + f"{node.name}: Different shapes src={src_shape} dst={dst_shape}, " + f"inserting expand" + ) + + with graph.inserting_before(node): + expand_node = create_node( + graph, + torch.ops.aten.expand.default, + args=(src, dst_shape), + ) + + node.replace_all_uses_with(expand_node, propagate_meta=True) + modified = True + + graph.eliminate_dead_code() + graph.lint() + graph_module.recompile() + + return PassResult(modified) diff --git a/tico/serialize/operators/op_copy.py b/tico/serialize/operators/op_copy.py deleted file mode 100644 index 309f6363..00000000 --- a/tico/serialize/operators/op_copy.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional, TYPE_CHECKING, Union - -if TYPE_CHECKING: - import torch._ops - import torch.fx -import torch -from circle_schema import circle - -from tico.serialize.circle_graph import CircleSubgraph -from tico.serialize.operators.hashable_opcode import OpCode -from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor -from tico.serialize.operators.utils import create_builtin_operator, get_op_index -from tico.utils.errors import NotYetSupportedError -from tico.utils.validate_args_kwargs import CopyArgs - - -@register_node_visitor -class CopyVisitor(NodeVisitor): - """ - NOTE `torch.Tensor.copy_`'s behavior matches with `Reshape` of CIRCLE. - - because `torch.Tensor.copy_` is a in-place operator, so `dst` is converted to `Shape` of CIRCLE. - - after that, `dst` converted to `Shape` is connected to shape of `Reshape`. - - `src` is connected to tensor of `Reshape`. - - if `dst` is not converted to `Shape`. - [dst] [src] - | - [Reshape] - - if `dst` is converted to `Shape`. - [dst] [src] - | | - [Shape] | - \ / - [Reshape] - """ - - target: List[torch._ops.OpOverload] = [torch.ops.aten.copy.default] - - def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): - super().__init__(op_codes, graph) - - def check_to_do_broadcast( - self, - dst: List[int], - dst_sig: Optional[List[int]], - src: List[int], - src_sig: Optional[List[int]], - ) -> bool: - assert dst_sig is None - assert src_sig is None - return dst != src - - def define_broadcast_to_node( - self, - inputs: List[Union[circle.Tensor.TensorT, torch.Tensor]], - outputs: List[circle.Tensor.TensorT], - ) -> circle.Operator.OperatorT: - op_index = get_op_index( - circle.BuiltinOperator.BuiltinOperator.BROADCAST_TO, self._op_codes - ) - operator = create_builtin_operator(self.graph, op_index, inputs, outputs) - operator.builtinOptionsType = ( - circle.BuiltinOptions.BuiltinOptions.BroadcastToOptions - ) - - option = circle.BroadcastToOptions.BroadcastToOptionsT() - operator.builtinOptions = option - return operator - - def define_shape_node( - self, inputs: List[torch.fx.Node], outputs: List[circle.Tensor.TensorT] - ) -> circle.Operator.OperatorT: - op_index = get_op_index( - circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes - ) - operator = create_builtin_operator(self.graph, op_index, inputs, outputs) - operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions - - option = circle.ShapeOptions.ShapeOptionsT() - option.outType = circle.TensorType.TensorType.INT32 - operator.builtinOptions = option - return operator - - def define_node( - self, - node: torch.fx.Node, - ) -> circle.Operator.OperatorT: - if len(node.args) == 3: - raise NotYetSupportedError("'non_blocking' is not supported yet.") - - assert len(node.args) == 2, len(node.args) - - args = CopyArgs(*node.args, **node.kwargs) # type: ignore[arg-type] - dst = args.dst - src = args.src - - # To connect 'dst' to Reshape node in the graph, 'dst' must be converted to Shape op. 
- dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst) - dst_shape: List[int] = dst_tensor.shape - dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature - - if dst_shape_signature is not None: - # TODO: support dynamic shape - raise NotYetSupportedError("Dynamic shape is not supported yet.") - - dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32) - - dst_shape_shape = [len(dst_shape)] - dst_name: str = dst.name - - shape_output = self.graph.add_tensor_from_scratch( - prefix=f"{dst_name}_shape_output", - shape=dst_shape_shape, - shape_signature=None, - dtype=circle.TensorType.TensorType.INT32, - source_node=node, - ) - - shape_operator = self.define_shape_node([dst], [shape_output]) - self.graph.add_operator(shape_operator) - - src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src) - src_shape: List[int] = src_tensor.shape - src_shape_signature: Optional[List[int]] = src_tensor.shapeSignature - - if src_shape_signature is not None: - # TODO: support dynamic shape - raise NotYetSupportedError("Dynamic shape is not supported yet.") - - # The src tensor must be broadcastable with the dst tensor. - do_broadcast = self.check_to_do_broadcast( - dst_shape, dst_shape_signature, src_shape, src_shape_signature - ) - if do_broadcast: - # create braodcastTo output tensor - src_name: str = src.name - src_type: int = src_tensor.type - - broadcast_to_output: circle.Tensor.TensorT = ( - self.graph.add_tensor_from_scratch( - prefix=f"{src_name}_broadcast_to_output", - shape=dst_shape, - shape_signature=dst_shape_signature, - dtype=src_type, - source_node=node, - ) - ) - - broadcast_to_operator: circle.Operator.OperatorT = ( - self.define_broadcast_to_node( - [src_tensor, dst_shape_tensor], [broadcast_to_output] - ) - ) - self.graph.add_operator(broadcast_to_operator) - inputs: List = [broadcast_to_output, shape_output] - else: - inputs = [src, shape_output] - - outputs = [node] - op_index = get_op_index( - circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes - ) - - operator = create_builtin_operator(self.graph, op_index, inputs, outputs) - - # Op-specific option - operator.builtinOptionsType = ( - circle.BuiltinOptions.BuiltinOptions.ReshapeOptions - ) - option = circle.ReshapeOptions.ReshapeOptionsT() - option.newShape = dst_shape - - operator.builtinOptions = option - return operator diff --git a/tico/utils/convert.py b/tico/utils/convert.py index 7ac47f25..2fe12093 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -47,6 +47,7 @@ from tico.passes.legalize_predefined_layout_operators import ( LegalizePreDefinedLayoutOperators, ) +from tico.passes.lower_copy import LowerCopy from tico.passes.lower_pow2_to_mul import LowerPow2ToMul from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor from tico.passes.lower_to_slice import passes as LowerToSlicePasses @@ -224,6 +225,7 @@ def convert_exported_module_to_circle( FillMetaVal(), ExtractDtypeKwargsPass(), RemoveNop(), + LowerCopy(), ConvertLayoutOpToReshape(), RestoreLinear(), ConvertToReLU6(), From cb596ee44812325716574af3778c054138607a9c Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 27 Nov 2025 14:18:54 +0900 Subject: [PATCH 2/4] symsize --- test/modules/op/sym_size.py | 79 +++++++++++++ test/pt2_to_circle_test/builder.py | 10 +- test/pt2_to_circle_test/test_pt2_to_circle.py | 2 +- test/utils/infer.py | 42 ++++++- .../convert_sym_size_to_circle_shape.py | 102 +++++++++++++++++ tico/serialize/operators/op_sym_size_int.py | 108 ++++++++++++++++++ 
tico/serialize/operators/utils.py | 2 +- tico/utils/convert.py | 4 + tico/utils/register_custom_op.py | 27 +++++ 9 files changed, 371 insertions(+), 5 deletions(-) create mode 100644 test/modules/op/sym_size.py create mode 100644 tico/passes/convert_sym_size_to_circle_shape.py create mode 100644 tico/serialize/operators/op_sym_size_int.py diff --git a/test/modules/op/sym_size.py b/test/modules/op/sym_size.py new file mode 100644 index 00000000..2a482d3b --- /dev/null +++ b/test/modules/op/sym_size.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.export import Dim + +from test.modules.base import TestModuleBase +from test.utils import tag + + +@tag.use_onert +class SymSizeSimple(TestModuleBase): + """ + Simplest test case for sym_size.int generation. + Just returns the batch size (first dimension). + """ + + def forward(self, x): + # Accessing x.shape[0] on a dynamic dimension creates sym_size.int + return x.shape[0] + + def get_example_inputs(self): + return (torch.randn(2, 4, 4),), {} + + def get_dynamic_shapes(self): + batch = Dim("batch", min=1, max=128) + return {"x": {0: batch}} + + +@tag.use_onert +class SymSizeInReshape(TestModuleBase): + """ + Test case using sym_size.int in a reshape operation. + This is a common pattern in dynamic batch size models. + """ + + def forward(self, x): + batch_size = x.shape[0] + # Use the dynamic batch size in reshape + return x.reshape(batch_size, -1) + + def get_example_inputs(self): + return (torch.randn(2, 4, 4),), {} + + def get_dynamic_shapes(self): + batch = Dim("batch", min=1, max=128) + return {"x": {0: batch}} + + +@tag.use_onert +class SymSizeMultipleDims(TestModuleBase): + """ + Test case using multiple dynamic dimensions. 
+ """ + + def forward(self, x): + h = x.shape[1] + w = x.shape[2] + # Reshape using multiple dynamic dimensions + return x.reshape(-1, h * w) + + def get_example_inputs(self): + return (torch.randn(2, 4, 4),), {} + + def get_dynamic_shapes(self): + batch = Dim("batch", min=1, max=128) + return {"x": {0: batch}} + diff --git a/test/pt2_to_circle_test/builder.py b/test/pt2_to_circle_test/builder.py index 3f793067..df300ac7 100644 --- a/test/pt2_to_circle_test/builder.py +++ b/test/pt2_to_circle_test/builder.py @@ -20,6 +20,8 @@ from pathlib import Path from typing import Optional +import torch + from tico.config.base import CompileConfigBase from tico.utils.signature import ModelInputSpec @@ -204,8 +206,12 @@ def has_symbolic_input(circle_model_path: str) -> bool: forward_kwargs=deepcopy(self.forward_kwargs), runtime="onert", ) - torch_shape = torch_result[0].shape - circle_result[0] = circle_result[0].reshape(torch_shape) + for idx, (tr, cr) in enumerate(zip(torch_result, circle_result)): + if isinstance(tr, torch.Tensor): + circle_result[idx] = circle_result[idx].reshape(tr.shape) + else: + # tr is scalar + torch_result[idx] = torch.tensor([tr], dtype=type(cr)) # TODO Fix else: circle_result = infer_circle( circle_model_path, diff --git a/test/pt2_to_circle_test/test_pt2_to_circle.py b/test/pt2_to_circle_test/test_pt2_to_circle.py index 2ca99219..57b6e36b 100644 --- a/test/pt2_to_circle_test/test_pt2_to_circle.py +++ b/test/pt2_to_circle_test/test_pt2_to_circle.py @@ -235,7 +235,7 @@ def validate_result( else: raise TypeError("Expected result must be a tensor or scalar value.") - # Check both dypte and value mismatch + # Check both dtype and value mismatch torch.testing.assert_close( actual=circle_tensor, expected=expected_tensor, diff --git a/test/utils/infer.py b/test/utils/infer.py index bd835182..2ba7d3a4 100644 --- a/test/utils/infer.py +++ b/test/utils/infer.py @@ -17,7 +17,7 @@ import tico.utils import tico.utils.model from tico.utils.signature import ModelInputSpec - +import torch def infer_with_circle_interpreter( circle_path: str, @@ -51,6 +51,13 @@ def infer_with_circle_interpreter( return circle_result +from tico.serialize.circle_mapping import ( + extract_circle_dtype, + extract_circle_shape, + str_to_circle_dtype, + to_circle_dtype, + to_circle_shape, +) def infer_with_onert( circle_path: str, @@ -83,6 +90,39 @@ def infer_with_onert( inputs = ispec.bind(forward_args, forward_kwargs, check=True) session_float = infer.session(circle_path) + + # Handle dynamic shapes: onert cannot execute models with unspecified dimensions + # Check if any input has dynamic dimensions (indicated by -1) + input_tensorinfos = session_float.get_inputs_tensorinfo() + has_dynamic_shapes = any( + -1 in info.dims for info in input_tensorinfos + ) + + if has_dynamic_shapes: + # Set concrete input shapes based on the actual input data + from onert.native.libnnfw_api_pybind import tensorinfo + + for idx, (info, input_data) in enumerate(zip(input_tensorinfos, inputs)): + if -1 in info.dims: + # Create new tensorinfo with concrete shape from input data + new_info = tensorinfo() + new_info.rank = len(input_data.shape) + new_info.dims = list(input_data.shape) + + assert input_data.dtype in [torch.float32, torch.float] + new_info.dtype = "float32" + + try: + session_float.session.set_input_tensorinfo(idx, new_info) + except Exception as e: + # If setting tensorinfo fails, try to continue anyway + # Some versions of onert might handle this differently + import warnings + warnings.warn( + f"Failed to set input 
tensorinfo for input {idx}: {e}. " + f"Attempting inference anyway." + ) + output = session_float.infer(inputs) return output diff --git a/tico/passes/convert_sym_size_to_circle_shape.py b/tico/passes/convert_sym_size_to_circle_shape.py new file mode 100644 index 00000000..d8b6ae09 --- /dev/null +++ b/tico/passes/convert_sym_size_to_circle_shape.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from torch.export import ExportedProgram + +from tico.utils import logging +from tico.utils.graph import create_node +from tico.utils.passes import PassBase, PassResult +from tico.utils.trace_decorators import trace_graph_diff_on_pass + + +@trace_graph_diff_on_pass +class ConvertSymSizeToCircleShape(PassBase): + """ + This pass converts torch.ops.aten.sym_size.int operations to circle_custom::shape. + + The circle_custom::shape operator is similar to TensorFlow's shape operator and + allows preserving dynamic shape information in the Circle model. This is essential + for models with dynamic batch sizes or other dynamic dimensions. + + Example: + Before: %sym_size_int_1 = call_function[target=torch.ops.aten.sym_size.int](args=(%x, 0)) + After: %shape_0 = call_function[target=torch.ops.circle_custom.shape](args=(%x,)) + %slice_0 = call_function[target=torch.ops.aten.slice.Tensor](args=(%shape_0, 0, 0, 1, 1)) + %squeeze_0 = call_function[target=torch.ops.aten.squeeze.dim](args=(%slice_0, 0)) + """ + + def __init__(self): + super().__init__() + + def call(self, exported_program: ExportedProgram) -> PassResult: + logger = logging.getLogger(__name__) + + graph_module = exported_program.graph_module + graph = graph_module.graph + modified = False + + for node in graph.nodes: + if node.op != "call_function": + continue + + if node.target == torch.ops.aten.sym_size.int: + # sym_size.int has args: (input, dim) + input_tensor = node.args[0] + dim = node.args[1] + + # Create circle_custom::shape node + with graph.inserting_after(node): + shape_node = create_node( + graph, + torch.ops.circle_custom.shape, + args=(input_tensor,), + ) + + # Set metadata for shape_node + if "val" in input_tensor.meta: + input_val = input_tensor.meta["val"] + rank = len(input_val.shape) + # shape output is a 1D tensor of size rank, dtype int32 + # We use a real tensor here as a placeholder for metadata + shape_node.meta["val"] = torch.zeros(rank, dtype=torch.int32) + + # Extract the specific dimension using slice + squeeze + # shape is 1D, so we slice [dim:dim+1] then squeeze dim 0 + with graph.inserting_after(shape_node): + slice_node = create_node( + graph, + torch.ops.aten.slice.Tensor, + args=(shape_node, 0, dim, dim + 1, 1), + ) + # slice output is 1D tensor of size 1 + slice_node.meta["val"] = torch.zeros(1, dtype=torch.int32) + + # Replace all uses + node.replace_all_uses_with(slice_node, propagate_meta=False) + modified = True + + # logger.debug( + # 
f"Converted {node.name} (sym_size.int) to {shape_node.name} (circle_custom::shape) + {slice_node.name} (slice) + {squeeze_node.name} (squeeze)" + # ) + + graph.eliminate_dead_code() + graph.lint() + graph_module.recompile() + + return PassResult(modified) diff --git a/tico/serialize/operators/op_sym_size_int.py b/tico/serialize/operators/op_sym_size_int.py new file mode 100644 index 00000000..a6e2f0be --- /dev/null +++ b/tico/serialize/operators/op_sym_size_int.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch._ops + import torch.fx +import torch +from circle_schema import circle + +from tico.serialize.circle_graph import CircleSubgraph +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index + +@register_node_visitor +class SymSizeIntVisitor(NodeVisitor): + target: List[torch._ops.OpOverload] = [ + torch.ops.aten.sym_size.int, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + # args: (input, dim) + input_node = node.args[0] + dim = node.args[1] + + # 1. Shape op + op_index_shape = get_op_index( + circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes + ) + + # Create a temporary tensor for shape output + # The shape of 'Shape' output is [rank_of_input] + input_tensor = self.graph.get_tensor(input_node) + rank = len(input_tensor.shape) + shape_output_shape = [rank] + + shape_output = self.graph.add_tensor_from_scratch( + prefix=f"{node.name}_shape", + shape=shape_output_shape, + shape_signature=None, + dtype=circle.TensorType.TensorType.INT32, + source_node=node, + ) + + shape_op = create_builtin_operator( + self.graph, op_index_shape, [input_node], [shape_output] + ) + shape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions + shape_op.builtinOptions = circle.ShapeOptions.ShapeOptionsT() + shape_op.builtinOptions.outType = circle.TensorType.TensorType.INT32 + + self.graph.add_operator(shape_op) + + # Handle negative dim + if dim < 0: + dim += rank + + # 2. 
StridedSlice to extract the dimension + # Input: shape_output + # Output: node (scalar) + + op_index_slice = get_op_index( + circle.BuiltinOperator.BuiltinOperator.STRIDED_SLICE, self._op_codes + ) + + # Create const tensors for begin, end, strides + dim_i32 = torch.tensor([dim], dtype=torch.int32) + begin_tensor = self.graph.add_const_tensor(dim_i32) + end_tensor = self.graph.add_const_tensor(dim_i32 + 1) + strides_tensor = self.graph.add_const_tensor(torch.tensor([1], dtype=torch.int32)) + + inputs = [shape_output, begin_tensor, end_tensor, strides_tensor] + outputs = [node] + + slice_op = create_builtin_operator( + self.graph, op_index_slice, inputs, outputs + ) + + slice_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.StridedSliceOptions + option = circle.StridedSliceOptions.StridedSliceOptionsT() + option.beginMask = 0 + option.endMask = 0 + option.ellipsisMask = 0 + option.newAxisMask = 0 + option.shrinkAxisMask = 1 # Shrink the 0-th axis to make it scalar + + slice_op.builtinOptions = option + + return slice_op diff --git a/tico/serialize/operators/utils.py b/tico/serialize/operators/utils.py index 462001d6..f58c8a06 100644 --- a/tico/serialize/operators/utils.py +++ b/tico/serialize/operators/utils.py @@ -47,7 +47,7 @@ def create_builtin_operator( ) -> circle.Operator.OperatorT: operator = circle.Operator.OperatorT() operator.opcodeIndex = op_index - operator.inputs = [graph.get_tid(input) for input in inputs] + operator.inputs = [graph.get_tid(input_) for input_ in inputs] operator.outputs = [graph.get_tid(output) for output in outputs] return operator diff --git a/tico/utils/convert.py b/tico/utils/convert.py index 2fe12093..499c2f10 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -27,6 +27,7 @@ from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape +from tico.passes.convert_sym_size_to_circle_shape import ConvertSymSizeToCircleShape from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy from tico.passes.convert_to_relu6 import ConvertToReLU6 @@ -226,6 +227,7 @@ def convert_exported_module_to_circle( ExtractDtypeKwargsPass(), RemoveNop(), LowerCopy(), + ConvertSymSizeToCircleShape(), ConvertLayoutOpToReshape(), RestoreLinear(), ConvertToReLU6(), @@ -294,6 +296,8 @@ def convert_exported_module_to_circle( check_unsupported_target(exported_program) check_training_ops(exported_program) + + exported_program.graph.print_tabular() circle_program = build_circle(exported_program, config) return circle_program diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 095cb6a5..ea06ee5f 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -727,6 +727,32 @@ def _( return hidden_states.new_empty(hidden_states.size()) +def CircleShape(): + """ + Custom operator to extract the shape of a tensor. + This is similar to TensorFlow's shape operator and is used to preserve + dynamic shape information in the Circle model. 
+ + Args: + input_: Input tensor + + Returns: + A 1D tensor containing the shape of the input tensor + """ + @custom_op("circle_custom::shape", mutates_args=()) + def shape(input_: torch.Tensor) -> torch.Tensor: + # Return the shape of the input tensor as a 1D tensor + shape_val = list(input_.size()) + return torch.tensor(shape_val, dtype=torch.int32) + + @register_fake("circle_custom::shape") + def _(input_: torch.Tensor) -> torch.Tensor: + # Return a 1D tensor with symbolic shape + # The actual value will be determined at runtime + rank = len(input_.size()) + return torch.empty([rank], dtype=torch.int32) + + # Add custom ops to the torch namespace def RegisterOps(): CircleResizeNearestNeighbor() @@ -740,3 +766,4 @@ def RegisterOps(): CircleInstanceNorm() CircleQuantizeMX() CircleRMSNorm() + CircleShape() From c9c9a06c8dfe4f4b12a02f8f21aca512558fa428 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 27 Nov 2025 15:37:14 +0900 Subject: [PATCH 3/4] merge --- test/pt2_to_circle_test/builder.py | 2 +- tico/serialize/circle_graph.py | 2 +- tico/serialize/circle_serializer.py | 4 +++- tico/serialize/operators/op_slice.py | 18 +++++++++--------- tico/utils/validate_args_kwargs.py | 12 ++++++++---- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/test/pt2_to_circle_test/builder.py b/test/pt2_to_circle_test/builder.py index df300ac7..f6360807 100644 --- a/test/pt2_to_circle_test/builder.py +++ b/test/pt2_to_circle_test/builder.py @@ -211,7 +211,7 @@ def has_symbolic_input(circle_model_path: str) -> bool: circle_result[idx] = circle_result[idx].reshape(tr.shape) else: # tr is scalar - torch_result[idx] = torch.tensor([tr], dtype=type(cr)) # TODO Fix + torch_result[idx] = torch.tensor([tr], dtype=torch.int32) # TODO Fix properly else: circle_result = infer_circle( circle_model_path, diff --git a/tico/serialize/circle_graph.py b/tico/serialize/circle_graph.py index aacba4f2..796a04b8 100644 --- a/tico/serialize/circle_graph.py +++ b/tico/serialize/circle_graph.py @@ -324,4 +324,4 @@ def get_tid( return self.name_to_tid[node_name] # Unreachable - raise RuntimeError("fx Node was not converted to tensor.") + raise RuntimeError(f"fx Node was not converted to tensor: {node}") diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 9ac21cc1..aad1e99c 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -149,7 +149,9 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: if node.target in multiple_output_ops: continue node_val = node.meta["val"] - if node_val.layout != torch.strided: + if isinstance(node_val, torch.SymInt): + pass + elif node_val.layout != torch.strided: raise RuntimeError( f"Only support dense tensors (node layout: {node_val.layout})" ) diff --git a/tico/serialize/operators/op_slice.py b/tico/serialize/operators/op_slice.py index 7566bfec..3c99b3b5 100644 --- a/tico/serialize/operators/op_slice.py +++ b/tico/serialize/operators/op_slice.py @@ -51,14 +51,14 @@ def define_node( circle.BuiltinOperator.BuiltinOperator.STRIDED_SLICE, self._op_codes ) - args = SliceArgs(*node.args, **node.kwargs) # type: ignore[arg-type] - input = args.input - dim = args.dim - start = args.start - end = args.end - step = args.step - - input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input) + slice_args = SliceArgs(*node.args, **node.kwargs) # type: ignore[arg-type] + input_ = slice_args.input + dim = slice_args.dim + start = slice_args.start + end = slice_args.end + step = 
slice_args.step + + input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input_) input_shape: List[int] = input_tensor.shape if start is None: @@ -140,7 +140,7 @@ def define_node( stride_shape[dim] = step stride_shape_tensor = torch.as_tensor(stride_shape, dtype=torch.int32) - inputs = [input, begin_shape_tensor, end_shape_tensor, stride_shape_tensor] + inputs = [input_, begin_shape_tensor, end_shape_tensor, stride_shape_tensor] outputs = [node] operator = create_builtin_operator(self.graph, op_index, inputs, outputs) diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py index 8a5feb21..f84550f3 100644 --- a/tico/utils/validate_args_kwargs.py +++ b/tico/utils/validate_args_kwargs.py @@ -930,7 +930,7 @@ class RepeatArgs: """ input: torch.fx.Node - repeats: List[int] + repeats: List[Union[int, torch.SymInt, torch.fx.Node]] @enforce_type @@ -938,10 +938,14 @@ class RepeatArgs: class ReshapeArgs: """ reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) + + Note: After PrepareReshapeDynamicShape pass, shape can be either: + - A list of int/SymInt/Node (original or static) + - A single Node (dynamic shape tensor prepared by the pass) """ input: torch.fx.Node - shape: List[int] + shape: Union[List[Union[int, torch.SymInt, torch.fx.Node]], torch.fx.Node] @enforce_type @@ -1077,7 +1081,7 @@ class SplitWithSizesArgs: """ input: torch.fx.Node - split_sizes: List[int] + split_sizes: List[Union[int, torch.SymInt, torch.fx.Node]] dim: int = 0 @@ -1218,7 +1222,7 @@ class ViewArgs: """ input: torch.fx.Node - size: List[int] + size: List[Union[int, torch.SymInt, torch.fx.Node]] @enforce_type From fea6de0fb838ff9bc7bc427437a74fc0261d471d Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 27 Nov 2025 15:37:20 +0900 Subject: [PATCH 4/4] temp --- test/modules/op/reshape.py | 24 ++ test/modules/op/sym_size.py | 1 + tico/interpreter/infer.py | 30 ++- tico/passes/convert_layout_op_to_reshape.py | 81 ++++-- tico/passes/prepare_reshape_dynamic_shape.py | 149 +++++++++++ tico/passes/remove_redundant_reshape.py | 32 ++- tico/serialize/operators/op_circle_shape.py | 66 +++++ tico/serialize/operators/op_repeat.py | 246 ++++++++++++++---- tico/serialize/operators/op_reshape.py | 53 +++- .../operators/op_split_with_sizes.py | 88 ++++++- tico/serialize/operators/op_view.py | 98 ++++++- 11 files changed, 760 insertions(+), 108 deletions(-) create mode 100644 tico/passes/prepare_reshape_dynamic_shape.py create mode 100644 tico/serialize/operators/op_circle_shape.py diff --git a/test/modules/op/reshape.py b/test/modules/op/reshape.py index d3b6b4c1..0afda54a 100644 --- a/test/modules/op/reshape.py +++ b/test/modules/op/reshape.py @@ -13,9 +13,12 @@ # limitations under the License. import torch +from torch.export import Dim from test.modules.base import TestModuleBase +from test.utils import tag + # Note. tests that call `aten.reshape` or `torch.reshape` are exporeted to aten graph that has `aten.view` instead of `aten.reshape`. 
@@ -65,3 +68,24 @@ def forward(self, x): def get_example_inputs(self): return (torch.randn(2, 4, 5),), {} + + +@tag.use_onert +class ReshapeDynamicShape(TestModuleBase): + def __init__(self): + super().__init__() + + def forward(self, x): + # Reshape to (batch, -1) where batch is dynamic + return x.reshape(x.shape[0], -1) + + def get_example_inputs(self): + return (torch.randn(4, 4, 4),), {} + + def get_dynamic_shapes(self): + batch = Dim("batch", min=2, max=128) + dynamic_shapes = { + "x": {0: batch}, + } + return dynamic_shapes + diff --git a/test/modules/op/sym_size.py b/test/modules/op/sym_size.py index 2a482d3b..27462a3e 100644 --- a/test/modules/op/sym_size.py +++ b/test/modules/op/sym_size.py @@ -39,6 +39,7 @@ def get_dynamic_shapes(self): @tag.use_onert +@tag.skip(reason="Not yet supported") class SymSizeInReshape(TestModuleBase): """ Test case using sym_size.int in a reshape operation. diff --git a/tico/interpreter/infer.py b/tico/interpreter/infer.py index 792699b7..d388920a 100644 --- a/tico/interpreter/infer.py +++ b/tico/interpreter/infer.py @@ -69,6 +69,7 @@ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any: graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength()) ] model_input_shapes_np = [t.ShapeAsNumpy() for t in model_input_tensors] + model_input_shape_sigs_np = [t.ShapeSignatureAsNumpy() for t in model_input_tensors] model_input_types_cm = [t.Type() for t in model_input_tensors] # Check if given inputs' dtype and shape from users match the inputs' from model binary. @@ -77,11 +78,30 @@ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any: f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})" ) for input_idx, user_input in enumerate(user_inputs): - # Shape - if list(user_input.shape) != list(model_input_shapes_np[input_idx]): - raise RuntimeError( - f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})" - ) + # Shape - check against shape_signature if available (for dynamic shapes) + model_shape = model_input_shapes_np[input_idx] + model_shape_sig = model_input_shape_sigs_np[input_idx] + user_shape = list(user_input.shape) + + # If shape_signature exists, validate against it (supports dynamic dimensions) + if model_shape_sig is not None and len(model_shape_sig) > 0: + if len(user_shape) != len(model_shape_sig): + raise RuntimeError( + f"Mismatch input {input_idx} rank: input({len(user_shape)}) != circle model({len(model_shape_sig)})" + ) + for dim_idx, (user_dim, sig_dim) in enumerate(zip(user_shape, model_shape_sig)): + # -1 in shape_signature means dynamic dimension, accept any value + if sig_dim != -1 and user_dim != sig_dim: + raise RuntimeError( + f"Mismatch input {input_idx} shape at dimension {dim_idx}: input({user_dim}) != circle model({sig_dim})" + ) + else: + # No shape_signature, validate against static shape + if user_shape != list(model_shape): + raise RuntimeError( + f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_shape})" + ) + # Data type user_input_type_cm = to_circle_dtype(user_input.dtype) if user_input_type_cm != model_input_types_cm[input_idx]: diff --git a/tico/passes/convert_layout_op_to_reshape.py b/tico/passes/convert_layout_op_to_reshape.py index 443f37a4..ced3739f 100644 --- a/tico/passes/convert_layout_op_to_reshape.py +++ b/tico/passes/convert_layout_op_to_reshape.py @@ -20,7 +20,6 @@ from torch.export import ExportedProgram from tico.passes import ops -from 
tico.serialize.circle_mapping import extract_shape from tico.utils import logging from tico.utils.graph import create_node from tico.utils.passes import PassBase, PassResult @@ -45,38 +44,78 @@ def call(self, exported_program: ExportedProgram) -> PassResult: graph = graph_module.graph modified = False - def convert(node, input): - out_shape = list(extract_shape(node)) - - with graph.inserting_after(node): - reshape_node = create_node( - graph, - torch.ops.aten.reshape.default, - args=(input, out_shape), - ) - node.replace_all_uses_with(reshape_node, propagate_meta=True) - - logger.debug(f"{node.name} is replaced with {reshape_node.name}") - for node in graph.nodes: if not node.op == "call_function": continue + reshape_node = None + if node.target in ops.aten.view: view_args = ViewArgs(*node.args, **node.kwargs) - convert(node, view_args.input) + # Preserve the original size argument which may contain dynamic shapes + # (e.g., sym_size.int nodes, SymInt values, or -1 for inferred dimensions) + with graph.inserting_after(node): + reshape_node = create_node( + graph, + torch.ops.aten.reshape.default, + args=(view_args.input, view_args.size), + ) modified = True - continue + elif node.target in ops.aten.unsqueeze: unsqueeze_args = UnSqueezeArgs(*node.args, **node.kwargs) - convert(node, unsqueeze_args.input) - modified = True - continue + # For unsqueeze, we need to construct the output shape dynamically + # to preserve symbolic dimensions from the input + input_node = unsqueeze_args.input + dim = unsqueeze_args.dim + + # Get input shape - may contain symbolic dimensions + input_meta = input_node.meta.get("val") + if input_meta is not None: + input_shape = list(input_meta.shape) + # Build output shape by inserting 1 at the specified dimension + # Preserve any symbolic dimensions (SymInt) from input + output_shape = input_shape[:dim] + [1] + input_shape[dim:] + + with graph.inserting_after(node): + reshape_node = create_node( + graph, + torch.ops.aten.reshape.default, + args=(input_node, output_shape), + ) + modified = True + elif node.target in ops.aten.squeeze: squeeze_args = SqueezeArgs(*node.args, **node.kwargs) - convert(node, squeeze_args.input) + # For squeeze, we need to construct the output shape dynamically + # to preserve symbolic dimensions from the input + input_node = squeeze_args.input + dims = squeeze_args.dims + + # Get input shape - may contain symbolic dimensions + input_meta = input_node.meta.get("val") + assert input_meta is not None + + input_shape = list(input_meta.shape) + # Remove specific dimension if it's size 1 + for dim in dims: + assert input_shape[dim] == 1 + output_shape = [] + for dim in range(len(input_shape)): + if dim not in dims: + output_shape.append(input_shape[dim]) + + with graph.inserting_after(node): + reshape_node = create_node( + graph, + torch.ops.aten.reshape.default, + args=(input_node, output_shape), + ) modified = True - continue + + if reshape_node is not None: + node.replace_all_uses_with(reshape_node, propagate_meta=True) + logger.debug(f"{node.name} is replaced with {reshape_node.name}") graph.eliminate_dead_code() graph.lint() diff --git a/tico/passes/prepare_reshape_dynamic_shape.py b/tico/passes/prepare_reshape_dynamic_shape.py new file mode 100644 index 00000000..c519f360 --- /dev/null +++ b/tico/passes/prepare_reshape_dynamic_shape.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from torch.export import ExportedProgram + +from tico.utils import logging +from tico.utils.graph import create_node +from tico.utils.passes import PassBase, PassResult +from tico.utils.trace_decorators import trace_graph_diff_on_pass +from tico.utils.validate_args_kwargs import ReshapeArgs + + +@trace_graph_diff_on_pass +class PrepareReshapeDynamicShape(PassBase): + """ + This pass prepares dynamic shape arguments for reshape operations. + + For reshape operations with dynamic shapes (containing fx.Node or SymInt), + this pass converts the shape list into a single 1D tensor by: + 1. Converting scalar Node elements to 1D tensors via slice + 2. Converting int/SymInt elements to constant 1D tensors + 3. Concatenating all elements into a single shape tensor + + This simplifies the serialization logic by ensuring reshape always receives + either a constant shape list or a single shape tensor node. + + Example: + Before: %reshape = call_function[target=torch.ops.aten.reshape.default]( + args=(%x, [%slice_tensor, -1])) + After: %const_neg1 = call_function[target=torch.ops.aten.tensor.default](args=([-1],)) + %cat = call_function[target=torch.ops.aten.cat.default]( + args=([%slice_tensor, %const_neg1], 0)) + %reshape = call_function[target=torch.ops.aten.reshape.default]( + args=(%x, %cat)) + """ + + def __init__(self): + super().__init__() + + def call(self, exported_program: ExportedProgram) -> PassResult: + logger = logging.getLogger(__name__) + + graph_module = exported_program.graph_module + graph = graph_module.graph + modified = False + + for node in graph.nodes: + if node.op != "call_function": + continue + + if node.target != torch.ops.aten.reshape.default: + continue + + # Get the shape argument - need to access it directly since + # ReshapeArgs expects a list, but we might have a Node + if len(node.args) < 2: + continue + + size = node.args[1] + + # If size is already a single Node (already prepared), skip + if isinstance(size, torch.fx.Node): + continue + + # Check if this is a dynamic reshape + is_dynamic = any(isinstance(s, (torch.SymInt, torch.fx.Node)) for s in size) + + if not is_dynamic: + continue + + args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type] + input_node = args.input + + # Build list of 1D tensor nodes for each dimension + shape_elements = [] + + with graph.inserting_before(node): + for s in size: + if isinstance(s, torch.fx.Node): + # Node is already a tensor, but might be scalar + # We need to ensure it's 1D [1] shape + # Check if it's already 1D with shape [1] + s_meta = s.meta.get("val") + if s_meta is not None and len(s_meta.shape) == 1 and s_meta.shape[0] == 1: + # Already 1D, use as-is + shape_elements.append(s) + else: + # Need to reshape to [1] + reshape_node = create_node( + graph, + torch.ops.aten.reshape.default, + args=(s, [1]), + ) + reshape_node.meta["val"] = torch.zeros(1, 
dtype=torch.int32) + shape_elements.append(reshape_node) + + elif isinstance(s, (int, torch.SymInt)): + # Create a constant 1D tensor using full + val = int(s) + const_node = create_node( + graph, + torch.ops.aten.full.default, + args=([1], val), + kwargs={"dtype": torch.int32}, + ) + const_node.meta["val"] = torch.tensor([val], dtype=torch.int32) + shape_elements.append(const_node) + else: + raise RuntimeError(f"Unsupported size element: {s} {type(s)}") + + # Concatenate all shape elements + cat_node = create_node( + graph, + torch.ops.aten.cat.default, + args=(shape_elements, 0), + ) + # Set metadata for cat output + cat_node.meta["val"] = torch.zeros(len(size), dtype=torch.int32) + + # Replace the reshape args with the concatenated tensor + node.args = (input_node, cat_node) + modified = True + + logger.debug( + f"Prepared dynamic shape for {node.name}: concatenated {len(shape_elements)} elements" + ) + + if modified: + graph.eliminate_dead_code() + graph.lint() + graph_module.recompile() + + return PassResult(modified) diff --git a/tico/passes/remove_redundant_reshape.py b/tico/passes/remove_redundant_reshape.py index 87f4d83d..15285a5b 100644 --- a/tico/passes/remove_redundant_reshape.py +++ b/tico/passes/remove_redundant_reshape.py @@ -76,6 +76,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: if len(reshape1.users) != 1: continue reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type] + # Skip dynamic reshapes (shape is a single Node) + if isinstance(reshape1_args.shape, torch.fx.Node): + continue reshape1_input = reshape1_args.input # `(AxBxC) - aten.reshape` - (1xAxBxC) if [1] + list(extract_shape(reshape1_input)) != list( @@ -158,6 +161,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: if len(reshape1.users) != 1: continue reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type] + # Skip dynamic reshapes (shape is a single Node) + if isinstance(reshape1_args.shape, torch.fx.Node): + continue reshape1_input = reshape1_args.input # `(AxBxC) - aten.reshape` - (1xAxBxC) if [1] + list(extract_shape(reshape1_input)) != list( @@ -239,6 +245,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: if not is_target_node(reshape_1, ops.aten.reshape): continue reshape_1_args = ReshapeArgs(*reshape_1.args, **reshape_1.kwargs) # type: ignore[arg-type] + # Skip dynamic reshapes (shape is a single Node) + if isinstance(reshape_1_args.shape, torch.fx.Node): + continue softmax = reshape_1_args.input # softmax @@ -275,6 +284,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: if reshape_2.target not in ops.aten.reshape: continue reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs) # type: ignore[arg-type] + # Skip dynamic reshapes (shape is a single Node) + if isinstance(reshape_2_args.shape, torch.fx.Node): + continue reshape_2_input = reshape_2_args.input assert isinstance(reshape_2_input, torch.fx.Node), type(reshape_2_input) # reshape_3 @@ -283,6 +295,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: if reshape_3.target not in ops.aten.reshape: continue reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs) # type: ignore[arg-type] + # Skip dynamic reshapes (shape is a single Node) + if isinstance(reshape_3_args.shape, torch.fx.Node): + continue reshape_3_input = reshape_3_args.input assert isinstance(reshape_3_input, torch.fx.Node), type(reshape_3_input) @@ -354,9 +369,12 @@ def call(self, exported_program: 
ExportedProgram) -> PassResult: reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type] reshape1_input, size = reshape1_args.input, reshape1_args.shape assert isinstance(reshape1_input, torch.fx.Node), type(reshape1_input) + # Skip if dynamic shape (single Node or list with SymInt/Node) + if isinstance(size, torch.fx.Node): + continue assert isinstance(size, list), type(size) - for s in size: - assert isinstance(s, int), type(s) + if any(isinstance(s, (torch.SymInt, torch.fx.Node)) for s in size): + continue if not len(reshape1.users) == 1: continue @@ -369,9 +387,12 @@ def call(self, exported_program: ExportedProgram) -> PassResult: reshape2_args = ReshapeArgs(*reshape2.args, **reshape2.kwargs) # type: ignore[arg-type] reshape2_input, reshape2_size = reshape2_args.input, reshape2_args.shape assert isinstance(reshape2_input, torch.fx.Node), type(reshape2_input) + # Skip if dynamic shape (single Node or list with SymInt/Node) + if isinstance(reshape2_size, torch.fx.Node): + continue assert isinstance(reshape2_size, list), type(reshape2_size) - for s in reshape2_size: - assert isinstance(s, int), type(s) + if any(isinstance(s, (torch.SymInt, torch.fx.Node)) for s in reshape2_size): + continue with graph.inserting_before(reshape1): fused_reshape = create_node( @@ -418,6 +439,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type] output_shape = args.shape + # Skip dynamic reshapes (shape is a single Node) + if isinstance(output_shape, torch.fx.Node): + continue input_shape = list(extract_shape(args.input)) if output_shape != input_shape: diff --git a/tico/serialize/operators/op_circle_shape.py b/tico/serialize/operators/op_circle_shape.py new file mode 100644 index 00000000..bb9648b7 --- /dev/null +++ b/tico/serialize/operators/op_circle_shape.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch._ops + import torch.fx +import torch +from circle_schema import circle + +from tico.serialize.circle_graph import CircleSubgraph +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index + + +@register_node_visitor +class CircleShapeVisitor(NodeVisitor): + """ + Visitor for circle_custom::shape operator. + + This operator extracts the shape of a tensor. + It's implemented using Circle's SHAPE operator. 
+ """ + + target: List[torch._ops.OpOverload] = [ + torch.ops.circle_custom.shape, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + # Args: (input) + input_node = node.args[0] + + # SHAPE operator to get the full shape of input + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes + ) + + inputs = [input_node] + outputs = [node] + operator = create_builtin_operator( + self.graph, op_index, inputs, outputs + ) + operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions + operator.builtinOptions = circle.ShapeOptions.ShapeOptionsT() + operator.builtinOptions.outType = circle.TensorType.TensorType.INT32 + + return operator diff --git a/tico/serialize/operators/op_repeat.py b/tico/serialize/operators/op_repeat.py index bfc2a1a4..7d0d4345 100644 --- a/tico/serialize/operators/op_repeat.py +++ b/tico/serialize/operators/op_repeat.py @@ -25,6 +25,7 @@ extract_circle_dtype, extract_shape, to_circle_shape, + circle_legalize_dtype_to, ) from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor @@ -48,62 +49,201 @@ def define_node( input = args.input repeats = args.repeats - for r in repeats: - if r == 0: - # TODO: Support r == 0 case - raise NotYetSupportedError("Only support positive repeat value") - elif r < 0: - raise InvalidArgumentError("Only support positive repeat value") - - tensor_shape = extract_shape(input) - assert len(tensor_shape) <= len(repeats) - if len(tensor_shape) != len(repeats): - # TODO Support len(tensor_shape) < len(repeats) - raise NotYetSupportedError( - "Length of both input tensor and repeats vector should be same." - ) - repeat_dim_cnt = len(repeats) - repeats.count(1) - tensor_dtype = extract_circle_dtype(input) - op_index = get_op_index( - circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes - ) - concat_input: torch.fx.Node | circle.Tensor.TensorT = input - concat_output: torch.fx.node.Node | circle.Tensor.TensorT = node - for idx, r in enumerate(repeats): - # concat along idx dimension - if r > 1: - # Except last created concat, a tensor should be created. 
-                if repeat_dim_cnt > 1:
-                    repeated_shape: List[int | torch.SymInt] = list(tensor_shape)
-                    repeated_shape[idx] = repeated_shape[idx] * r
+        # Check ranks
+        input_tensor = self.graph.get_tensor(input)
+        input_rank = len(input_tensor.shape)
+        repeats_len = len(repeats)
+
+        if input_rank > repeats_len:
+            raise RuntimeError("Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor")
+
+        # If input rank < repeats length, reshape input to match the rank.
+        tile_input = input
+        if input_rank < repeats_len:
+            # Prepend 1s to the input shape. For a static input the new shape is
+            # computed directly; for a dynamic input a shape tensor is built at runtime.
+
+            # Check if input is dynamic
+            input_is_dynamic = input_tensor.shapeSignature is not None
+
+            if not input_is_dynamic:
+                new_shape = [1] * (repeats_len - input_rank) + input_tensor.shape
+                # Create Reshape op
+                reshape_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes)
+                reshaped_input = self.graph.add_tensor_from_scratch(
+                    prefix=f"{input.name}_reshaped_for_tile",
+                    shape=new_shape,
+                    shape_signature=None,
+                    dtype=input_tensor.type,
+                    source_node=input
+                )
+
+                # Provide an explicit shape tensor to RESHAPE for consistency with
+                # op_reshape, which always passes the target shape as a second input.
+                shape_tensor = self.graph.add_const_tensor(circle_legalize_dtype_to(new_shape, dtype=torch.int32))
+
+                reshape_op = create_builtin_operator(
+                    self.graph, reshape_op_idx, [input, shape_tensor], [reshaped_input]
+                )
+                reshape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+                reshape_op.builtinOptions = circle.ReshapeOptions.ReshapeOptionsT()
+                reshape_op.builtinOptions.newShape = new_shape
+                self.graph.add_operator(reshape_op)
+                tile_input = reshaped_input
+            else:
+                # Dynamic input. Construct shape tensor: [1, 1, ...] + Shape(input)
+                # 1. Shape op
+                shape_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes)
+                input_shape_tensor = self.graph.add_tensor_from_scratch(
+                    prefix=f"{input.name}_shape",
+                    shape=[input_rank],
+                    shape_signature=None,
+                    dtype=circle.TensorType.TensorType.INT32,
+                    source_node=input
+                )
+                shape_op = create_builtin_operator(
+                    self.graph, shape_op_idx, [input], [input_shape_tensor]
+                )
+                shape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions
+                shape_op.builtinOptions = circle.ShapeOptions.ShapeOptionsT()
+                shape_op.builtinOptions.outType = circle.TensorType.TensorType.INT32
+                self.graph.add_operator(shape_op)
+
+                # 2. Const tensor for prefix 1s
+                prefix_len = repeats_len - input_rank
+                prefix_shape = [1] * prefix_len
+                prefix_tensor = self.graph.add_const_tensor(circle_legalize_dtype_to(prefix_shape, dtype=torch.int32))
+
+                # 3. Concat
+                concat_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes)
+                new_shape_tensor = self.graph.add_tensor_from_scratch(
+                    prefix=f"{input.name}_new_shape",
+                    shape=[repeats_len],
+                    shape_signature=None,
+                    dtype=circle.TensorType.TensorType.INT32,
+                    source_node=input
+                )
+                concat_op = create_builtin_operator(
+                    self.graph, concat_op_idx, [prefix_tensor, input_shape_tensor], [new_shape_tensor]
+                )
+                concat_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
+                concat_op.builtinOptions = circle.ConcatenationOptions.ConcatenationOptionsT()
+                concat_op.builtinOptions.axis = 0
+                self.graph.add_operator(concat_op)
+
+                # 4. Reshape
+                reshape_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes)
+                # Output rank is repeats_len. Use shapeSignature for the signature when
+                # present: it carries -1 for dynamic dims, while shape holds placeholders.
+                new_sig = [1] * prefix_len + (input_tensor.shapeSignature if input_tensor.shapeSignature else input_tensor.shape)
+
+                reshaped_input = self.graph.add_tensor_from_scratch(
+                    prefix=f"{input.name}_reshaped_for_tile",
+                    shape=[1] * prefix_len + input_tensor.shape,
+                    shape_signature=new_sig,
+                    dtype=input_tensor.type,
+                    source_node=input
+                )
+
+                reshape_op = create_builtin_operator(
+                    self.graph, reshape_op_idx, [input, new_shape_tensor], [reshaped_input]
+                )
+                reshape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+                reshape_op.builtinOptions = circle.ReshapeOptions.ReshapeOptionsT()
+                reshape_op.builtinOptions.newShape = [-1] * repeats_len  # Dummy for dynamic
+                self.graph.add_operator(reshape_op)
+                tile_input = reshaped_input
 
-                    repeated_cshape, repeated_cshape_signature = to_circle_shape(
-                        repeated_shape
+        # Construct multiples tensor
+        multiples_tensors = []
+        for r in repeats:
+            if isinstance(r, torch.fx.Node):
+                r_tensor = self.graph.get_tensor(r)
+                # Cast to INT32 if needed
+                if r_tensor.type == circle.TensorType.TensorType.INT64:
+                    cast_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes)
+                    r_i32 = self.graph.add_tensor_from_scratch(
+                        prefix=f"{r.name}_cast_i32",
+                        shape=list(r_tensor.shape),
+                        shape_signature=list(r_tensor.shapeSignature) if r_tensor.shapeSignature else None,
+                        dtype=circle.TensorType.TensorType.INT32,
+                        source_node=r
                     )
-                    concat_output = self.graph.add_tensor_from_scratch(
-                        prefix=f"{node.name}_concat_{idx}",
-                        shape=repeated_cshape,
-                        shape_signature=repeated_cshape_signature,
-                        dtype=tensor_dtype,
-                        source_node=node,
+                    cast_op = create_builtin_operator(
+                        self.graph, cast_op_idx, [r], [r_i32]
                     )
-                inputs = [concat_input] * r
-                if repeat_dim_cnt == 1:
-                    outputs: List[torch.fx.node.Node | circle.Tensor.TensorT] = [node]
-                else:
-                    outputs = [concat_output]
-                operator = create_builtin_operator(
-                    self.graph, op_index, inputs, outputs
+                    cast_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
+                    cast_op.builtinOptions = circle.CastOptions.CastOptionsT()
+                    cast_op.builtinOptions.inDataType = circle.TensorType.TensorType.INT64
+                    cast_op.builtinOptions.outDataType = circle.TensorType.TensorType.INT32
+                    self.graph.add_operator(cast_op)
+                    r_tensor = r_i32
+                    r = r_i32
+
+                # Reshape to [1]
+                reshape_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes)
+                reshaped_r = self.graph.add_tensor_from_scratch(
+                    prefix=f"{r_tensor.name}_reshaped_1d",
+                    shape=[1],
+                    shape_signature=[1],
+                    dtype=circle.TensorType.TensorType.INT32,
+                    source_node=None
                )
-                operator.builtinOptionsType = (
-                    circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
+                shape_1 = self.graph.add_const_tensor(circle_legalize_dtype_to([1], dtype=torch.int32))
+
+                reshape_op = create_builtin_operator(
+                    self.graph, reshape_op_idx, [r_tensor, shape_1], [reshaped_r]
                )
-                option = circle.ConcatenationOptions.ConcatenationOptionsT()
-                option.axis = idx
-                operator.builtinOptions = option
-                if repeat_dim_cnt > 1:
-                    self.graph.add_operator(operator)
-                    concat_input = concat_output
-                    repeat_dim_cnt -= 1
-
+                reshape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+                reshape_op.builtinOptions = circle.ReshapeOptions.ReshapeOptionsT()
+                reshape_op.builtinOptions.newShape = [1]
+                self.graph.add_operator(reshape_op)
+
+                multiples_tensors.append(reshaped_r)
+
+            elif isinstance(r, (int, torch.SymInt)):
+                val = int(r)
+                t_i32_val = circle_legalize_dtype_to([val], dtype=torch.int32)
+                t = self.graph.add_const_tensor(t_i32_val)
+                multiples_tensors.append(t)
+            else:
+                raise RuntimeError(f"Unsupported repeat element: {r} {type(r)}")
+
+        # Concatenate multiples
+        concat_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes)
+        multiples_tensor = self.graph.add_tensor_from_scratch(
+            prefix=f"{node.name}_multiples",
+            shape=[repeats_len],
+            shape_signature=None,  # multiples is always static shape [rank]
+            dtype=circle.TensorType.TensorType.INT32,
+            source_node=node
+        )
+        concat_op = create_builtin_operator(
+            self.graph, concat_op_idx, multiples_tensors, [multiples_tensor]
+        )
+        concat_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
+        concat_op.builtinOptions = circle.ConcatenationOptions.ConcatenationOptionsT()
+        concat_op.builtinOptions.axis = 0
+        self.graph.add_operator(concat_op)
+
+        # Tile op
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.TILE, self._op_codes
+        )
+        inputs = [tile_input, multiples_tensor]
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.TileOptions
+        operator.builtinOptions = circle.TileOptions.TileOptionsT()
+
         return operator
diff --git a/tico/serialize/operators/op_reshape.py b/tico/serialize/operators/op_reshape.py
index ae81c582..b6c4d55d 100644
--- a/tico/serialize/operators/op_reshape.py
+++ b/tico/serialize/operators/op_reshape.py
@@ -47,16 +47,55 @@ def define_node(
             self._op_codes,
         )
         args = ReshapeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-        input = args.input
+        input_node = args.input
         size = args.shape
 
-        if isinstance(size, int):
-            raise NotYetSupportedError("scalar size conversion is not supported yet.")
+        # After PrepareReshapeDynamicShape pass, size is either:
+        # 1. A constant list/tuple (static shape)
+        # 2. A single fx.Node (dynamic shape tensor prepared by the pass)
+
+        if isinstance(size, torch.fx.Node):
+            # Dynamic shape: size is a 1D tensor node
+            # Cast to INT32 if needed
+            size_tensor = self.graph.get_tensor(size)
+
+            if size_tensor.type == circle.TensorType.TensorType.INT64:
+                cast_op_idx = get_op_index(
+                    circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
+                )
+                size_i32 = self.graph.add_tensor_from_scratch(
+                    prefix=f"{size.name}_i32",
+                    shape=list(size_tensor.shape),
+                    shape_signature=list(size_tensor.shapeSignature) if size_tensor.shapeSignature else None,
+                    dtype=circle.TensorType.TensorType.INT32,
+                    source_node=size
+                )
+                cast_op = create_builtin_operator(
+                    self.graph, cast_op_idx, [size], [size_i32]
+                )
+                cast_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
+                cast_op.builtinOptions = circle.CastOptions.CastOptionsT()
+                cast_op.builtinOptions.inDataType = circle.TensorType.TensorType.INT64
+                cast_op.builtinOptions.outDataType = circle.TensorType.TensorType.INT32
+                self.graph.add_operator(cast_op)
+                size_node = size_i32
+            else:
+                size_node = size
+
+            inputs = [input_node, size_node]
+            # size_tensor.shape[0] gives us the rank of the output tensor
+            new_shape = [-1] * size_tensor.shape[0]  # Placeholder for dynamic shape
+        else:
+            # Static shape: size is a constant list/tuple
+            if isinstance(size, int):
+                raise NotYetSupportedError("scalar size conversion is not supported yet.")
 
-        assert is_const(size), type(size)
+            assert is_const(size), type(size)
+
+            size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
+            inputs = [input_node, size_i32]
+            new_shape = size_i32.tolist()
 
-        size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
-        inputs = [input, size_i32]
         outputs = [node]
         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
@@ -66,7 +105,7 @@ def define_node(
             circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
         )
         option = circle.ReshapeOptions.ReshapeOptionsT()
-        option.newShape = size_i32.tolist()
+        option.newShape = new_shape
 
         operator.builtinOptions = option
diff --git a/tico/serialize/operators/op_split_with_sizes.py b/tico/serialize/operators/op_split_with_sizes.py
index eff3aaf0..cd450e79 100644
--- a/tico/serialize/operators/op_split_with_sizes.py
+++ b/tico/serialize/operators/op_split_with_sizes.py
@@ -22,7 +22,7 @@
 from torch._subclasses.fake_tensor import FakeTensor
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import circle_legalize_dtype_to, to_circle_dtype
+from tico.serialize.circle_mapping import circle_legalize_dtype_to, to_circle_dtype, to_circle_shape
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -50,12 +50,80 @@ def define_node(
         split_sizes = args.split_sizes
         axis = args.dim
 
-        split_sizes_i32 = [
-            circle_legalize_dtype_to(split_size, dtype=torch.int32)
-            for split_size in split_sizes
-        ]
+        # Construct split_sizes_tensor
+        split_sizes_tensors = []
+        for s in split_sizes:
+            if isinstance(s, torch.fx.Node):
+                s_tensor = self.graph.get_tensor(s)
+                # Cast to INT32 if needed
+                if s_tensor.type == circle.TensorType.TensorType.INT64:
+                    cast_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes)
+                    s_i32 = self.graph.add_tensor_from_scratch(
+                        prefix=f"{s.name}_cast_i32",
+                        shape=list(s_tensor.shape),
+                        shape_signature=list(s_tensor.shapeSignature) if s_tensor.shapeSignature else None,
+                        dtype=circle.TensorType.TensorType.INT32,
+                        source_node=s
+                    )
+                    cast_op = create_builtin_operator(
+                        self.graph, cast_op_idx, [s], [s_i32]
+                    )
+                    cast_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
+                    cast_op.builtinOptions = circle.CastOptions.CastOptionsT()
+                    cast_op.builtinOptions.inDataType = circle.TensorType.TensorType.INT64
+                    cast_op.builtinOptions.outDataType = circle.TensorType.TensorType.INT32
+                    self.graph.add_operator(cast_op)
+                    s_tensor = s_i32
+                    s = s_i32
+
+                # Reshape to [1]
+                reshape_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes)
+                reshaped_s = self.graph.add_tensor_from_scratch(
+                    prefix=f"{s_tensor.name}_reshaped_1d",
+                    shape=[1],
+                    shape_signature=[1],
+                    dtype=circle.TensorType.TensorType.INT32,
+                    source_node=None
+                )
+                shape_1 = self.graph.add_const_tensor(circle_legalize_dtype_to([1], dtype=torch.int32))
+
+                reshape_op = create_builtin_operator(
+                    self.graph, reshape_op_idx, [s_tensor, shape_1], [reshaped_s]
+                )
+                reshape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+                reshape_op.builtinOptions = circle.ReshapeOptions.ReshapeOptionsT()
+                reshape_op.builtinOptions.newShape = [1]
+                self.graph.add_operator(reshape_op)
+
+                split_sizes_tensors.append(reshaped_s)
+
+            elif isinstance(s, (int, torch.SymInt)):
+                val = int(s)
+                t_i32_val = circle_legalize_dtype_to([val], dtype=torch.int32)
+                t = self.graph.add_const_tensor(t_i32_val)
+                split_sizes_tensors.append(t)
+            else:
+                raise RuntimeError(f"Unsupported split_size element: {s} {type(s)}")
+
+        # Concatenate split_sizes
+        concat_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes)
+        split_sizes_tensor = self.graph.add_tensor_from_scratch(
+            prefix=f"{node.name}_split_sizes",
+            shape=[len(split_sizes)],
+            shape_signature=None,
+            dtype=circle.TensorType.TensorType.INT32,
+            source_node=node
+        )
+        concat_op = create_builtin_operator(
+            self.graph, concat_op_idx, split_sizes_tensors, [split_sizes_tensor]
+        )
+        concat_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
+        concat_op.builtinOptions = circle.ConcatenationOptions.ConcatenationOptionsT()
+        concat_op.builtinOptions.axis = 0
+        self.graph.add_operator(concat_op)
+
         axis_i32 = circle_legalize_dtype_to(axis, dtype=torch.int32)
-        inputs = [input, split_sizes_i32, axis_i32]
+        inputs = [input, split_sizes_tensor, axis_i32]
 
         """
         `split_with_sizes` has multiple output tensors along with `getitem`.
@@ -83,15 +151,13 @@ def define_node(
             assert isinstance(fake_tensor, FakeTensor)
             shape = list(fake_tensor.size())
-            if any(isinstance(s, torch.SymInt) for s in shape):
-                # TODO: support dynamic shape
-                raise NotImplementedError("Dynamic shape is not supported yet.")
+            c_shape, c_sig = to_circle_shape(shape)
 
             dtype = to_circle_dtype(fake_tensor.dtype)
             tensor = self.graph.add_tensor_from_scratch(
                 prefix=f"{node.name}_unused_{idx}",
-                shape=shape,
-                shape_signature=None,  # TODO: support dynamic shape
+                shape=c_shape,
+                shape_signature=c_sig,
                 dtype=dtype,
                 source_node=node,
             )
diff --git a/tico/serialize/operators/op_view.py b/tico/serialize/operators/op_view.py
index 8572c25d..fda237fd 100644
--- a/tico/serialize/operators/op_view.py
+++ b/tico/serialize/operators/op_view.py
@@ -51,14 +51,95 @@ def define_node(
         input = args.input
         size = args.size
 
-        assert is_const(size), type(size)
+        is_dynamic = any(isinstance(s, (torch.SymInt, torch.fx.Node)) for s in size)
 
-        if isinstance(size, int):
-            raise Exception("scalar size conversion is not supported yet.")
+        if not is_dynamic:
+            assert is_const(size), type(size)
+
+            if isinstance(size, int):
+                raise Exception("scalar size conversion is not supported yet.")
+
+            size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
+            inputs = [input, size_i32]
+        else:
+            shape_tensors = []
+            for s in size:
+                if isinstance(s, torch.fx.Node):
+                    s_tensor = self.graph.get_tensor(s)
+
+                    # Cast to INT32 if needed
+                    if s_tensor.type == circle.TensorType.TensorType.INT64:
+                        cast_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes)
+                        s_i32 = self.graph.add_tensor_from_scratch(
+                            prefix=f"{s.name}_cast_i32",
+                            shape=list(s_tensor.shape),
+                            shape_signature=list(s_tensor.shapeSignature) if s_tensor.shapeSignature else None,
+                            dtype=circle.TensorType.TensorType.INT32,
+                            source_node=s
+                        )
+                        cast_op = create_builtin_operator(
+                            self.graph, cast_op_idx, [s], [s_i32]
+                        )
+                        cast_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
+                        cast_op.builtinOptions = circle.CastOptions.CastOptionsT()
+                        cast_op.builtinOptions.inDataType = circle.TensorType.TensorType.INT64
+                        cast_op.builtinOptions.outDataType = circle.TensorType.TensorType.INT32
+                        self.graph.add_operator(cast_op)
+                        s_tensor = s_i32
+                        s = s_i32
+
+                    # Reshape to [1]
+                    reshape_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes)
+                    reshaped_s = self.graph.add_tensor_from_scratch(
+                        prefix=f"{s_tensor.name}_reshaped_1d",
+                        shape=[1],
+                        shape_signature=[1],
+                        dtype=circle.TensorType.TensorType.INT32,
+                        source_node=None
+                    )
+                    shape_1_data = circle_legalize_dtype_to([1], dtype=torch.int32)
+                    shape_1 = self.graph.add_const_tensor(shape_1_data)
+
+                    reshape_op = create_builtin_operator(
+                        self.graph, reshape_op_idx, [s_tensor, shape_1], [reshaped_s]
+                    )
+                    reshape_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+                    reshape_op.builtinOptions = circle.ReshapeOptions.ReshapeOptionsT()
+                    reshape_op.builtinOptions.newShape = [1]
+                    self.graph.add_operator(reshape_op)
+
+                    shape_tensors.append(reshaped_s)
+
+                elif isinstance(s, (int, torch.SymInt)):
+                    val = int(s)
+                    t_i32_val = circle_legalize_dtype_to([val], dtype=torch.int32)
+                    t = self.graph.add_const_tensor(t_i32_val)
+                    shape_tensors.append(t)
+                else:
+                    raise RuntimeError(f"Unsupported size element: {s} {type(s)}")
+
+            # Concatenate
+            concat_op_idx = get_op_index(circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes)
+
+            shape_tensor_shape = [len(size)]
+            shape_tensor = self.graph.add_tensor_from_scratch(
+                prefix=f"{node.name}_shape_tensor",
+                shape=shape_tensor_shape,
+                shape_signature=shape_tensor_shape,
+                dtype=circle.TensorType.TensorType.INT32,
+                source_node=node
+            )
+
+            concat_op = create_builtin_operator(
+                self.graph, concat_op_idx, shape_tensors, [shape_tensor]
+            )
+            concat_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
+            concat_op.builtinOptions = circle.ConcatenationOptions.ConcatenationOptionsT()
+            concat_op.builtinOptions.axis = 0
+            self.graph.add_operator(concat_op)
+
+            inputs = [input, shape_tensor]
 
-        # TODO: support dynamic shape
-        size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
-        inputs = [input, size_i32]
         outputs = [node]
         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
@@ -68,7 +149,10 @@ def define_node(
             circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
        )
         option = circle.ReshapeOptions.ReshapeOptionsT()
-        option.newShape = size_i32.tolist()
+        if not is_dynamic:
+            option.newShape = size_i32.tolist()
+        else:
+            option.newShape = [-1] * len(size)
 
         operator.builtinOptions = option
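
Reviewer note: the op_repeat visitor above lowers `aten.repeat` to a rank-aligning RESHAPE followed by TILE. A minimal PyTorch-level sketch of that equivalence, assuming Circle's TILE matches torch.Tensor.tile semantics; the helper name is illustrative only and is not part of the patch:

import torch

def repeat_via_reshape_and_tile(x: torch.Tensor, repeats) -> torch.Tensor:
    # Rank-align the input by prepending 1s (the RESHAPE step), then
    # replicate along every dimension with the repeats vector (the TILE step).
    assert x.dim() <= len(repeats)
    aligned = x.reshape([1] * (len(repeats) - x.dim()) + list(x.shape))
    return aligned.tile(tuple(repeats))

x = torch.randn(2, 3)
assert torch.equal(repeat_via_reshape_and_tile(x, [4, 2, 3]), x.repeat(4, 2, 3))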