30 changes: 30 additions & 0 deletions example_dynamic_onnx.py
@@ -0,0 +1,30 @@
import torch
from torch.export import Dim


class SimpleCopyWithBroadcastToDynamicShape(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, dst, src):
dst.copy_(src)
return dst

def get_example_inputs(self):
return (torch.randn(5, 5), torch.randn(1, 5)), {}

def get_dynamic_shapes(self):
dim = Dim("dim", min=1, max=128)
dynamic_shapes = {
"dst": {0: dim},
"src": {},
}
return dynamic_shapes

model = SimpleCopyWithBroadcastToDynamicShape()

example_args, example_kwargs = model.get_example_inputs()
ep = torch.export.export(
    model,
    args=example_args,
    kwargs=example_kwargs,
    dynamic_shapes=model.get_dynamic_shapes(),
)

breakpoint()  # pause here to inspect the ExportedProgram interactively
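A quick sanity check for this example script (a sketch, assuming the export above succeeds): the exported module should accept any dst whose first dimension lies within the declared Dim range, while src stays fixed at (1, 5).

for n in (1, 7, 128):
    dst = torch.randn(n, 5)
    src = torch.randn(1, 5)
    out = ep.module()(dst, src)  # ep.module() returns a runnable module
    assert out.shape == (n, 5)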
22 changes: 22 additions & 0 deletions test/modules/op/copy.py
@@ -15,6 +15,8 @@
import torch

from test.modules.base import TestModuleBase
from torch.export import Dim
from test.utils.tag import use_onert


class SimpleCopy(TestModuleBase):
@@ -39,3 +41,23 @@ def forward(self, dst, src):

def get_example_inputs(self):
return (torch.randn(5, 5), torch.randn(1, 5)), {}

@use_onert
class SimpleCopyWithBroadcastToDynamicShape(TestModuleBase):
def __init__(self):
super().__init__()

def forward(self, dst, src):
dst.copy_(src)
return dst

def get_example_inputs(self):
return (torch.randn(5, 5), torch.randn(1, 5)), {}

def get_dynamic_shapes(self):
dim = Dim("dim", min=1, max=128)
dynamic_shapes = {
"dst": {0: dim},
"src": {},
}
return dynamic_shapes
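For context on why the (5, 5) dst / (1, 5) src pair exercises the broadcast path: Tensor.copy_ broadcasts src to dst's shape. A minimal eager-mode illustration:

import torch

dst = torch.zeros(3, 5)
src = torch.arange(5.0).reshape(1, 5)
dst.copy_(src)  # the single src row is broadcast across dim 0
assert torch.equal(dst, src.expand(3, 5))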
1 change: 1 addition & 0 deletions test/pt2_to_circle_test/builder.py
@@ -104,6 +104,7 @@ def _run(
assert (
self.use_onert
), "Dynamic shapes are only supported with onert runtime. Please set 'use_onert' to True."
dynamic_shapes = self.nnmodule.get_dynamic_shapes()

compile_config: Optional[CompileConfigBase] = None
if hasattr(self.nnmodule, "get_compile_config"):
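The builder now pulls the spec from the module itself. A hypothetical sketch of how _run() presumably threads it into torch.export.export (the actual plumbing is elided from this hunk):

args, kwargs = self.nnmodule.get_example_inputs()
ep = torch.export.export(
    self.nnmodule,
    args=args,
    kwargs=kwargs,
    dynamic_shapes=dynamic_shapes,  # None when the module defines no dynamic dims
)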
3 changes: 2 additions & 1 deletion tico/serialize/circle_serializer.py
@@ -84,7 +84,7 @@ def build_circle(ep: ExportedProgram) -> bytes:

graph.add_output(user_output)
logger.debug(f"Registered output: {user_output}")

# Export operators
logger.debug("---------------Export operators--------------")
op_codes: Dict[OpCode, int] = {}
@@ -146,6 +146,7 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None:
raise RuntimeError(
f"Only support dense tensors (node layout: {node_val.layout})"
)

graph.add_tensor_from_node(node)
logger.debug(f"call_function: {node.name} tensor exported.")

88 changes: 41 additions & 47 deletions tico/serialize/operators/op_copy.py
@@ -13,7 +13,7 @@
# limitations under the License.

from typing import Dict, List, Optional, TYPE_CHECKING, Union

from copy import deepcopy
if TYPE_CHECKING:
import torch._ops
import torch.fx
@@ -59,9 +59,8 @@ def check_to_do_broadcast(
src: List[int],
src_sig: Optional[List[int]],
) -> bool:
assert dst_sig is None
assert src_sig is None
return dst != src
exactly_same = (dst_sig == src_sig) and (dst == src)
return not exactly_same
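The rewritten predicate reports "no broadcast needed" only when both the static shapes and the shape signatures agree. A self-contained sketch of the assumed semantics (shapeSignature is None for fully static tensors, with -1 marking dynamic dims):

def check_to_do_broadcast(dst, dst_sig, src, src_sig):
    exactly_same = (dst_sig == src_sig) and (dst == src)
    return not exactly_same

assert check_to_do_broadcast([5, 5], None, [1, 5], None)      # shapes differ
assert check_to_do_broadcast([5, 5], [-1, 5], [5, 5], None)   # dst is dynamic
assert not check_to_do_broadcast([5, 5], None, [5, 5], None)  # fully identical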

def define_broadcast_to_node(
self,
@@ -98,54 +97,43 @@ def define_node(
self,
node: torch.fx.Node,
) -> circle.Operator.OperatorT:
if len(node.args) == 3:
raise NotYetSupportedError("'non_blocking' is not supported yet.")

assert len(node.args) == 2, len(node.args)

args = CopyArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
args = CopyArgs(*node.args, **node.kwargs)
dst = args.dst
src = args.src

# To connect 'dst' to Reshape node in the graph, 'dst' must be converted to Shape op.
dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
dst_shape: List[int] = dst_tensor.shape
dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature

if dst_shape_signature is not None:
# TODO: support dynamic shape
raise NotYetSupportedError("Dynamic shape is not supported yet.")

dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)

dst_shape_shape = [len(dst_shape)]
dst_name: str = dst.name

shape_output = self.graph.add_tensor_from_scratch(
prefix=f"{dst_name}_shape_output",
shape=dst_shape_shape,
shape_signature=None,
dtype=circle.TensorType.TensorType.INT32,
source_node=node,
)

shape_operator = self.define_shape_node([dst], [shape_output])
self.graph.add_operator(shape_operator)

src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
src_shape: List[int] = src_tensor.shape
src_shape_signature: Optional[List[int]] = src_tensor.shapeSignature

if src_shape_signature is not None:
# TODO: support dynamic shape
raise NotYetSupportedError("Dynamic shape is not supported yet.")

dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
dst_shape: List[int] = dst_tensor.shape
dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature
# The src tensor must be broadcastable with the dst tensor.
do_broadcast = self.check_to_do_broadcast(
dst_shape, dst_shape_signature, src_shape, src_shape_signature
)
if do_broadcast:
# create broadcastTo output tensor

if not do_broadcast:
# To connect 'dst' to Reshape node in the graph, 'dst' must be converted to Shape op.
dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)

dst_shape_shape = [len(dst_shape)]
dst_name: str = dst.name

shape_output = self.graph.add_tensor_from_scratch(
prefix=f"{dst_name}_shape_output",
shape=dst_shape_shape,
shape_signature=None,
dtype=circle.TensorType.TensorType.INT32,
source_node=node,
)

shape_operator = self.define_shape_node([dst], [shape_output])
self.graph.add_operator(shape_operator)
inputs = [src, shape_output]
else:
# create broadcastTo output tensor
src_name: str = src.name
src_type: int = src_tensor.type

@@ -159,15 +147,21 @@
)
)

dst_shape_merged = deepcopy(dst_shape)
if dst_shape_signature is not None:
for idx, sig in enumerate(dst_shape_signature):
if sig == -1:
dst_shape_merged[idx] = -1

dst_shape_tensor = torch.as_tensor(dst_shape_merged, dtype=torch.int32)
broadcast_to_operator: circle.Operator.OperatorT = (
self.define_broadcast_to_node(
[src_tensor, dst_shape_tensor], [broadcast_to_output]
[src_tensor, dst_shape_tensor], [node]
)
)
self.graph.add_operator(broadcast_to_operator)
inputs: List = [broadcast_to_output, shape_output]
else:
inputs = [src, shape_output]
# self.graph.add_operator(broadcast_to_operator)
# inputs: List = [broadcast_to_output, shape_output]
return broadcast_to_operator  # NOTE: early return in this draft; the Copy/Reshape wiring below is bypassed

outputs = [node]
op_index = get_op_index(
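The key new step in op_copy.py is the shape/signature merge: the shape tensor handed to BroadcastTo gets -1 stamped into every dimension the signature marks as dynamic. A standalone sketch, assuming the -1 convention described above:

from copy import deepcopy
import torch

dst_shape = [5, 5]
dst_shape_signature = [-1, 5]  # dim 0 is dynamic

dst_shape_merged = deepcopy(dst_shape)
for idx, sig in enumerate(dst_shape_signature):
    if sig == -1:
        dst_shape_merged[idx] = -1

assert dst_shape_merged == [-1, 5]
dst_shape_tensor = torch.as_tensor(dst_shape_merged, dtype=torch.int32)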
10 changes: 9 additions & 1 deletion tico/utils/convert.py
@@ -197,6 +197,8 @@ def convert_exported_module_to_circle(
logger.debug("Input ExportedProgram (must be core aten)")
logger.debug(exported_program)

# Debug: inspect the shape metadata recorded for the first graph node
nodes = list(exported_program.graph.nodes)
print(nodes[0].meta['val'])
# PRE-EDGE PASSES
#
# Here are the passes that run before to_edge() conversion.
@@ -221,6 +223,8 @@
# CompositeImplicitAutograd and have functional schema are safe to not decompose.
exported_program = traced_run_decompositions(exported_program)

nodes = list(exported_program.graph.nodes)
print(nodes[0].meta['val'])
# TODO Distinguish legalize and optimize
circle_legalize = PassManager(
passes=[
@@ -259,7 +263,9 @@
]
)
circle_legalize.run(exported_program)


nodes = list(exported_program.graph.nodes)
print(nodes[0].meta['val'])
# After this stage, ExportedProgram invariant is broken, i.e.,
# graph can have a constant torch.tensor not lifted to a placeholder
circle_legalize = PassManager(
@@ -270,6 +276,8 @@
)
circle_legalize.run(exported_program)

nodes = list(exported_program.graph.nodes)
print(nodes[0].meta['val'])
# TODO Give an option to enable quantization to user
enable_quantization = has_quantization_ops(exported_program.graph)
if enable_quantization:
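What the temporary prints above are expected to show (an assumption based on torch.export semantics, not output captured in this PR): a placeholder's meta['val'] is a FakeTensor, and a dynamic dimension appears as a SymInt rather than a concrete size.

node = next(iter(exported_program.graph.nodes))
val = node.meta["val"]
print(type(val).__name__)  # FakeTensor
print(val.shape)           # e.g. torch.Size([s0, 5]) when dim 0 is dynamic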
6 changes: 6 additions & 0 deletions tico/utils/validate_args_kwargs.py
@@ -296,6 +296,12 @@ class CopyArgs:

dst: torch.fx.Node
src: torch.fx.Node
non_blocking: bool = False

def __post_init__(self):
if self.non_blocking:
raise NotImplementedError("non_blocking option is not supported yet.")



@enforce_type
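A self-contained sketch of the new guard (the real CopyArgs also carries @enforce_type and torch.fx.Node fields; plain strings stand in for nodes here):

import dataclasses

@dataclasses.dataclass
class CopyArgsSketch:
    dst: object
    src: object
    non_blocking: bool = False

    def __post_init__(self):
        if self.non_blocking:
            raise NotImplementedError("non_blocking option is not supported yet.")

CopyArgsSketch("dst", "src")  # fine: non_blocking defaults to False
try:
    CopyArgsSketch("dst", "src", non_blocking=True)
except NotImplementedError as err:
    print(err)  # non_blocking option is not supported yet.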