diff --git a/call.py b/call.py new file mode 100644 index 00000000..940502b8 --- /dev/null +++ b/call.py @@ -0,0 +1,73 @@ +""" Example - circle model import/export """ + +import pycircle + +from pycircle.circleir.model import Model +from pycircle.circleir.operators import CircleAdd, CircleCall +from pycircle.circleir.subgraph import Subgraph +from pycircle.circleir.tensor import Tensor +from pycircle.util.alias import TensorType + + +### subgraph 0 +### input0, input1 -> call0 (subgraph 1) -> tensor0 +### tensor0, weights0 -> add0 -> tensor1 +graph0 = Subgraph() +graph0.name = "graph0" +graph0.inputs = [ + Tensor("sub1_input0", [1, 3], TensorType.FLOAT32), + Tensor("sub1_input1", [1, 3], TensorType.FLOAT32), +] + +call0 = CircleCall() +call0.inputs = [graph0.inputs[0], graph0.inputs[1]] +call0.subgraph = 1 +call0.outputs(0).attribute("Call0", [1, 3], TensorType.FLOAT32) + + +add1 = CircleAdd() +weights0 = Tensor("weights0", [1, 3], TensorType.FLOAT32, [100.0, 100.0, 100.0]) +add1.inputs = [call0.outputs(0), weights0] +add1.outputs(0).attribute("add0", [1, 3], TensorType.FLOAT32) + +graph0.outputs = [add1.outputs(0)] + +### subgraph 1 +### input0, input1 -> ADD -> output +graph1 = Subgraph() +graph1.name = "graph1" +graph1.inputs = [ + Tensor("input0", [1, 3], TensorType.FLOAT32), + Tensor("input1", [1, 3], TensorType.FLOAT32, [-100.0, -100.0, -100.0]), +] +sub_add = CircleAdd() +sub_add.inputs = [graph1.inputs[0], graph1.inputs[1]] +sub_add.outputs(0).attribute("SubAdd", [1, 3], TensorType.FLOAT32) +graph1.outputs = [sub_add.outputs(0)] + +### model +circle_model = Model() +circle_model.subgraphs = [graph0, graph1] +circle_model.signature_defs = { + "graph0": {"subgraph_index": 0}, + "graph1": {"subgraph_index": 1}, +} + +pycircle.export_circle_model(circle_model, "call.circle") + +import torch + +try: + from onert import infer +except ImportError: + raise RuntimeError("The 'onert' package is required to run this function.") + +session_float = infer.session("call.circle") +output = session_float.infer( + ( + torch.randn(1, 3), + torch.randn(1, 3), + ), + measure=True, +) +print(output) diff --git a/example_liquid.py b/example_liquid.py new file mode 100644 index 00000000..fb9dd987 --- /dev/null +++ b/example_liquid.py @@ -0,0 +1,37 @@ +from transformers import AutoProcessor, AutoModelForImageTextToText +from transformers.image_utils import load_image +from transformers.integrations.executorch import sdpa_mask_without_vmap +from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS +# Load model and processor +model_id = "LiquidAI/LFM2-VL-450M" +model = AutoModelForImageTextToText.from_pretrained( + model_id, + device_map="auto", + torch_dtype="bfloat16", + trust_remote_code=True +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Load image and create conversation +url = "https://www.ilankelman.org/stopsigns/australia.jpg" +image = load_image(url) +conversation = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "What is in this image?"}, + ], + }, +] + +# Generate Answer +inputs = processor.apply_chat_template( + conversation, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + tokenize=True, +).to(model.device) +outputs = model.generate(**inputs, max_new_tokens=64) +processor.batch_decode(outputs, skip_special_tokens=True)[0] \ No newline at end of file diff --git a/if.py b/if.py new file mode 100644 index 00000000..ce923760 --- /dev/null +++ b/if.py @@ -0,0 +1,80 @@ +""" 
Example - circle model import/export """ + +import pycircle +from pycircle.circleir.model import Model +from pycircle.circleir.operators import CircleAdd, CircleIf +from pycircle.circleir.subgraph import Subgraph +from pycircle.circleir.tensor import Tensor +from pycircle.util.alias import TensorType + +# Define the input tensors and constant tensors +input_tensor0 = Tensor("input0", [1, 3], TensorType.FLOAT32) +input_tensor1 = Tensor("input1", [1, 3], TensorType.FLOAT32) +weight_add_100 = Tensor("constant0", [1, 3], TensorType.FLOAT32, [100, 100, 100]) +weight_sub_100 = Tensor("constant1", [1, 3], TensorType.FLOAT32, [-100, -100, -100]) + +### then_subgraph ### +then_subgraph = Subgraph() +then_subgraph.inputs = [Tensor("input0", [1, 3], TensorType.FLOAT32), weight_add_100] + +add_op_then = CircleAdd() +add_op_then.inputs = [then_subgraph.inputs[0], then_subgraph.inputs[1]] +add_op_then.outputs(0).attribute("add_output_then", [1, 3], TensorType.FLOAT32) +then_subgraph.outputs = [add_op_then.outputs(0)] + +### else_subgraph ### +else_subgraph = Subgraph() +else_subgraph.inputs = [Tensor("input0", [1, 3], TensorType.FLOAT32), weight_sub_100] + +add_op_else = CircleAdd() +add_op_else.inputs = [ + else_subgraph.inputs[0], + else_subgraph.inputs[1], +] +add_op_else.outputs(0).attribute("add_output_else", [1, 3], TensorType.FLOAT32) +else_subgraph.outputs = [add_op_else.outputs(0)] + +### root_subgraph with CircleIf ### +root_subgraph = Subgraph() +root_subgraph.name = "root_subgraph" +condition_tensor = Tensor("condition", [1], TensorType.BOOL) +root_subgraph.inputs = [condition_tensor, input_tensor0, input_tensor1] + +circle_if_op = CircleIf(1, 2) +circle_if_op.inputs = [condition_tensor, input_tensor0, input_tensor1] +circle_if_op.outputs(0).attribute("output_tensor", [1, 3], TensorType.FLOAT32) +circle_if_op.then_subgraph_index = 1 +circle_if_op.else_subgraph_index = 2 +root_subgraph.outputs = [circle_if_op.outputs(0)] + +# Build the model +circle_model = Model() +circle_model.description = "pycircle example : signature_def" +circle_model.subgraphs = [root_subgraph, then_subgraph, else_subgraph] +circle_model.signature_defs = { + "root_graph": {"subgraph_index": 0}, + "then_graph": {"subgraph_index": 1}, + "else_graph": {"subgraph_index": 2}, +} + +# Export the model +pycircle.export_circle_model(circle_model, "signature_def.circle") + +# Inference via onert +import torch + +try: + from onert import infer +except ImportError: + raise RuntimeError("The 'onert' package is required to run this function.") + +session = infer.session("signature_def.circle") +output = session.infer( + ( + torch.tensor([True]), # condition tensor + torch.randn(1, 3), # input tensor 0 + torch.tensor([[100.0, 100.0, 100.0]]), # input tensor 1 + ), + measure=True, +) +print(output) diff --git a/shortconv.py b/shortconv.py new file mode 100644 index 00000000..51acb54d --- /dev/null +++ b/shortconv.py @@ -0,0 +1,132 @@ +import torch +from torch import nn +from torch.export import export # available in PyTorch 2.x + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + +# Branch 1: module used when past_key_value exists and cache_position[0] > 0 +class TrueBranch(nn.Module): + def __init__(self, conv, conv_cache, 
layer_idx, L_cache, bias): + super().__init__() + self.conv = conv + self.register_buffer( + "conv_cache", + torch.zeros(10, 32, 1024, 20), + ) + self.layer_idx = layer_idx + self.L_cache = L_cache + self.bias = bias + + def forward(self, Bx, cache_position, seq_len): + conv_state = self.conv_cache[self.layer_idx,:,:,:] + cache_position = torch.clamp(cache_position, 0, self.L_cache - 1) + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) + self.conv_cache[self.layer_idx, :, :, :] = conv_state.clone() + conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) + if self.bias is not None: + conv_out += self.bias + conv_out = conv_out.unsqueeze(-1) + return conv_out + +# Branch 2: module that handles the remaining cases +class FalseBranch(nn.Module): + def __init__(self, conv, conv_cache, layer_idx, L_cache): + super().__init__() + self.conv = conv + self.register_buffer( + "conv_cache", + torch.zeros(10, 32, 1024, 20), + ) + self.layer_idx = layer_idx + self.L_cache = L_cache + + def forward(self, Bx, cache_position, seqlen): + conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + self.conv_cache[self.layer_idx, :, :, :] = conv_state.clone() + conv_out = self.conv(Bx)[..., :seqlen] + return conv_out + +class ShortConv(nn.Module): + def __init__(self, in_proj, conv, conv_cache, out_proj, layer_idx, L_cache, bias): + super().__init__() + self.in_proj = in_proj + self.conv = conv + self.conv_cache = conv_cache + self.out_proj = out_proj + self.layer_idx = layer_idx + self.L_cache = L_cache + self.bias = bias + + self.true_branch = TrueBranch(conv, conv_cache, layer_idx, L_cache, bias) + self.false_branch = FalseBranch(conv, conv_cache, layer_idx, L_cache) + + def forward(self, x, past_key_value=None, cache_position=None, attention_mask=None): + seqlen = torch.tensor(x.shape[1]) + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + Bx = B * x + + + # Condition: past_key_value exists and cache_position[0] > 0 + pred = (past_key_value is not None) and (cache_position[0] > 0) + # if pred is True: + # conv_out = true_fn() + # else: + # conv_out = false_fn() + conv_out = torch.cond(pred, self.true_branch, self.false_branch, (Bx, cache_position, seqlen,)) + + y = C * conv_out + y = y.transpose(-1, -2).contiguous() + y = self.out_proj(y) + return y + +import torch +from torch import nn +from torch.export import export + +# Assumes the TrueBranch, FalseBranch, and ShortConv classes defined above + +# Initialize arbitrary parameters needed to instantiate the module (example) +in_proj = nn.Linear(1024, 1024 * 3) # assumes embedding size 1024 +conv = nn.Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(2,), groups=1024, bias=False) # depthwise 1D conv over 1024 feature channels +conv_cache = [torch.zeros(32, 1024, 20)] * 10 # assumes batch 32, 1024 features, cache size 20, 10 layers +out_proj = nn.Linear(1024, 1024) # output embedding size 1024 + +layer_idx = 0 +L_cache = 20 +bias = conv.bias + +# Create the ShortConv model +model = ShortConv(in_proj, conv, conv_cache, out_proj, layer_idx, L_cache, bias) + +# Create example inputs (batch 32, sequence length 10, embedding 1024) +x = torch.randn(32, 10, 1024) + +# Also create past_key_value and cache_position if needed (None is allowed) +past_key_value = type('', (), {})() +past_key_value.conv_cache = conv_cache +cache_position = torch.tensor([5]) + +model.forward(x, past_key_value, cache_position, None) #ADDED +# model.eval() +# # Example of exporting the model with torch.export +# exported_model = export(model, (x, past_key_value.conv_cache, cache_position, None)) + +# # Check the ExportedProgram type
+# print(type(exported_model)) +# print(exported_model) + +# import tico + +# tico.convert_from_exported_program(exported_model).save("shortconv.circle") diff --git a/signature.py b/signature.py new file mode 100644 index 00000000..5c858da6 --- /dev/null +++ b/signature.py @@ -0,0 +1,58 @@ +""" Example - circle model import/export """ + +import pycircle + +from pycircle.circleir.model import Model +from pycircle.circleir.operators import CircleAdd +from pycircle.circleir.subgraph import Subgraph +from pycircle.circleir.tensor import Tensor +from pycircle.util.alias import TensorType + +subgraph1 = Subgraph() +subgraph1.name = "subgraph1" +subgraph1.inputs = [ + Tensor("subgraph1_input", [1, 3], TensorType.FLOAT32), +] + +weights1 = Tensor("constant1", [1, 3], TensorType.FLOAT32, [0.1, 0.2, 0.3]) + +add1 = CircleAdd() +add1.inputs = [subgraph1.inputs[0], weights1] +add1.outputs(0).attribute("Add1", [1, 3], TensorType.FLOAT32) + +subgraph1.outputs = [add1.outputs(0)] + +subgraph2 = Subgraph() +subgraph2.name = "subgraph2" +subgraph2.inputs = [ + Tensor("subgraph2_input1", [1, 3], TensorType.FLOAT32), + Tensor("subgraph2_input2", [1, 3], TensorType.FLOAT32), +] + +add2 = CircleAdd() +add2.inputs = [subgraph2.inputs[0], subgraph2.inputs[1]] +add2.outputs(0).attribute("Add2", [1, 3], TensorType.FLOAT32) + +subgraph2.outputs = [add2.outputs(0)] + +circle_model = Model() +circle_model.description = "pycircle example : signature_def" +# circle_model.subgraphs = [subgraph2, subgraph1] +circle_model.subgraphs = [subgraph1, subgraph2] +circle_model.signature_defs = { + "add_constant": {"subgraph_index": 0}, + "add_two_inputs": {"subgraph_index": 1}, +} + +pycircle.export_circle_model(circle_model, "signature_def_original.circle") + +import torch + +try: + from onert import infer +except ImportError: + raise RuntimeError("The 'onert' package is required to run this function.") + +session_float = infer.session("signature_def_original.circle") +output = session_float.infer((torch.randn(1, 3),)) +breakpoint() diff --git a/test.py b/test.py new file mode 100644 index 00000000..d785a88a --- /dev/null +++ b/test.py @@ -0,0 +1,8 @@ +def test(): + if 1 == 1: + pass + + print("HI") + + +test() diff --git a/test/modules/net/Lfm2ShortConv.py b/test/modules/net/Lfm2ShortConv.py new file mode 100644 index 00000000..19e44869 --- /dev/null +++ b/test/modules/net/Lfm2ShortConv.py @@ -0,0 +1,171 @@ + +from typing import Optional + +import torch +from torch import nn +from torch.export import Dim +from test.utils.tag import use_onert + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + + +@use_onert +class Lfm2ShortConv(nn.Module): + # Branch 1: module used when past_key_value exists and cache_position[0] > 0 + class WithCache(nn.Module): + """ + One-by-one Decoding Stage + + Condition: past_key_value (conv_cache) exists and cache_position[0] > 0 + """ + def __init__(self, conv, layer_idx, L_cache, bias): + super().__init__() + self.conv = conv + self.layer_idx = layer_idx + self.L_cache = L_cache + self.bias = bias + + def forward(self, Bx, cache_position, seq_len, conv_state): + cache_position = cache_position.clamp(0, self.L_cache - 1) + new_conv_state = 
conv_state.roll(shifts=-1, dims=-1) + new_conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) + + conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) + if self.bias: + conv_out += self.conv.bias + + conv_out = conv_out.unsqueeze(-1) + return conv_out #, new_conv_state + + # 분기 2: 그 외 케이스 처리 모듈 + class WithoutCache(nn.Module): + """ + Pure convolution + """ + def __init__(self, conv, layer_idx, L_cache): + super().__init__() + self.conv = conv + self.layer_idx = layer_idx + self.L_cache = L_cache + + def forward(self, Bx, cache_position, seqlen, conv_state): + new_conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) + FIXED_SEQ_LEN = 20 # instead of using seqlen + conv_out = self.conv(Bx)[..., :FIXED_SEQ_LEN] # for compilation + return conv_out #, new_conv_state + + def __init__( + self, + layer_idx: int = 0, + ): + super().__init__() + self.layer_idx = layer_idx + + self.L_cache = 3 + self.bias = False + + self.conv = nn.Conv1d( + in_channels=1024, + out_channels=1024, + kernel_size=self.L_cache, + groups=1024, + bias=self.bias, + padding=self.L_cache - 1, + ) + self.in_proj = nn.Linear( + 1024, + 3 * 1024, + bias=self.bias, + ) + self.out_proj = nn.Linear( + 1024, + 1024, + bias=self.bias, + ) + # # Sharing self.conv is not supported by torch.export + # self.conv_with_cache = self.WithCache(self.conv, self.layer_idx, self.L_cache, self.bias) + # self.conv_without_cache = self.WithoutCache(self.conv, self.layer_idx, self.L_cache) + self.conv_with_cache = self.WithCache(nn.Conv1d( + in_channels=1024, + out_channels=1024, + kernel_size=self.L_cache, + groups=1024, + bias=self.bias, + padding=self.L_cache - 1, + ), self.layer_idx, self.L_cache, self.bias) + self.conv_without_cache = self.WithoutCache(nn.Conv1d( + in_channels=1024, + out_channels=1024, + kernel_size=self.L_cache, + groups=1024, + bias=self.bias, + padding=self.L_cache - 1, + ), self.layer_idx, self.L_cache) + + # def slow_forward( + def forward( + self, + x: torch.Tensor, + conv_cache = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + seqlen = torch.tensor(x.shape[1]) + + x = apply_mask_to_padding_states(x, attention_mask) + BCx = self.in_proj(x).transpose(-1, -2) + B, C, x = BCx.chunk(3, dim=-2) + + Bx = B * x + + condition = conv_cache is not None and cache_position[0] > 0 + + conv_state = conv_cache[self.layer_idx] + + # # TODO Enable cache state update after exportation + # conv_out, new_conv_state = torch.cond(condition, + conv_out = torch.cond(condition, + self.conv_with_cache, self.conv_without_cache, + (Bx, cache_position, seqlen, conv_state,)) + # conv_cache[self.layer_idx, :, :, :] = new_conv_state.clone() + + y = C * conv_out + y = y.transpose(-1, -2).contiguous() + y = self.out_proj(y) + return y + + + def get_example_inputs(self): + sequence_length = 1 + x = torch.randn(32, sequence_length, 1024) + max_batch_size = 32 + conv_dim = 1024 + conv_L_cache = 3 + num_hidden_layers = 12 + + conv_cache = torch.zeros(num_hidden_layers, max_batch_size, conv_dim, conv_L_cache, dtype=torch.float32) + cache_position = torch.tensor([5]) + + # assert (conv_cache is not None and sequence_length == 1) or (conv_cache is None) + + return (x, conv_cache, cache_position,), {} + + def get_dynamic_shapes(self): + sequence_length = Dim("sequence_length", min=1, max=128) + dynamic_shapes = { + "x": {1: sequence_length}, + "conv_cache": {}, + "cache_position": {}, + } + + return 
dynamic_shapes diff --git a/test/modules/op/cond.py b/test/modules/op/cond.py new file mode 100644 index 00000000..2f5372dc --- /dev/null +++ b/test/modules/op/cond.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from test.modules.base import TestModuleBase +from test.utils.tag import use_onert + + +@use_onert +class SimpleCondWithBuffers(TestModuleBase): + class Sin(torch.nn.Module): + def forward(self, x): + return torch.sin(x) + 1 + + class Cos(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.tensor([3])) + def forward(self, x): + return torch.cos(x) - self.buf + + def __init__(self): + super().__init__() + self.sin = self.Sin() + self.cos = self.Cos() + + def forward(self, x, y): + return torch.cond( + x.sum() + y.sum() > 0, + lambda x_: self.sin(x_), + lambda x_: self.cos(x_), + operands=(x,), + ) + + def get_example_inputs(self): + return (torch.randn(3, 3), torch.randn(3, 3)), {} + + +@use_onert +class SimpleCond2(TestModuleBase): + class Sin(torch.nn.Module): + def forward(self, x, y): + return torch.sin(x) + 1 + + class Cos(torch.nn.Module): + def forward(self, x, y): + return torch.cos(x) - 1 + + def __init__(self): + super().__init__() + self.sin = self.Sin() + self.cos = self.Cos() + + def forward(self, x, y): + return torch.cond( + x.sum() + y.sum() > 0, + lambda x, y: self.sin(x, y), + lambda x, y: self.cos(x, y), + operands=(x, y), + ) + + def get_example_inputs(self): + return (torch.randn(3, 3), torch.randn(3, 3)), {} + + +@use_onert +class SimpleCond3(TestModuleBase): + class Sin(torch.nn.Module): + def forward(self, x, y): + return torch.sin(x) + torch.sin(y) + + class Cos(torch.nn.Module): + def forward(self, x, y): + return torch.cos(x) - torch.cos(y) + + def __init__(self): + super().__init__() + self.sin = self.Sin() + self.cos = self.Cos() + + def forward(self, x, y): + return torch.cond( + x.sum() + y.sum() > 0, + lambda x, y: self.sin(x, y), + lambda x, y: self.cos(x, y), + operands=(x, y), + ) + + def get_example_inputs(self): + return (torch.randn(3, 3), torch.randn(3, 3)), {} + + +# if __name__ == "__main__": +# model = SimpleCond2() +# x = torch.randn(3, 3) +# y = torch.randn(3, 3) + +# # export (그래프 생성) +# exported_model = torch.export.export(model, (x, y)) + +# # export된 모델 호출 테스트 +# output = exported_model.module()(x, y) +# exported_model.graph.print_tabular() +# print(exported_model.graph_signature.user_inputs) +# print(exported_model.graph_signature.user_outputs) +# print(output) +# breakpoint() diff --git a/tico/experimental/controlflow/passes/lower_cond.py b/tico/experimental/controlflow/passes/lower_cond.py new file mode 100644 index 00000000..09feba4a --- /dev/null +++ b/tico/experimental/controlflow/passes/lower_cond.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx + +import operator + +import torch + +from tico.utils import logging +from tico.utils.graph import create_node +from tico.utils.passes import PassBase, PassResult +from tico.utils.subgraph import get_all_graph_modules +from tico.utils.trace_decorators import trace_graph_diff_on_pass +from tico.utils.validate_args_kwargs import CondArgs +from torch.export import ExportedProgram +from torch.utils import _pytree as pytree + + +@trace_graph_diff_on_pass +class LowerCond(PassBase): + """ + To support torch.cond, with Circle If, translate into a custom IR. + Note that the custom IR must include the information of both graph node and graph index. + `graph node` is required to carry the graph until serialization step alive. + `graph index` is required to create the corresponding circle ir, because circle ir requires graph numbering. + + (1) fill in the meta values, which requires specific subgraph inference. (this process differs from that of filling meta of other tensors) + (2) get the subgraph index + (3) translate the information into a custom intermediate representation (IR) + """ + + def __init__(self): + super().__init__() + + def call(self, exported_program: ExportedProgram, _) -> PassResult: + logger = logging.getLogger(__name__) + + graph_module = exported_program.graph_module + graph: torch.fx.Graph = graph_module.graph + for node in graph.nodes: + if node.op != "call_function": + continue + if node.target != torch.ops.higher_order.cond: + continue + + cond_args = CondArgs(*node.args, **node.kwargs) + true_graph = cond_args.true_graph + false_graph = cond_args.false_graph + graph_args = cond_args.cond_args + + def _set_meta_val(graph_node, graph_module, graph_args): + def _get_meta_val(node): + assert hasattr( + node, "meta" + ), f"'node' has no attribute named 'meta' (node: {node})" + assert ( + "val" in node.meta + ), f"val key not in node.meta (node: {node}, meta: {node.meta})" + return node.meta["val"] + + args, kwargs = pytree.tree_map_only( + torch.fx.Node, + _get_meta_val, + (graph_args, {}), + ) + + new_val = graph_module(*args, **kwargs) # type: ignore[operator] + graph_node.meta["val"] = new_val + + # [1] Fill in the meta values + # [2] Get the subgraph indices + true_graph_idx = -1 + false_graph_idx = -1 + for idx, (graph_module, name) in enumerate( + get_all_graph_modules(exported_program, subgraph_only=True), start=1 + ): + if true_graph.name == name: + _set_meta_val(true_graph, graph_module, graph_args) + true_graph_idx = idx + if false_graph.name == name: + _set_meta_val(false_graph, graph_module, graph_args) + false_graph_idx = idx + + assert "val" in true_graph.meta, f"{true_graph} has no node.meta['val']" + assert "val" in false_graph.meta, f"{false_graph} has no node.meta['val']" + assert true_graph_idx != -1 + assert false_graph_idx != -1 + + # [3] Create the translated IR (circle_custom.if_) + with graph.inserting_before(node): + 
circle_if = create_node( + graph, + torch.ops.circle_custom.if_, + args=( + cond_args.condition, + cond_args.true_graph, + cond_args.false_graph, + true_graph_idx, + false_graph_idx, + cond_args.cond_args, + ), + kwargs={}, + origin=node, + ) + + for t, f in zip(true_graph.meta["val"], false_graph.meta["val"]): + # Ensure the true and false branches produce compatible tensors. + assert type(t) == type(f) + assert t.shape == f.shape, f"{t.shape} != {f.shape}" + assert t.dtype == f.dtype, f"{t.dtype} != {f.dtype}" + + circle_if.meta["val"] = true_graph.meta["val"][0] + + # FIX ME UNLESS torch.ops.higher_order.cond generates this pattern + assert len(node.users) == 1 + # The original cond node should have exactly one user: the getitem extracting the result. + getitem_node = list(node.users.items())[0][0] + assert getitem_node.target == operator.getitem + getitem_node.replace_all_uses_with(circle_if) + + graph.eliminate_dead_code() + # Clean up any nodes that are no longer reachable after replacement. + graph.lint() + # Verify graph consistency. + graph_module.recompile() + # Recompile the graph module to reflect the updated graph. + + # Run only once. + return PassResult(False) diff --git a/tico/experimental/quantization/passes/fold_quant_ops.py b/tico/experimental/quantization/passes/fold_quant_ops.py index 48afa7d0..5a4fe9f7 100644 --- a/tico/experimental/quantization/passes/fold_quant_ops.py +++ b/tico/experimental/quantization/passes/fold_quant_ops.py @@ -72,10 +72,9 @@ class FoldQuantOps(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for dq in graph.nodes: if dq.op != "call_function": diff --git a/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py b/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py index 2a442987..81e23e87 100644 --- a/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +++ b/tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py @@ -437,10 +437,8 @@ class InsertQuantizeOnDtypeMismatch(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: diff --git a/tico/experimental/quantization/passes/propagate_qparam_backward.py b/tico/experimental/quantization/passes/propagate_qparam_backward.py index 1d7f94b9..b77cea82 100644 --- a/tico/experimental/quantization/passes/propagate_qparam_backward.py +++ b/tico/experimental/quantization/passes/propagate_qparam_backward.py @@ -45,10 +45,8 @@ class PropagateQParamBackward(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): diff --git a/tico/experimental/quantization/passes/propagate_qparam_forward.py 
b/tico/experimental/quantization/passes/propagate_qparam_forward.py index eb3c8bee..a900cf0f 100644 --- a/tico/experimental/quantization/passes/propagate_qparam_forward.py +++ b/tico/experimental/quantization/passes/propagate_qparam_forward.py @@ -50,7 +50,7 @@ class PropagateQParamForward(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): @@ -67,7 +67,6 @@ def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.") - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "call_function": diff --git a/tico/experimental/quantization/passes/quantize_bias.py b/tico/experimental/quantization/passes/quantize_bias.py index 1b826093..40cb9104 100644 --- a/tico/experimental/quantization/passes/quantize_bias.py +++ b/tico/experimental/quantization/passes/quantize_bias.py @@ -41,10 +41,8 @@ class QuantizeBias(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if node.op != "call_function": diff --git a/tico/experimental/quantization/passes/remove_weight_dequant_op.py b/tico/experimental/quantization/passes/remove_weight_dequant_op.py index 35fecc2b..b17ed0fa 100644 --- a/tico/experimental/quantization/passes/remove_weight_dequant_op.py +++ b/tico/experimental/quantization/passes/remove_weight_dequant_op.py @@ -96,10 +96,8 @@ class RemoveWeightDequantOp(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph for dq in graph.nodes: if not dq.op == "call_function": diff --git a/tico/interpreter/infer.py b/tico/interpreter/infer.py index 792699b7..1816395d 100644 --- a/tico/interpreter/infer.py +++ b/tico/interpreter/infer.py @@ -63,7 +63,7 @@ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any: # Get input spec from circle binary. 
model = circle.Model.Model.GetRootAsModel(circle_binary, 0) - assert model.SubgraphsLength() == 1 + # assert model.SubgraphsLength() == 1 graph = model.Subgraphs(0) model_input_tensors = [ graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength()) diff --git a/tico/passes/cast_aten_where_arg_type.py b/tico/passes/cast_aten_where_arg_type.py index f1d1e362..386defc2 100644 --- a/tico/passes/cast_aten_where_arg_type.py +++ b/tico/passes/cast_aten_where_arg_type.py @@ -109,9 +109,8 @@ class CastATenWhereArgType(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/cast_clamp_mixed_type_args.py b/tico/passes/cast_clamp_mixed_type_args.py index 80288534..1c4ac0f7 100644 --- a/tico/passes/cast_clamp_mixed_type_args.py +++ b/tico/passes/cast_clamp_mixed_type_args.py @@ -92,11 +92,12 @@ class CastClampMixedTypeArgs(PassBase): def __init__(self): super().__init__() - def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool: + def convert( + self, exported_program: ExportedProgram, node: torch.fx.Node, graph_module + ) -> bool: logger = logging.getLogger(__name__) modified = False - graph_module = exported_program.graph_module graph = graph_module.graph # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor @@ -150,17 +151,15 @@ def _convert_arg(arg, arg_name: str): return modified - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: target_op = ops.aten.clamp - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: if not is_target_node(node, target_op): continue - modified |= self.convert(exported_program, node) + modified |= self.convert(exported_program, node, graph_module) graph.eliminate_dead_code() graph.lint() diff --git a/tico/passes/cast_mixed_type_args.py b/tico/passes/cast_mixed_type_args.py index 1168db62..010430a1 100644 --- a/tico/passes/cast_mixed_type_args.py +++ b/tico/passes/cast_mixed_type_args.py @@ -92,10 +92,8 @@ def __init__(self, preserve_ep_invariant=True): self.preserve_ep_invariant = preserve_ep_invariant # TODO Folding float and int values before this pass - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/const_prop_pass.py b/tico/passes/const_prop_pass.py index 6d3c96ac..988a0c8c 100644 --- a/tico/passes/const_prop_pass.py +++ b/tico/passes/const_prop_pass.py @@ -45,6 +45,7 @@ trace_graph_diff_on_pass, ) from tico.utils.utils import get_fake_mode +from tico.utils.subgraph import get_all_graph_modules def get_constant_placeholder_to_tensor_dict( @@ -54,23 +55,22 @@ def get_constant_placeholder_to_tensor_dict( Returns a dictionary of constant placeholder node to constant tensor. 
""" const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() - graph_module = exported_program.graph_module - graph: torch.fx.Graph = graph_module.graph - for node in graph.nodes: - if node.op != "placeholder": - continue - tensor: Optional[torch.Tensor] = None - if is_param(exported_program, node): - tensor = get_param(exported_program, node) - elif is_buffer(exported_program, node): - tensor = get_buffer(exported_program, node) - elif is_lifted_tensor_constant(exported_program, node): - tensor = get_lifted_tensor_constant(exported_program, node) - - if tensor is not None: - assert node not in const_node_to_tensor - const_node_to_tensor[node] = tensor - + + for graph_module, _ in get_all_graph_modules(exported_program): + for node in graph_module.graph.nodes: + if node.op != "placeholder": + continue + tensor: Optional[torch.Tensor] = None + if is_param(exported_program, node): + tensor = get_param(exported_program, node) + elif is_buffer(exported_program, node): + tensor = get_buffer(exported_program, node) + elif is_lifted_tensor_constant(exported_program, node): + tensor = get_lifted_tensor_constant(exported_program, node) + + if tensor is not None: + assert node not in const_node_to_tensor + const_node_to_tensor[node] = tensor return const_node_to_tensor @@ -113,45 +113,47 @@ def get_data( def propagate_constants( - exported_program: ExportedProgram, + exported_program: ExportedProgram ) -> OrderedDict[torch.fx.Node, torch.Tensor]: """ Propagates constants and returns a dictionary of node to constant tensors of the graph. """ - const_node_to_tensor = get_constant_placeholder_to_tensor_dict(exported_program) + const_node_to_tensor = get_constant_placeholder_to_tensor_dict( + exported_program + ) - graph_module = exported_program.graph_module - graph: torch.fx.Graph = graph_module.graph - for node in graph.nodes: - if node.op != "call_function": - continue - if node.target in [ - torch.ops.quantized_decomposed.dequantize_per_channel.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - ]: - continue - if not has_constant_data( - [node.args, node.kwargs], - const_node_to_tensor, - ): - continue + for graph_module, _ in get_all_graph_modules(exported_program): + graph: torch.fx.Graph = graph_module.graph + for node in graph.nodes: + if node.op != "call_function": + continue + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ]: + continue + if not has_constant_data( + [node.args, node.kwargs], + const_node_to_tensor, + ): + continue - args_data, kwargs_data = pytree.tree_map( - lambda x: get_data(x, exported_program, const_node_to_tensor), - (node.args, node.kwargs), - ) + args_data, kwargs_data = pytree.tree_map( + lambda x: get_data(x, exported_program, const_node_to_tensor), + (node.args, node.kwargs), + ) - # propagate constant because all of its args are constant tensors. - with torch.no_grad(): - prop_constant_tensor = node.target(*args_data, **kwargs_data) - const_node_to_tensor[node] = prop_constant_tensor + # propagate constant because all of its args are constant tensors. + with torch.no_grad(): + prop_constant_tensor = node.target(*args_data, **kwargs_data) + const_node_to_tensor[node] = prop_constant_tensor return const_node_to_tensor def erase_constant_node( exported_program: ExportedProgram, - node: torch.fx.Node, + node_to_erase: torch.fx.Node, ) -> None: """ Remove corresponding tensor from param/constants dict. 
@@ -161,17 +163,36 @@ def erase_constant_node( A) They internally uses `exported_program.graph_signature.input_specs` and the `input_specs` are updated at the end of the const_prop_pass. """ - signature = exported_program.graph_signature - if name := signature.inputs_to_parameters.get(node.name, None): - exported_program.state_dict.pop(name, None) - elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None): - exported_program.constants.pop(name, None) - elif name := signature.inputs_to_buffers.get(node.name, None): - exported_program.constants.pop(name, None) - exported_program.state_dict.pop(name, None) - # Remove from graph. - exported_program.graph.erase_node(node) + for graph_module, _ in get_all_graph_modules(exported_program): + if node_to_erase in graph_module.graph.nodes: + graph_module.graph.erase_node(node_to_erase) + + remove_from_program = True + for graph_module, _ in get_all_graph_modules(exported_program): + for node in graph_module.graph.nodes: + if node.name == node_to_erase.name: + remove_from_program = False + + if remove_from_program: + signature = exported_program.graph_signature + if name := signature.inputs_to_parameters.get(node_to_erase.name, None): + exported_program.state_dict.pop(name, None) + elif name := signature.inputs_to_lifted_tensor_constants.get(node_to_erase.name, None): + exported_program.constants.pop(name, None) + elif name := signature.inputs_to_buffers.get(node_to_erase.name, None): + exported_program.constants.pop(name, None) + exported_program.state_dict.pop(name, None) + +def get_first_node(exported_program, graph): + first_node = get_first_user_input(exported_program, graph) + if not first_node: + # Placeholder nodes must be the first N nodes in the nodes list of a graph. + # Therefore, insert the newly created placeholders at the start of the node list. + assert graph.nodes + first_node = list(graph.nodes)[0] + + return first_node def create_constant_placeholder( @@ -184,22 +205,21 @@ def create_constant_placeholder( placeholders = [] fake_mode = get_fake_mode(exported_program) - first_user_input = get_first_user_input(exported_program) - if not first_user_input: - # Placeholder nodes must be the first N nodes in the nodes list of a graph. - # Therefore, insert the newly created placeholders at the start of the node list. - assert exported_program.graph.nodes - first_node = list(exported_program.graph.nodes)[0] - first_user_input = first_node # Iterate over nodes in reverse order to insert created placeholder before the `first_user_input`. for node, prop_constant_tensor in reversed(const_node_to_tensor.items()): + if node.graph is not exported_program.graph_module.graph: + # Do not propagates constants of subgraphs + # WHY? + # Uplifting constants to placeholder may alter subgraph signature which may break control flow IR's invariant. + # They assumes that control flow ir's `argument` operands' number equals to those of subgraph's input signatures. + continue + # All users of this constant node are also constant, so we don't need to create a new constant node. if all(x in const_node_to_tensor for x in node.users): - # All users of this constant node are also constant, so we don't need to create a new constant node. erase_constant_node(exported_program, node) continue - if node.op == "placeholder": + if node.op == "placeholder":# and graph_module is exported_program.graph_module: continue # Add `prop_constant_tensor` to program.state_dict. 
@@ -208,17 +228,24 @@ def create_constant_placeholder( ) # Insert a new placeholder node for the propagated constant tensor. - with exported_program.graph.inserting_before(first_user_input): - const_placeholder_node = exported_program.graph.placeholder( + with node.graph.inserting_before(get_first_node(exported_program, node.graph)): + const_placeholder_node = node.graph.placeholder( prop_constant_tensor_fqn ) # The key here should be same with "target" arg of InputSpec when creating input specs. exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor + print(f"exported_program.constants[{prop_constant_tensor_fqn}] = {prop_constant_tensor}") # Replace the original node with the new constant node. node.replace_all_uses_with(const_placeholder_node, propagate_meta=True) - exported_program.graph.erase_node(node) + print(node) + for graph_module, _ in get_all_graph_modules(exported_program): + if node in graph_module.graph.nodes: + graph_module.graph.print_tabular() + graph_module.graph.erase_node(node) + graph_module.graph.print_tabular() + breakpoint() # Update the meta data of the new placeholder node. const_placeholder_node.meta["val"] = fake_mode.from_tensor( @@ -264,42 +291,56 @@ class ConstPropPass(PassBase): def __init__(self) -> None: super().__init__() - - def call(self, exported_program: ExportedProgram) -> PassResult: + + def call(self, exported_program: ExportedProgram, _) -> PassResult: + from tico.utils.subgraph import get_all_graph_modules logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module - graph: torch.fx.Graph = graph_module.graph - - # [1], [2] - const_node_to_tensor: OrderedDict[ - torch.fx.Node, torch.Tensor - ] = propagate_constants(exported_program) + all_placeholders = [] + all_new_name_to_spec = {} + + const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() + # [1], [2] + const_node_to_tensor.update(propagate_constants(exported_program)) + print(f"const_node_to_tensor: {const_node_to_tensor}") # [3] placeholders = create_constant_placeholder( const_node_to_tensor, exported_program ) + print(f"placeholders: {placeholders}") # [4] new_name_to_spec = create_input_specs(placeholders) - + print(f"new_name_to_spec: {new_name_to_spec}") + + all_placeholders.extend(placeholders) + all_new_name_to_spec.update(new_name_to_spec) + + + for graph_module, _ in get_all_graph_modules(exported_program): + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + # graph_module.graph.print_tabular() + # [5] # Get existing input specs. existing_name_to_spec = { s.arg.name: s for s in exported_program.graph_signature.input_specs } + # Add the new constants to existing input specs dict. - existing_name_to_spec.update(new_name_to_spec) - # Generate new input spec. + existing_name_to_spec.update(all_new_name_to_spec) + + # Generate new input spec. 
+ # I/O for root graph only new_input_specs = [] - for node in exported_program.graph.nodes: - if node.op != "placeholder": + for node in exported_program.graph_module.graph.nodes: + if node.op != "placeholder": # and graph_module is not exported_program.graph_module: continue assert node.name in existing_name_to_spec, node.name new_input_specs.append(existing_name_to_spec[node.name]) exported_program.graph_signature.input_specs = new_input_specs - graph.eliminate_dead_code() - graph_module.recompile() + # graph.eliminate_dead_code() + # graph_module.recompile() logger.debug("Constant nodes are propagated") # Constant folding can be done with only one time run. Let's set `modified` to False. diff --git a/tico/passes/convert_conv1d_to_conv2d.py b/tico/passes/convert_conv1d_to_conv2d.py index d7f511f1..1e9cc9aa 100644 --- a/tico/passes/convert_conv1d_to_conv2d.py +++ b/tico/passes/convert_conv1d_to_conv2d.py @@ -141,10 +141,8 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo modified = True return modified - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: target_conv_op = [torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding] - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/convert_layout_op_to_reshape.py b/tico/passes/convert_layout_op_to_reshape.py index 443f37a4..bdaf1968 100644 --- a/tico/passes/convert_layout_op_to_reshape.py +++ b/tico/passes/convert_layout_op_to_reshape.py @@ -38,10 +38,8 @@ class ConvertLayoutOpToReshape(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/convert_matmul_to_linear.py b/tico/passes/convert_matmul_to_linear.py index 2db06fa4..d0337eaa 100644 --- a/tico/passes/convert_matmul_to_linear.py +++ b/tico/passes/convert_matmul_to_linear.py @@ -44,8 +44,7 @@ class MatmulToLinearConverter(Converter): def __init__(self): super().__init__() - def convert(self, exported_program, node) -> torch.fx.Node: - graph_module = exported_program.graph_module + def convert(self, exported_program, graph_module, node) -> torch.fx.Node: graph = graph_module.graph mm_args = MatmulArgs(*node.args, **node.kwargs) # type: ignore[arg-type] @@ -73,7 +72,7 @@ class RhsConstMatmulToLinearConverter(MatmulToLinearConverter): def __init__(self): super().__init__() - def match(self, exported_program, node) -> bool: + def match(self, exported_program, graph_module, node) -> bool: if not node.target == torch.ops.aten.mm.default: return False @@ -99,7 +98,7 @@ class LhsConstMatmulToLinearConverter(MatmulToLinearConverter): def __init__(self): super().__init__() - def match(self, exported_program, node) -> bool: + def match(self, exported_program, graph_module, node) -> bool: if not node.target == torch.ops.aten.mm.default: return False @@ -114,8 +113,8 @@ def match(self, exported_program, node) -> bool: return True return False - def convert(self, exported_program, node) -> torch.fx.Node: - return super().convert(exported_program, node) + def convert(self, exported_program, graph_module, node) -> torch.fx.Node: + return super().convert(exported_program, graph_module, 
node) class SingleBatchLhsConstBmmToLinearConverter(Converter): @@ -154,7 +153,7 @@ class SingleBatchLhsConstBmmToLinearConverter(Converter): def __init__(self): super().__init__() - def match(self, exported_program, node) -> bool: + def match(self, exported_program, graph_module, node) -> bool: if not node.target == torch.ops.aten.bmm.default: return False @@ -185,7 +184,7 @@ def match(self, exported_program, node) -> bool: return True - def convert(self, exported_program, node) -> torch.fx.Node: + def convert(self, exported_program, graph_module, node) -> torch.fx.Node: graph_module = exported_program.graph_module graph = graph_module.graph @@ -284,10 +283,9 @@ def __init__( if enable_single_batch_lhs_const_bmm: self.converters.append(SingleBatchLhsConstBmmToLinearConverter()) - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: @@ -295,10 +293,10 @@ def call(self, exported_program: ExportedProgram) -> PassResult: continue for converter in self.converters: - if not converter.match(exported_program, node): + if not converter.match(exported_program, graph_module, node): continue - new_node = converter.convert(exported_program, node) + new_node = converter.convert(exported_program, graph_module, node) modified = True logger.debug( f"{node.name} is replaced with {new_node.name} operator (permute + linear)" diff --git a/tico/passes/convert_repeat_to_expand_copy.py b/tico/passes/convert_repeat_to_expand_copy.py index 82d6f93f..a8183748 100644 --- a/tico/passes/convert_repeat_to_expand_copy.py +++ b/tico/passes/convert_repeat_to_expand_copy.py @@ -38,10 +38,8 @@ class ConvertRepeatToExpandCopy(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/convert_to_relu6.py b/tico/passes/convert_to_relu6.py index 76d2a576..fec1bfa5 100644 --- a/tico/passes/convert_to_relu6.py +++ b/tico/passes/convert_to_relu6.py @@ -155,10 +155,8 @@ def __init__(self): ConvertDoubleClampsToReLU6(), ] - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/decompose_addmm.py b/tico/passes/decompose_addmm.py index 67793fc7..d6d8fe43 100644 --- a/tico/passes/decompose_addmm.py +++ b/tico/passes/decompose_addmm.py @@ -57,8 +57,8 @@ class DecomposeAddmm(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: - gm = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_batch_norm.py b/tico/passes/decompose_batch_norm.py index eccae901..a6586040 100644 --- a/tico/passes/decompose_batch_norm.py +++ b/tico/passes/decompose_batch_norm.py @@ -74,10 +74,10 @@ class 
DecomposeBatchNorm(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - gm = exported_program.graph_module + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_fake_quantize.py b/tico/passes/decompose_fake_quantize.py index e26dda3d..b3deb590 100644 --- a/tico/passes/decompose_fake_quantize.py +++ b/tico/passes/decompose_fake_quantize.py @@ -63,10 +63,10 @@ def forward(self, x): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: modified = False - gm = exported_program.graph_module + gm = graph_module g = gm.graph qd = torch.ops.quantized_decomposed # type: ignore[return] for node in gm.graph.nodes: diff --git a/tico/passes/decompose_fake_quantize_tensor_qparams.py b/tico/passes/decompose_fake_quantize_tensor_qparams.py index c263c91f..7d0987bd 100644 --- a/tico/passes/decompose_fake_quantize_tensor_qparams.py +++ b/tico/passes/decompose_fake_quantize_tensor_qparams.py @@ -195,10 +195,10 @@ def forward(self, x: "f32[4]"): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: modified = False - gm = exported_program.graph_module + gm = graph_module g = gm.graph qd = torch.ops.quantized_decomposed # type: ignore[return] for node in gm.graph.nodes: diff --git a/tico/passes/decompose_group_norm.py b/tico/passes/decompose_group_norm.py index d184644f..2ba73acc 100644 --- a/tico/passes/decompose_group_norm.py +++ b/tico/passes/decompose_group_norm.py @@ -124,8 +124,8 @@ def _insert_norm(self, graph, tensor, eps, origin): graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin ) - def call(self, exported_program: ExportedProgram) -> PassResult: - gm = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_grouped_conv2d.py b/tico/passes/decompose_grouped_conv2d.py index a2082eb8..41215e54 100644 --- a/tico/passes/decompose_grouped_conv2d.py +++ b/tico/passes/decompose_grouped_conv2d.py @@ -81,10 +81,10 @@ class DecomposeGroupedConv2d(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - gm = exported_program.graph_module + gm = graph_module graph: torch.fx.Graph = gm.graph modified = False diff --git a/tico/passes/decompose_slice_scatter.py b/tico/passes/decompose_slice_scatter.py index e39242be..0e908c5b 100644 --- a/tico/passes/decompose_slice_scatter.py +++ b/tico/passes/decompose_slice_scatter.py @@ -75,10 +75,8 @@ class DecomposeSliceScatter(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph: torch.fx.Graph = graph_module.graph modified = False diff --git a/tico/passes/extract_dtype_kwargs.py 
b/tico/passes/extract_dtype_kwargs.py index 19357e89..3cab0f3b 100644 --- a/tico/passes/extract_dtype_kwargs.py +++ b/tico/passes/extract_dtype_kwargs.py @@ -103,8 +103,7 @@ def __init__(self): self.target_ops = dict() self.target_ops[torch.ops.aten.full_like.default] = _extract_to_output - def call(self, exported_program: ExportedProgram) -> PassResult: - graph_module = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: graph: torch.fx.Graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/fill_meta_val.py b/tico/passes/fill_meta_val.py index b35631b8..27db1629 100644 --- a/tico/passes/fill_meta_val.py +++ b/tico/passes/fill_meta_val.py @@ -29,10 +29,8 @@ class FillMetaVal(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False # To make sure graph is topologically sorted diff --git a/tico/passes/fuse_leading_unsqueeze_reshape.py b/tico/passes/fuse_leading_unsqueeze_reshape.py index 532f5f90..03678747 100644 --- a/tico/passes/fuse_leading_unsqueeze_reshape.py +++ b/tico/passes/fuse_leading_unsqueeze_reshape.py @@ -49,11 +49,10 @@ class FuseLeadingUnsqueezeReshape(PassBase): x - aten.reshape([1]*k + s1) - aten.permute(list(range(k)) + [d+k for d in p]) """ - def call(self, ep: ExportedProgram) -> PassResult: + def call(self, ep: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - gm = ep.graph_module - graph = gm.graph + graph = graph_module.graph modified = False for reshape_back in graph.nodes: if not is_target_node(reshape_back, ops.aten.reshape): @@ -107,6 +106,6 @@ def call(self, ep: ExportedProgram) -> PassResult: if modified: graph.eliminate_dead_code() graph.lint() - gm.recompile() + graph_module.recompile() return PassResult(modified) diff --git a/tico/passes/fuse_redundant_reshape_to_mean.py b/tico/passes/fuse_redundant_reshape_to_mean.py index 9d9a88e9..1dfbb03a 100644 --- a/tico/passes/fuse_redundant_reshape_to_mean.py +++ b/tico/passes/fuse_redundant_reshape_to_mean.py @@ -39,10 +39,8 @@ class FuseRedundantReshapeToMean(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/legalize_causal_mask_value.py b/tico/passes/legalize_causal_mask_value.py index 01a99222..b3738238 100644 --- a/tico/passes/legalize_causal_mask_value.py +++ b/tico/passes/legalize_causal_mask_value.py @@ -43,14 +43,12 @@ def __init__(self, enabled: bool = False): super().__init__() self.enabled = enabled - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: if not self.enabled: return PassResult(False) new_mask = -120 # Make it configurable logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/legalize_predefined_layout_operators.py b/tico/passes/legalize_predefined_layout_operators.py index 
5e9bd690..250797cf 100644 --- a/tico/passes/legalize_predefined_layout_operators.py +++ b/tico/passes/legalize_predefined_layout_operators.py @@ -434,7 +434,7 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool: modified = True return modified - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: target_to_legalize_func = { torch.ops.aten.conv2d.default: self.legalize_conv2d, torch.ops.aten.conv2d.padding: self.legalize_conv2d, @@ -443,8 +443,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d, torch.ops.aten.instance_norm.default: self.legalize_instance_norm, } - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/lower_pow2_to_mul.py b/tico/passes/lower_pow2_to_mul.py index 70074dcc..7227a563 100644 --- a/tico/passes/lower_pow2_to_mul.py +++ b/tico/passes/lower_pow2_to_mul.py @@ -38,10 +38,8 @@ class LowerPow2ToMul(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/lower_to_resize_nearest_neighbor.py b/tico/passes/lower_to_resize_nearest_neighbor.py index ae0d6c1a..039a518a 100644 --- a/tico/passes/lower_to_resize_nearest_neighbor.py +++ b/tico/passes/lower_to_resize_nearest_neighbor.py @@ -197,11 +197,10 @@ def close_enough(x, y, epsilon=1e-10): node.replace_all_uses_with(nhwc_to_nchw, propagate_meta=True) return resize_nearest_neighbor - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) modified = False - graph_module = exported_program.graph_module graph = graph_module.graph for node in graph.nodes: if not is_target_node( diff --git a/tico/passes/lower_to_slice.py b/tico/passes/lower_to_slice.py index 20bb4afc..240fed4a 100644 --- a/tico/passes/lower_to_slice.py +++ b/tico/passes/lower_to_slice.py @@ -77,10 +77,8 @@ class LowerSelectCopyToSlice(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: @@ -157,10 +155,8 @@ class LowerIndexSelectToSlice(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/merge_consecutive_cat.py b/tico/passes/merge_consecutive_cat.py index 110f1419..bb825b64 100644 --- a/tico/passes/merge_consecutive_cat.py +++ b/tico/passes/merge_consecutive_cat.py @@ -31,10 +31,8 @@ class MergeConsecutiveCat(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, 
exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for cat in graph.nodes: diff --git a/tico/passes/remove_nop.py b/tico/passes/remove_nop.py index 57d686f3..da434ee0 100644 --- a/tico/passes/remove_nop.py +++ b/tico/passes/remove_nop.py @@ -45,10 +45,8 @@ class RemoveNop(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_assert_nodes.py b/tico/passes/remove_redundant_assert_nodes.py index a12abdc9..f0a96971 100644 --- a/tico/passes/remove_redundant_assert_nodes.py +++ b/tico/passes/remove_redundant_assert_nodes.py @@ -37,8 +37,7 @@ class RemoveRedundantAssertionNodes(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: - graph_module = exported_program.graph_module + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_expand.py b/tico/passes/remove_redundant_expand.py index 2402af5e..c862fa3d 100644 --- a/tico/passes/remove_redundant_expand.py +++ b/tico/passes/remove_redundant_expand.py @@ -32,10 +32,8 @@ class RemoveRedundantExpand(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_permute.py b/tico/passes/remove_redundant_permute.py index ab806bec..a7b73505 100644 --- a/tico/passes/remove_redundant_permute.py +++ b/tico/passes/remove_redundant_permute.py @@ -60,7 +60,7 @@ class RemoveRedundantPermutePattern1(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] (AxBxC) - aten.permute_1 - aten.permute_2 - (OUT_SHAPE) @@ -72,8 +72,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for permute2 in graph.nodes: diff --git a/tico/passes/remove_redundant_reshape.py b/tico/passes/remove_redundant_reshape.py index 87f4d83d..76501f68 100644 --- a/tico/passes/remove_redundant_reshape.py +++ b/tico/passes/remove_redundant_reshape.py @@ -55,7 +55,7 @@ class RemoveRedundantReshapePattern1(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (1xAxCxB) - `aten.mul` - (1xAxCxB) - `aten.reshape - (AxCxB)` @@ -63,8 +63,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: `(AxBxC) - `aten.permute` - (AxCxB) - `aten.mul` - (AxCxB)` """ logger = 
logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape1 in graph.nodes: @@ -139,7 +137,7 @@ class RemoveRedundantReshapePattern2(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] `(AxBxC) - aten.reshape` - (1xAxBxC) - `aten.permute` - (Bx1xAxC) - `aten.reshape - (Bx(A*C))` @@ -147,8 +145,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: `(AxBxC) - `aten.permute` - (BxAxC) - `aten.reshape` - (Bx(A*C))` """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape1 in graph.nodes: @@ -218,7 +214,7 @@ class RemoveRedundantReshapePattern3(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] (AxBxC) - aten.reshape - (1xAxBxC) - aten.add - (1xAxBxC) - aten.softmax - (1xAxBxC) - aten.reshape - (AxBxC) @@ -230,8 +226,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: (AxBxC) / (add) (softmax) """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape_1 in graph.nodes: @@ -332,18 +326,16 @@ class RemoveRedundantReshapePattern4(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ NOTE: Below graph is just an example. This pattern matches not only for the 3D tensors. What this pattern aims to remove is that the consecutive `aten.reshape` ops. 
[BEFORE] - (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC') + (AxBxC) - aten.reshape - (AxB'xC') - aten.reshape - (A'xB''xC) [AFTER] - (AxBxC) - aten.reshape - (A'xB''xC') + (AxBxC) - aten.reshape - (A'xB''xC) """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for reshape1 in graph.nodes: @@ -399,7 +391,7 @@ class RemoveRedundantReshapePattern5(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: """ [BEFORE] (AxBxC) - aten.reshape - (AxBxC) @@ -407,8 +399,6 @@ def call(self, exported_program: ExportedProgram) -> PassResult: (AxBxC) """ logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False diff --git a/tico/passes/remove_redundant_slice.py b/tico/passes/remove_redundant_slice.py index a71f4c85..d8e14c38 100644 --- a/tico/passes/remove_redundant_slice.py +++ b/tico/passes/remove_redundant_slice.py @@ -32,10 +32,8 @@ class RemoveRedundantSlice(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/remove_redundant_to_copy.py b/tico/passes/remove_redundant_to_copy.py index 375ffb57..9784256f 100644 --- a/tico/passes/remove_redundant_to_copy.py +++ b/tico/passes/remove_redundant_to_copy.py @@ -35,10 +35,8 @@ class RemoveRedundantToCopy(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/restore_linear.py b/tico/passes/restore_linear.py index f9ed5351..fe9ba4d9 100644 --- a/tico/passes/restore_linear.py +++ b/tico/passes/restore_linear.py @@ -48,10 +48,8 @@ class RestoreLinear(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/passes/segment_index_select.py b/tico/passes/segment_index_select.py index 31411043..212fa6f0 100644 --- a/tico/passes/segment_index_select.py +++ b/tico/passes/segment_index_select.py @@ -72,10 +72,8 @@ class SegmentIndexSelectConst(PassBase): def __init__(self): super().__init__() - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, graph_module) -> PassResult: logger = logging.getLogger(__name__) - - graph_module = exported_program.graph_module graph = graph_module.graph modified = False for node in graph.nodes: diff --git a/tico/serialize/circle_graph.py b/tico/serialize/circle_graph.py index aacba4f2..f4a5e18e 100644 --- a/tico/serialize/circle_graph.py +++ b/tico/serialize/circle_graph.py @@ -69,7 +69,9 @@ class 
CircleModel(circle.Model.ModelT): def __init__(self): super().__init__() self.subgraphs: List[circle.SubGraph.SubGraphT] = [] - self.buffers: List[circle.Buffer.BufferT] = [] + self.buffers: List[circle.Buffer.BufferT] = [ + circle.Buffer.BufferT() + ] # Add empty buffer at the front def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None: self.subgraphs.append(graph) @@ -129,7 +131,7 @@ def add_input(self, input_name: str) -> None: def add_output(self, output: Any) -> None: if isinstance(output, str): - assert output in self.name_to_tid + assert output in self.name_to_tid, f"{output} is not in {self.name_to_tid}" output_name = output elif isinstance(output, int | float): # output is built-in type. diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 9ac21cc1..5f70d363 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -28,11 +28,14 @@ from tico.serialize.operators.node_visitor import get_node_visitors from tico.utils import logging from tico.utils.serialize import finalise_tensor_names, validate_tensor_shapes +from tico.utils.subgraph import get_all_graph_modules multiple_output_ops = [ torch.ops.aten.split_with_sizes.default, torch.ops.aten.max.dim, + # torch.ops.circle_custom.if_.default, + # torch.ops.circle_custom.if_, ] @@ -61,61 +64,68 @@ def build_circle( """ logger = logging.getLogger(__name__) builder = flatbuffers.Builder() - model, graph = _initialize_model() + model = CircleModel() - # Export tensors - _export_tensors(graph, ep) + op_codes: Dict[OpCode, int] = {} + for graph_module, name in get_all_graph_modules(ep): + ep_graph = graph_module.graph + graph = CircleSubgraph(model) - # Register inputs - logger.debug("---------------Register inputs--------------") - for in_spec in ep.graph_signature.input_specs: - if in_spec.kind != InputKind.USER_INPUT: - continue - if isinstance(in_spec.arg, ConstantArgument): - # ConstantArgument is ignored when option is given - if config.get("remove_constant_input"): - continue - # NoneType ConstantArgument is ignored. - if in_spec.arg.value == None: + # Export tensors + if name == "": # root graph + _export_tensors(graph, ep_graph, ep) + else: + _export_tensors_for_subgraph(graph, ep_graph, ep) + + if name == "": # root graph + # Register inputs + logger.debug("---------------Register inputs--------------") + for in_spec in ep.graph_signature.input_specs: + if in_spec.kind != InputKind.USER_INPUT: + continue + if isinstance(in_spec.arg, ConstantArgument): + # ConstantArgument is ignored when option is given + if config.get("remove_constant_input"): + continue + # NoneType ConstantArgument is ignored. 
+ if in_spec.arg.value == None: + continue + arg_name = in_spec.arg.name + graph.add_input(arg_name) + logger.debug(f"Registered input: {arg_name}") + + # Register outputs + logger.debug("---------------Register outputs--------------") + for user_output in ep.graph_signature.user_outputs: + if user_output == None: + logger.debug("Ignore 'None' output") + continue + + graph.add_output(user_output) + logger.debug(f"Registered output: {user_output}") + + # Export operators + logger.debug("---------------Export operators--------------") + visitors = get_node_visitors(op_codes, graph) + for node in ep_graph.nodes: + if node.op != "call_function": continue - arg_name = in_spec.arg.name - graph.add_input(arg_name) - logger.debug(f"Registered input: {arg_name}") - - # Register outputs - logger.debug("---------------Register outputs--------------") - for user_output in ep.graph_signature.user_outputs: - if user_output == None: - logger.debug("Ignore 'None' output") - continue - graph.add_output(user_output) - logger.debug(f"Registered output: {user_output}") - - # Export operators - logger.debug("---------------Export operators--------------") - op_codes: Dict[OpCode, int] = {} - visitors = get_node_visitors(op_codes, graph) - for node in ep.graph.nodes: - if node.op != "call_function": - continue - - opcode = node.target - if opcode == operator.getitem: - continue - if opcode not in visitors: - raise RuntimeError(f"{opcode} is not yet supported") - circle_op = visitors[opcode].define_node(node) + opcode = node.target + if opcode == operator.getitem: + continue + if opcode not in visitors: + raise RuntimeError(f"{opcode} is not yet supported") + circle_op = visitors[opcode].define_node(node) - if circle_op: - graph.add_operator(circle_op) - logger.debug(f"call_function: {node.name} ({opcode}) Op exported.") + if circle_op: + graph.add_operator(circle_op) + logger.debug(f"call_function: {node.name} ({opcode}) Op exported.") - finalise_tensor_names(graph) - validate_tensor_shapes(graph) + finalise_tensor_names(graph) + validate_tensor_shapes(graph) - # Register subgraph - model.subgraphs.append(graph) + model.subgraphs.append(graph) # Encode operator codes model.operatorCodes = [ @@ -133,7 +143,7 @@ def build_circle( return bytes(buf) -def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: +def _export_tensors(graph: CircleSubgraph, ep_graph, ep: ExportedProgram) -> None: """Export all tensors from the exported program to the circle graph. 
Args: @@ -144,7 +154,7 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: logger.debug("---------------Export tensors--------------") buf_name_to_data = {name: buf for name, buf in ep.named_buffers()} - for node in ep.graph.nodes: + for node in ep_graph.nodes: if node.op == "call_function": if node.target in multiple_output_ops: continue @@ -157,15 +167,67 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: logger.debug(f"call_function: {node.name} tensor exported.") elif node.op == "placeholder": - _handle_placeholder_node(graph, node, ep, buf_name_to_data) + _handle_placeholder_node(graph, node, ep_graph, ep, buf_name_to_data) elif node.op == "get_attr": _handle_get_attr_node(graph, node) + elif node.op == "output": + for output in node.args[0]: + if isinstance(output, torch.fx.Node): + assert graph.has_tensor(output.name), f"{output}" + continue + + elif node.op == "call_method": + raise AssertionError("Not yet implemented") + + elif node.op == "call_module": + raise AssertionError("Not yet implemented") + + else: + raise AssertionError(f"Unknown fx.Node op {node.op}") + + +def _export_tensors_for_subgraph( + graph: CircleSubgraph, ep_graph, ep: ExportedProgram +) -> None: + """Export all tensors from the exported program to the circle graph. + + Args: + graph: The CircleSubgraph to add tensors to + ep: The exported PyTorch program + """ + logger = logging.getLogger(__name__) + logger.debug("---------------Export tensors--------------") + buf_name_to_data = { + name: buf for name, buf in ep.named_buffers() + } # model-wise context + + for node in ep_graph.nodes: + if node.op == "call_function": + if node.target in multiple_output_ops: + continue + node_val = node.meta["val"] + if node.name == "cond": + continue + if node_val.layout != torch.strided: + raise RuntimeError( + f"Only support dense tensors (node layout: {node_val.layout})" + ) + graph.add_tensor_from_node(node) + logger.debug(f"call_function: {node.name} tensor exported.") + + elif node.op == "placeholder": + _handle_placeholder_node(graph, node, ep_graph, ep, buf_name_to_data) + graph.add_input(node.name) # This is added for subgraph + + elif node.op == "get_attr": + _handle_get_attr_node(graph, node) elif node.op == "output": for output in node.args[0]: if isinstance(output, torch.fx.Node): assert graph.has_tensor(output.name) + graph.add_output(output.name) # This is added for subgraph continue elif node.op == "call_method": @@ -181,6 +243,7 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: def _handle_placeholder_node( graph: CircleSubgraph, node: torch.fx.Node, + ep_graph, ep: ExportedProgram, buf_name_to_data: dict, ) -> None: @@ -326,20 +389,22 @@ def _handle_get_attr_node( node: The get_attr node to process """ assert isinstance(node.target, str) - attr_tensor = getattr(node.graph.owning_module, node.target) - - if not isinstance(attr_tensor, torch.Tensor): - raise ValueError(f"Attribute {node.target} is not a tensor") - - attr_shape, attr_shape_signature = to_circle_shape(attr_tensor.shape) - - graph.add_tensor_from_scratch( - prefix=node.name, - shape=attr_shape, - shape_signature=attr_shape_signature, - dtype=to_circle_dtype(attr_tensor.dtype), - source_node=node, - ) - - logger = logging.getLogger(__name__) - logger.debug(f"Exported attribute tensor: {node.name}") + attr = getattr(node.graph.owning_module, node.target) + + if isinstance(attr, torch.fx.graph_module.GraphModule): + pass + elif isinstance(attr, torch.Tensor): + attr_shape, 
attr_shape_signature = to_circle_shape(attr.shape) + + graph.add_tensor_from_scratch( + prefix=node.name, + shape=attr_shape, + shape_signature=attr_shape_signature, + dtype=to_circle_dtype(attr.dtype), + source_node=node, + ) + + logger = logging.getLogger(__name__) + logger.debug(f"Exported attribute tensor: {node.name}") + else: + raise ValueError(f"Unsupported get_attr target type {type(attr)}") diff --git a/tico/serialize/operators/op_circle_if.py b/tico/serialize/operators/op_circle_if.py new file mode 100644 index 00000000..f92eedb8 --- /dev/null +++ b/tico/serialize/operators/op_circle_if.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch._ops + import torch.fx +import torch +from circle_schema import circle + +from tico.serialize.circle_graph import CircleSubgraph +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index +from tico.utils.validate_args_kwargs import CircleIfArgs + + +@register_node_visitor +class CircleIfVisitor(NodeVisitor): + target: List[torch._ops.OpOverload] = [ + torch.ops.circle_custom.if_, + torch.ops.circle_custom.if_.default, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.IF, self._op_codes + ) + if_args = CircleIfArgs(*node.args, **node.kwargs) + + pred = if_args.pred + arguments = if_args.if_args + then_graph_idx = if_args.then_graph_idx + else_graph_idx = if_args.else_graph_idx + + inputs = [pred, *arguments] + outputs = [node] + # outputs = [i for i in node.users.keys()] + operator = create_builtin_operator(self.graph, op_index, inputs, outputs) + operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.IfOptions + operator.builtinOptions = circle.IfOptions.IfOptionsT() + operator.builtinOptions.thenSubgraphIndex = then_graph_idx + operator.builtinOptions.elseSubgraphIndex = else_graph_idx + + return operator diff --git a/tico/serialize/operators/utils.py b/tico/serialize/operators/utils.py index 462001d6..1b87f5b3 100644 --- a/tico/serialize/operators/utils.py +++ b/tico/serialize/operators/utils.py @@ -41,14 +41,33 @@ def get_op_index(opcode: int, opcode_map: Dict[OpCode, int]) -> int: return op_index +import torch + # TODO Move this to CircleSubGraph def create_builtin_operator( graph, op_index: int, inputs: List, outputs: List ) -> circle.Operator.OperatorT: operator = circle.Operator.OperatorT() operator.opcodeIndex = op_index - operator.inputs = [graph.get_tid(input) for input in inputs] - operator.outputs = [graph.get_tid(output) for output in outputs] + + operator.inputs = [] + for 
inp in inputs: + if isinstance(inp, torch.fx.immutable_collections.immutable_list): + operator.inputs.append( + tuple(graph.get_tid(inp_item) for inp_item in inp) + ) # TODO: extend to multiple tuple processing + print(f"input: {inp}") + else: + operator.inputs.append(graph.get_tid(inp)) + operator.outputs = [] + for outp in outputs: + if isinstance(outp, torch.fx.immutable_collections.immutable_list): + print(f"output: {outp}") + operator.outputs.append( + tuple(graph.get_tid(outp_item) for outp_item in outp) + ) # TODO: extend to multiple tuple processing + else: + operator.outputs.append(graph.get_tid(outp)) return operator diff --git a/tico/utils/convert.py b/tico/utils/convert.py index d13551e3..bcb0e55c 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -20,6 +20,7 @@ from torch.export import export, ExportedProgram from tico.config import CompileConfigBase, get_default_config +from tico.experimental.controlflow.passes.lower_cond import LowerCond from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import ( InsertQuantizeOnDtypeMismatch, @@ -79,6 +80,7 @@ from tico.utils.errors import NotYetSupportedError from tico.utils.model import CircleModel from tico.utils.passes import PassManager +from tico.utils.subgraph import get_all_graph_modules from tico.utils.trace_decorators import ( trace_const_diff_on_func, trace_graph_diff_on_func, @@ -154,6 +156,7 @@ def check_unsupported_target(exported_program: ExportedProgram): supported_target = list(get_support_targets()) # Ignore `getitem` since it is no-op for multiple outputs. supported_target.append(operator.getitem) + supported_target.append(torch.ops.higher_order.cond) unsupported = [] for n in exported_program.graph.nodes: if n.op != "call_function": @@ -194,21 +197,23 @@ def convert_exported_module_to_circle( assert isinstance(config, CompileConfigBase) - logger = logging.getLogger(__name__) - logger.debug("Input ExportedProgram (must be core aten)") - logger.debug(exported_program) - - # PRE-EDGE PASSES - # - # Here are the passes that run before to_edge() conversion. - # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR. - decompose_quantize_op = PassManager( - passes=[ - DecomposeFakeQuantize(), - DecomposeFakeQuantizeTensorQParams(), - ] - ) - decompose_quantize_op.run(exported_program) + for graph_module, _ in get_all_graph_modules(exported_program): + logger = logging.getLogger(__name__) + logger.debug("Input ExportedProgram (must be core aten)") + logger.debug(exported_program) + # graph_module.graph.print_tabular() + + # PRE-EDGE PASSES + # + # Here are the passes that run before to_edge() conversion. + # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR. + decompose_quantize_op = PassManager( + passes=[ + DecomposeFakeQuantize(), + DecomposeFakeQuantizeTensorQParams(), + ] + ) + decompose_quantize_op.run(exported_program, graph_module) # This pass should be run before 'RestoreLinear' and after 'decompose_quantize_op'. # TODO run pass regardless of the orders. @@ -222,76 +227,92 @@ def convert_exported_module_to_circle( # CompositeImplicitAutograd and have functional schema are safe to not decompose. 
exported_program = traced_run_decompositions(exported_program) - # TODO Distinguish legalize and optimize - circle_legalize = PassManager( - passes=[ - FillMetaVal(), - ExtractDtypeKwargsPass(), - RemoveNop(), - ConvertLayoutOpToReshape(), - RestoreLinear(), - ConvertToReLU6(), - DecomposeAddmm(), - DecomposeSliceScatter(), - DecomposeGroupNorm(), - DecomposeBatchNorm(), - DecomposeGroupedConv2d(), - CastATenWhereArgType(), - ConvertRepeatToExpandCopy(), - *RemoveRedundantPermutePasses(), - RemoveRedundantAssertionNodes(), - RemoveRedundantExpand(), - RemoveRedundantSlice(), - FuseRedundantReshapeToMean(), - *RemoveRedundantViewPasses(), - RemoveRedundantToCopy(), - MergeConsecutiveCat(), - CastMixedTypeArgs(preserve_ep_invariant=True), - ConstPropPass(), - SegmentIndexSelectConst(), - LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")), - ConvertMatmulToLinear( - enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"), - enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"), - enable_single_batch_lhs_const_bmm=config.get( - "convert_single_batch_lhs_const_bmm_to_fc" + for graph_module, _ in get_all_graph_modules(exported_program): + graph = graph_module.graph + + reinterpret_pass = PassManager( + passes=[ + LowerCond(), + ] + ) + reinterpret_pass.run(exported_program, graph_module) + + ConstPropPass().call(exported_program, exported_program.graph_module) + + for graph_module, _ in get_all_graph_modules(exported_program): + graph = graph_module.graph + # TODO Distinguish legalize and optimize + circle_legalize = PassManager( + passes=[ + FillMetaVal(), + ExtractDtypeKwargsPass(), + RemoveNop(), + ConvertLayoutOpToReshape(), + RestoreLinear(), + ConvertToReLU6(), + DecomposeAddmm(), + DecomposeSliceScatter(), + DecomposeGroupNorm(), + DecomposeBatchNorm(), + DecomposeGroupedConv2d(), + CastATenWhereArgType(), + ConvertRepeatToExpandCopy(), + *RemoveRedundantPermutePasses(), + RemoveRedundantAssertionNodes(), + RemoveRedundantExpand(), + RemoveRedundantSlice(), + FuseRedundantReshapeToMean(), + *RemoveRedundantViewPasses(), + RemoveRedundantToCopy(), + MergeConsecutiveCat(), + CastMixedTypeArgs(preserve_ep_invariant=True), + SegmentIndexSelectConst(), + LegalizeCausalMaskValue( + enabled=config.get("legalize_causal_mask_value") ), - ), - LowerToResizeNearestNeighbor(), - LegalizePreDefinedLayoutOperators(), - LowerPow2ToMul(), - ConvertConv1dToConv2d(), - *LowerToSlicePasses(), - FuseLeadingUnsqueezeReshape(), - CastClampMixedTypeArgs(), - ] - ) - circle_legalize.run(exported_program) - - # After this stage, ExportedProgram invariant is broken, i.e., - # graph can have a constant torch.tensor not lifted to a placeholder - circle_legalize = PassManager( - passes=[ - FillMetaVal(), - CastMixedTypeArgs(preserve_ep_invariant=False), - ] - ) - circle_legalize.run(exported_program) - - # TODO Give an option to enable quantiztion to user - enable_quantization = has_quantization_ops(exported_program.graph) - if enable_quantization: - quantize_graph = PassManager( + ConvertMatmulToLinear( + enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"), + enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"), + enable_single_batch_lhs_const_bmm=config.get( + "convert_single_batch_lhs_const_bmm_to_fc" + ), + ), + LowerToResizeNearestNeighbor(), + LegalizePreDefinedLayoutOperators(), + LowerPow2ToMul(), + ConvertConv1dToConv2d(), + *LowerToSlicePasses(), + FuseLeadingUnsqueezeReshape(), + CastClampMixedTypeArgs(), + ] + ) + circle_legalize.run(exported_program, graph_module) 
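For reference, a minimal sketch of a pass written against the two-argument `call` signature used throughout this patch. Only `PassBase`, `PassResult`, the `(exported_program, graph_module)` parameters, and the recompile/lint idiom come from the diffs above; the class name, the import path for `PassResult`, and the node filter are illustrative assumptions, not part of the patch.

from torch.export import ExportedProgram
from tico.utils.passes import PassBase, PassResult  # PassResult location assumed


class ExamplePass(PassBase):
    """Hypothetical pass illustrating the updated signature; not part of the patch."""

    def call(self, exported_program: ExportedProgram, graph_module) -> PassResult:
        # Operate on the graph module handed in by PassManager,
        # not on exported_program.graph_module.
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            # ... rewrite nodes here and set modified = True ...
        if modified:
            graph.eliminate_dead_code()
            graph.lint()
            graph_module.recompile()
        return PassResult(modified)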
+ + # After this stage, ExportedProgram invariant is broken, i.e., + # graph can have a constant torch.tensor not lifted to a placeholder + circle_legalize = PassManager( passes=[ - FoldQuantOps(), - RemoveWeightDequantOp(), - PropagateQParamForward(), - PropagateQParamBackward(), - QuantizeBias(), - InsertQuantizeOnDtypeMismatch(), + FillMetaVal(), + CastMixedTypeArgs(preserve_ep_invariant=False), ] ) - quantize_graph.run(exported_program) + circle_legalize.run(exported_program, graph_module) + ConstPropPass().call(exported_program, exported_program.graph_module) + + # TODO Give an option to enable quantiztion to user + enable_quantization = has_quantization_ops(graph) + if enable_quantization: + quantize_graph = PassManager( + passes=[ + FoldQuantOps(), + RemoveWeightDequantOp(), + PropagateQParamForward(), + PropagateQParamBackward(), + QuantizeBias(), + InsertQuantizeOnDtypeMismatch(), + ] + ) + quantize_graph.run(exported_program) check_unsupported_target(exported_program) check_training_ops(exported_program) diff --git a/tico/utils/diff_graph.py b/tico/utils/diff_graph.py index 4a38a4cf..bfb7a05f 100644 --- a/tico/utils/diff_graph.py +++ b/tico/utils/diff_graph.py @@ -150,8 +150,40 @@ def capture(graph: torch.fx.Graph): global graph_captured graph_captured = str(graph) +from tico.utils.subgraph import get_all_graph_modules +all_graphs_captured = {} -@disable_when(LOG_LEVEL > DEBUG) +@disable_when(LOG_LEVEL > LOGGER_THRESHOLD) +def capture_all(ep: torch.export.ExportedProgram): + assert isinstance(ep, torch.export.ExportedProgram) + global all_graphs_captured + for graph, name in get_all_graph_modules(ep): + all_graphs_captured[name] = str(graph) + +@disable_when(LOG_LEVEL > LOGGER_THRESHOLD) +def log_all(ep: torch.export.ExportedProgram, title: str, recapture: bool): + assert isinstance(ep, torch.export.ExportedProgram) + global all_graphs_captured + all_graphs_now = {} + logger = getLogger(__name__) + for graph, name in get_all_graph_modules(ep): + all_graphs_now[name] = str(graph) + + for name, graph in all_graphs_now.items(): + graph_captured = all_graphs_captured[name] + diff = strdiff(f"{graph_captured}\n", f"{graph}\n") + prefix = f"[{title}]" if title else "" + if len(diff) > 0: + logger.debug(f"{prefix} Graph({ (name if name != '' else 'root') }) is changed.") + logger.debug(f"\n{diff}") + + if recapture: + all_graphs_captured[name] = deepcopy(graph) + else: + all_graphs_captured[name] = None # reset + + +@disable_when(LOG_LEVEL > LOGGER_THRESHOLD) def log(graph: torch.fx.Graph, title: str, recapture: bool): """ Capture the end-point graph for graph-diff. 
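A minimal usage sketch of the per-graph-module iteration that the convert.py and diff_graph.py changes above rely on. `get_all_graph_modules` and the two-argument `PassManager.run` come from this patch (see tico/utils/subgraph.py and tico/utils/passes.py); the wrapper function and its `passes` argument are placeholders.

from tico.utils.passes import PassManager
from tico.utils.subgraph import get_all_graph_modules


def run_passes_on_all_graphs(exported_program, passes):
    # The root graph module is yielded first with an empty name; control-flow
    # subgraphs (GraphModules reached via get_attr nodes) follow, named after
    # their get_attr node.
    for graph_module, name in get_all_graph_modules(exported_program):
        PassManager(passes=passes).run(exported_program, graph_module)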
diff --git a/tico/utils/graph.py b/tico/utils/graph.py index f0905334..a324fef2 100644 --- a/tico/utils/graph.py +++ b/tico/utils/graph.py @@ -67,11 +67,9 @@ def get_torch_buffer_value(node: torch.fx.Node, ep: ExportedProgram): return named_buf[buf_name] -def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]: +def get_first_user_input(exported_program: ExportedProgram, graph: torch.fx.Graph) -> Optional[torch.fx.Node]: """Returns the first user input node in the graph.""" first_user_input: Optional[torch.fx.Node] = None - graph_module = exported_program.graph_module - graph: torch.fx.Graph = graph_module.graph for node in graph.nodes: if ( node.op == "placeholder" diff --git a/tico/utils/passes.py b/tico/utils/passes.py index 0a59d9bf..cee92c75 100644 --- a/tico/utils/passes.py +++ b/tico/utils/passes.py @@ -31,7 +31,7 @@ class PassBase(ABC): """ @abstractmethod - def call(self, exported_program: ExportedProgram) -> PassResult: + def call(self, exported_program: ExportedProgram, gm) -> PassResult: pass @@ -51,7 +51,7 @@ def __init__( self.passes: List[PassBase] = passes self.strategy: PassStrategy = strategy - def run(self, exported_program: ExportedProgram): + def run(self, exported_program: ExportedProgram, graph_module): MAXIMUM_STEP_COUNT = 1000 step = 0 while True: @@ -59,10 +59,10 @@ def run(self, exported_program: ExportedProgram): for _pass in self.passes: # Automatically update the signatures of the input and output. # https://github.com/pytorch/executorch/issues/4013#issuecomment-2187161844 - with exported_program.graph_module._set_replace_hook( + with graph_module._set_replace_hook( exported_program.graph_signature.get_replace_hook() ): - result = _pass.call(exported_program) + result = _pass.call(exported_program, graph_module) modified = modified or result.modified if modified and self.strategy == PassStrategy.RESTART: break diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 095cb6a5..023de45d 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Union import torch from torch._subclasses.fake_tensor import FakeTensor @@ -20,6 +20,41 @@ from tico.utils.mx.mx_ops import _quantize_mx + +def CircleIf(): + @custom_op("circle_custom::if_", mutates_args=()) + def if_( + pred: torch.Tensor, + then_graph: torch.Tensor, + else_graph: torch.Tensor, + then_graph_idx: int, + else_graph_idx: int, + if_args: List[torch.Tensor], + ) -> torch.Tensor: + if pred: + result = then_graph(*if_args) + assert len(result) == 1 # TODO: Support tuple of result + return result[0] + else: + result = else_graph(*if_args) + assert len(result) == 1 # TODO: Support tuple of result + return result[0] + + @register_fake("circle_custom::if_") + def _( + pred: torch.Tensor, + then_graph: torch.Tensor, + else_graph: torch.Tensor, + then_graph_idx: int, + else_graph_idx: int, + if_args: List[torch.Tensor], + ): + result = then_graph(*if_args) + assert len(result) == 1 # TODO: Support tuple of result + + return result[0] + + # Note that an operator assumes input tensor has NHWC format. 
def CircleResizeNearestNeighbor(): @custom_op("circle_custom::resize_nearest_neighbor", mutates_args=()) @@ -740,3 +775,4 @@ def RegisterOps(): CircleInstanceNorm() CircleQuantizeMX() CircleRMSNorm() + CircleIf() diff --git a/tico/utils/signature.py b/tico/utils/signature.py index acabaf89..a69c0796 100644 --- a/tico/utils/signature.py +++ b/tico/utils/signature.py @@ -117,8 +117,8 @@ def load(circle_path: str) -> bytes: def __init__(self, circle_binary): model = circle.Model.Model.GetRootAsModel(circle_binary, 0) - assert model.SubgraphsLength() == 1, "Only one subgraph is supported" + # Assumption; Circle model's user IO signature is defined in the first subgraph graph = model.Subgraphs(0) tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())] diff --git a/tico/utils/subgraph.py b/tico/utils/subgraph.py new file mode 100644 index 00000000..ae33066c --- /dev/null +++ b/tico/utils/subgraph.py @@ -0,0 +1,24 @@ +from typing import Iterator + +import torch +from torch.export import ExportedProgram + + +def get_all_graph_modules( + ep: ExportedProgram, subgraph_only: bool = False +) -> Iterator[tuple[torch.fx.GraphModule, str]]: + """ + Get all graph modules and its name + """ + if not subgraph_only: + yield ep.graph_module, "" # root has no name + + # yield subgraphs + for node in ep.graph.nodes: + if node.op == "get_attr": + graph_module = getattr(node.graph.owning_module, node.target) + + # TODO: Enable recursion (n-depth) + if isinstance(graph_module, torch.fx.graph_module.GraphModule): + assert hasattr(graph_module, "meta") + yield graph_module, getattr(node, "name") diff --git a/tico/utils/trace_decorators.py b/tico/utils/trace_decorators.py index c47c180a..ea1577c2 100644 --- a/tico/utils/trace_decorators.py +++ b/tico/utils/trace_decorators.py @@ -17,7 +17,7 @@ import torch from torch.export import ExportedProgram -from tico.utils.diff_graph import capture, capture_const, log, log_const +from tico.utils.diff_graph import capture, capture_const, log, log_const, capture_all, log_all from tico.utils.passes import PassBase @@ -28,13 +28,11 @@ def trace_const_diff_on_pass(cls): def _call_traced(fn): @wraps(fn) - def wrapped(*args): - _, exported_program = args + def wrapped(self, exported_program, graph_module): assert isinstance(exported_program, ExportedProgram) - graph_module = exported_program.graph_module assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module) capture_const(exported_program) - ret = fn(*args) + ret = fn(self, exported_program, graph_module) log_const(exported_program, title=str(cls.__name__), recapture=False) return ret @@ -54,14 +52,12 @@ def trace_graph_diff_on_pass(cls): def _call_traced(fn): @wraps(fn) - def wrapped(*args): - _, exported_program = args + def wrapped(self, exported_program, graph_module): assert isinstance(exported_program, ExportedProgram) - graph_module = exported_program.graph_module assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module) - capture(graph_module.graph) - ret = fn(*args) - log(graph_module.graph, title=str(cls.__name__), recapture=False) + capture_all(exported_program) + ret = fn(self, exported_program, graph_module) + log_all(exported_program, title=str(cls.__name__), recapture=False) return ret return wrapped diff --git a/tico/utils/utils.py b/tico/utils/utils.py index 3848e9b2..da88d9c4 100644 --- a/tico/utils/utils.py +++ b/tico/utils/utils.py @@ -185,13 +185,23 @@ def set_new_meta_val(node: torch.fx.node.Node): """ assert isinstance(node, torch.fx.node.Node) + def 
_get_meta_val(node): + assert hasattr( + node, "meta" + ), f"'node' has no attribute named 'meta' (node: {node})" + assert ( + "val" in node.meta + ), f"val key not in node.meta (node: {node}, meta: {node.meta})" + return node.meta["val"] + # `node.target()` needs only `Tensor` for its arguments. # Therefore, let's retrieve `FakeTensor` if it is `torch.fx.Node`. args, kwargs = pytree.tree_map_only( torch.fx.Node, - lambda n: n.meta["val"], + _get_meta_val, (node.args, node.kwargs), ) + new_val = node.target(*args, **kwargs) # type: ignore[operator] node.meta["val"] = new_val diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py index 8a5feb21..646b1246 100644 --- a/tico/utils/validate_args_kwargs.py +++ b/tico/utils/validate_args_kwargs.py @@ -206,6 +206,36 @@ class CloneArgs: memory_format: Optional[torch.memory_format] = None +@enforce_type +@dataclass +class CircleIfArgs: + """ + Why carry both `graph` and `graph_idx`? + [1] `graph` is required to get the proper meta value while processing FakeTensor for torch nodes. Plus, carrying this information until `serialize` leaves the graph until that process, otherwise the graph will be cleaned-up by dead code elimination of graph_module. + [2] `graph_idx` is required to map to circle ir. + """ + + pred: torch.fx.Node + then_graph: torch.fx.Node + else_graph: torch.fx.Node + then_graph_idx: int + else_graph_idx: int + if_args: torch.fx.immutable_collections.immutable_list + + +@enforce_type +@dataclass +class CondArgs: + """ + # This is not aten operator but `torch.ops.higher_order_op.cond` + """ + + condition: torch.fx.Node + true_graph: torch.fx.Node + false_graph: torch.fx.Node + cond_args: torch.fx.immutable_collections.immutable_list + + @enforce_type @dataclass class ConstantPadNdArgs: