From a8b176f12b4a8fb35e34653516484388b098db33 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 10 Sep 2025 21:15:23 +0900 Subject: [PATCH 1/9] PoC: Convert cond --- call.py | 70 ++++++ if.py | 76 +++++++ signature.py | 61 +++++ test.py | 8 + test/modules/op/cond.py | 80 +++++++ .../controlflow/passes/map_subgraph.py | 88 ++++++++ .../quantization/passes/fold_quant_ops.py | 3 +- .../insert_quantize_on_dtype_mismatch.py | 4 +- .../passes/propagate_qparam_backward.py | 4 +- .../passes/propagate_qparam_forward.py | 4 +- .../quantization/passes/quantize_bias.py | 4 +- .../passes/remove_weight_dequant_op.py | 4 +- tico/interpreter/infer.py | 2 +- tico/passes/cast_aten_where_arg_type.py | 3 +- tico/passes/cast_clamp_mixed_type_args.py | 9 +- tico/passes/cast_mixed_type_args.py | 4 +- tico/passes/const_prop_pass.py | 15 +- tico/passes/convert_conv1d_to_conv2d.py | 4 +- tico/passes/convert_layout_op_to_reshape.py | 4 +- tico/passes/convert_repeat_to_expand_copy.py | 4 +- tico/passes/convert_to_relu6.py | 4 +- tico/passes/decompose_addmm.py | 4 +- tico/passes/decompose_batch_norm.py | 4 +- tico/passes/decompose_fake_quantize.py | 4 +- .../decompose_fake_quantize_tensor_qparams.py | 4 +- tico/passes/decompose_group_norm.py | 4 +- tico/passes/decompose_grouped_conv2d.py | 4 +- tico/passes/decompose_slice_scatter.py | 4 +- tico/passes/extract_dtype_kwargs.py | 3 +- tico/passes/fill_meta_val.py | 4 +- tico/passes/fuse_leading_unsqueeze_reshape.py | 7 +- tico/passes/fuse_redundant_reshape_to_mean.py | 4 +- tico/passes/legalize_causal_mask_value.py | 4 +- .../legalize_predefined_layout_operators.py | 4 +- tico/passes/lower_pow2_to_mul.py | 4 +- .../lower_to_resize_nearest_neighbor.py | 3 +- tico/passes/lower_to_slice.py | 8 +- tico/passes/merge_consecutive_cat.py | 4 +- tico/passes/remove_nop.py | 4 +- tico/passes/remove_redundant_assert_nodes.py | 3 +- tico/passes/remove_redundant_expand.py | 4 +- tico/passes/remove_redundant_permute.py | 4 +- tico/passes/remove_redundant_reshape.py | 24 +- tico/passes/remove_redundant_slice.py | 4 +- tico/passes/remove_redundant_to_copy.py | 4 +- tico/passes/restore_linear.py | 4 +- tico/passes/segment_index_select.py | 4 +- tico/serialize/circle_graph.py | 7 +- tico/serialize/circle_serializer.py | 210 ++++++++++++------ tico/serialize/operators/op_circle_if.py | 66 ++++++ tico/serialize/operators/utils.py | 18 +- tico/utils/convert.py | 188 +++++++++------- tico/utils/passes.py | 8 +- tico/utils/register_custom_op.py | 45 +++- tico/utils/signature.py | 2 +- tico/utils/subgraph.py | 43 ++++ tico/utils/trace_decorators.py | 12 +- tico/utils/validate_args_kwargs.py | 21 ++ 58 files changed, 895 insertions(+), 308 deletions(-) create mode 100644 call.py create mode 100644 if.py create mode 100644 signature.py create mode 100644 test.py create mode 100644 test/modules/op/cond.py create mode 100644 tico/experimental/controlflow/passes/map_subgraph.py create mode 100644 tico/serialize/operators/op_circle_if.py create mode 100644 tico/utils/subgraph.py diff --git a/call.py b/call.py new file mode 100644 index 00000000..d807a56d --- /dev/null +++ b/call.py @@ -0,0 +1,70 @@ +""" Example - circle model import/export """ + +import pycircle + +from pycircle.circleir.model import Model +from pycircle.circleir.subgraph import Subgraph +from pycircle.circleir.tensor import Tensor +from pycircle.circleir.operators import CircleAdd, CircleCall +from pycircle.util.alias import TensorType + + +### subgraph 0 +### input0, input1 -> call0 (subgraph 1) -> tensor0 +### tensor0, weights0 
-> add0 -> tensor1
graph0 = Subgraph()
graph0.name = "graph0"
graph0.inputs = [
    Tensor("sub1_input0", [1, 3], TensorType.FLOAT32),
    Tensor("sub1_input1", [1, 3], TensorType.FLOAT32),
]

call0 = CircleCall()
call0.inputs = [graph0.inputs[0], graph0.inputs[1]]
call0.subgraph = 1
call0.outputs(0).attribute("Call0", [1, 3], TensorType.FLOAT32)


add1 = CircleAdd()
weights0 = Tensor("weights0", [1, 3], TensorType.FLOAT32, [100., 100., 100.])
add1.inputs = [call0.outputs(0), weights0]
add1.outputs(0).attribute("add0", [1, 3], TensorType.FLOAT32)

graph0.outputs = [add1.outputs(0)]

### subgraph 1
### input0, input1 -> ADD -> output
graph1 = Subgraph()
graph1.name = "graph1"
graph1.inputs = [
    Tensor("input0", [1, 3], TensorType.FLOAT32),
    Tensor("input1", [1, 3], TensorType.FLOAT32, [-100., -100., -100.])
]
sub_add = CircleAdd()
sub_add.inputs = [graph1.inputs[0], graph1.inputs[1]]
sub_add.outputs(0).attribute("SubAdd", [1, 3], TensorType.FLOAT32)
graph1.outputs = [sub_add.outputs(0)]

### model
circle_model = Model()
circle_model.subgraphs = [graph0, graph1]
circle_model.signature_defs = {
    "graph0": {
        "subgraph_index": 0
    },
    "graph1": {
        "subgraph_index": 1
    },
}

pycircle.export_circle_model(circle_model, "call.circle")

import torch
try:
    from onert import infer
except ImportError:
    raise RuntimeError("The 'onert' package is required to run this function.")

session_float = infer.session("call.circle")
output = session_float.infer((torch.randn(1, 3), torch.randn(1, 3)), measure=True)
print(output)
\ No newline at end of file
diff --git a/if.py b/if.py
new file mode 100644
index 00000000..cf240484
--- /dev/null
+++ b/if.py
""" Example - circle model import/export """

import pycircle
from pycircle.circleir.model import Model
from pycircle.circleir.subgraph import Subgraph
from pycircle.circleir.tensor import Tensor
from pycircle.circleir.operators import CircleAdd, CircleIf
from pycircle.util.alias import TensorType

# Define the input tensors and constant tensors
input_tensor0 = Tensor("input0", [1, 3], TensorType.FLOAT32)
input_tensor1 = Tensor("input1", [1, 3], TensorType.FLOAT32)
weight_add_100 = Tensor("constant0", [1, 3], TensorType.FLOAT32, [100, 100, 100])
weight_sub_100 = Tensor("constant1", [1, 3], TensorType.FLOAT32, [-100, -100, -100])

### then_subgraph ###
then_subgraph = Subgraph()
then_subgraph.inputs = [Tensor("input0", [1, 3], TensorType.FLOAT32), weight_add_100]

add_op_then = CircleAdd()
add_op_then.inputs = [then_subgraph.inputs[0], then_subgraph.inputs[1]]
add_op_then.outputs(0).attribute("add_output_then", [1, 3], TensorType.FLOAT32)
then_subgraph.outputs = [add_op_then.outputs(0)]

### else_subgraph ###
else_subgraph = Subgraph()
else_subgraph.inputs = [Tensor("input0", [1, 3], TensorType.FLOAT32), weight_sub_100]

add_op_else = CircleAdd()
add_op_else.inputs = [else_subgraph.inputs[0], else_subgraph.inputs[1]]
add_op_else.outputs(0).attribute("add_output_else", [1, 3], TensorType.FLOAT32)
else_subgraph.outputs = [add_op_else.outputs(0)]

### root_subgraph with CircleIf ###
root_subgraph = Subgraph()
root_subgraph.name = "root_subgraph"
condition_tensor = Tensor("condition", [1], TensorType.BOOL)
root_subgraph.inputs = [condition_tensor, input_tensor0, input_tensor1]

circle_if_op = CircleIf(1, 2)
circle_if_op.inputs = [condition_tensor, input_tensor0, input_tensor1]
circle_if_op.outputs(0).attribute("output_tensor", [1, 3], TensorType.FLOAT32)
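# The Circle IF operator selects between two subgraphs at runtime; the index
# fields below refer to positions in `circle_model.subgraphs`, so they must
# match the [root_subgraph, then_subgraph, else_subgraph] ordering used when
# the model is assembled below.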
circle_if_op.then_subgraph_index = 1
circle_if_op.else_subgraph_index = 2
root_subgraph.outputs = [circle_if_op.outputs(0)]

# Build the model
circle_model = Model()
circle_model.description = "pycircle example : signature_def"
circle_model.subgraphs = [root_subgraph, then_subgraph, else_subgraph]
circle_model.signature_defs = {
    "root_graph": {"subgraph_index": 0},
    "then_graph": {"subgraph_index": 1},
    "else_graph": {"subgraph_index": 2},
}

# Export the model
pycircle.export_circle_model(circle_model, "signature_def.circle")

# Inference through onert
import torch
try:
    from onert import infer
except ImportError:
    raise RuntimeError("The 'onert' package is required to run this function.")

session = infer.session("signature_def.circle")
output = session.infer(
    (
        torch.tensor([True]),                # condition tensor
        torch.randn(1, 3),                   # input tensor 0
        torch.tensor([[100., 100., 100.]]),  # weights tensor
    ),
    measure=True
)
print(output)
diff --git a/signature.py b/signature.py
new file mode 100644
index 00000000..d4ad0f89
--- /dev/null
+++ b/signature.py
""" Example - circle model import/export """

import pycircle

from pycircle.circleir.model import Model
from pycircle.circleir.subgraph import Subgraph
from pycircle.circleir.tensor import Tensor
from pycircle.circleir.operators import CircleAdd
from pycircle.util.alias import TensorType

subgraph1 = Subgraph()
subgraph1.name = "subgraph1"
subgraph1.inputs = [
    Tensor("subgraph1_input", [1, 3], TensorType.FLOAT32),
]

weights1 = Tensor("constant1", [1, 3], TensorType.FLOAT32, [0.1, 0.2, 0.3])

add1 = CircleAdd()
add1.inputs = [subgraph1.inputs[0], weights1]
add1.outputs(0).attribute("Add1", [1, 3], TensorType.FLOAT32)

subgraph1.outputs = [add1.outputs(0)]

subgraph2 = Subgraph()
subgraph2.name = "subgraph2"
subgraph2.inputs = [
    Tensor("subgraph2_input1", [1, 3], TensorType.FLOAT32),
    Tensor("subgraph2_input2", [1, 3], TensorType.FLOAT32),
]

add2 = CircleAdd()
add2.inputs = [subgraph2.inputs[0], subgraph2.inputs[1]]
add2.outputs(0).attribute("Add2", [1, 3], TensorType.FLOAT32)

subgraph2.outputs = [add2.outputs(0)]

circle_model = Model()
circle_model.description = "pycircle example : signature_def"
# circle_model.subgraphs = [subgraph2, subgraph1]
circle_model.subgraphs = [subgraph1, subgraph2]
circle_model.signature_defs = {
    "add_constant": {
        "subgraph_index": 0
    },
    "add_two_inputs": {
        "subgraph_index": 1
    },
}

pycircle.export_circle_model(circle_model, "signature_def_original.circle")

import torch
try:
    from onert import infer
except ImportError:
    raise RuntimeError("The 'onert' package is required to run this function.")

session_float = infer.session("signature_def_original.circle")
output = session_float.infer((torch.randn(1, 3),))
print(output)
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..6f0f4e4b
--- /dev/null
+++ b/test.py

def test():
    if 1 == 1:
        pass

    print("HI")

test()
\ No newline at end of file
diff --git a/test/modules/op/cond.py b/test/modules/op/cond.py
new file mode 100644
index 00000000..8aaff89b
--- /dev/null
+++ b/test/modules/op/cond.py
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from test.modules.base import TestModuleBase


class SimpleCond1(TestModuleBase):
    class Sin(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x) + 1

    class Cos(torch.nn.Module):
        def forward(self, x):
            return torch.cos(x) - 1

    def __init__(self):
        super().__init__()
        self.sin = self.Sin()
        self.cos = self.Cos()

    def forward(self, x, y):
        return torch.cond(x.sum() + y.sum() > 0,
                          lambda x_: self.sin(x_),
                          lambda x_: self.cos(x_),
                          operands=(x,))

    def get_example_inputs(self):
        return (torch.randn(3, 3), torch.randn(3, 3)), {}


class SimpleCond2(TestModuleBase):
    class Sin(torch.nn.Module):
        def forward(self, x, y):
            return torch.sin(x) + 1

    class Cos(torch.nn.Module):
        def forward(self, x, y):
            return torch.cos(x) - 1

    def __init__(self):
        super().__init__()
        self.sin = self.Sin()
        self.cos = self.Cos()

    def forward(self, x, y):
        return torch.cond(x.sum() + y.sum() > 0,
                          lambda x, y: self.sin(x, y),
                          lambda x, y: self.cos(x, y),
                          operands=(x, y))

    def get_example_inputs(self):
        return (torch.randn(3, 3), torch.randn(3, 3)), {}


if __name__ == "__main__":
    model = SimpleCond2()
    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    # Export (build the graph)
    exported_model = torch.export.export(model, (x, y))

    # Test calling the exported model
    output = exported_model.module()(x, y)
    exported_model.graph.print_tabular()
    print(exported_model.graph_signature.user_inputs)
    print(exported_model.graph_signature.user_outputs)
    print(output)
\ No newline at end of file
diff --git a/tico/experimental/controlflow/passes/map_subgraph.py b/tico/experimental/controlflow/passes/map_subgraph.py
new file mode 100644
index 00000000..c86d9d0a
--- /dev/null
+++ b/tico/experimental/controlflow/passes/map_subgraph.py
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
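
"""Map `torch.ops.higher_order.cond` nodes onto the Circle IF operator.

For every `cond` node in the root graph, this pass looks up the indices of
the true/false branch GraphModules (via `get_gm_map`) and rewrites the
`cond` + `getitem` pattern into a single `torch.ops.circle_custom.if_` node
that carries those subgraph indices as arguments.
"""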
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch.fx

import operator

import torch
from torch.export import ExportedProgram

from tico.utils import logging
from tico.utils.graph import create_node
from tico.utils.passes import PassBase, PassResult
from tico.utils.subgraph import get_gm_map
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.validate_args_kwargs import CondArgs


@trace_graph_diff_on_pass
class MapSubgraph(PassBase):
    """
    Rewrite `torch.ops.higher_order.cond` into `torch.ops.circle_custom.if_`
    whose arguments carry the then/else subgraph indices.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram, _) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != torch.ops.higher_order.cond:
                continue

            cond_args = CondArgs(*node.args, **node.kwargs)

            # Resolve the subgraph indices of the true/false branch modules.
            true_graph_idx = None
            false_graph_idx = None
            for gm_info in get_gm_map(exported_program):
                if gm_info["name"] == cond_args.true_graph.name:
                    true_graph_idx = gm_info["index"]
                    continue
                if gm_info["name"] == cond_args.false_graph.name:
                    false_graph_idx = gm_info["index"]
                    continue
            assert true_graph_idx is not None
            assert false_graph_idx is not None

            with graph.inserting_before(node):
                circle_if = create_node(
                    graph,
                    torch.ops.circle_custom.if_,
                    args=(
                        cond_args.condition,
                        true_graph_idx,
                        false_graph_idx,
                        cond_args.cond_args,
                    ),
                    kwargs={},
                    origin=node,
                )

            # FIXME This assumes `torch.ops.higher_order.cond` always produces
            # a single `getitem` user; revisit if that pattern ever changes.
            assert len(node.users) == 1
            getitem_node = next(iter(node.users))
            assert getitem_node.target == operator.getitem
            getitem_node.replace_all_uses_with(circle_if)
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # Run only once.
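        # (Returning `modified=False` reports that nothing is left to do, so
        # the pass pipeline does not iterate this pass again.)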
+ return PassResult(False) diff --git a/tico/experimental/quantization/passes/fold_quant_ops.py b/tico/experimental/quantization/passes/fold_quant_ops.py index 48afa7d0..5a4fe9f7 100644 --- a/tico/experimental/quantization/passes/fold_quant_ops.py +++ b/tico/experimental/quantization/passes/fold_quant_ops.py @@ -72,10 +72,9 @@ class FoldQuantOps(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for dq in graph.nodes: if dq.op != "call_function": diff --git a/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py b/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py index 2a442987..81e23e87 100644 --- a/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +++ b/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py @@ -437,10 +437,8 @@ class InsertQuantizeOnDtypeMismatch(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: diff --git a/tico/experimental/quantization/passes/propagate_qparam_backward.py b/tico/experimental/quantization/passes/propagate_qparam_backward.py index 1d7f94b9..b77cea82 100644 --- a/tico/experimental/quantization/passes/propagate_qparam_backward.py +++ b/tico/experimental/quantization/passes/propagate_qparam_backward.py @@ -45,10 +45,8 @@ class PropagateQParamBackward(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): diff --git a/tico/experimental/quantization/passes/propagate_qparam_forward.py b/tico/experimental/quantization/passes/propagate_qparam_forward.py index eb3c8bee..56d835ae 100644 --- a/tico/experimental/quantization/passes/propagate_qparam_forward.py +++ b/tico/experimental/quantization/passes/propagate_qparam_forward.py @@ -50,7 +50,7 @@ class PropagateQParamForward(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): @@ -66,8 +66,6 @@ def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY]) logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.") - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "call_function": diff --git a/tico/experimental/quantization/passes/quantize_bias.py b/tico/experimental/quantization/passes/quantize_bias.py index 1b826093..40cb9104 100644 --- a/tico/experimental/quantization/passes/quantize_bias.py +++ 
b/tico/experimental/quantization/passes/quantize_bias.py @@ -41,10 +41,8 @@ class QuantizeBias(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "call_function": diff --git a/tico/experimental/quantization/passes/remove_weight_dequant_op.py b/tico/experimental/quantization/passes/remove_weight_dequant_op.py index 35fecc2b..b17ed0fa 100644 --- a/tico/experimental/quantization/passes/remove_weight_dequant_op.py +++ b/tico/experimental/quantization/passes/remove_weight_dequant_op.py @@ -96,10 +96,8 @@ class RemoveWeightDequantOp(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for dq in graph.nodes: if not dq.op == "call_function": diff --git a/tico/interpreter/infer.py b/tico/interpreter/infer.py index 792699b7..1816395d 100644 --- a/tico/interpreter/infer.py +++ b/tico/interpreter/infer.py @@ -63,7 +63,7 @@ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any: # Get input spec from circle binary. model = circle.Model.Model.GetRootAsModel(circle_binary, 0) - assert model.SubgraphsLength() == 1 + # assert model.SubgraphsLength() == 1 graph = model.Subgraphs(0) model_input_tensors = [ graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength()) diff --git a/tico/passes/cast_aten_where_arg_type.py b/tico/passes/cast_aten_where_arg_type.py index f1d1e362..386defc2 100644 --- a/tico/passes/cast_aten_where_arg_type.py +++ b/tico/passes/cast_aten_where_arg_type.py @@ -109,9 +109,8 @@ class CastATenWhereArgType(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/cast_clamp_mixed_type_args.py b/tico/passes/cast_clamp_mixed_type_args.py index 80288534..651a626d 100644 --- a/tico/passes/cast_clamp_mixed_type_args.py +++ b/tico/passes/cast_clamp_mixed_type_args.py @@ -92,11 +92,10 @@ class CastClampMixedTypeArgs(PassBase): def __init__(self): super().__init__() - def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool: + def convert(self, exported_program: ExportedProgram, node: torch.fx.Node, graph_module) -> bool: logger = logging.getLogger(__name__) modified = False - graph_module = exported_program.graph_module graph = graph_module.graph # clamp(Tensor self, Scalar? min=None, Scalar? 
max=None) -> Tensor @@ -150,17 +149,15 @@ def _convert_arg(arg, arg_name: str): return modified - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: target_op = ops.aten.clamp - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: if not is_target_node(node, target_op): continue - modified |= self.convert(exported_program, node) + modified |= self.convert(exported_program, node, graph_module) graph.eliminate_dead_code() graph.lint() diff --git a/tico/passes/cast_mixed_type_args.py b/tico/passes/cast_mixed_type_args.py index 1168db62..010430a1 100644 --- a/tico/passes/cast_mixed_type_args.py +++ b/tico/passes/cast_mixed_type_args.py @@ -92,10 +92,8 @@ def __init__(self, preserve_ep_invariant=True): self.preserve_ep_invariant = preserve_ep_invariant # TODO Folding float and int values before this pass - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/const_prop_pass.py b/tico/passes/const_prop_pass.py index 6d3c96ac..9d697bb0 100644 --- a/tico/passes/const_prop_pass.py +++ b/tico/passes/const_prop_pass.py @@ -49,12 +49,12 @@ def get_constant_placeholder_to_tensor_dict( exported_program: ExportedProgram, + graph_module, ) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ Returns a dictionary of constant placeholder node to constant tensor. """ const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "placeholder": @@ -114,13 +114,13 @@ def get_data( def propagate_constants( exported_program: ExportedProgram, + graph_module ) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ Propagates constants and returns a dictionary of node to constant tensors of the graph. """ - const_node_to_tensor = get_constant_placeholder_to_tensor_dict(exported_program) + const_node_to_tensor = get_constant_placeholder_to_tensor_dict(exported_program, graph_module) - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "call_function": @@ -177,6 +177,7 @@ def erase_constant_node( def create_constant_placeholder( const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor], exported_program: ExportedProgram, + graph_module ) -> List[torch.fx.Node]: """ This function creates constant placeholder nodes according to the given constant nodes (`const_node_to_tensor`) and replace it with the original node. 
@@ -265,19 +266,17 @@ class ConstPropPass(PassBase): def __init__(self) -> None: super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph # [1], [2] const_node_to_tensor: OrderedDict[ torch.fx.Node, torch.Tensor - ] = propagate_constants(exported_program) + ] = propagate_constants(exported_program, graph_module) # [3] placeholders = create_constant_placeholder( - const_node_to_tensor, exported_program + const_node_to_tensor, exported_program, graph_module ) # [4] new_name_to_spec = create_input_specs(placeholders) diff --git a/tico/passes/convert_conv1d_to_conv2d.py b/tico/passes/convert_conv1d_to_conv2d.py index d7f511f1..1e9cc9aa 100644 --- a/tico/passes/convert_conv1d_to_conv2d.py +++ b/tico/passes/convert_conv1d_to_conv2d.py @@ -141,10 +141,8 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo modified = True return modified - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: target_conv_op = [torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding] - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/convert_layout_op_to_reshape.py b/tico/passes/convert_layout_op_to_reshape.py index 443f37a4..bdaf1968 100644 --- a/tico/passes/convert_layout_op_to_reshape.py +++ b/tico/passes/convert_layout_op_to_reshape.py @@ -38,10 +38,8 @@ class ConvertLayoutOpToReshape(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/convert_repeat_to_expand_copy.py b/tico/passes/convert_repeat_to_expand_copy.py index 82d6f93f..a8183748 100644 --- a/tico/passes/convert_repeat_to_expand_copy.py +++ b/tico/passes/convert_repeat_to_expand_copy.py @@ -38,10 +38,8 @@ class ConvertRepeatToExpandCopy(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/convert_to_relu6.py b/tico/passes/convert_to_relu6.py index 76d2a576..fec1bfa5 100644 --- a/tico/passes/convert_to_relu6.py +++ b/tico/passes/convert_to_relu6.py @@ -155,10 +155,8 @@ def __init__(self): ConvertDoubleClampsToReLU6(), ] - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/decompose_addmm.py b/tico/passes/decompose_addmm.py index 67793fc7..d6d8fe43 100644 --- a/tico/passes/decompose_addmm.py +++ b/tico/passes/decompose_addmm.py @@ -57,8 +57,8 @@ class DecomposeAddmm(PassBase): def __init__(self): super().__init__() - def 
call(self, exported_program: ExportedProgram) -> PassResult: - gm = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_batch_norm.py b/tico/passes/decompose_batch_norm.py index eccae901..a6586040 100644 --- a/tico/passes/decompose_batch_norm.py +++ b/tico/passes/decompose_batch_norm.py @@ -74,10 +74,10 @@ class DecomposeBatchNorm(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - gm = exported_program.graph_module + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_fake_quantize.py b/tico/passes/decompose_fake_quantize.py index e26dda3d..b3deb590 100644 --- a/tico/passes/decompose_fake_quantize.py +++ b/tico/passes/decompose_fake_quantize.py @@ -63,10 +63,10 @@ def forward(self, x): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: modified = False - gm = exported_program.graph_module + gm = graph_module g = gm.graph qd = torch.ops.quantized_decomposed # type: ignore[return] for node in gm.graph.nodes: diff --git a/tico/passes/decompose_fake_quantize_tensor_qparams.py b/tico/passes/decompose_fake_quantize_tensor_qparams.py index c263c91f..7d0987bd 100644 --- a/tico/passes/decompose_fake_quantize_tensor_qparams.py +++ b/tico/passes/decompose_fake_quantize_tensor_qparams.py @@ -195,10 +195,10 @@ def forward(self, x: "f32[4]"): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: modified = False - gm = exported_program.graph_module + gm = graph_module g = gm.graph qd = torch.ops.quantized_decomposed # type: ignore[return] for node in gm.graph.nodes: diff --git a/tico/passes/decompose_group_norm.py b/tico/passes/decompose_group_norm.py index d184644f..2ba73acc 100644 --- a/tico/passes/decompose_group_norm.py +++ b/tico/passes/decompose_group_norm.py @@ -124,8 +124,8 @@ def _insert_norm(self, graph, tensor, eps, origin): graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin ) - def call(self, exported_program: ExportedProgram) -> PassResult: - gm = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_grouped_conv2d.py b/tico/passes/decompose_grouped_conv2d.py index a2082eb8..41215e54 100644 --- a/tico/passes/decompose_grouped_conv2d.py +++ b/tico/passes/decompose_grouped_conv2d.py @@ -81,10 +81,10 @@ class DecomposeGroupedConv2d(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - gm = exported_program.graph_module + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_slice_scatter.py b/tico/passes/decompose_slice_scatter.py index e39242be..0e908c5b 100644 --- a/tico/passes/decompose_slice_scatter.py +++ 
b/tico/passes/decompose_slice_scatter.py @@ -75,10 +75,8 @@ class DecomposeSliceScatter(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph modified = False diff --git a/tico/passes/extract_dtype_kwargs.py b/tico/passes/extract_dtype_kwargs.py index 19357e89..3cab0f3b 100644 --- a/tico/passes/extract_dtype_kwargs.py +++ b/tico/passes/extract_dtype_kwargs.py @@ -103,8 +103,7 @@ def __init__(self): self.target_ops = dict() self.target_ops[torch.ops.aten.full_like.default] = _extract_to_output - def call(self, exported_program: ExportedProgram) -> PassResult: - graph_module = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: graph: torch.fx.Graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/fill_meta_val.py b/tico/passes/fill_meta_val.py index b35631b8..27db1629 100644 --- a/tico/passes/fill_meta_val.py +++ b/tico/passes/fill_meta_val.py @@ -29,10 +29,8 @@ class FillMetaVal(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False # To make sure graph is topologically sorted diff --git a/tico/passes/fuse_leading_unsqueeze_reshape.py b/tico/passes/fuse_leading_unsqueeze_reshape.py index 532f5f90..03678747 100644 --- a/tico/passes/fuse_leading_unsqueeze_reshape.py +++ b/tico/passes/fuse_leading_unsqueeze_reshape.py @@ -49,11 +49,10 @@ class FuseLeadingUnsqueezeReshape(PassBase): x - aten.reshape([1]*k + s1) - aten.permute(list(range(k)) + [d+k for d in p]) """ - def call(self, ep: ExportedProgram) -> PassResult: + def call(self, ep: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - gm = ep.graph_module - graph = gm.graph + graph = graph_module.graph modified = False for reshape_back in graph.nodes: if not is_target_node(reshape_back, ops.aten.reshape): @@ -107,6 +106,6 @@ def call(self, ep: ExportedProgram) -> PassResult: if modified: graph.eliminate_dead_code() graph.lint() - gm.recompile() + graph_module.recompile() return PassResult(modified) diff --git a/tico/passes/fuse_redundant_reshape_to_mean.py b/tico/passes/fuse_redundant_reshape_to_mean.py index 9d9a88e9..1dfbb03a 100644 --- a/tico/passes/fuse_redundant_reshape_to_mean.py +++ b/tico/passes/fuse_redundant_reshape_to_mean.py @@ -39,10 +39,8 @@ class FuseRedundantReshapeToMean(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/legalize_causal_mask_value.py b/tico/passes/legalize_causal_mask_value.py index 01a99222..b3738238 100644 --- a/tico/passes/legalize_causal_mask_value.py +++ b/tico/passes/legalize_causal_mask_value.py @@ -43,14 +43,12 @@ def __init__(self, enabled: bool = False): super().__init__() self.enabled = enabled - def call(self, 
exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: if not self.enabled: return PassResult(False) new_mask = -120 # Make it configurable logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/legalize_predefined_layout_operators.py b/tico/passes/legalize_predefined_layout_operators.py index 5e9bd690..250797cf 100644 --- a/tico/passes/legalize_predefined_layout_operators.py +++ b/tico/passes/legalize_predefined_layout_operators.py @@ -434,7 +434,7 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool: modified = True return modified - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: target_to_legalize_func = { torch.ops.aten.conv2d.default: self.legalize_conv2d, torch.ops.aten.conv2d.padding: self.legalize_conv2d, @@ -443,8 +443,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d, torch.ops.aten.instance_norm.default: self.legalize_instance_norm, } - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/lower_pow2_to_mul.py b/tico/passes/lower_pow2_to_mul.py index 70074dcc..7227a563 100644 --- a/tico/passes/lower_pow2_to_mul.py +++ b/tico/passes/lower_pow2_to_mul.py @@ -38,10 +38,8 @@ class LowerPow2ToMul(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/lower_to_resize_nearest_neighbor.py b/tico/passes/lower_to_resize_nearest_neighbor.py index ae0d6c1a..039a518a 100644 --- a/tico/passes/lower_to_resize_nearest_neighbor.py +++ b/tico/passes/lower_to_resize_nearest_neighbor.py @@ -197,11 +197,10 @@ def close_enough(x, y, epsilon=1e-10): node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True) return resize_nearest_neighbor - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) modified = False - graph_module = exported_program.graph_module graph = graph_module.graph for node in graph.nodes: if not is_target_node( diff --git a/tico/passes/lower_to_slice.py b/tico/passes/lower_to_slice.py index 20bb4afc..240fed4a 100644 --- a/tico/passes/lower_to_slice.py +++ b/tico/passes/lower_to_slice.py @@ -77,10 +77,8 @@ class LowerSelectCopyToSlice(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: @@ -157,10 +155,8 @@ class LowerIndexSelectToSlice(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = 
exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/merge_consecutive_cat.py b/tico/passes/merge_consecutive_cat.py index 110f1419..bb825b64 100644 --- a/tico/passes/merge_consecutive_cat.py +++ b/tico/passes/merge_consecutive_cat.py @@ -31,10 +31,8 @@ class MergeConsecutiveCat(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for cat in graph.nodes: diff --git a/tico/passes/remove_nop.py b/tico/passes/remove_nop.py index 57d686f3..da434ee0 100644 --- a/tico/passes/remove_nop.py +++ b/tico/passes/remove_nop.py @@ -45,10 +45,8 @@ class RemoveNop(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_assert_nodes.py b/tico/passes/remove_redundant_assert_nodes.py index a12abdc9..f0a96971 100644 --- a/tico/passes/remove_redundant_assert_nodes.py +++ b/tico/passes/remove_redundant_assert_nodes.py @@ -37,8 +37,7 @@ class RemoveRedundantAssertionNodes(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: - graph_module = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_expand.py b/tico/passes/remove_redundant_expand.py index 2402af5e..c862fa3d 100644 --- a/tico/passes/remove_redundant_expand.py +++ b/tico/passes/remove_redundant_expand.py @@ -32,10 +32,8 @@ class RemoveRedundantExpand(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_permute.py b/tico/passes/remove_redundant_permute.py index ab806bec..a7b73505 100644 --- a/tico/passes/remove_redundant_permute.py +++ b/tico/passes/remove_redundant_permute.py @@ -60,7 +60,7 @@ class RemoveRedundantPermutePattern1(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] (AxBxC) - aten.permute_1 - aten.permute_2 - (OUT_SHAPE) @@ -72,8 +72,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for permute2 in graph.nodes: diff --git a/tico/passes/remove_redundant_reshape.py b/tico/passes/remove_redundant_reshape.py index 87f4d83d..76501f68 100644 --- a/tico/passes/remove_redundant_reshape.py +++ b/tico/passes/remove_redundant_reshape.py @@ -55,7 +55,7 @@ class RemoveRedundantReshapePattern1(PassBase): def 
__init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (1xAxCxB) - `aten.mul` - (1xAxCxB) - `aten.reshape - (AxCxB)` @@ -63,8 +63,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: `(AxBxC) - `aten.permute` - (AxCxB) - `aten.mul` - (AxCxB)` """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape1 in graph.nodes: @@ -139,7 +137,7 @@ class RemoveRedundantReshapePattern2(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (Bx1xAxC) - `aten.reshape - (Bx(A*C))` @@ -147,8 +145,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: `(AxBxC) - `aten.permute` - (BxAxC) - `aten.reshape` - (Bx(A*C))` """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape1 in graph.nodes: @@ -218,7 +214,7 @@ class RemoveRedundantReshapePattern3(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] (AxBxC) - aten.reshape - (1xAxBxC) - aten.add - (1xAxBxC) - aten.softmax - (1xAxBxC) - aten.reshape - (AxBxC) @@ -230,8 +226,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: (AxBxC) / (add) (softmax) """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape_1 in graph.nodes: @@ -332,18 +326,16 @@ class RemoveRedundantReshapePattern4(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ NOTE: Below graph is just an example. This pattern matches not only for the 3D tensors. What this pattern aims to remove is that the consecutive `aten.reshape` ops. 
[BEFORE] - (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC') + (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC) [AFTER] - (AxBxC) - aten.reshape - (A'xB''xC') + (AxBxC) - aten.reshape - (A'xB''xC) """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape1 in graph.nodes: @@ -399,7 +391,7 @@ class RemoveRedundantReshapePattern5(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] (AxBxC) - aten.reshape - (AxBxC) @@ -407,8 +399,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: (AxBxC) """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/remove_redundant_slice.py b/tico/passes/remove_redundant_slice.py index a71f4c85..d8e14c38 100644 --- a/tico/passes/remove_redundant_slice.py +++ b/tico/passes/remove_redundant_slice.py @@ -32,10 +32,8 @@ class RemoveRedundantSlice(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_to_copy.py b/tico/passes/remove_redundant_to_copy.py index 375ffb57..9784256f 100644 --- a/tico/passes/remove_redundant_to_copy.py +++ b/tico/passes/remove_redundant_to_copy.py @@ -35,10 +35,8 @@ class RemoveRedundantToCopy(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/restore_linear.py b/tico/passes/restore_linear.py index f9ed5351..fe9ba4d9 100644 --- a/tico/passes/restore_linear.py +++ b/tico/passes/restore_linear.py @@ -48,10 +48,8 @@ class RestoreLinear(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/segment_index_select.py b/tico/passes/segment_index_select.py index 31411043..212fa6f0 100644 --- a/tico/passes/segment_index_select.py +++ b/tico/passes/segment_index_select.py @@ -72,10 +72,8 @@ class SegmentIndexSelectConst(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/serialize/circle_graph.py b/tico/serialize/circle_graph.py index aacba4f2..42b47c72 100644 --- a/tico/serialize/circle_graph.py +++ b/tico/serialize/circle_graph.py @@ -69,7 +69,7 @@ class 
CircleModel(circle.Model.ModelT): def __init__(self): super().__init__() self.subgraphs: List[circle.SubGraph.SubGraphT] = [] - self.buffers: List[circle.Buffer.BufferT] = [] + self.buffers: List[circle.Buffer.BufferT] = [circle.Buffer.BufferT()] # Add empty buffer at the front def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None: self.subgraphs.append(graph) @@ -80,7 +80,6 @@ def add_buffer(self, buffer: circle.Buffer.BufferT) -> int: buf_id = len(self.buffers) - 1 # last index return buf_id - @final class CircleSubgraph(circle.SubGraph.SubGraphT): def __init__(self, model: CircleModel): @@ -97,6 +96,7 @@ def __init__(self, model: CircleModel): # human-readable tensor names after serialization. self.name_to_node: Dict[str, torch.fx.Node] = {} self.counter: defaultdict = defaultdict(int) + # Generate a unique name with prefix. # Naming rule @@ -129,7 +129,7 @@ def add_input(self, input_name: str) -> None: def add_output(self, output: Any) -> None: if isinstance(output, str): - assert output in self.name_to_tid + assert output in self.name_to_tid, f"{output} is not in {self.name_to_tid}" output_name = output elif isinstance(output, int | float): # output is built-in type. @@ -324,4 +324,5 @@ def get_tid( return self.name_to_tid[node_name] # Unreachable + breakpoint() raise RuntimeError("fx Node was not converted to tensor.") diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 9ac21cc1..947abad4 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -35,7 +35,6 @@ torch.ops.aten.max.dim, ] - def _initialize_model() -> tuple[CircleModel, CircleSubgraph]: """Initialize a new Circle model and subgraph. @@ -47,6 +46,7 @@ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]: graph = CircleSubgraph(model) return model, graph +from tico.utils.subgraph import get_gm_map def build_circle( ep: ExportedProgram, config: CompileConfigBase = get_default_config() @@ -61,61 +61,77 @@ def build_circle( """ logger = logging.getLogger(__name__) builder = flatbuffers.Builder() - model, graph = _initialize_model() - - # Export tensors - _export_tensors(graph, ep) + model = CircleModel() + + op_codes: Dict[OpCode, int] = {} + + for gm_info in get_gm_map(ep): + if gm_info["name"]: #non-root subgraph + graph_module = getattr(ep.graph_module, gm_info["name"]) + else: + graph_module = ep.graph_module + ep_graph = graph_module.graph + + graph = CircleSubgraph(model) + # Export tensors + if gm_info["name"]: #non-root subgraph + _export_tensors_for_subgraph(graph, ep_graph, ep) + else: + _export_tensors(graph, ep_graph, ep) + if gm_info["index"] == 0: # Root graph + # Register inputs + logger.debug("---------------Register inputs--------------") + for in_spec in ep.graph_signature.input_specs: + if in_spec.kind != InputKind.USER_INPUT: + continue + if isinstance(in_spec.arg, ConstantArgument): + # ConstantArgument is ignored when option is given + if config.get("remove_constant_input"): + continue + # NoneType ConstantArgument is ignored. 
+ if in_spec.arg.value == None: + continue + arg_name = in_spec.arg.name + graph.add_input(arg_name) + logger.debug(f"Registered input: {arg_name}") + + # Register outputs + logger.debug("---------------Register outputs--------------") + for user_output in ep.graph_signature.user_outputs: + if user_output == None: + logger.debug("Ignore 'None' output") + continue + + graph.add_output(user_output) + logger.debug(f"Registered output: {user_output}") + + # Export operators + logger.debug("---------------Export operators--------------") + visitors = get_node_visitors(op_codes, graph) + ep_graph.print_tabular() + breakpoint() + for node in ep_graph.nodes: + if node.op != "call_function": + continue - # Register inputs - logger.debug("---------------Register inputs--------------") - for in_spec in ep.graph_signature.input_specs: - if in_spec.kind != InputKind.USER_INPUT: - continue - if isinstance(in_spec.arg, ConstantArgument): - # ConstantArgument is ignored when option is given - if config.get("remove_constant_input"): + opcode = node.target + if opcode == operator.getitem: continue - # NoneType ConstantArgument is ignored. - if in_spec.arg.value == None: + if opcode == torch.ops.higher_order.cond: + # TODO process continue - arg_name = in_spec.arg.name - graph.add_input(arg_name) - logger.debug(f"Registered input: {arg_name}") - - # Register outputs - logger.debug("---------------Register outputs--------------") - for user_output in ep.graph_signature.user_outputs: - if user_output == None: - logger.debug("Ignore 'None' output") - continue - - graph.add_output(user_output) - logger.debug(f"Registered output: {user_output}") - - # Export operators - logger.debug("---------------Export operators--------------") - op_codes: Dict[OpCode, int] = {} - visitors = get_node_visitors(op_codes, graph) - for node in ep.graph.nodes: - if node.op != "call_function": - continue - - opcode = node.target - if opcode == operator.getitem: - continue - if opcode not in visitors: - raise RuntimeError(f"{opcode} is not yet supported") - circle_op = visitors[opcode].define_node(node) - - if circle_op: - graph.add_operator(circle_op) - logger.debug(f"call_function: {node.name} ({opcode}) Op exported.") + if opcode not in visitors: + raise RuntimeError(f"{opcode} is not yet supported") + circle_op = visitors[opcode].define_node(node) - finalise_tensor_names(graph) - validate_tensor_shapes(graph) + if circle_op: + graph.add_operator(circle_op) + logger.debug(f"call_function: {node.name} ({opcode}) Op exported.") - # Register subgraph - model.subgraphs.append(graph) + finalise_tensor_names(graph) + validate_tensor_shapes(graph) + + model.subgraphs.append(graph) # Encode operator codes model.operatorCodes = [ @@ -133,7 +149,7 @@ def build_circle( return bytes(buf) -def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: +def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> None: """Export all tensors from the exported program to the circle graph. 
Args: @@ -144,11 +160,13 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: logger.debug("---------------Export tensors--------------") buf_name_to_data = {name: buf for name, buf in ep.named_buffers()} - for node in ep.graph.nodes: + for node in ep_graph.nodes: if node.op == "call_function": if node.target in multiple_output_ops: continue node_val = node.meta["val"] + if node.name == 'cond': + continue if node_val.layout != torch.strided: raise RuntimeError( f"Only support dense tensors (node layout: {node_val.layout})" @@ -157,7 +175,7 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: logger.debug(f"call_function: {node.name} tensor exported.") elif node.op == "placeholder": - _handle_placeholder_node(graph, node, ep, buf_name_to_data) + _handle_placeholder_node(graph, node, ep_graph, ep, buf_name_to_data) elif node.op == "get_attr": _handle_get_attr_node(graph, node) @@ -178,9 +196,57 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: raise AssertionError(f"Unknown fx.Node op {node.op}") +def _export_tensors_for_subgraph(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> None: + """Export all tensors from the exported program to the circle graph. + + Args: + graph: The CircleSubgraph to add tensors to + ep: The exported PyTorch program + """ + logger = logging.getLogger(__name__) + logger.debug("---------------Export tensors--------------") + buf_name_to_data = {name: buf for name, buf in ep.named_buffers()} #model-wise context + + for node in ep_graph.nodes: + if node.op == "call_function": + if node.target in multiple_output_ops: + continue + node_val = node.meta["val"] + if node.name == 'cond': + continue + if node_val.layout != torch.strided: + raise RuntimeError( + f"Only support dense tensors (node layout: {node_val.layout})" + ) + graph.add_tensor_from_node(node) + logger.debug(f"call_function: {node.name} tensor exported.") + + elif node.op == "placeholder": + _handle_placeholder_node(graph, node, ep_graph, ep, buf_name_to_data) + graph.add_input(node.name) # This is added for subgraph + + elif node.op == "get_attr": + _handle_get_attr_node(graph, node) + elif node.op == "output": + for output in node.args[0]: + if isinstance(output, torch.fx.Node): + assert graph.has_tensor(output.name) + graph.add_output(output.name) # This is added for subgraph + continue + + elif node.op == "call_method": + raise AssertionError("Not yet implemented") + + elif node.op == "call_module": + raise AssertionError("Not yet implemented") + + else: + raise AssertionError(f"Unknown fx.Node op {node.op}") + def _handle_placeholder_node( graph: CircleSubgraph, node: torch.fx.Node, + ep_graph, ep: ExportedProgram, buf_name_to_data: dict, ) -> None: @@ -326,20 +392,22 @@ def _handle_get_attr_node( node: The get_attr node to process """ assert isinstance(node.target, str) - attr_tensor = getattr(node.graph.owning_module, node.target) - - if not isinstance(attr_tensor, torch.Tensor): - raise ValueError(f"Attribute {node.target} is not a tensor") - - attr_shape, attr_shape_signature = to_circle_shape(attr_tensor.shape) - - graph.add_tensor_from_scratch( - prefix=node.name, - shape=attr_shape, - shape_signature=attr_shape_signature, - dtype=to_circle_dtype(attr_tensor.dtype), - source_node=node, - ) - - logger = logging.getLogger(__name__) - logger.debug(f"Exported attribute tensor: {node.name}") + attr = getattr(node.graph.owning_module, node.target) + + if isinstance(attr, torch.fx.graph_module.GraphModule): + pass + elif 
isinstance(attr, torch.Tensor): + attr_shape, attr_shape_signature = to_circle_shape(attr.shape) + + graph.add_tensor_from_scratch( + prefix=node.name, + shape=attr_shape, + shape_signature=attr_shape_signature, + dtype=to_circle_dtype(attr.dtype), + source_node=node, + ) + + logger = logging.getLogger(__name__) + logger.debug(f"Exported attribute tensor: {node.name}") + else: + raise ValueError(f"Unsupported get_attr target type {type(attr)}") diff --git a/tico/serialize/operators/op_circle_if.py b/tico/serialize/operators/op_circle_if.py new file mode 100644 index 00000000..4093d3d5 --- /dev/null +++ b/tico/serialize/operators/op_circle_if.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch._ops + import torch.fx +import torch +from circle_schema import circle + +from tico.serialize.circle_graph import CircleSubgraph +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index +from tico.utils.validate_args_kwargs import CircleIfArgs +from tico.utils.errors import NotYetSupportedError + + +@register_node_visitor +class CircleIfVisitor(NodeVisitor): + target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.if_] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.IF, self._op_codes + ) + if_args = CircleIfArgs(*node.args, **node.kwargs) + + pred = if_args.pred + then_idx = if_args.then_graph_idx + else_idx = if_args.else_graph_idx + arguments = if_args.if_args + + if len(arguments) > 1: + raise NotYetSupportedError("Not supported multiple input case yet. 
Only one input is allowed.") + + arguments = arguments[0] + + inputs = [pred, arguments] + outputs = [node] + + operator = create_builtin_operator(self.graph, op_index, inputs, outputs) + operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.IfOptions + operator.builtinOptions = circle.IfOptions.IfOptionsT() + operator.builtinOptions.thenSubgraphIndex = then_idx + operator.builtinOptions.elseSubgraphIndex = else_idx + + return operator diff --git a/tico/serialize/operators/utils.py b/tico/serialize/operators/utils.py index 462001d6..3b4fa8c1 100644 --- a/tico/serialize/operators/utils.py +++ b/tico/serialize/operators/utils.py @@ -40,6 +40,7 @@ def get_op_index(opcode: int, opcode_map: Dict[OpCode, int]) -> int: op_index = opcode_map[op_code] return op_index +import torch # TODO Move this to CircleSubGraph def create_builtin_operator( @@ -47,8 +48,21 @@ def create_builtin_operator( ) -> circle.Operator.OperatorT: operator = circle.Operator.OperatorT() operator.opcodeIndex = op_index - operator.inputs = [graph.get_tid(input) for input in inputs] - operator.outputs = [graph.get_tid(output) for output in outputs] + + operator.inputs = [] + for inp in inputs: + if isinstance(inp, torch.fx.immutable_collections.immutable_list): + operator.inputs.append(tuple(graph.get_tid(inp_item) for inp_item in inp)) # TODO: extend to multiple tuple processing + print(f"input: {inp}") + else: + operator.inputs.append(graph.get_tid(inp)) + operator.outputs = [] + for outp in outputs: + if isinstance(outp, torch.fx.immutable_collections.immutable_list): + print(f"output: {outp}") + operator.outputs.append(tuple(graph.get_tid(outp_item) for outp_item in outp)) # TODO: extend to multiple tuple processing + else: + operator.outputs.append(graph.get_tid(outp)) return operator diff --git a/tico/utils/convert.py b/tico/utils/convert.py index d13551e3..4105a5bf 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -34,6 +34,9 @@ from tico.experimental.quantization.passes.remove_weight_dequant_op import ( RemoveWeightDequantOp, ) +from tico.experimental.controlflow.passes.map_subgraph import ( + MapSubgraph, +) from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs from tico.passes.cast_mixed_type_args import CastMixedTypeArgs @@ -84,6 +87,7 @@ trace_graph_diff_on_func, ) from tico.utils.utils import has_quantization_ops, SuppressWarning +from tico.utils.subgraph import get_gm_map @trace_const_diff_on_func @@ -154,6 +158,7 @@ def check_unsupported_target(exported_program: ExportedProgram): supported_target = list(get_support_targets()) # Ignore `getitem` since it is no-op for multiple outputs. supported_target.append(operator.getitem) + supported_target.append(torch.ops.higher_order.cond) unsupported = [] for n in exported_program.graph.nodes: if n.op != "call_function": @@ -193,22 +198,27 @@ def convert_exported_module_to_circle( config = get_default_config() assert isinstance(config, CompileConfigBase) - - logger = logging.getLogger(__name__) - logger.debug("Input ExportedProgram (must be core aten)") - logger.debug(exported_program) - - # PRE-EDGE PASSES - # - # Here are the passes that run before to_edge() conversion. - # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR. 
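Before going further into the conversion pipeline, it helps to recall the FX pattern that `torch.export` produces for `torch.cond`, since it is the pattern the new CircleIf visitor above consumes. A minimal sketch (module name and shapes are illustrative, mirroring the SimpleCond tests later in this series):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.cond(
            x.sum() > 0,                  # pred: a scalar bool tensor
            lambda x_: torch.sin(x_),     # true branch
            lambda x_: torch.cos(x_),     # false branch
            operands=(x,),
        )

ep = torch.export.export(M(), (torch.randn(3, 3),))
ep.graph.print_tabular()
# Expected node pattern (abridged):
#   get_attr       true_graph_0 / false_graph_0  -- branch GraphModules lifted as attributes
#   call_function  torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, [x])
#   call_function  operator.getitem(cond, 0)     -- unpacks the one-element result tuple
# which is why the serializer treats `getitem` as a no-op and the branches as extra subgraphs.
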
- decompose_quantize_op = PassManager( - passes=[ - DecomposeFakeQuantize(), - DecomposeFakeQuantizeTensorQParams(), - ] - ) - decompose_quantize_op.run(exported_program) + + for gm_info in get_gm_map(exported_program): + if gm_info["name"]: #non-root subgraph + graph_module = getattr(exported_program.graph_module, gm_info["name"]) + else: + graph_module = exported_program.graph_module + logger = logging.getLogger(__name__) + logger.debug("Input ExportedProgram (must be core aten)") + logger.debug(exported_program) + + # PRE-EDGE PASSES + # + # Here are the passes that run before to_edge() conversion. + # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR. + decompose_quantize_op = PassManager( + passes=[ + DecomposeFakeQuantize(), + DecomposeFakeQuantizeTensorQParams(), + ] + ) + decompose_quantize_op.run(exported_program, graph_module) # This pass should be run before 'RestoreLinear' and after 'decompose_quantize_op'. # TODO run pass regardless of the orders. @@ -221,77 +231,91 @@ def convert_exported_module_to_circle( # UserWarning: At pre-dispatch tracing, we assume that any custom op marked with # CompositeImplicitAutograd and have functional schema are safe to not decompose. exported_program = traced_run_decompositions(exported_program) + + for gm_info in get_gm_map(exported_program): + if gm_info["name"]: #non-root subgraph + graph_module = getattr(exported_program.graph_module, gm_info["name"]) + else: + graph_module = exported_program.graph_module + graph = graph_module.graph + + reinterpret_pass = PassManager( + passes=[ + MapSubgraph(), + ] + ) + reinterpret_pass.run(exported_program, graph_module) - # TODO Distinguish legalize and optimize - circle_legalize = PassManager( - passes=[ - FillMetaVal(), - ExtractDtypeKwargsPass(), - RemoveNop(), - ConvertLayoutOpToReshape(), - RestoreLinear(), - ConvertToReLU6(), - DecomposeAddmm(), - DecomposeSliceScatter(), - DecomposeGroupNorm(), - DecomposeBatchNorm(), - DecomposeGroupedConv2d(), - CastATenWhereArgType(), - ConvertRepeatToExpandCopy(), - *RemoveRedundantPermutePasses(), - RemoveRedundantAssertionNodes(), - RemoveRedundantExpand(), - RemoveRedundantSlice(), - FuseRedundantReshapeToMean(), - *RemoveRedundantViewPasses(), - RemoveRedundantToCopy(), - MergeConsecutiveCat(), - CastMixedTypeArgs(preserve_ep_invariant=True), - ConstPropPass(), - SegmentIndexSelectConst(), - LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")), - ConvertMatmulToLinear( - enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"), - enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"), - enable_single_batch_lhs_const_bmm=config.get( - "convert_single_batch_lhs_const_bmm_to_fc" + # TODO Distinguish legalize and optimize + circle_legalize = PassManager( + passes=[ + FillMetaVal(), + ExtractDtypeKwargsPass(), + RemoveNop(), + ConvertLayoutOpToReshape(), + RestoreLinear(), + ConvertToReLU6(), + DecomposeAddmm(), + DecomposeSliceScatter(), + DecomposeGroupNorm(), + DecomposeBatchNorm(), + DecomposeGroupedConv2d(), + CastATenWhereArgType(), + ConvertRepeatToExpandCopy(), + *RemoveRedundantPermutePasses(), + RemoveRedundantAssertionNodes(), + RemoveRedundantExpand(), + RemoveRedundantSlice(), + FuseRedundantReshapeToMean(), + *RemoveRedundantViewPasses(), + RemoveRedundantToCopy(), + MergeConsecutiveCat(), + CastMixedTypeArgs(preserve_ep_invariant=True), + ConstPropPass(), + SegmentIndexSelectConst(), + LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")), + 
ConvertMatmulToLinear(
+                    enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+                    enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+                    enable_single_batch_lhs_const_bmm=config.get(
+                        "convert_single_batch_lhs_const_bmm_to_fc"
+                    ),
                 ),
-            ),
-            LowerToResizeNearestNeighbor(),
-            LegalizePreDefinedLayoutOperators(),
-            LowerPow2ToMul(),
-            ConvertConv1dToConv2d(),
-            *LowerToSlicePasses(),
-            FuseLeadingUnsqueezeReshape(),
-            CastClampMixedTypeArgs(),
-        ]
-    )
-    circle_legalize.run(exported_program)
-
-    # After this stage, ExportedProgram invariant is broken, i.e.,
-    # graph can have a constant torch.tensor not lifted to a placeholder
-    circle_legalize = PassManager(
-        passes=[
-            FillMetaVal(),
-            CastMixedTypeArgs(preserve_ep_invariant=False),
-        ]
-    )
-    circle_legalize.run(exported_program)
-
-    # TODO Give an option to enable quantization to user
-    enable_quantization = has_quantization_ops(exported_program.graph)
-    if enable_quantization:
-        quantize_graph = PassManager(
+                LowerToResizeNearestNeighbor(),
+                LegalizePreDefinedLayoutOperators(),
+                LowerPow2ToMul(),
+                ConvertConv1dToConv2d(),
+                *LowerToSlicePasses(),
+                FuseLeadingUnsqueezeReshape(),
+                CastClampMixedTypeArgs(),
+            ]
+        )
+        circle_legalize.run(exported_program, graph_module)
+
+        # After this stage, ExportedProgram invariant is broken, i.e.,
+        # graph can have a constant torch.tensor not lifted to a placeholder
+        circle_legalize = PassManager(
             passes=[
-                FoldQuantOps(),
-                RemoveWeightDequantOp(),
-                PropagateQParamForward(),
-                PropagateQParamBackward(),
-                QuantizeBias(),
-                InsertQuantizeOnDtypeMismatch(),
+                FillMetaVal(),
+                CastMixedTypeArgs(preserve_ep_invariant=False),
             ]
         )
-        quantize_graph.run(exported_program)
+        circle_legalize.run(exported_program, graph_module)
+
+        # TODO Give an option to enable quantization to user
+        enable_quantization = has_quantization_ops(graph)
+        if enable_quantization:
+            quantize_graph = PassManager(
+                passes=[
+                    FoldQuantOps(),
+                    RemoveWeightDequantOp(),
+                    PropagateQParamForward(),
+                    PropagateQParamBackward(),
+                    QuantizeBias(),
+                    InsertQuantizeOnDtypeMismatch(),
+                ]
+            )
+            quantize_graph.run(exported_program)
 
     check_unsupported_target(exported_program)
     check_training_ops(exported_program)
diff --git a/tico/utils/passes.py b/tico/utils/passes.py
index 0a59d9bf..cee92c75 100644
--- a/tico/utils/passes.py
+++ b/tico/utils/passes.py
@@ -31,7 +31,7 @@ class PassBase(ABC):
     """
 
     @abstractmethod
-    def call(self, exported_program: ExportedProgram) -> PassResult:
+    def call(self, exported_program: ExportedProgram, gm) -> PassResult:
         pass
 
 
@@ -51,7 +51,7 @@ def __init__(
         self.passes: List[PassBase] = passes
         self.strategy: PassStrategy = strategy
 
-    def run(self, exported_program: ExportedProgram):
+    def run(self, exported_program: ExportedProgram, graph_module):
         MAXIMUM_STEP_COUNT = 1000
         step = 0
         while True:
@@ -59,10 +59,10 @@ def run(self, exported_program: ExportedProgram):
             for _pass in self.passes:
                 # Automatically update the signatures of the input and output.
                 
# https://github.com/pytorch/executorch/issues/4013#issuecomment-2187161844 - with exported_program.graph_module._set_replace_hook( + with graph_module._set_replace_hook( exported_program.graph_signature.get_replace_hook() ): - result = _pass.call(exported_program) + result = _pass.call(exported_program, graph_module) modified = modified or result.modified if modified and self.strategy == PassStrategy.RESTART: break diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 095cb6a5..82005986 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -12,13 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Union import torch from torch._subclasses.fake_tensor import FakeTensor from torch.library import custom_op, register_fake from tico.utils.mx.mx_ops import _quantize_mx +from tico.utils.subgraph import get_gm_map + +# Note that an operator assumes input tensor has NHWC format. +def CircleIf(): + @custom_op("circle_custom::if_", mutates_args=()) + def if_(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args: List[torch.Tensor]) -> torch.Tensor: + true_graph = None + false_graph = None + for gm_info in get_gm_map(): + if gm_info["index"] == true_graph_idx: + true_graph = gm_info["gm"] + continue + if gm_info["index"] == false_graph_idx: + false_graph = gm_info["gm"] + continue + + if pred: + result = true_graph(*if_args) + assert len(result) == 1 # TODO: Support tuple of result + return result[0] + else: + result = false_graph(*if_args) + assert len(result) == 1 # TODO: Support tuple of result + return result[0] + + @register_fake("circle_custom::if_") + def _(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args: List): + true_graph = None + false_graph = None + for gm_info in get_gm_map(): + if gm_info["index"] == true_graph_idx: + true_graph = gm_info["gm"] + continue + if gm_info["index"] == false_graph_idx: + false_graph = gm_info["gm"] + continue + + result = true_graph(*if_args) + assert len(result) == 1 # TODO: Support tuple of result + + return result[0] + # Note that an operator assumes input tensor has NHWC format. 
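The `CircleIf` registration above follows the standard `torch.library` pattern: an eager kernel that really executes one branch, plus a "fake" kernel that only propagates metadata so `torch.export` can trace through the op. A self-contained sketch of the same pattern, with a made-up op name (`my_lib::scale` is not part of this series):

import torch
from torch.library import custom_op, register_fake

@custom_op("my_lib::scale", mutates_args=())
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    # Eager kernel: executes with real tensors.
    return x * s

@register_fake("my_lib::scale")
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # Fake kernel: sees FakeTensors during tracing, so it may only derive
    # output metadata (shape/dtype) and must never touch real data.
    return torch.empty_like(x)

out = torch.ops.my_lib.scale(torch.randn(2, 2), 3.0)  # dispatches to the eager kernel
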
def CircleResizeNearestNeighbor():
@@ -740,3 +782,4 @@ def RegisterOps():
     CircleInstanceNorm()
     CircleQuantizeMX()
     CircleRMSNorm()
+    CircleIf()
diff --git a/tico/utils/signature.py b/tico/utils/signature.py
index acabaf89..a69c0796 100644
--- a/tico/utils/signature.py
+++ b/tico/utils/signature.py
@@ -117,8 +117,8 @@ def load(circle_path: str) -> bytes:
 
     def __init__(self, circle_binary):
         model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
-        assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
 
+        # Assumption: the Circle model's user IO signature is defined in the first subgraph
        graph = model.Subgraphs(0)
 
         tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())]
diff --git a/tico/utils/subgraph.py b/tico/utils/subgraph.py
new file mode 100644
index 00000000..3c9aa8ea
--- /dev/null
+++ b/tico/utils/subgraph.py
@@ -0,0 +1,43 @@
+import torch
+from torch.export import ExportedProgram
+from typing import Optional
+import functools
+
+_gm_map = None
+def get_gm_map(ep: Optional[ExportedProgram] = None):
+    """
+    Returns [{"index": 0, "name": "", "gm": ep.graph_module}, ...]
+    """
+    # Build _gm_map only once while the compiler is running
+    global _gm_map
+    if _gm_map is None:
+        assert ep is not None
+        _gm_map = _build_gm_map(ep)
+    return _gm_map
+
+def _build_gm_map(ep: ExportedProgram):
+    ret = []
+
+    # Add the root GraphModule
+    ret.append({
+        "index": len(ret),
+        "name": "",
+        "gm": ep.graph_module,
+    })
+
+    # Inspect non-root subgraphs
+    for node in ep.graph.nodes:
+        if node.op == "get_attr":
+            attr = getattr(node.graph.owning_module, node.target)
+
+            # TODO: Enable recursion (n-depth)
+            if isinstance(attr, torch.fx.graph_module.GraphModule):
+                assert hasattr(node, 'name')
+                assert getattr(node, 'name') != ret[0]["name"]
+                graph_name = getattr(node, 'name')
+                ret.append({
+                    "index": len(ret),
+                    "name": graph_name,
+                    "gm": attr,
+                })
+    return ret
diff --git a/tico/utils/trace_decorators.py b/tico/utils/trace_decorators.py
index c47c180a..62dab48b 100644
--- a/tico/utils/trace_decorators.py
+++ b/tico/utils/trace_decorators.py
@@ -28,13 +28,11 @@ def trace_const_diff_on_pass(cls):
 
     def _call_traced(fn):
         @wraps(fn)
-        def wrapped(*args):
-            _, exported_program = args
+        def wrapped(self, exported_program, graph_module):
             assert isinstance(exported_program, ExportedProgram)
-            graph_module = exported_program.graph_module
             assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
             capture_const(exported_program)
-            ret = fn(*args)
+            ret = fn(self, exported_program, graph_module)
             log_const(exported_program, title=str(cls.__name__), recapture=False)
             return ret
 
@@ -54,13 +52,11 @@ def trace_graph_diff_on_pass(cls):
 
     def _call_traced(fn):
         @wraps(fn)
-        def wrapped(*args):
-            _, exported_program = args
+        def wrapped(self, exported_program, graph_module):
             assert isinstance(exported_program, ExportedProgram)
-            graph_module = exported_program.graph_module
             assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
             capture(graph_module.graph)
-            ret = fn(*args)
+            ret = fn(self, exported_program, graph_module)
             log(graph_module.graph, title=str(cls.__name__), recapture=False)
             return ret
 
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index 8a5feb21..b5de2337 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -205,6 +205,27 @@ class CloneArgs:
     input: torch.fx.Node
     memory_format: Optional[torch.memory_format] = None
 
+@enforce_type
+@dataclass
+class CircleIfArgs:
+    """
+    Args for the `circle_custom.if_` operator.
+    """
+    pred: torch.fx.Node
+    then_graph_idx: int
+    else_graph_idx: int
+    if_args: torch.fx.immutable_collections.immutable_list
+
+@enforce_type
+@dataclass
+class CondArgs:
+    """
+    This is not an ATen operator but `torch.ops.higher_order.cond`.
+    """
+    condition: torch.fx.Node
+    true_graph: torch.fx.Node
+    false_graph: torch.fx.Node
+    cond_args: torch.fx.immutable_collections.immutable_list
+
 
 @enforce_type
 @dataclass

From 329b6acf13f04b1f57573e75b4c282b202d5a76b Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Thu, 11 Sep 2025 11:05:28 +0900
Subject: [PATCH 2/9] Pass: test SimpleCond1

---
 test/modules/op/cond.py             | 53 ++++++++++++++++++++++-------
 tico/serialize/circle_serializer.py |  4 ---
 2 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/test/modules/op/cond.py b/test/modules/op/cond.py
index 8aaff89b..b543567f 100644
--- a/test/modules/op/cond.py
+++ b/test/modules/op/cond.py
@@ -15,7 +15,9 @@
 import torch
 
 from test.modules.base import TestModuleBase
+from test.utils.tag import use_onert
 
+@use_onert
 class SimpleCond1(TestModuleBase):
     class Sin(torch.nn.Module):
         def forward(self, x):
@@ -40,6 +42,7 @@ def get_example_inputs(self):
 
 
 
+@use_onert
 class SimpleCond2(TestModuleBase):
     class Sin(torch.nn.Module):
         def forward(self, x, y):
@@ -61,20 +64,44 @@ def forward(self, x, y):
                           operands=(x,y))
     def get_example_inputs(self):
         return (torch.randn(3, 3), torch.randn(3, 3)), {}
+
+
+@use_onert
+class SimpleCond3(TestModuleBase):
+    class Sin(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.sin(x) + torch.sin(y)
+
+    class Cos(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.cos(x) - torch.cos(y)
+
+    def __init__(self):
+        super().__init__()
+        self.sin = self.Sin()
+        self.cos = self.Cos()
+
+    def forward(self, x, y):
+        return torch.cond(x.sum() + y.sum() > 0,
+                          lambda x, y: self.sin(x, y),
+                          lambda x, y: self.cos(x, y),
+                          operands=(x,y))
+    def get_example_inputs(self):
+        return (torch.randn(3, 3), torch.randn(3, 3)), {}
 
 
-if __name__ == "__main__":
-    model = SimpleCond2()
-    x = torch.randn(3, 3)
-    y = torch.randn(3, 3)
+# if __name__ == "__main__":
+#     model = SimpleCond2()
+#     x = torch.randn(3, 3)
+#     y = torch.randn(3, 3)
 
-    # export (create the graph)
-    exported_model = torch.export.export(model, (x, y))
+#     # export (create the graph)
+#     exported_model = torch.export.export(model, (x, y))
 
-    # test calling the exported model
-    output = exported_model.module()(x, y)
-    exported_model.graph.print_tabular()
-    print(exported_model.graph_signature.user_inputs)
-    print(exported_model.graph_signature.user_outputs)
-    print(output)
-    breakpoint()
\ No newline at end of file
+#     # test calling the exported model
+#     output = exported_model.module()(x, y)
+#     exported_model.graph.print_tabular()
+#     print(exported_model.graph_signature.user_inputs)
+#     print(exported_model.graph_signature.user_outputs)
+#     print(output)
+#     breakpoint()
\ No newline at end of file
diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py
index 947abad4..5bcb1c68 100644
--- a/tico/serialize/circle_serializer.py
+++ b/tico/serialize/circle_serializer.py
@@ -109,7 +109,6 @@ def build_circle(
         logger.debug("---------------Export operators--------------")
         visitors = get_node_visitors(op_codes, graph)
         ep_graph.print_tabular()
-        breakpoint()
         for node in ep_graph.nodes:
             if node.op != "call_function":
                 continue
@@ -117,9 +116,6 @@ def build_circle(
             opcode = node.target
             if opcode == operator.getitem:
                 continue
-            if opcode == torch.ops.higher_order.cond:
-                # TODO process
-                continue
             if opcode not in visitors:
                 raise 
RuntimeError(f"{opcode} is not yet supported") circle_op = visitors[opcode].define_node(node) From f7c0c8bff91234baed4659c23fd03301d0943135 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 11 Sep 2025 11:11:28 +0900 Subject: [PATCH 3/9] Pass: test SimpleCond3 --- tico/serialize/operators/op_circle_if.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tico/serialize/operators/op_circle_if.py b/tico/serialize/operators/op_circle_if.py index 4093d3d5..d4c49765 100644 --- a/tico/serialize/operators/op_circle_if.py +++ b/tico/serialize/operators/op_circle_if.py @@ -49,12 +49,7 @@ def define_node( else_idx = if_args.else_graph_idx arguments = if_args.if_args - if len(arguments) > 1: - raise NotYetSupportedError("Not supported multiple input case yet. Only one input is allowed.") - - arguments = arguments[0] - - inputs = [pred, arguments] + inputs = [pred, *arguments] outputs = [node] operator = create_builtin_operator(self.graph, op_index, inputs, outputs) From 0dea901a1390574022bc2165c782409a381deb65 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 24 Sep 2025 19:04:13 +0900 Subject: [PATCH 4/9] merge --- tico/passes/convert_matmul_to_linear.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tico/passes/convert_matmul_to_linear.py b/tico/passes/convert_matmul_to_linear.py index 2db06fa4..d0337eaa 100644 --- a/tico/passes/convert_matmul_to_linear.py +++ b/tico/passes/convert_matmul_to_linear.py @@ -44,8 +44,7 @@ class MatmulToLinearConverter(Converter): def __init__(self): super().__init__() - def convert(self, exported_program, node) -> torch.fx.Node: - graph_module = exported_program.graph_module + def convert(self, exported_program, graph_module, node) -> torch.fx.Node: graph = graph_module.graph mm_args = MatmulArgs(*node.args, **node.kwargs) # type: ignore[arg-type] @@ -73,7 +72,7 @@ class RhsConstMatmulToLinearConverter(MatmulToLinearConverter): def __init__(self): super().__init__() - def match(self, exported_program, node) -> bool: + def match(self, exported_program, graph_module, node) -> bool: if not node.target == torch.ops.aten.mm.default: return False @@ -99,7 +98,7 @@ class LhsConstMatmulToLinearConverter(MatmulToLinearConverter): def __init__(self): super().__init__() - def match(self, exported_program, node) -> bool: + def match(self, exported_program, graph_module, node) -> bool: if not node.target == torch.ops.aten.mm.default: return False @@ -114,8 +113,8 @@ def match(self, exported_program, node) -> bool: return True return False - def convert(self, exported_program, node) -> torch.fx.Node: - return super().convert(exported_program, node) + def convert(self, exported_program, graph_module, node) -> torch.fx.Node: + return super().convert(exported_program, graph_module, node) class SingleBatchLhsConstBmmToLinearConverter(Converter): @@ -154,7 +153,7 @@ class SingleBatchLhsConstBmmToLinearConverter(Converter): def __init__(self): super().__init__() - def match(self, exported_program, node) -> bool: + def match(self, exported_program, graph_module, node) -> bool: if not node.target == torch.ops.aten.bmm.default: return False @@ -185,7 +184,7 @@ def match(self, exported_program, node) -> bool: return True - def convert(self, exported_program, node) -> torch.fx.Node: + def convert(self, exported_program, graph_module, node) -> torch.fx.Node: graph_module = exported_program.graph_module graph = graph_module.graph @@ -284,10 +283,9 @@ def __init__( if enable_single_batch_lhs_const_bmm: 
self.converters.append(SingleBatchLhsConstBmmToLinearConverter()) - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: @@ -295,10 +293,10 @@ def call(self, exported_program: ExportedProgram) -> PassResult: continue for converter in self.converters: - if not converter.match(exported_program, node): + if not converter.match(exported_program, graph_module, node): continue - new_node = converter.convert(exported_program, node) + new_node = converter.convert(exported_program, graph_module, node) modified = True logger.debug( f"{node.name} is replaced with {new_node.name} operator (permute + linear)" From 3ed898a59dbc87a4f46a66364633390124e062be Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Fri, 26 Sep 2025 19:34:30 +0900 Subject: [PATCH 5/9] Pass: op.cond all test --- .../controlflow/passes/map_subgraph.py | 52 ++++++++++----- tico/serialize/circle_graph.py | 1 - tico/serialize/circle_serializer.py | 27 ++++---- tico/serialize/operators/op_circle_if.py | 25 +++++--- tico/utils/convert.py | 16 ++--- tico/utils/register_custom_op.py | 27 +------- tico/utils/subgraph.py | 63 +++++++++---------- tico/utils/utils.py | 9 ++- tico/utils/validate_args_kwargs.py | 4 +- 9 files changed, 112 insertions(+), 112 deletions(-) diff --git a/tico/experimental/controlflow/passes/map_subgraph.py b/tico/experimental/controlflow/passes/map_subgraph.py index c86d9d0a..c74d037d 100644 --- a/tico/experimental/controlflow/passes/map_subgraph.py +++ b/tico/experimental/controlflow/passes/map_subgraph.py @@ -20,15 +20,14 @@ import torch from torch.export import ExportedProgram -from tico.serialize.quant_param import QPARAM_KEY, QuantParam from tico.utils import logging from tico.utils.passes import PassBase, PassResult from tico.utils.trace_decorators import trace_graph_diff_on_pass -from tico.utils.utils import get_quant_dtype from tico.utils.validate_args_kwargs import CondArgs from tico.utils.graph import create_node -from tico.utils.subgraph import get_gm_map +from tico.utils.subgraph import get_all_graph_modules, freeze_subgraphs import operator +from torch.utils import _pytree as pytree @trace_graph_diff_on_pass class MapSubgraph(PassBase): @@ -53,27 +52,50 @@ def call(self, exported_program: ExportedProgram, _) -> PassResult: continue cond_args = CondArgs(*node.args, **node.kwargs) + true_graph = cond_args.true_graph + false_graph = cond_args.false_graph + graph_args = cond_args.cond_args - true_graph_idx = None - false_graph_idx = None - for gm_info in get_gm_map(exported_program): - if gm_info["name"] == cond_args.true_graph.name: - true_graph_idx = gm_info["index"] - continue - if gm_info["name"] == cond_args.false_graph.name: - false_graph_idx = gm_info["index"] - continue - assert true_graph_idx is not None - assert false_graph_idx is not None + def _set_meta_val(graph_node, graph_module, graph_args): + def _get_meta_val(node): + assert hasattr(node, 'meta'), f"'node' has no attribute named 'meta' (node: {node})" + assert "val" in node.meta, f"val key not in node.meta (node: {node}, meta: {node.meta})" + return node.meta["val"] + + args, kwargs = pytree.tree_map_only( + torch.fx.Node, + _get_meta_val, + (graph_args, {}), + ) + + new_val = graph_module(*args, **kwargs) # type: ignore[operator] + graph_node.meta["val"] = new_val + + for graph_module, name in 
get_all_graph_modules(exported_program, subgraph_only=True): + if true_graph.name == name: + _set_meta_val(true_graph, graph_module, graph_args) + if false_graph.name == name: + _set_meta_val(false_graph, graph_module, graph_args) + + assert "val" in true_graph.meta, f"{true_graph} has no node.meta['val']" + assert "val" in false_graph.meta, f"{false_graph} has no node.meta['val']" + freeze_subgraphs(exported_program) with graph.inserting_before(node): circle_if = create_node( graph, torch.ops.circle_custom.if_, - args=(cond_args.condition, true_graph_idx, false_graph_idx, cond_args.cond_args), + args=(cond_args.condition, cond_args.true_graph, cond_args.false_graph, cond_args.cond_args), kwargs={}, origin=node, ) + + for t, f in zip(true_graph.meta['val'], false_graph.meta['val']): + assert type(t) == type(f) + assert t.shape == f.shape, f"{t.shape} != {f.shape}" + assert t.dtype == f.dtype, f"{t.dtype} != {f.dtype}" + + circle_if.meta["val"] = true_graph.meta['val'][0] # FIX ME UNLESS torch.ops.higher_order.cond generates this pattern assert len(node.users) == 1 diff --git a/tico/serialize/circle_graph.py b/tico/serialize/circle_graph.py index 42b47c72..a8b53c90 100644 --- a/tico/serialize/circle_graph.py +++ b/tico/serialize/circle_graph.py @@ -324,5 +324,4 @@ def get_tid( return self.name_to_tid[node_name] # Unreachable - breakpoint() raise RuntimeError("fx Node was not converted to tensor.") diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 5bcb1c68..0f3299d9 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -28,11 +28,14 @@ from tico.serialize.operators.node_visitor import get_node_visitors from tico.utils import logging from tico.utils.serialize import finalise_tensor_names, validate_tensor_shapes +from tico.utils.subgraph import get_all_graph_modules multiple_output_ops = [ torch.ops.aten.split_with_sizes.default, torch.ops.aten.max.dim, + # torch.ops.circle_custom.if_.default, + # torch.ops.circle_custom.if_, ] def _initialize_model() -> tuple[CircleModel, CircleSubgraph]: @@ -46,7 +49,6 @@ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]: graph = CircleSubgraph(model) return model, graph -from tico.utils.subgraph import get_gm_map def build_circle( ep: ExportedProgram, config: CompileConfigBase = get_default_config() @@ -64,21 +66,17 @@ def build_circle( model = CircleModel() op_codes: Dict[OpCode, int] = {} - - for gm_info in get_gm_map(ep): - if gm_info["name"]: #non-root subgraph - graph_module = getattr(ep.graph_module, gm_info["name"]) - else: - graph_module = ep.graph_module + for graph_module, name in get_all_graph_modules(ep): ep_graph = graph_module.graph - graph = CircleSubgraph(model) + # Export tensors - if gm_info["name"]: #non-root subgraph - _export_tensors_for_subgraph(graph, ep_graph, ep) - else: + if name == '': # root graph _export_tensors(graph, ep_graph, ep) - if gm_info["index"] == 0: # Root graph + else: + _export_tensors_for_subgraph(graph, ep_graph, ep) + + if name == '': # root graph # Register inputs logger.debug("---------------Register inputs--------------") for in_spec in ep.graph_signature.input_specs: @@ -108,7 +106,6 @@ def build_circle( # Export operators logger.debug("---------------Export operators--------------") visitors = get_node_visitors(op_codes, graph) - ep_graph.print_tabular() for node in ep_graph.nodes: if node.op != "call_function": continue @@ -161,8 +158,6 @@ def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) 
-> Non if node.target in multiple_output_ops: continue node_val = node.meta["val"] - if node.name == 'cond': - continue if node_val.layout != torch.strided: raise RuntimeError( f"Only support dense tensors (node layout: {node_val.layout})" @@ -179,7 +174,7 @@ def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> Non elif node.op == "output": for output in node.args[0]: if isinstance(output, torch.fx.Node): - assert graph.has_tensor(output.name) + assert graph.has_tensor(output.name), f"{output}" continue elif node.op == "call_method": diff --git a/tico/serialize/operators/op_circle_if.py b/tico/serialize/operators/op_circle_if.py index d4c49765..cb2aaed7 100644 --- a/tico/serialize/operators/op_circle_if.py +++ b/tico/serialize/operators/op_circle_if.py @@ -26,11 +26,11 @@ from tico.serialize.operators.utils import create_builtin_operator, get_op_index from tico.utils.validate_args_kwargs import CircleIfArgs from tico.utils.errors import NotYetSupportedError - +from tico.utils.subgraph import get_frozen_subgraphs @register_node_visitor class CircleIfVisitor(NodeVisitor): - target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.if_] + target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.if_, torch.ops.circle_custom.if_.default] def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): super().__init__(op_codes, graph) @@ -45,17 +45,28 @@ def define_node( if_args = CircleIfArgs(*node.args, **node.kwargs) pred = if_args.pred - then_idx = if_args.then_graph_idx - else_idx = if_args.else_graph_idx + then_graph = if_args.then_graph + else_graph = if_args.else_graph arguments = if_args.if_args + then_graph_idx = None + else_graph_idx = None + for frozen_subgraph in get_frozen_subgraphs(): + if frozen_subgraph.name == then_graph.name: + then_graph_idx = frozen_subgraph.idx + if frozen_subgraph.name == else_graph.name: + else_graph_idx = frozen_subgraph.idx + assert then_graph_idx is not None + assert else_graph_idx is not None + + inputs = [pred, *arguments] outputs = [node] - + # outputs = [i for i in node.users.keys()] operator = create_builtin_operator(self.graph, op_index, inputs, outputs) operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.IfOptions operator.builtinOptions = circle.IfOptions.IfOptionsT() - operator.builtinOptions.thenSubgraphIndex = then_idx - operator.builtinOptions.elseSubgraphIndex = else_idx + operator.builtinOptions.thenSubgraphIndex = then_graph_idx + operator.builtinOptions.elseSubgraphIndex = else_graph_idx return operator diff --git a/tico/utils/convert.py b/tico/utils/convert.py index 4105a5bf..57882c54 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -87,7 +87,7 @@ trace_graph_diff_on_func, ) from tico.utils.utils import has_quantization_ops, SuppressWarning -from tico.utils.subgraph import get_gm_map +from tico.utils.subgraph import get_all_graph_modules @trace_const_diff_on_func @@ -199,11 +199,7 @@ def convert_exported_module_to_circle( assert isinstance(config, CompileConfigBase) - for gm_info in get_gm_map(exported_program): - if gm_info["name"]: #non-root subgraph - graph_module = getattr(exported_program.graph_module, gm_info["name"]) - else: - graph_module = exported_program.graph_module + for graph_module, _ in get_all_graph_modules(exported_program): logger = logging.getLogger(__name__) logger.debug("Input ExportedProgram (must be core aten)") logger.debug(exported_program) @@ -232,13 +228,9 @@ def convert_exported_module_to_circle( # CompositeImplicitAutograd and have 
functional schema are safe to not decompose.
     exported_program = traced_run_decompositions(exported_program)
-
-    for gm_info in get_gm_map(exported_program):
-        if gm_info["name"]: #non-root subgraph
-            graph_module = getattr(exported_program.graph_module, gm_info["name"])
-        else:
-            graph_module = exported_program.graph_module
+
+    for graph_module, _ in get_all_graph_modules(exported_program):
         graph = graph_module.graph
-
+
         reinterpret_pass = PassManager(
             passes=[
                 MapSubgraph(),
diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py
index 82005986..d1bf0839 100644
--- a/tico/utils/register_custom_op.py
+++ b/tico/utils/register_custom_op.py
@@ -19,22 +19,10 @@
 from torch.library import custom_op, register_fake
 
 from tico.utils.mx.mx_ops import _quantize_mx
-from tico.utils.subgraph import get_gm_map
 
-# Note that an operator assumes input tensor has NHWC format.
 def CircleIf():
     @custom_op("circle_custom::if_", mutates_args=())
-    def if_(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args: List[torch.Tensor]) -> torch.Tensor:
-        true_graph = None
-        false_graph = None
-        for gm_info in get_gm_map():
-            if gm_info["index"] == true_graph_idx:
-                true_graph = gm_info["gm"]
-                continue
-            if gm_info["index"] == false_graph_idx:
-                false_graph = gm_info["gm"]
-                continue
-
+    def if_(pred: torch.Tensor, true_graph: torch.Tensor, false_graph: torch.Tensor, if_args: List[torch.Tensor]) -> torch.Tensor:
         if pred:
             result = true_graph(*if_args)
             assert len(result) == 1 # TODO: Support tuple of result
             return result[0]
         else:
             result = false_graph(*if_args)
             assert len(result) == 1 # TODO: Support tuple of result
             return result[0]
 
     @register_fake("circle_custom::if_")
-    def _(pred: torch.Tensor, true_graph_idx: int, false_graph_idx: int, if_args: List):
-        true_graph = None
-        false_graph = None
-        for gm_info in get_gm_map():
-            if gm_info["index"] == true_graph_idx:
-                true_graph = gm_info["gm"]
-                continue
-            if gm_info["index"] == false_graph_idx:
-                false_graph = gm_info["gm"]
-                continue
-
+    def _(pred: torch.Tensor, true_graph: torch.Tensor, false_graph: torch.Tensor, if_args: List[torch.Tensor]):
         result = true_graph(*if_args)
         assert len(result) == 1 # TODO: Support tuple of result
 
         return result[0]
 
-
 # Note that an operator assumes input tensor has NHWC format.
 def CircleResizeNearestNeighbor():
diff --git a/tico/utils/subgraph.py b/tico/utils/subgraph.py
index 3c9aa8ea..80d8dcca 100644
--- a/tico/utils/subgraph.py
+++ b/tico/utils/subgraph.py
@@ -1,43 +1,42 @@
 import torch
 from torch.export import ExportedProgram
-from typing import Optional
-import functools
+from copy import deepcopy
+from typing import Iterator, List
+from dataclasses import dataclass
 
+@dataclass
+class FrozenSubgraph:
+    idx: int
+    name: str # model-wise, unique name
+    frozen_graph_module: torch.fx.GraphModule # copied subgraph
 
-_gm_map = None
-def get_gm_map(ep: Optional[ExportedProgram] = None):
+_frozen_subgraphs: List[FrozenSubgraph] = []
+
+def freeze_subgraphs(ep: ExportedProgram):
     """
-    Returns [{"index": 0, "name": "", "gm": ep.graph_module}, ...]
+    Freeze (deep-copy) subgraphs so that FakeTensor shape inference can still run against them later.
     """
-    # Build _gm_map only once while the compiler is running
-    global _gm_map
-    if _gm_map is None:
-        assert ep is not None
-        _gm_map = _build_gm_map(ep)
-    return _gm_map
+    for idx, (graph_module, name) in enumerate(get_all_graph_modules(ep, subgraph_only=True), start = 1):
+        global _frozen_subgraphs
+        _frozen_subgraphs += [FrozenSubgraph(idx = idx, name = name, frozen_graph_module = deepcopy(graph_module))]
 
-def _build_gm_map(ep: ExportedProgram):
-    ret = []
-
-    # Add the root GraphModule
-    ret.append({
-        "index": len(ret),
-        "name": "",
-        "gm": ep.graph_module,
-    })
+def get_frozen_subgraphs() -> List[FrozenSubgraph]:
+    global _frozen_subgraphs
+    return _frozen_subgraphs
+
+
+def get_all_graph_modules(ep: ExportedProgram, subgraph_only: bool = False) -> Iterator[tuple[torch.fx.GraphModule, str]]:
+    """
+    Get all graph modules and their names
+    """
+    if not subgraph_only:
+        yield ep.graph_module, "" # root has no name
 
-    # Inspect non-root subgraphs
+    # yield subgraphs
     for node in ep.graph.nodes:
         if node.op == "get_attr":
-            attr = getattr(node.graph.owning_module, node.target)
+            graph_module = getattr(node.graph.owning_module, node.target)
 
             # TODO: Enable recursion (n-depth)
-            if isinstance(attr, torch.fx.graph_module.GraphModule):
-                assert hasattr(node, 'name')
-                assert getattr(node, 'name') != ret[0]["name"]
-                graph_name = getattr(node, 'name')
-                ret.append({
-                    "index": len(ret),
-                    "name": graph_name,
-                    "gm": attr,
-                })
-    return ret
+            if isinstance(graph_module, torch.fx.graph_module.GraphModule):
+                assert hasattr(graph_module, 'meta')
+                yield graph_module, getattr(node, 'name')
diff --git a/tico/utils/utils.py b/tico/utils/utils.py
index 3848e9b2..1a134fb3 100644
--- a/tico/utils/utils.py
+++ b/tico/utils/utils.py
@@ -184,18 +184,23 @@ def set_new_meta_val(node: torch.fx.node.Node):
     - After updating node's args or kwargs
     """
     assert isinstance(node, torch.fx.node.Node)
+
+    def _get_meta_val(node):
+        assert hasattr(node, 'meta'), f"'node' has no attribute named 'meta' (node: {node})"
+        assert "val" in node.meta, f"val key not in node.meta (node: {node}, meta: {node.meta})"
+        return node.meta["val"]
 
     # `node.target()` needs only `Tensor` for its arguments.
     # Therefore, let's retrieve `FakeTensor` if it is `torch.fx.Node`.
     args, kwargs = pytree.tree_map_only(
         torch.fx.Node,
-        lambda n: n.meta["val"],
+        _get_meta_val,
         (node.args, node.kwargs),
     )
+
     new_val = node.target(*args, **kwargs)  # type: ignore[operator]
     node.meta["val"] = new_val
 
-
 def unset_meta_val(node: torch.fx.node.Node):
     """
     Unset node.meta["val"].
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index b5de2337..14d97cb1 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -211,8 +211,8 @@ class CircleIfArgs:
     """
     Args for the `circle_custom.if_` operator.
     """
     pred: torch.fx.Node
-    then_graph_idx: int
-    else_graph_idx: int
+    then_graph: torch.fx.Node
+    else_graph: torch.fx.Node
     if_args: torch.fx.immutable_collections.immutable_list
 
 @enforce_type
 @dataclass

From 8764327a1fb999e41abf0698a3ac87a1d33384a2 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Thu, 2 Oct 2025 13:31:28 +0900
Subject: [PATCH 6/9] Refactor and format

---
 test/modules/op/cond.py                            | 43 ++++++----
 .../passes/{map_subgraph.py => lower_cond.py}      | 82 +++++++++++++------
 .../passes/propagate_qparam_forward.py             |  1 +
 tico/passes/cast_clamp_mixed_type_args.py          |  4 +-
 tico/passes/const_prop_pass.py                     |  9 +-
 tico/serialize/circle_graph.py                     |  6 +-
 tico/serialize/circle_serializer.py                | 30 ++++---
 tico/serialize/operators/op_circle_if.py           | 22 +++--
 tico/serialize/operators/utils.py                  | 11 ++-
 tico/utils/convert.py                              | 16 ++--
 tico/utils/register_custom_op.py                   | 24 ++++--
 tico/utils/subgraph.py                             | 53 ++++++------
 tico/utils/utils.py                                | 13 ++-
 tico/utils/validate_args_kwargs.py                 |  4 +-
 14 files changed, 204 insertions(+), 114 deletions(-)
 rename tico/experimental/controlflow/passes/{map_subgraph.py => lower_cond.py} (55%)

diff --git a/test/modules/op/cond.py b/test/modules/op/cond.py
index b543567f..6856ee55 100644
--- a/test/modules/op/cond.py
+++ b/test/modules/op/cond.py
@@ -17,6 +17,7 @@
 from test.modules.base import TestModuleBase
 from test.utils.tag import use_onert
 
+
 @use_onert
 class SimpleCond1(TestModuleBase):
     class Sin(torch.nn.Module):
@@ -33,14 +34,16 @@ def __init__(self):
         self.cos = self.Cos()
 
     def forward(self, x, y):
-        return torch.cond(x.sum() + y.sum() > 0,
-                          lambda x_: self.sin(x_),
-                          lambda x_: self.cos(x_),
-                          operands=(x,))
+        return torch.cond(
+            x.sum() + y.sum() > 0,
+            lambda x_: self.sin(x_),
+            lambda x_: self.cos(x_),
+            operands=(x,),
+        )
+
     def get_example_inputs(self):
         return (torch.randn(3, 3), torch.randn(3, 3)), {}
-
-
+
 
 @use_onert
 class SimpleCond2(TestModuleBase):
@@ -58,14 +61,17 @@ def __init__(self):
         self.cos = self.Cos()
 
     def forward(self, x, y):
-        return torch.cond(x.sum() + y.sum() > 0,
-                          lambda x, y: self.sin(x, y),
-                          lambda x, y: self.cos(x, y),
-                          operands=(x,y))
+        return torch.cond(
+            x.sum() + y.sum() > 0,
+            lambda x, y: self.sin(x, y),
+            lambda x, y: self.cos(x, y),
+            operands=(x, y),
+        )
+
     def get_example_inputs(self):
         return (torch.randn(3, 3), torch.randn(3, 3)), {}
-
-
+
+
 @use_onert
 class SimpleCond3(TestModuleBase):
@@ -82,10 +88,13 @@ def __init__(self):
         self.cos = self.Cos()
 
     def forward(self, x, y):
-        return torch.cond(x.sum() + y.sum() > 0,
-                          lambda x, y: self.sin(x, y),
-                          lambda x, y: self.cos(x, y),
-                          operands=(x,y))
+        return torch.cond(
+            x.sum() + y.sum() > 0,
+            lambda x, y: self.sin(x, y),
+            lambda x, y: self.cos(x, y),
+            operands=(x, y),
+        )
+
     def get_example_inputs(self):
         return (torch.randn(3, 3), torch.randn(3, 3)), {}
 
@@ -104,4 +113,4 @@ def get_example_inputs(self):
 # print(exported_model.graph_signature.user_inputs)
 # print(exported_model.graph_signature.user_outputs)
 # print(output)
-# breakpoint()
\ No newline at end of file
+# breakpoint()
diff --git a/tico/experimental/controlflow/passes/map_subgraph.py b/tico/experimental/controlflow/passes/lower_cond.py
index c86d9d0a..d873c8a0 100644
--- a/tico/experimental/controlflow/passes/map_subgraph.py
+++ b/tico/experimental/controlflow/passes/lower_cond.py
@@ -17,49 +17,65 @@
 if TYPE_CHECKING:
     import torch.fx
 
+import operator
+
 import torch
-from torch.export import ExportedProgram
 
 from tico.utils import logging
+from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
+from tico.utils.subgraph import get_all_graph_modules, store_subgraph_indices
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.validate_args_kwargs import CondArgs
-from tico.utils.graph import create_node
-from tico.utils.subgraph import get_all_graph_modules, freeze_subgraphs
-import operator
+from torch.export import ExportedProgram
 from torch.utils import _pytree as pytree
 
+
 @trace_graph_diff_on_pass
-class MapSubgraph(PassBase):
+class LowerCond(PassBase):
+    # Pass that lowers `torch.cond` higher-order ops into a custom intermediate representation.
     """
+    To support torch.cond, this pass:
+    (1) fills in the meta values, which requires inference on the specific subgraph
+        (this differs from how meta is filled for other tensors),
+    (2) freezes the subgraph information to ensure no further modifications during the
+        compilation or export phases, and
+    (3) translates the frozen subgraph, along with its meta information, into a custom
+        intermediate representation (IR).
     """
 
     def __init__(self):
+        # Initialise the base Pass class.
         super().__init__()
 
     def call(self, exported_program: ExportedProgram, _) -> PassResult:
+        # Main entry point for the pass. It walks the graph, finds `torch.ops.higher_order.cond`
+        # nodes, extracts their subgraphs, computes meta values, freezes the subgraphs, and
+        # replaces the original cond node with a custom `circle_custom.if_` node.
         logger = logging.getLogger(__name__)
 
         graph_module = exported_program.graph_module
         graph: torch.fx.Graph = graph_module.graph
         for node in graph.nodes:
+            # Iterate over all nodes to locate `torch.ops.higher_order.cond` calls.
             if node.op != "call_function":
                 continue
 
-            if (
-                node.target
-                != torch.ops.higher_order.cond
-            ):
+            if node.target != torch.ops.higher_order.cond:
+                # Skip nodes that are not `torch.ops.higher_order.cond`.
                 continue
 
             cond_args = CondArgs(*node.args, **node.kwargs)
+            # Extract the true/false subgraphs and the condition arguments.
             true_graph = cond_args.true_graph
             false_graph = cond_args.false_graph
             graph_args = cond_args.cond_args
-
+
             def _set_meta_val(graph_node, graph_module, graph_args):
+                # Helper to compute and set the `meta["val"]` for a node in a subgraph.
                 def _get_meta_val(node):
-                    assert hasattr(node, 'meta'), f"'node' has no attribute named 'meta' (node: {node})"
-                    assert "val" in node.meta, f"val key not in node.meta (node: {node}, meta: {node.meta})"
+                    assert hasattr(
+                        node, "meta"
+                    ), f"'node' has no attribute named 'meta' (node: {node})"
+                    assert (
+                        "val" in node.meta
+                    ), f"val key not in node.meta (node: {node}, meta: {node.meta})"
                     return node.meta["val"]
 
                 args, kwargs = pytree.tree_map_only(
                     torch.fx.Node,
                     _get_meta_val,
                     (graph_args, {}),
                 )
 
-                new_val = graph_module(*args, **kwargs) # type: ignore[operator]
+                new_val = graph_module(*args, **kwargs)  # type: ignore[operator]
+                # Execute the subgraph on the collected meta values to derive the output's meta value.
                 
graph_node.meta["val"] = new_val - - for graph_module, name in get_all_graph_modules(exported_program, subgraph_only=True): + + for graph_module, name in get_all_graph_modules( + exported_program, subgraph_only=True + ): if true_graph.name == name: _set_meta_val(true_graph, graph_module, graph_args) if false_graph.name == name: _set_meta_val(false_graph, graph_module, graph_args) - + assert "val" in true_graph.meta, f"{true_graph} has no node.meta['val']" assert "val" in false_graph.meta, f"{false_graph} has no node.meta['val']" - - freeze_subgraphs(exported_program) + + store_subgraph_indices(exported_program) + # Freeze subgraphs to prevent further modifications during later compilation stages. with graph.inserting_before(node): + # Create a custom `circle_custom.if_` node that represents the lowered conditional. circle_if = create_node( graph, torch.ops.circle_custom.if_, - args=(cond_args.condition, cond_args.true_graph, cond_args.false_graph, cond_args.cond_args), + args=( + cond_args.condition, + cond_args.true_graph, + cond_args.false_graph, + cond_args.cond_args, + ), kwargs={}, origin=node, ) - - for t, f in zip(true_graph.meta['val'], false_graph.meta['val']): + + for t, f in zip(true_graph.meta["val"], false_graph.meta["val"]): + # Ensure the true and false branches produce compatible tensors. assert type(t) == type(f) assert t.shape == f.shape, f"{t.shape} != {f.shape}" assert t.dtype == f.dtype, f"{t.dtype} != {f.dtype}" - - circle_if.meta["val"] = true_graph.meta['val'][0] - + + circle_if.meta["val"] = true_graph.meta["val"][0] + # FIX ME UNLESS torch.ops.higher_order.cond generates this pattern assert len(node.users) == 1 + # The original cond node should have exactly one user: the getitem extracting the result. getitem_node = list(node.users.items())[0][0] assert getitem_node.target == operator.getitem getitem_node.replace_all_uses_with(circle_if) + graph.eliminate_dead_code() + # Clean up any nodes that are no longer reachable after replacement. graph.lint() + # Verify graph consistency. graph_module.recompile() + # Recompile the graph module to reflect the updated graph. # Run only once. 
return PassResult(False) diff --git a/tico/experimental/quantization/passes/propagate_qparam_forward.py b/tico/experimental/quantization/passes/propagate_qparam_forward.py index 56d835ae..a900cf0f 100644 --- a/tico/experimental/quantization/passes/propagate_qparam_forward.py +++ b/tico/experimental/quantization/passes/propagate_qparam_forward.py @@ -66,6 +66,7 @@ def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY]) logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.") + graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "call_function": diff --git a/tico/passes/cast_clamp_mixed_type_args.py b/tico/passes/cast_clamp_mixed_type_args.py index 651a626d..1c4ac0f7 100644 --- a/tico/passes/cast_clamp_mixed_type_args.py +++ b/tico/passes/cast_clamp_mixed_type_args.py @@ -92,7 +92,9 @@ class CastClampMixedTypeArgs(PassBase): def __init__(self): super().__init__() - def convert(self, exported_program: ExportedProgram, node: torch.fx.Node, graph_module) -> bool: + def convert( + self, exported_program: ExportedProgram, node: torch.fx.Node, graph_module + ) -> bool: logger = logging.getLogger(__name__) modified = False diff --git a/tico/passes/const_prop_pass.py b/tico/passes/const_prop_pass.py index 9d697bb0..68973c23 100644 --- a/tico/passes/const_prop_pass.py +++ b/tico/passes/const_prop_pass.py @@ -113,13 +113,14 @@ def get_data( def propagate_constants( - exported_program: ExportedProgram, - graph_module + exported_program: ExportedProgram, graph_module ) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ Propagates constants and returns a dictionary of node to constant tensors of the graph. """ - const_node_to_tensor = get_constant_placeholder_to_tensor_dict(exported_program, graph_module) + const_node_to_tensor = get_constant_placeholder_to_tensor_dict( + exported_program, graph_module + ) graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: @@ -177,7 +178,7 @@ def erase_constant_node( def create_constant_placeholder( const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor], exported_program: ExportedProgram, - graph_module + graph_module, ) -> List[torch.fx.Node]: """ This function creates constant placeholder nodes according to the given constant nodes (`const_node_to_tensor`) and replace it with the original node. diff --git a/tico/serialize/circle_graph.py b/tico/serialize/circle_graph.py index a8b53c90..f4a5e18e 100644 --- a/tico/serialize/circle_graph.py +++ b/tico/serialize/circle_graph.py @@ -69,7 +69,9 @@ class CircleModel(circle.Model.ModelT): def __init__(self): super().__init__() self.subgraphs: List[circle.SubGraph.SubGraphT] = [] - self.buffers: List[circle.Buffer.BufferT] = [circle.Buffer.BufferT()] # Add empty buffer at the front + self.buffers: List[circle.Buffer.BufferT] = [ + circle.Buffer.BufferT() + ] # Add empty buffer at the front def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None: self.subgraphs.append(graph) @@ -80,6 +82,7 @@ def add_buffer(self, buffer: circle.Buffer.BufferT) -> int: buf_id = len(self.buffers) - 1 # last index return buf_id + @final class CircleSubgraph(circle.SubGraph.SubGraphT): def __init__(self, model: CircleModel): @@ -96,7 +99,6 @@ def __init__(self, model: CircleModel): # human-readable tensor names after serialization. self.name_to_node: Dict[str, torch.fx.Node] = {} self.counter: defaultdict = defaultdict(int) - # Generate a unique name with prefix. 
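
One detail worth noting in the `CircleModel` change above: Circle, like TFLite, reserves buffer 0 as the shared empty buffer, so every tensor with real data receives a buffer id >= 1. A standalone sketch of that convention (not project code; the payload and list-based `data` assignment are illustrative assumptions about the flatbuffers object API):

from circle_schema import circle

model = circle.Model.ModelT()
model.buffers = [circle.Buffer.BufferT()]     # index 0: the shared empty buffer

def add_buffer(model, buffer):
    model.buffers.append(buffer)
    return len(model.buffers) - 1             # id of the buffer just appended

weights = circle.Buffer.BufferT()
weights.data = list(bytes(16))                # illustrative constant payload
wid = add_buffer(model, weights)              # always >= 1
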
# Naming rule diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 0f3299d9..5f70d363 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -38,6 +38,7 @@ # torch.ops.circle_custom.if_, ] + def _initialize_model() -> tuple[CircleModel, CircleSubgraph]: """Initialize a new Circle model and subgraph. @@ -64,19 +65,19 @@ def build_circle( logger = logging.getLogger(__name__) builder = flatbuffers.Builder() model = CircleModel() - + op_codes: Dict[OpCode, int] = {} for graph_module, name in get_all_graph_modules(ep): ep_graph = graph_module.graph graph = CircleSubgraph(model) - + # Export tensors - if name == '': # root graph + if name == "": # root graph _export_tensors(graph, ep_graph, ep) else: _export_tensors_for_subgraph(graph, ep_graph, ep) - if name == '': # root graph + if name == "": # root graph # Register inputs logger.debug("---------------Register inputs--------------") for in_spec in ep.graph_signature.input_specs: @@ -123,7 +124,7 @@ def build_circle( finalise_tensor_names(graph) validate_tensor_shapes(graph) - + model.subgraphs.append(graph) # Encode operator codes @@ -187,7 +188,9 @@ def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> Non raise AssertionError(f"Unknown fx.Node op {node.op}") -def _export_tensors_for_subgraph(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> None: +def _export_tensors_for_subgraph( + graph: CircleSubgraph, ep_graph, ep: ExportedProgram +) -> None: """Export all tensors from the exported program to the circle graph. Args: @@ -196,14 +199,16 @@ def _export_tensors_for_subgraph(graph: CircleSubgraph, ep_graph, ep: ExportedPr """ logger = logging.getLogger(__name__) logger.debug("---------------Export tensors--------------") - buf_name_to_data = {name: buf for name, buf in ep.named_buffers()} #model-wise context + buf_name_to_data = { + name: buf for name, buf in ep.named_buffers() + } # model-wise context for node in ep_graph.nodes: if node.op == "call_function": if node.target in multiple_output_ops: continue node_val = node.meta["val"] - if node.name == 'cond': + if node.name == "cond": continue if node_val.layout != torch.strided: raise RuntimeError( @@ -214,15 +219,15 @@ def _export_tensors_for_subgraph(graph: CircleSubgraph, ep_graph, ep: ExportedPr elif node.op == "placeholder": _handle_placeholder_node(graph, node, ep_graph, ep, buf_name_to_data) - graph.add_input(node.name) # This is added for subgraph + graph.add_input(node.name) # This is added for subgraph elif node.op == "get_attr": _handle_get_attr_node(graph, node) elif node.op == "output": - for output in node.args[0]: + for output in node.args[0]: if isinstance(output, torch.fx.Node): assert graph.has_tensor(output.name) - graph.add_output(output.name) # This is added for subgraph + graph.add_output(output.name) # This is added for subgraph continue elif node.op == "call_method": @@ -234,10 +239,11 @@ def _export_tensors_for_subgraph(graph: CircleSubgraph, ep_graph, ep: ExportedPr else: raise AssertionError(f"Unknown fx.Node op {node.op}") + def _handle_placeholder_node( graph: CircleSubgraph, node: torch.fx.Node, - ep_graph, + ep_graph, ep: ExportedProgram, buf_name_to_data: dict, ) -> None: diff --git a/tico/serialize/operators/op_circle_if.py b/tico/serialize/operators/op_circle_if.py index cb2aaed7..9ec3ed5a 100644 --- a/tico/serialize/operators/op_circle_if.py +++ b/tico/serialize/operators/op_circle_if.py @@ -24,13 +24,17 @@ from 
tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index -from tico.utils.validate_args_kwargs import CircleIfArgs from tico.utils.errors import NotYetSupportedError -from tico.utils.subgraph import get_frozen_subgraphs +from tico.utils.subgraph import get_subgraph_indices +from tico.utils.validate_args_kwargs import CircleIfArgs + @register_node_visitor class CircleIfVisitor(NodeVisitor): - target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.if_, torch.ops.circle_custom.if_.default] + target: List[torch._ops.OpOverload] = [ + torch.ops.circle_custom.if_, + torch.ops.circle_custom.if_.default, + ] def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): super().__init__(op_codes, graph) @@ -43,23 +47,23 @@ def define_node( circle.BuiltinOperator.BuiltinOperator.IF, self._op_codes ) if_args = CircleIfArgs(*node.args, **node.kwargs) - + pred = if_args.pred then_graph = if_args.then_graph else_graph = if_args.else_graph arguments = if_args.if_args - + then_graph_idx = None else_graph_idx = None - for frozen_subgraph in get_frozen_subgraphs(): + + for frozen_subgraph in get_subgraph_indices(): if frozen_subgraph.name == then_graph.name: then_graph_idx = frozen_subgraph.idx if frozen_subgraph.name == else_graph.name: else_graph_idx = frozen_subgraph.idx assert then_graph_idx is not None assert else_graph_idx is not None - - + inputs = [pred, *arguments] outputs = [node] # outputs = [i for i in node.users.keys()] @@ -68,5 +72,5 @@ def define_node( operator.builtinOptions = circle.IfOptions.IfOptionsT() operator.builtinOptions.thenSubgraphIndex = then_graph_idx operator.builtinOptions.elseSubgraphIndex = else_graph_idx - + return operator diff --git a/tico/serialize/operators/utils.py b/tico/serialize/operators/utils.py index 3b4fa8c1..1b87f5b3 100644 --- a/tico/serialize/operators/utils.py +++ b/tico/serialize/operators/utils.py @@ -40,6 +40,7 @@ def get_op_index(opcode: int, opcode_map: Dict[OpCode, int]) -> int: op_index = opcode_map[op_code] return op_index + import torch # TODO Move this to CircleSubGraph @@ -48,11 +49,13 @@ def create_builtin_operator( ) -> circle.Operator.OperatorT: operator = circle.Operator.OperatorT() operator.opcodeIndex = op_index - + operator.inputs = [] for inp in inputs: if isinstance(inp, torch.fx.immutable_collections.immutable_list): - operator.inputs.append(tuple(graph.get_tid(inp_item) for inp_item in inp)) # TODO: extend to multiple tuple processing + operator.inputs.append( + tuple(graph.get_tid(inp_item) for inp_item in inp) + ) # TODO: extend to multiple tuple processing print(f"input: {inp}") else: operator.inputs.append(graph.get_tid(inp)) @@ -60,7 +63,9 @@ def create_builtin_operator( for outp in outputs: if isinstance(outp, torch.fx.immutable_collections.immutable_list): print(f"output: {outp}") - operator.outputs.append(tuple(graph.get_tid(outp_item) for outp_item in outp)) # TODO: extend to multiple tuple processing + operator.outputs.append( + tuple(graph.get_tid(outp_item) for outp_item in outp) + ) # TODO: extend to multiple tuple processing else: operator.outputs.append(graph.get_tid(outp)) return operator diff --git a/tico/utils/convert.py b/tico/utils/convert.py index 57882c54..e68012c9 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -20,6 +20,7 @@ from torch.export import export, ExportedProgram from tico.config import CompileConfigBase, 
get_default_config +from tico.experimental.controlflow.passes.lower_cond import LowerCond from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import ( InsertQuantizeOnDtypeMismatch, @@ -34,9 +35,6 @@ from tico.experimental.quantization.passes.remove_weight_dequant_op import ( RemoveWeightDequantOp, ) -from tico.experimental.controlflow.passes.map_subgraph import ( - MapSubgraph, -) from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs from tico.passes.cast_mixed_type_args import CastMixedTypeArgs @@ -82,12 +80,12 @@ from tico.utils.errors import NotYetSupportedError from tico.utils.model import CircleModel from tico.utils.passes import PassManager +from tico.utils.subgraph import get_all_graph_modules from tico.utils.trace_decorators import ( trace_const_diff_on_func, trace_graph_diff_on_func, ) from tico.utils.utils import has_quantization_ops, SuppressWarning -from tico.utils.subgraph import get_all_graph_modules @trace_const_diff_on_func @@ -198,7 +196,7 @@ def convert_exported_module_to_circle( config = get_default_config() assert isinstance(config, CompileConfigBase) - + for graph_module, _ in get_all_graph_modules(exported_program): logger = logging.getLogger(__name__) logger.debug("Input ExportedProgram (must be core aten)") @@ -227,13 +225,13 @@ def convert_exported_module_to_circle( # UserWarning: At pre-dispatch tracing, we assume that any custom op marked with # CompositeImplicitAutograd and have functional schema are safe to not decompose. exported_program = traced_run_decompositions(exported_program) - + for graph_module, _ in get_all_graph_modules(exported_program): graph = graph_module.graph reinterpret_pass = PassManager( passes=[ - MapSubgraph(), + LowerCond(), ] ) reinterpret_pass.run(exported_program, graph_module) @@ -265,7 +263,9 @@ def convert_exported_module_to_circle( CastMixedTypeArgs(preserve_ep_invariant=True), ConstPropPass(), SegmentIndexSelectConst(), - LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")), + LegalizeCausalMaskValue( + enabled=config.get("legalize_causal_mask_value") + ), ConvertMatmulToLinear( enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"), enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"), diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index d1bf0839..de404abd 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -20,25 +20,37 @@ from tico.utils.mx.mx_ops import _quantize_mx + def CircleIf(): @custom_op("circle_custom::if_", mutates_args=()) - def if_(pred: torch.Tensor, true_graph: torch.Tensor, false_graph: torch.Tensor, if_args: List[torch.Tensor]) -> torch.Tensor: + def if_( + pred: torch.Tensor, + true_graph: torch.Tensor, + false_graph: torch.Tensor, + if_args: List[torch.Tensor], + ) -> torch.Tensor: if pred: result = true_graph(*if_args) - assert len(result) == 1 # TODO: Support tuple of result + assert len(result) == 1 # TODO: Support tuple of result return result[0] else: result = false_graph(*if_args) - assert len(result) == 1 # TODO: Support tuple of result + assert len(result) == 1 # TODO: Support tuple of result return result[0] @register_fake("circle_custom::if_") - def _(pred: torch.Tensor, true_graph: torch.Tensor, false_graph: torch.Tensor, if_args: List[torch.Tensor]): + def _( + pred: torch.Tensor, + true_graph: torch.Tensor, + 
false_graph: torch.Tensor, + if_args: List[torch.Tensor], + ): result = true_graph(*if_args) - assert len(result) == 1 # TODO: Support tuple of result - + assert len(result) == 1 # TODO: Support tuple of result + return result[0] + # Note that an operator assumes input tensor has NHWC format. def CircleResizeNearestNeighbor(): @custom_op("circle_custom::resize_nearest_neighbor", mutates_args=()) diff --git a/tico/utils/subgraph.py b/tico/utils/subgraph.py index 80d8dcca..cd42d9ba 100644 --- a/tico/utils/subgraph.py +++ b/tico/utils/subgraph.py @@ -1,42 +1,49 @@ -import torch -from torch.export import ExportedProgram from copy import deepcopy -from typing import Iterator, List, Iterator from dataclasses import dataclass +from typing import Iterator, Iterator, List + +import torch +from torch.export import ExportedProgram + + @dataclass -class FrozenSubgraph: +class SubgraphIdx: idx: int - name: str # model-wise, unique name - frozen_graph_module: torch.fx.GraphModule # copied subgraph + name: str # model-wise, unique name -_frozen_subgraphs: List[FrozenSubgraph] = [] -def freeze_subgraphs(ep: ExportedProgram): - """ - Freeze subgraphs to provide shape inference logic of FakeTensor. - """ - for idx, (graph_module, name) in enumerate(get_all_graph_modules(ep, subgraph_only=True), start = 1): - global _frozen_subgraphs - _frozen_subgraphs += [FrozenSubgraph(idx = idx, name = name, frozen_graph_module = deepcopy(graph_module))] +_subgraph_indices: List[SubgraphIdx] = [] + + +def store_subgraph_indices(ep: ExportedProgram): + global _subgraph_indices + + for idx, (_, name) in enumerate( + get_all_graph_modules(ep, subgraph_only=True), start=1 + ): + _subgraph_indices += [SubgraphIdx(idx=idx, name=name)] -def get_frozen_subgraphs() -> List[FrozenSubgraph]: - global _frozen_subgraphs - return _frozen_subgraphs +def get_subgraph_indices() -> List[SubgraphIdx]: + global _subgraph_indices + return _subgraph_indices -def get_all_graph_modules(ep: ExportedProgram, subgraph_only: bool = False) -> Iterator[tuple[torch.fx.GraphModule, str]]: + +def get_all_graph_modules( + ep: ExportedProgram, subgraph_only: bool = False +) -> Iterator[tuple[torch.fx.GraphModule, str]]: """ Get all graph modules and its name """ if not subgraph_only: - yield ep.graph_module, "" # root has no name - + yield ep.graph_module, "" # root has no name + # yield subgraphs for node in ep.graph.nodes: if node.op == "get_attr": graph_module = getattr(node.graph.owning_module, node.target) - + # TODO: Enable recursion (n-depth) if isinstance(graph_module, torch.fx.graph_module.GraphModule): - assert hasattr(graph_module, 'meta') - yield graph_module, getattr(node, 'name') + assert hasattr(graph_module, "meta") + yield graph_module, getattr(node, "name") diff --git a/tico/utils/utils.py b/tico/utils/utils.py index 1a134fb3..da88d9c4 100644 --- a/tico/utils/utils.py +++ b/tico/utils/utils.py @@ -184,10 +184,14 @@ def set_new_meta_val(node: torch.fx.node.Node): - After updating node's args or kwargs """ assert isinstance(node, torch.fx.node.Node) - + def _get_meta_val(node): - assert hasattr(node, 'meta'), f"'node' has no attribute named 'meta' (node: {node})" - assert "val" in node.meta, f"val key not in node.meta (node: {node}, meta: {node.meta})" + assert hasattr( + node, "meta" + ), f"'node' has no attribute named 'meta' (node: {node})" + assert ( + "val" in node.meta + ), f"val key not in node.meta (node: {node}, meta: {node.meta})" return node.meta["val"] # `node.target()` needs only `Tensor` for its arguments. 
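For orientation, `set_new_meta_val` above boils down to: map each fx.Node argument to its recorded meta["val"] (typically a FakeTensor), then re-run the node's target to refresh the node's own meta value. A minimal standalone sketch of that idea, assuming `call_function` nodes whose argument nodes all carry meta["val"] (the helper name is illustrative, not tico's API):

import torch

def refresh_meta_val(node: torch.fx.Node) -> None:
    # Swap every fx.Node in args/kwargs for its recorded meta value.
    args, kwargs = torch.fx.node.map_arg(
        (node.args, node.kwargs), lambda n: n.meta["val"]
    )
    # Re-execute the op on the (fake) values to recompute node.meta["val"].
    node.meta["val"] = node.target(*args, **kwargs)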
@@ -197,10 +201,11 @@ def _get_meta_val(node):
         _get_meta_val,
         (node.args, node.kwargs),
     )
-    
+
     new_val = node.target(*args, **kwargs)  # type: ignore[operator]
     node.meta["val"] = new_val
 
+
 def unset_meta_val(node: torch.fx.node.Node):
     """
     Unset node.meta["val"].
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index 14d97cb1..13a65a50 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -205,22 +205,26 @@ class CloneArgs:
     input: torch.fx.Node
     memory_format: Optional[torch.memory_format] = None
 
+
 @enforce_type
 @dataclass
 class CircleIfArgs:
     """
+    CondArgs
     """
     pred: torch.fx.Node
     then_graph: torch.fx.Node
     else_graph: torch.fx.Node
     if_args: torch.fx.immutable_collections.immutable_list
 
+
 @enforce_type
 @dataclass
 class CondArgs:
     """
     # This is not an ATen operator but `torch.ops.higher_order.cond`
     """
+
     condition: torch.fx.Node
     true_graph: torch.fx.Node
     false_graph: torch.fx.Node
From 0defa5e80ad9fa7d6abb15381443a27486706507 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Thu, 2 Oct 2025 13:52:56 +0900
Subject: [PATCH 7/9] Refactor

---
 .../controlflow/passes/lower_cond.py          | 40 ++++++++++---------
 tico/serialize/operators/op_circle_if.py      | 17 +-------
 tico/utils/register_custom_op.py              | 18 +++++----
 tico/utils/subgraph.py                        | 27 +------------
 tico/utils/validate_args_kwargs.py            |  7 +++-
 5 files changed, 41 insertions(+), 68 deletions(-)

diff --git a/tico/experimental/controlflow/passes/lower_cond.py b/tico/experimental/controlflow/passes/lower_cond.py
index d873c8a0..09feba4a 100644
--- a/tico/experimental/controlflow/passes/lower_cond.py
+++ b/tico/experimental/controlflow/passes/lower_cond.py
@@ -24,7 +24,7 @@
 from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
-from tico.utils.subgraph import get_all_graph_modules, store_subgraph_indices
+from tico.utils.subgraph import get_all_graph_modules
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.validate_args_kwargs import CondArgs
 from torch.export import ExportedProgram
@@ -33,42 +33,37 @@
 @trace_graph_diff_on_pass
 class LowerCond(PassBase):
-    # Pass that lowers `torch.cond` higher‑order ops into a custom intermediate representation.
     """
-    To support torch.cond,
+    To support torch.cond with Circle IF, translate it into a custom IR.
+    Note that the custom IR must carry both the graph node and the graph index:
+    the graph node is required to keep the subgraph alive until the serialization step,
+    and the graph index is required to create the corresponding Circle IR, because Circle numbers its subgraphs.
+
     (1) fill in the meta values, which requires specific subgraph inference.
         (this process differs from that of filling meta of other tensors)
-    (2) freeze the subgraph information to ensure no further modifications during compilation or export phases.
-    (3) translate the frozen subgraph along with meta-information into a custom intermediate representation (IR)
+    (2) get the subgraph index
+    (3) translate the information into a custom intermediate representation (IR)
     """
 
     def __init__(self):
-        # Initialise the base Pass class.
         super().__init__()
 
     def call(self, exported_program: ExportedProgram, _) -> PassResult:
-        # Main entry point for the pass. It walks the graph, finds `torch.ops.higher_order.cond`
-        # nodes, extracts their subgraphs, computes meta‑values, freezes the subgraphs, and
-        # replaces the original cond node with a custom `circle_custom.if_` node.
logger = logging.getLogger(__name__) graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: - # Iterate over all nodes to locate `torch.ops.higher_order.cond` calls. if node.op != "call_function": continue if node.target != torch.ops.higher_order.cond: - # Skip nodes that are not `torch.ops.higher_order.cond`. continue cond_args = CondArgs(*node.args, **node.kwargs) - # Extract the true/false subgraphs and the condition arguments. true_graph = cond_args.true_graph false_graph = cond_args.false_graph graph_args = cond_args.cond_args def _set_meta_val(graph_node, graph_module, graph_args): - # Helper to compute and set the `meta["val"]` for a node in a subgraph. def _get_meta_val(node): assert hasattr( node, "meta" @@ -85,24 +80,29 @@ def _get_meta_val(node): ) new_val = graph_module(*args, **kwargs) # type: ignore[operator] - # Execute the subgraph with concrete arguments to obtain runtime values. graph_node.meta["val"] = new_val - for graph_module, name in get_all_graph_modules( - exported_program, subgraph_only=True + # [1] Fill in the meta values + # [2] Get the subgraph indices + true_graph_idx = -1 + false_graph_idx = -1 + for idx, (graph_module, name) in enumerate( + get_all_graph_modules(exported_program, subgraph_only=True), start=1 ): if true_graph.name == name: _set_meta_val(true_graph, graph_module, graph_args) + true_graph_idx = idx if false_graph.name == name: _set_meta_val(false_graph, graph_module, graph_args) + false_graph_idx = idx assert "val" in true_graph.meta, f"{true_graph} has no node.meta['val']" assert "val" in false_graph.meta, f"{false_graph} has no node.meta['val']" + assert true_graph_idx != -1 + assert false_graph_idx != -1 - store_subgraph_indices(exported_program) - # Freeze subgraphs to prevent further modifications during later compilation stages. + # [3] Create the translated IR (circle_custom.if_) with graph.inserting_before(node): - # Create a custom `circle_custom.if_` node that represents the lowered conditional. 
circle_if = create_node( graph, torch.ops.circle_custom.if_, @@ -110,6 +110,8 @@ def _get_meta_val(node): cond_args.condition, cond_args.true_graph, cond_args.false_graph, + true_graph_idx, + false_graph_idx, cond_args.cond_args, ), kwargs={}, diff --git a/tico/serialize/operators/op_circle_if.py b/tico/serialize/operators/op_circle_if.py index 9ec3ed5a..f92eedb8 100644 --- a/tico/serialize/operators/op_circle_if.py +++ b/tico/serialize/operators/op_circle_if.py @@ -24,8 +24,6 @@ from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index -from tico.utils.errors import NotYetSupportedError -from tico.utils.subgraph import get_subgraph_indices from tico.utils.validate_args_kwargs import CircleIfArgs @@ -49,20 +47,9 @@ def define_node( if_args = CircleIfArgs(*node.args, **node.kwargs) pred = if_args.pred - then_graph = if_args.then_graph - else_graph = if_args.else_graph arguments = if_args.if_args - - then_graph_idx = None - else_graph_idx = None - - for frozen_subgraph in get_subgraph_indices(): - if frozen_subgraph.name == then_graph.name: - then_graph_idx = frozen_subgraph.idx - if frozen_subgraph.name == else_graph.name: - else_graph_idx = frozen_subgraph.idx - assert then_graph_idx is not None - assert else_graph_idx is not None + then_graph_idx = if_args.then_graph_idx + else_graph_idx = if_args.else_graph_idx inputs = [pred, *arguments] outputs = [node] diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index de404abd..023de45d 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -25,27 +25,31 @@ def CircleIf(): @custom_op("circle_custom::if_", mutates_args=()) def if_( pred: torch.Tensor, - true_graph: torch.Tensor, - false_graph: torch.Tensor, + then_graph: torch.Tensor, + else_graph: torch.Tensor, + then_graph_idx: int, + else_graph_idx: int, if_args: List[torch.Tensor], ) -> torch.Tensor: if pred: - result = true_graph(*if_args) + result = then_graph(*if_args) assert len(result) == 1 # TODO: Support tuple of result return result[0] else: - result = false_graph(*if_args) + result = else_graph(*if_args) assert len(result) == 1 # TODO: Support tuple of result return result[0] @register_fake("circle_custom::if_") def _( pred: torch.Tensor, - true_graph: torch.Tensor, - false_graph: torch.Tensor, + then_graph: torch.Tensor, + else_graph: torch.Tensor, + then_graph_idx: int, + else_graph_idx: int, if_args: List[torch.Tensor], ): - result = true_graph(*if_args) + result = then_graph(*if_args) assert len(result) == 1 # TODO: Support tuple of result return result[0] diff --git a/tico/utils/subgraph.py b/tico/utils/subgraph.py index cd42d9ba..ae33066c 100644 --- a/tico/utils/subgraph.py +++ b/tico/utils/subgraph.py @@ -1,34 +1,9 @@ -from copy import deepcopy -from dataclasses import dataclass -from typing import Iterator, Iterator, List +from typing import Iterator import torch from torch.export import ExportedProgram -@dataclass -class SubgraphIdx: - idx: int - name: str # model-wise, unique name - - -_subgraph_indices: List[SubgraphIdx] = [] - - -def store_subgraph_indices(ep: ExportedProgram): - global _subgraph_indices - - for idx, (_, name) in enumerate( - get_all_graph_modules(ep, subgraph_only=True), start=1 - ): - _subgraph_indices += [SubgraphIdx(idx=idx, name=name)] - - -def get_subgraph_indices() -> List[SubgraphIdx]: - global _subgraph_indices - 
return _subgraph_indices
-
-
 def get_all_graph_modules(
     ep: ExportedProgram, subgraph_only: bool = False
 ) -> Iterator[tuple[torch.fx.GraphModule, str]]:
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index 13a65a50..646b1246 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -210,11 +210,16 @@ class CloneArgs:
 @dataclass
 class CircleIfArgs:
     """
-    CondArgs
+    Why carry both `graph` and `graph_idx`?
+    [1] `graph` is required to compute the proper meta value when processing FakeTensors for torch nodes. Carrying it until `serialize` also keeps the subgraph alive; otherwise it would be cleaned up by the graph module's dead code elimination.
+    [2] `graph_idx` is required to map the node to the corresponding Circle subgraph.
     """
+
     pred: torch.fx.Node
     then_graph: torch.fx.Node
     else_graph: torch.fx.Node
+    then_graph_idx: int
+    else_graph_idx: int
     if_args: torch.fx.immutable_collections.immutable_list
 

From a65c438d4a645de00e92bdbc56b7018d8a255185 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Thu, 2 Oct 2025 18:03:53 +0900
Subject: [PATCH 8/9] temp

---
 call.py                           |  25 +++--
 example_liquid.py                 |  37 +++++++
 if.py                             |  16 +--
 shortconv.py                      | 132 +++++++++++++++++++++
 signature.py                      |  15 ++-
 test.py                           |   8 +-
 test/modules/net/Lfm2ShortConv.py | 171 ++++++++++++++++++++++++++++++
 7 files changed, 374 insertions(+), 30 deletions(-)
 create mode 100644 example_liquid.py
 create mode 100644 shortconv.py
 create mode 100644 test/modules/net/Lfm2ShortConv.py

diff --git a/call.py b/call.py
index d807a56d..940502b8 100644
--- a/call.py
+++ b/call.py
@@ -3,9 +3,9 @@
 import pycircle
 
 from pycircle.circleir.model import Model
+from pycircle.circleir.operators import CircleAdd, CircleCall
 from pycircle.circleir.subgraph import Subgraph
 from pycircle.circleir.tensor import Tensor
-from pycircle.circleir.operators import CircleAdd, CircleCall
 from pycircle.util.alias import TensorType
 
 
@@ -26,7 +26,7 @@
 
 add1 = CircleAdd()
-weights0 = Tensor("weights0", [1, 3], TensorType.FLOAT32, [100., 100., 100.])
+weights0 = Tensor("weights0", [1, 3], TensorType.FLOAT32, [100.0, 100.0, 100.0])
 
 add1.inputs = [call0.outputs(0), weights0]
 add1.outputs(0).attribute("add0", [1, 3], TensorType.FLOAT32)
@@ -38,7 +38,7 @@
 graph1.name = "graph1"
 graph1.inputs = [
     Tensor("input0", [1, 3], TensorType.FLOAT32),
-    Tensor("input1", [1, 3], TensorType.FLOAT32, [-100., -100., -100.])
+    Tensor("input1", [1, 3], TensorType.FLOAT32, [-100.0, -100.0, -100.0]),
 ]
 sub_add = CircleAdd()
 sub_add.inputs = [graph1.inputs[0], graph1.inputs[1]]
@@ -49,22 +49,25 @@
 circle_model = Model()
 circle_model.subgraphs = [graph0, graph1]
 circle_model.signature_defs = {
-    "graph0": {
-        "subgraph_index": 0
-    },
-    "graph1": {
-        "subgraph_index": 1
-    },
+    "graph0": {"subgraph_index": 0},
+    "graph1": {"subgraph_index": 1},
 }
 
 pycircle.export_circle_model(circle_model, "call.circle")
 
 import torch
+
 try:
     from onert import infer
 except ImportError:
     raise RuntimeError("The 'onert' package is required to run this function.")
 
 session_float = infer.session("call.circle")
-output = session_float.infer((torch.randn(1,3),torch.randn(1,3),), measure=True)
-print(output)
\ No newline at end of file
+output = session_float.infer(
+    (
+        torch.randn(1, 3),
+        torch.randn(1, 3),
+    ),
+    measure=True,
+)
+print(output)
diff --git a/example_liquid.py b/example_liquid.py
new file mode 100644
index 00000000..fb9dd987
--- /dev/null
+++ b/example_liquid.py
@@ -0,0 +1,37 @@
+from transformers import AutoProcessor, AutoModelForImageTextToText
+from transformers.image_utils import load_image
+from transformers.integrations.executorch import sdpa_mask_without_vmap
+from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
+# Load model and processor
+model_id = "LiquidAI/LFM2-VL-450M"
+model = AutoModelForImageTextToText.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype="bfloat16",
+    trust_remote_code=True
+)
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+# Load image and create conversation
+url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+image = load_image(url)
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "image": image},
+            {"type": "text", "text": "What is in this image?"},
+        ],
+    },
+]
+
+# Generate Answer
+inputs = processor.apply_chat_template(
+    conversation,
+    add_generation_prompt=True,
+    return_tensors="pt",
+    return_dict=True,
+    tokenize=True,
+).to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=64)
+processor.batch_decode(outputs, skip_special_tokens=True)[0]
\ No newline at end of file
diff --git a/if.py b/if.py
index cf240484..ce923760 100644
--- a/if.py
+++ b/if.py
@@ -2,9 +2,9 @@
 import pycircle
 
 from pycircle.circleir.model import Model
+from pycircle.circleir.operators import CircleAdd, CircleIf
 from pycircle.circleir.subgraph import Subgraph
 from pycircle.circleir.tensor import Tensor
-from pycircle.circleir.operators import CircleAdd, CircleIf
 from pycircle.util.alias import TensorType
 
 # Define the input and constant tensors
@@ -27,7 +27,10 @@
 else_subgraph.inputs = [Tensor("input0", [1, 3], TensorType.FLOAT32), weight_sub_100]
 
 add_op_else = CircleAdd()
-add_op_else.inputs = [else_subgraph.inputs[0], Tensor("input0", [1, 3], TensorType.FLOAT32)]
+add_op_else.inputs = [
+    else_subgraph.inputs[0],
+    Tensor("input0", [1, 3], TensorType.FLOAT32),
+]
 add_op_else.outputs(0).attribute("add_output_else", [1, 3], TensorType.FLOAT32)
 else_subgraph.outputs = [add_op_else.outputs(0)]
@@ -59,6 +62,7 @@
 # Inference via onert
 import torch
+
 try:
     from onert import infer
 except ImportError:
@@ -67,10 +71,10 @@
 session = infer.session("signature_def.circle")
 output = session.infer(
     (
-        torch.tensor([True]),       # condition tensor
-        torch.randn(1, 3),          # input tensor 0
-        torch.tensor([[100., 100., 100.]]),# weights tensor
+        torch.tensor([True]),  # condition tensor
+        torch.randn(1, 3),  # input tensor 0
+        torch.tensor([[100.0, 100.0, 100.0]]),  # weights tensor
     ),
-    measure=True
+    measure=True,
 )
 print(output)
diff --git a/shortconv.py b/shortconv.py
new file mode 100644
index 00000000..51acb54d
--- /dev/null
+++ b/shortconv.py
@@ -0,0 +1,132 @@
+import torch
+from torch import nn
+from torch.export import export  # used with PyTorch 2.x
+
+
+def apply_mask_to_padding_states(hidden_states, attention_mask):
+    """
+    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
+    """
+    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+        dtype = hidden_states.dtype
+        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+    return hidden_states
+
+# Branch 1: module for the case where past_key_value exists and cache_position[0] > 0
+class TrueBranch(nn.Module):
+    def __init__(self, conv, conv_cache, layer_idx, L_cache, bias):
+        super().__init__()
+        self.conv = conv
+        self.register_buffer(
+            "conv_cache",
+            torch.zeros(10, 32, 1024, 20),
+        )
+        self.layer_idx = layer_idx
+        self.L_cache = L_cache
+        self.bias = bias
+
+    def forward(self, Bx, cache_position, seq_len):
+        conv_state = self.conv_cache[self.layer_idx,:,:,:]
+        cache_position = torch.clamp(cache_position, 0, self.L_cache - 1)
+        conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
+        conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
+        self.conv_cache[self.layer_idx, :, :, :] = conv_state.clone()
+        conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
+        if self.bias is not None:
+            conv_out += self.bias
+        conv_out = conv_out.unsqueeze(-1)
+        return conv_out
+
+# Branch 2: module for all other cases
+class FalseBranch(nn.Module):
+    def __init__(self, conv, conv_cache, layer_idx, L_cache):
+        super().__init__()
+        self.conv = conv
+        self.register_buffer(
+            "conv_cache",
+            torch.zeros(10, 32, 1024, 20),
+        )
+        self.layer_idx = layer_idx
+        self.L_cache = L_cache
+
+    def forward(self, Bx, cache_position, seqlen):
+        conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
+        self.conv_cache[self.layer_idx, :, :, :] = conv_state.clone()
+        conv_out = self.conv(Bx)[..., :seqlen]
+        return conv_out
+
+class ShortConv(nn.Module):
+    def __init__(self, in_proj, conv, conv_cache, out_proj, layer_idx, L_cache, bias):
+        super().__init__()
+        self.in_proj = in_proj
+        self.conv = conv
+        self.conv_cache = conv_cache
+        self.out_proj = out_proj
+        self.layer_idx = layer_idx
+        self.L_cache = L_cache
+        self.bias = bias
+
+        self.true_branch = TrueBranch(conv, conv_cache, layer_idx, L_cache, bias)
+        self.false_branch = FalseBranch(conv, conv_cache, layer_idx, L_cache)
+
+    def forward(self, x, past_key_value=None, cache_position=None, attention_mask=None):
+        seqlen = torch.tensor(x.shape[1])
+        x = apply_mask_to_padding_states(x, attention_mask)
+        BCx = self.in_proj(x).transpose(-1, -2)
+        B, C, x = BCx.chunk(3, dim=-2)
+        Bx = B * x
+
+
+        # Condition: past_key_value exists and cache_position[0] > 0
+        pred = (past_key_value is not None) and (cache_position[0] > 0)
+        # if pred is True:
+        #     conv_out = true_fn()
+        # else:
+        #     conv_out = false_fn()
+        conv_out = torch.cond(pred, self.true_branch, self.false_branch, (Bx, cache_position, seqlen,))
+
+        y = C * conv_out
+        y = y.transpose(-1, -2).contiguous()
+        y = self.out_proj(y)
+        return y
+
+import torch
+from torch import nn
+from torch.export import export
+
+# Assumes the TrueBranch, FalseBranch, and ShortConv classes defined above
+
+# Initialize the parameters needed to instantiate the module (example)
+in_proj = nn.Linear(1024, 1024 * 3)  # assumes embedding size 1024
+conv = nn.Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(2,), groups=1024, bias=False)  # depthwise 1D conv over 1024 feature channels
+conv_cache = [torch.zeros(32, 1024, 20)] * 10  # assumes batch 32, 1024 features, cache size 20, 10 layers
+out_proj = nn.Linear(1024, 1024)  # output embedding size 1024
+
+layer_idx = 0
+L_cache = 20
+bias = conv.bias
+
+# Create the ShortConv model
+model = ShortConv(in_proj, conv, conv_cache, out_proj, layer_idx, L_cache, bias)
+
+# Create an example input (batch 32, sequence length 10, embedding 1024)
+x = torch.randn(32, 10, 1024)
+
+# Create past_key_value and cache_position when needed (None is allowed)
+past_key_value = type('', (), {})()
+past_key_value.conv_cache = conv_cache
+cache_position = torch.tensor([5])
+
+model.forward(x, past_key_value, cache_position, None)  # ADDED
+# model.eval()
+# # Example of exporting the model with torch.export
+# exported_model = export(model, (x, past_key_value.conv_cache, cache_position, None))
+
+# # Check the type of the ExportedProgram
+# print(type(exported_model))
+# print(exported_model)
+
+# import tico
+
+# tico.convert_from_exported_program(exported_model).save("shortconv.circle")
diff --git a/signature.py b/signature.py
index d4ad0f89..5c858da6 100644
--- a/signature.py
+++ b/signature.py
@@ -3,9 +3,9 @@
 import pycircle
 
 from pycircle.circleir.model import Model
+from pycircle.circleir.operators import CircleAdd
 from pycircle.circleir.subgraph import Subgraph
 from pycircle.circleir.tensor import Tensor
-from pycircle.circleir.operators import CircleAdd
 from pycircle.util.alias import TensorType
 
 subgraph1 = Subgraph()
@@ -40,22 +40,19 @@
 # circle_model.subgraphs = [subgraph2, subgraph1]
 circle_model.subgraphs = [subgraph1, subgraph2]
 circle_model.signature_defs = {
-    "add_constant": {
-        "subgraph_index": 0
-    },
-    "add_two_inputs": {
-        "subgraph_index": 1
-    },
+    "add_constant": {"subgraph_index": 0},
+    "add_two_inputs": {"subgraph_index": 1},
 }
 
 pycircle.export_circle_model(circle_model, "signature_def_original.circle")
 
 import torch
+
 try:
     from onert import infer
 except ImportError:
     raise RuntimeError("The 'onert' package is required to run this function.")
 
 session_float = infer.session("signature_def_original.circle")
-output = session_float.infer((torch.randn(1,3),))
-breakpoint()
\ No newline at end of file
+output = session_float.infer((torch.randn(1, 3),))
+breakpoint()
diff --git a/test.py b/test.py
index 6f0f4e4b..d785a88a 100644
--- a/test.py
+++ b/test.py
@@ -1,8 +1,8 @@
-
 def test():
     if 1 is 1:
         pass
-    
+
     print("HI")
-
-test()
\ No newline at end of file
+
+
+test()
diff --git a/test/modules/net/Lfm2ShortConv.py b/test/modules/net/Lfm2ShortConv.py
new file mode 100644
index 00000000..19e44869
--- /dev/null
+++ b/test/modules/net/Lfm2ShortConv.py
@@ -0,0 +1,171 @@
+
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.export import Dim
+from test.utils.tag import use_onert
+
+
+def apply_mask_to_padding_states(hidden_states, attention_mask):
+    """
+    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
+    """
+    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+        dtype = hidden_states.dtype
+        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+    return hidden_states
+
+
+
+@use_onert
+class Lfm2ShortConv(nn.Module):
+    # Branch 1: module for the case where past_key_value exists and cache_position[0] > 0
+    class WithCache(nn.Module):
+        """
+        One-by-one Decoding Stage
+
+        Condition: past_key_value (conv_cache) exists and cache_position[0] > 0
+        """
+        def __init__(self, conv, layer_idx, L_cache, bias):
+            super().__init__()
+            self.conv = conv
+            self.layer_idx = layer_idx
+            self.L_cache = L_cache
+            self.bias = bias
+
+        def forward(self, Bx, cache_position, seq_len, conv_state):
+            cache_position = cache_position.clamp(0, self.L_cache - 1)
+            new_conv_state = conv_state.roll(shifts=-1, dims=-1)
+            new_conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
+
+            conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
+            if self.bias:
+                conv_out += self.conv.bias
+
+            conv_out = conv_out.unsqueeze(-1)
+            return conv_out #, new_conv_state
+
+    # Branch 2: module for all other cases
+    class WithoutCache(nn.Module):
+        """
+        Pure convolution
+        """
+        def __init__(self, conv, layer_idx, L_cache):
+            super().__init__()
+            self.conv = conv
+            self.layer_idx = layer_idx
+            self.L_cache = L_cache
+
+        def forward(self, Bx, cache_position, seqlen, conv_state):
+            new_conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
+            FIXED_SEQ_LEN = 20 # instead of using seqlen
+            conv_out = self.conv(Bx)[..., :FIXED_SEQ_LEN] # for compilation
+            return conv_out #, new_conv_state
+
+    def __init__(
+        self,
+        layer_idx: int = 0,
+    ):
+        super().__init__()
+        self.layer_idx = layer_idx
+
+        self.L_cache = 3
+        self.bias = False
+
+        self.conv = nn.Conv1d(
+            in_channels=1024,
+            out_channels=1024,
+            kernel_size=self.L_cache,
+            groups=1024,
+            bias=self.bias,
+            padding=self.L_cache - 1,
+        )
+        self.in_proj = nn.Linear(
+            1024,
+            3 * 1024,
+            bias=self.bias,
+        )
+        self.out_proj = nn.Linear(
+            1024,
+            1024,
+            bias=self.bias,
+        )
+        # # Sharing self.conv is not supported by torch.export
+        # self.conv_with_cache = self.WithCache(self.conv, self.layer_idx, self.L_cache, self.bias)
+        # self.conv_without_cache = self.WithoutCache(self.conv, self.layer_idx, self.L_cache)
+        self.conv_with_cache = self.WithCache(nn.Conv1d(
+            in_channels=1024,
+            out_channels=1024,
+            kernel_size=self.L_cache,
+            groups=1024,
+            bias=self.bias,
+            padding=self.L_cache - 1,
+        ), self.layer_idx, self.L_cache, self.bias)
+        self.conv_without_cache = self.WithoutCache(nn.Conv1d(
+            in_channels=1024,
+            out_channels=1024,
+            kernel_size=self.L_cache,
+            groups=1024,
+            bias=self.bias,
+            padding=self.L_cache - 1,
+        ), self.layer_idx, self.L_cache)
+
+    # def slow_forward(
+    def forward(
+        self,
+        x: torch.Tensor,
+        conv_cache = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        seqlen = torch.tensor(x.shape[1])
+
+        x = apply_mask_to_padding_states(x, attention_mask)
+        BCx = self.in_proj(x).transpose(-1, -2)
+        B, C, x = BCx.chunk(3, dim=-2)
+
+        Bx = B * x
+
+        condition = conv_cache is not None and cache_position[0] > 0
+
+        conv_state = conv_cache[self.layer_idx]
+
+        # # TODO Enable cache state update after exportation
+        # conv_out, new_conv_state = torch.cond(condition,
+        conv_out = torch.cond(condition,
+            self.conv_with_cache, self.conv_without_cache,
+            (Bx, cache_position, seqlen, conv_state,))
+        # conv_cache[self.layer_idx, :, :, :] = new_conv_state.clone()
+
+        y = C * conv_out
+        y = y.transpose(-1, -2).contiguous()
+        y = self.out_proj(y)
+        return y
+
+
+    def get_example_inputs(self):
+        sequence_length = 1
+        x = torch.randn(32, sequence_length, 1024)
+        max_batch_size = 32
+        conv_dim = 1024
+        conv_L_cache = 3
+        num_hidden_layers = 12
+
+        conv_cache = torch.zeros(num_hidden_layers, max_batch_size, conv_dim, conv_L_cache, dtype=torch.float32)
+        cache_position = torch.tensor([5])
+
+        # assert (conv_cache is not None and sequence_length == 1) or (conv_cache is None)
+
+        return (x, conv_cache, cache_position,), {}
+
+    def get_dynamic_shapes(self):
+        sequence_length = Dim("sequence_length", min=1, max=128)
+        dynamic_shapes = {
+            "x": {1: sequence_length},
+            "conv_cache": {},
+            "cache_position": {},
+        }
+
+        return dynamic_shapes
From a788f66106ab058e3ec86d7b8b79f4f39ee19aa3 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Mon, 13 Oct 2025 20:34:40 +0900
Subject: [PATCH 9/9] Pass: op.cond passes all tests, but const prop doesn't work for subgraphs.
This may result in circle file size overflow --- test/modules/op/cond.py | 7 +- tico/passes/const_prop_pass.py | 203 ++++++++++++++++++++------------- tico/utils/convert.py | 7 +- tico/utils/diff_graph.py | 34 +++++- tico/utils/graph.py | 4 +- tico/utils/trace_decorators.py | 6 +- 6 files changed, 170 insertions(+), 91 deletions(-) diff --git a/test/modules/op/cond.py b/test/modules/op/cond.py index 6856ee55..2f5372dc 100644 --- a/test/modules/op/cond.py +++ b/test/modules/op/cond.py @@ -19,14 +19,17 @@ @use_onert -class SimpleCond1(TestModuleBase): +class SimpleCondWithBuffers(TestModuleBase): class Sin(torch.nn.Module): def forward(self, x): return torch.sin(x) + 1 class Cos(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.tensor([3])) def forward(self, x): - return torch.cos(x) - 1 + return torch.cos(x) - self.buf def __init__(self): super().__init__() diff --git a/tico/passes/const_prop_pass.py b/tico/passes/const_prop_pass.py index 68973c23..988a0c8c 100644 --- a/tico/passes/const_prop_pass.py +++ b/tico/passes/const_prop_pass.py @@ -45,32 +45,32 @@ trace_graph_diff_on_pass, ) from tico.utils.utils import get_fake_mode +from tico.utils.subgraph import get_all_graph_modules def get_constant_placeholder_to_tensor_dict( exported_program: ExportedProgram, - graph_module, ) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ Returns a dictionary of constant placeholder node to constant tensor. """ const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() - graph: torch.fx.Graph = graph_module.graph - for node in graph.nodes: - if node.op != "placeholder": - continue - tensor: Optional[torch.Tensor] = None - if is_param(exported_program, node): - tensor = get_param(exported_program, node) - elif is_buffer(exported_program, node): - tensor = get_buffer(exported_program, node) - elif is_lifted_tensor_constant(exported_program, node): - tensor = get_lifted_tensor_constant(exported_program, node) - - if tensor is not None: - assert node not in const_node_to_tensor - const_node_to_tensor[node] = tensor - + + for graph_module, _ in get_all_graph_modules(exported_program): + for node in graph_module.graph.nodes: + if node.op != "placeholder": + continue + tensor: Optional[torch.Tensor] = None + if is_param(exported_program, node): + tensor = get_param(exported_program, node) + elif is_buffer(exported_program, node): + tensor = get_buffer(exported_program, node) + elif is_lifted_tensor_constant(exported_program, node): + tensor = get_lifted_tensor_constant(exported_program, node) + + if tensor is not None: + assert node not in const_node_to_tensor + const_node_to_tensor[node] = tensor return const_node_to_tensor @@ -113,46 +113,47 @@ def get_data( def propagate_constants( - exported_program: ExportedProgram, graph_module + exported_program: ExportedProgram ) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ Propagates constants and returns a dictionary of node to constant tensors of the graph. 
""" const_node_to_tensor = get_constant_placeholder_to_tensor_dict( - exported_program, graph_module + exported_program ) - graph: torch.fx.Graph = graph_module.graph - for node in graph.nodes: - if node.op != "call_function": - continue - if node.target in [ - torch.ops.quantized_decomposed.dequantize_per_channel.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ]: - continue - if not has_constant_data( - [node.args, node.kwargs], - const_node_to_tensor, - ): - continue + for graph_module, _ in get_all_graph_modules(exported_program): + graph: torch.fx.Graph = graph_module.graph + for node in graph.nodes: + if node.op != "call_function": + continue + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ]: + continue + if not has_constant_data( + [node.args, node.kwargs], + const_node_to_tensor, + ): + continue - args_data, kwargs_data = pytree.tree_map( - lambda x: get_data(x, exported_program, const_node_to_tensor), - (node.args, node.kwargs), - ) + args_data, kwargs_data = pytree.tree_map( + lambda x: get_data(x, exported_program, const_node_to_tensor), + (node.args, node.kwargs), + ) - # propagate constant because all of its args are constant tensors. - with torch.no_grad(): - prop_constant_tensor = node.target(*args_data, **kwargs_data) - const_node_to_tensor[node] = prop_constant_tensor + # propagate constant because all of its args are constant tensors. + with torch.no_grad(): + prop_constant_tensor = node.target(*args_data, **kwargs_data) + const_node_to_tensor[node] = prop_constant_tensor return const_node_to_tensor def erase_constant_node( exported_program: ExportedProgram, - node: torch.fx.Node, + node_to_erase: torch.fx.Node, ) -> None: """ Remove corresponding tensor from param/constants dict. @@ -162,23 +163,41 @@ def erase_constant_node( A) They internally uses `exported_program.graph_signature.input_specs` and the `input_specs` are updated at the end of the const_prop_pass. """ - signature = exported_program.graph_signature - if name := signature.inputs_to_parameters.get(node.name, None): - exported_program.state_dict.pop(name, None) - elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None): - exported_program.constants.pop(name, None) - elif name := signature.inputs_to_buffers.get(node.name, None): - exported_program.constants.pop(name, None) - exported_program.state_dict.pop(name, None) - # Remove from graph. 
-    exported_program.graph.erase_node(node)
+    for graph_module, _ in get_all_graph_modules(exported_program):
+        if node_to_erase in graph_module.graph.nodes:
+            graph_module.graph.erase_node(node_to_erase)
+
+    remove_from_program = True
+    for graph_module, _ in get_all_graph_modules(exported_program):
+        for node in graph_module.graph.nodes:
+            if node.name == node_to_erase.name:
+                remove_from_program = False
+
+    if remove_from_program:
+        signature = exported_program.graph_signature
+        if name := signature.inputs_to_parameters.get(node_to_erase.name, None):
+            exported_program.state_dict.pop(name, None)
+        elif name := signature.inputs_to_lifted_tensor_constants.get(node_to_erase.name, None):
+            exported_program.constants.pop(name, None)
+        elif name := signature.inputs_to_buffers.get(node_to_erase.name, None):
+            exported_program.constants.pop(name, None)
+            exported_program.state_dict.pop(name, None)
+
+
+def get_first_node(exported_program, graph):
+    first_node = get_first_user_input(exported_program, graph)
+    if not first_node:
+        # Placeholder nodes must be the first N nodes in the nodes list of a graph.
+        # Therefore, insert the newly created placeholders at the start of the node list.
+        assert graph.nodes
+        first_node = list(graph.nodes)[0]
+
+    return first_node
 
 
 def create_constant_placeholder(
     const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
     exported_program: ExportedProgram,
-    graph_module,
 ) -> List[torch.fx.Node]:
     """
     This function creates constant placeholder nodes according to the given constant nodes (`const_node_to_tensor`) and replace it with the original node.
@@ -186,22 +205,21 @@
     placeholders = []
     fake_mode = get_fake_mode(exported_program)
 
-    first_user_input = get_first_user_input(exported_program)
-    if not first_user_input:
-        # Placeholder nodes must be the first N nodes in the nodes list of a graph.
-        # Therefore, insert the newly created placeholders at the start of the node list.
-        assert exported_program.graph.nodes
-        first_node = list(exported_program.graph.nodes)[0]
-        first_user_input = first_node
 
     # Iterate over nodes in reverse order to insert created placeholder before the `first_user_input`.
     for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
+        if node.graph is not exported_program.graph_module.graph:
+            # Do not propagate constants of subgraphs.
+            # WHY?
+            # Uplifting a constant to a placeholder may alter the subgraph signature, which may break the control flow IR's invariant:
+            # the control flow IR assumes that the number of its `argument` operands equals the number of the subgraph's inputs.
+            continue
+
+        # All users of this constant node are also constant, so we don't need to create a new constant node.
         if all(x in const_node_to_tensor for x in node.users):
-            # All users of this constant node are also constant, so we don't need to create a new constant node.
             erase_constant_node(exported_program, node)
             continue
-        if node.op == "placeholder":
+        if node.op == "placeholder":  # and graph_module is exported_program.graph_module:
            continue
 
         # Add `prop_constant_tensor` to program.state_dict.
@@ -210,17 +228,24 @@
     )
 
     # Insert a new placeholder node for the propagated constant tensor.
- with exported_program.graph.inserting_before(first_user_input): - const_placeholder_node = exported_program.graph.placeholder( + with node.graph.inserting_before(get_first_node(exported_program, node.graph)): + const_placeholder_node = node.graph.placeholder( prop_constant_tensor_fqn ) # The key here should be same with "target" arg of InputSpec when creating input specs. exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor + print(f"exported_program.constants[{prop_constant_tensor_fqn}] = {prop_constant_tensor}") # Replace the original node with the new constant node. node.replace_all_uses_with(const_placeholder_node, propagate_meta=True) - exported_program.graph.erase_node(node) + print(node) + for graph_module, _ in get_all_graph_modules(exported_program): + if node in graph_module.graph.nodes: + graph_module.graph.print_tabular() + graph_module.graph.erase_node(node) + graph_module.graph.print_tabular() + breakpoint() # Update the meta data of the new placeholder node. const_placeholder_node.meta["val"] = fake_mode.from_tensor( @@ -266,40 +291,56 @@ class ConstPropPass(PassBase): def __init__(self) -> None: super().__init__() - - def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: + + def call(self, exported_program: ExportedProgram, _) -> PassResult: + from tico.utils.subgraph import get_all_graph_modules logger = logging.getLogger(__name__) - graph: torch.fx.Graph = graph_module.graph - - # [1], [2] - const_node_to_tensor: OrderedDict[ - torch.fx.Node, torch.Tensor - ] = propagate_constants(exported_program, graph_module) + all_placeholders = [] + all_new_name_to_spec = {} + + const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() + # [1], [2] + const_node_to_tensor.update(propagate_constants(exported_program)) + print(f"const_node_to_tensor: {const_node_to_tensor}") # [3] placeholders = create_constant_placeholder( - const_node_to_tensor, exported_program, graph_module + const_node_to_tensor, exported_program ) + print(f"placeholders: {placeholders}") # [4] new_name_to_spec = create_input_specs(placeholders) - + print(f"new_name_to_spec: {new_name_to_spec}") + + all_placeholders.extend(placeholders) + all_new_name_to_spec.update(new_name_to_spec) + + + for graph_module, _ in get_all_graph_modules(exported_program): + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + # graph_module.graph.print_tabular() + # [5] # Get existing input specs. existing_name_to_spec = { s.arg.name: s for s in exported_program.graph_signature.input_specs } + # Add the new constants to existing input specs dict. - existing_name_to_spec.update(new_name_to_spec) - # Generate new input spec. + existing_name_to_spec.update(all_new_name_to_spec) + + # Generate new input spec. + # I/O for root graph only new_input_specs = [] - for node in exported_program.graph.nodes: - if node.op != "placeholder": + for node in exported_program.graph_module.graph.nodes: + if node.op != "placeholder": # and graph_module is not exported_program.graph_module: continue assert node.name in existing_name_to_spec, node.name new_input_specs.append(existing_name_to_spec[node.name]) exported_program.graph_signature.input_specs = new_input_specs - graph.eliminate_dead_code() - graph_module.recompile() + # graph.eliminate_dead_code() + # graph_module.recompile() logger.debug("Constant nodes are propagated") # Constant folding can be done with only one time run. Let's set `modified` to False. 
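For context, the folding step that `propagate_constants` generalizes across subgraphs can be sketched on a single graph: seed a table with the module's constant tensors, then evaluate every `call_function` node whose inputs are all already known, exactly as the pass does under `torch.no_grad()`. A toy sketch assuming flat `get_attr` targets and no fx.Node kwargs (not tico's API):

import torch

def fold_constants(gm: torch.fx.GraphModule) -> dict[torch.fx.Node, torch.Tensor]:
    known: dict[torch.fx.Node, torch.Tensor] = {}
    for node in gm.graph.nodes:
        if node.op == "get_attr" and isinstance(getattr(gm, node.target, None), torch.Tensor):
            # Seed the table with tensors attached directly to the module.
            known[node] = getattr(gm, node.target)
        elif node.op == "call_function" and all(
            not isinstance(a, torch.fx.Node) or a in known for a in node.args
        ):
            # Every tensor input is a known constant: evaluate the op eagerly.
            args = tuple(known[a] if isinstance(a, torch.fx.Node) else a for a in node.args)
            with torch.no_grad():
                known[node] = node.target(*args, **node.kwargs)
    return known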
diff --git a/tico/utils/convert.py b/tico/utils/convert.py index e68012c9..bcb0e55c 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -201,6 +201,7 @@ def convert_exported_module_to_circle( logger = logging.getLogger(__name__) logger.debug("Input ExportedProgram (must be core aten)") logger.debug(exported_program) + # graph_module.graph.print_tabular() # PRE-EDGE PASSES # @@ -235,7 +236,11 @@ def convert_exported_module_to_circle( ] ) reinterpret_pass.run(exported_program, graph_module) + + ConstPropPass().call(exported_program, exported_program.graph_module) + for graph_module, _ in get_all_graph_modules(exported_program): + graph = graph_module.graph # TODO Distinguish legalize and optimize circle_legalize = PassManager( passes=[ @@ -261,7 +266,6 @@ def convert_exported_module_to_circle( RemoveRedundantToCopy(), MergeConsecutiveCat(), CastMixedTypeArgs(preserve_ep_invariant=True), - ConstPropPass(), SegmentIndexSelectConst(), LegalizeCausalMaskValue( enabled=config.get("legalize_causal_mask_value") @@ -293,6 +297,7 @@ def convert_exported_module_to_circle( ] ) circle_legalize.run(exported_program, graph_module) + ConstPropPass().call(exported_program, exported_program.graph_module) # TODO Give an option to enable quantiztion to user enable_quantization = has_quantization_ops(graph) diff --git a/tico/utils/diff_graph.py b/tico/utils/diff_graph.py index 4a38a4cf..bfb7a05f 100644 --- a/tico/utils/diff_graph.py +++ b/tico/utils/diff_graph.py @@ -150,8 +150,40 @@ def capture(graph: torch.fx.Graph): global graph_captured graph_captured = str(graph) +from tico.utils.subgraph import get_all_graph_modules +all_graphs_captured = {} -@disable_when(LOG_LEVEL > DEBUG) +@disable_when(LOG_LEVEL > LOGGER_THRESHOLD) +def capture_all(ep: torch.export.ExportedProgram): + assert isinstance(ep, torch.export.ExportedProgram) + global all_graphs_captured + for graph, name in get_all_graph_modules(ep): + all_graphs_captured[name] = str(graph) + +@disable_when(LOG_LEVEL > LOGGER_THRESHOLD) +def log_all(ep: torch.export.ExportedProgram, title: str, recapture: bool): + assert isinstance(ep, torch.export.ExportedProgram) + global all_graphs_captured + all_graphs_now = {} + logger = getLogger(__name__) + for graph, name in get_all_graph_modules(ep): + all_graphs_now[name] = str(graph) + + for name, graph in all_graphs_now.items(): + graph_captured = all_graphs_captured[name] + diff = strdiff(f"{graph_captured}\n", f"{graph}\n") + prefix = f"[{title}]" if title else "" + if len(diff) > 0: + logger.debug(f"{prefix} Graph({ (name if name != '' else 'root') }) is changed.") + logger.debug(f"\n{diff}") + + if recapture: + all_graphs_captured[name] = deepcopy(graph) + else: + all_graphs_captured[name] = None # reset + + +@disable_when(LOG_LEVEL > LOGGER_THRESHOLD) def log(graph: torch.fx.Graph, title: str, recapture: bool): """ Capture the end-point graph for graph-diff. 
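The `capture_all`/`log_all` pair above snapshots `str(graph)` for every subgraph before a pass runs and diffs the printouts afterwards. The project's `strdiff` helper is not shown in this hunk; `difflib.unified_diff` is one plausible stand-in for a self-contained sketch:

import difflib

def graph_strdiff(before: str, after: str) -> str:
    # Line-based unified diff of two str(torch.fx.Graph) snapshots.
    return "".join(
        difflib.unified_diff(
            before.splitlines(keepends=True),
            after.splitlines(keepends=True),
            fromfile="before",
            tofile="after",
        )
    )

# Usage sketch: snapshot = str(gm.graph); run the pass; print(graph_strdiff(snapshot, str(gm.graph)))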
diff --git a/tico/utils/graph.py b/tico/utils/graph.py index f0905334..a324fef2 100644 --- a/tico/utils/graph.py +++ b/tico/utils/graph.py @@ -67,11 +67,9 @@ def get_torch_buffer_value(node: torch.fx.Node, ep: ExportedProgram): return named_buf[buf_name] -def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]: +def get_first_user_input(exported_program: ExportedProgram, graph: torch.fx.Graph) -> Optional[torch.fx.Node]: """Returns the first user input node in the graph.""" first_user_input: Optional[torch.fx.Node] = None - graph_module = exported_program.graph_module - graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if ( node.op == "placeholder" diff --git a/tico/utils/trace_decorators.py b/tico/utils/trace_decorators.py index 62dab48b..ea1577c2 100644 --- a/tico/utils/trace_decorators.py +++ b/tico/utils/trace_decorators.py @@ -17,7 +17,7 @@ import torch from torch.export import ExportedProgram -from tico.utils.diff_graph import capture, capture_const, log, log_const +from tico.utils.diff_graph import capture, capture_const, log, log_const, capture_all, log_all from tico.utils.passes import PassBase @@ -55,9 +55,9 @@ def _call_traced(fn): def wrapped(self, exported_program, graph_module): assert isinstance(exported_program, ExportedProgram) assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module) - capture(graph_module.graph) + capture_all(exported_program) ret = fn(self, exported_program, graph_module) - log(graph_module.graph, title=str(cls.__name__), recapture=False) + log_all(exported_program, title=str(cls.__name__), recapture=False) return ret return wrapped
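Taken together, the series aims to enable roughly the following flow. This sketch mirrors the cond test module and the commented-out tail of shortconv.py; the toy model and output file name are illustrative, while torch.cond, torch.export.export, and tico.convert_from_exported_program(...).save(...) come from the code above:

import torch
import tico

class SimpleCond(torch.nn.Module):
    class Sin(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x) + 1

    class Cos(torch.nn.Module):
        def forward(self, x):
            return torch.cos(x) - 1

    def __init__(self):
        super().__init__()
        self.sin = self.Sin()
        self.cos = self.Cos()

    def forward(self, x):
        # torch.cond appears as torch.ops.higher_order.cond in the exported graph;
        # LowerCond rewrites it into circle_custom.if_, and the serializer emits a
        # Circle IF operator whose then/else indices point at the two subgraphs.
        return torch.cond(x.sum() > 0, self.sin, self.cos, (x,))

ep = torch.export.export(SimpleCond(), (torch.randn(2, 3),))
tico.convert_from_exported_program(ep).save("cond.circle")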