123 changes: 123 additions & 0 deletions test/modules/op/copy.py
@@ -18,6 +18,11 @@


class SimpleCopy(TestModuleBase):
    """
    Test case: Same-shape copy (should be folded away by the ConvertCopyToReshape pass)
    5x5 -> 5x5
    """

    def __init__(self):
        super().__init__()

@@ -30,6 +35,11 @@ def get_example_inputs(self):


class SimpleCopyWithBroadcastTo(TestModuleBase):
    """
    Test case: Broadcast from 1x5 to 5x5
    This exercises the expand + reshape path in the ConvertCopyToReshape pass
    """

    def __init__(self):
        super().__init__()

@@ -39,3 +49,116 @@ def forward(self, dst, src):

    def get_example_inputs(self):
        return (torch.randn(5, 5), torch.randn(1, 5)), {}


class CopyWithScalarBroadcast(TestModuleBase):
    """
    Test case: Broadcast from 1x1 to 3x3 (scalar-like broadcast)
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(3, 3), torch.randn(1, 1)), {}


class CopyWithRowBroadcast(TestModuleBase):
    """
    Test case: Broadcast from 1x4 to 3x4 (row broadcast)
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(3, 4), torch.randn(1, 4)), {}


class CopyWithColumnBroadcast(TestModuleBase):
    """
    Test case: Broadcast from 3x1 to 3x4 (column broadcast)
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(3, 4), torch.randn(3, 1)), {}


class CopyWith3DTensor(TestModuleBase):
    """
    Test case: 3D tensor copy with same shape
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(2, 3, 4), torch.randn(2, 3, 4)), {}


class CopyWith3DBroadcast(TestModuleBase):
    """
    Test case: 3D tensor broadcast from 1x3x4 to 2x3x4
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(2, 3, 4), torch.randn(1, 3, 4)), {}


class CopyWithMultiDimBroadcast(TestModuleBase):
    """
    Test case: Multi-dimensional broadcast from 1x1x4 to 2x3x4
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(2, 3, 4), torch.randn(1, 1, 4)), {}


class CopyWith4DTensor(TestModuleBase):
    """
    Test case: 4D tensor copy (batch, channel, height, width)
    Broadcast from 1x3x1x1 to 2x3x4x4
    """

    def __init__(self):
        super().__init__()

    def forward(self, dst, src):
        dst.copy_(src)
        return dst

    def get_example_inputs(self):
        return (torch.randn(2, 3, 4, 4), torch.randn(1, 3, 1, 1)), {}
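For orientation, the broadcast pattern these modules exercise can be reproduced outside the test harness with plain torch.export (a minimal sketch using only standard PyTorch; a plain nn.Module stands in for TestModuleBase here):

import torch

class BroadcastCopy(torch.nn.Module):
    def forward(self, dst, src):
        dst.copy_(src)  # src broadcasts from 1x5 to dst's 5x5 shape
        return dst

# Same shapes as SimpleCopyWithBroadcastTo above
ep = torch.export.export(BroadcastCopy(), (torch.randn(5, 5), torch.randn(1, 5)))
print(ep.graph)  # roughly the aten graph the ConvertCopyToReshape pass would see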
24 changes: 24 additions & 0 deletions test/modules/op/reshape.py
@@ -13,9 +13,12 @@
# limitations under the License.

import torch
from torch.export import Dim

from test.modules.base import TestModuleBase

from test.utils import tag

# Note: tests that call `aten.reshape` or `torch.reshape` are exported to an aten graph that contains `aten.view` instead of `aten.reshape`.


@@ -65,3 +68,24 @@ def forward(self, x):

    def get_example_inputs(self):
        return (torch.randn(2, 4, 5),), {}


@tag.use_onert
class ReshapeDynamicShape(TestModuleBase):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Reshape to (batch, -1) where batch is dynamic
        return x.reshape(x.shape[0], -1)

    def get_example_inputs(self):
        return (torch.randn(4, 4, 4),), {}

    def get_dynamic_shapes(self):
        batch = Dim("batch", min=2, max=128)
        dynamic_shapes = {
            "x": {0: batch},
        }
        return dynamic_shapes
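As noted at the top of this file, the reshape is expected to appear as aten.view after export. A minimal standalone sketch (standard PyTorch only; ReshapeFlatten is an ad-hoc stand-in for the test module):

import torch
from torch.export import Dim, export

class ReshapeFlatten(torch.nn.Module):
    def forward(self, x):
        return x.reshape(x.shape[0], -1)

batch = Dim("batch", min=2, max=128)
ep = export(ReshapeFlatten(), (torch.randn(4, 4, 4),), dynamic_shapes={"x": {0: batch}})
print(ep.graph)  # the reshape shows up as aten.view with a symbolic batch dimension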

80 changes: 80 additions & 0 deletions test/modules/op/sym_size.py
@@ -0,0 +1,80 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.export import Dim

from test.modules.base import TestModuleBase
from test.utils import tag


@tag.use_onert
class SymSizeSimple(TestModuleBase):
    """
    Simplest test case for sym_size.int generation.
    Just returns the batch size (first dimension).
    """

    def forward(self, x):
        # Accessing x.shape[0] on a dynamic dimension creates sym_size.int
        return x.shape[0]

    def get_example_inputs(self):
        return (torch.randn(2, 4, 4),), {}

    def get_dynamic_shapes(self):
        batch = Dim("batch", min=1, max=128)
        return {"x": {0: batch}}


@tag.use_onert
@tag.skip(reason="Not yet supported")
class SymSizeInReshape(TestModuleBase):
    """
    Test case using sym_size.int in a reshape operation.
    This is a common pattern in dynamic-batch-size models.
    """

    def forward(self, x):
        batch_size = x.shape[0]
        # Use the dynamic batch size in the reshape
        return x.reshape(batch_size, -1)

    def get_example_inputs(self):
        return (torch.randn(2, 4, 4),), {}

    def get_dynamic_shapes(self):
        batch = Dim("batch", min=1, max=128)
        return {"x": {0: batch}}


@tag.use_onert
class SymSizeMultipleDims(TestModuleBase):
    """
    Test case reading multiple dimensions (h, w) and using them in a reshape.
    Only the batch dimension is declared dynamic; h and w stay static.
    """

    def forward(self, x):
        h = x.shape[1]
        w = x.shape[2]
        # Reshape using the dimensions read above
        return x.reshape(-1, h * w)

    def get_example_inputs(self):
        return (torch.randn(2, 4, 4),), {}

    def get_dynamic_shapes(self):
        batch = Dim("batch", min=1, max=128)
        return {"x": {0: batch}}

10 changes: 8 additions & 2 deletions test/pt2_to_circle_test/builder.py
@@ -20,6 +20,8 @@
from pathlib import Path
from typing import Optional

import torch

from tico.config.base import CompileConfigBase
from tico.utils.signature import ModelInputSpec

@@ -204,8 +206,12 @@ def has_symbolic_input(circle_model_path: str) -> bool:
                forward_kwargs=deepcopy(self.forward_kwargs),
                runtime="onert",
            )
            torch_shape = torch_result[0].shape
            circle_result[0] = circle_result[0].reshape(torch_shape)
            for idx, (tr, cr) in enumerate(zip(torch_result, circle_result)):
                if isinstance(tr, torch.Tensor):
                    circle_result[idx] = circle_result[idx].reshape(tr.shape)
                else:
                    # tr is a Python scalar (e.g., a sym_size.int output); wrap it so it can be compared
                    torch_result[idx] = torch.tensor([tr], dtype=torch.int32)  # TODO Fix properly
        else:
            circle_result = infer_circle(
                circle_model_path,
2 changes: 1 addition & 1 deletion test/pt2_to_circle_test/test_pt2_to_circle.py
@@ -235,7 +235,7 @@ def validate_result(
    else:
        raise TypeError("Expected result must be a tensor or scalar value.")

    # Check both dypte and value mismatch
    # Check both dtype and value mismatch
    torch.testing.assert_close(
        actual=circle_tensor,
        expected=expected_tensor,
42 changes: 41 additions & 1 deletion test/utils/infer.py
@@ -17,7 +17,7 @@
import tico.utils
import tico.utils.model
from tico.utils.signature import ModelInputSpec

import torch

def infer_with_circle_interpreter(
    circle_path: str,
@@ -51,6 +51,13 @@ def infer_with_circle_interpreter(

    return circle_result

from tico.serialize.circle_mapping import (
    extract_circle_dtype,
    extract_circle_shape,
    str_to_circle_dtype,
    to_circle_dtype,
    to_circle_shape,
)

def infer_with_onert(
    circle_path: str,
@@ -83,6 +90,39 @@ def infer_with_onert(
    inputs = ispec.bind(forward_args, forward_kwargs, check=True)

    session_float = infer.session(circle_path)

    # Handle dynamic shapes: onert cannot execute models with unspecified dimensions.
    # Check if any input has dynamic dimensions (indicated by -1).
    input_tensorinfos = session_float.get_inputs_tensorinfo()
    has_dynamic_shapes = any(-1 in info.dims for info in input_tensorinfos)

    if has_dynamic_shapes:
        # Set concrete input shapes based on the actual input data
        from onert.native.libnnfw_api_pybind import tensorinfo

        for idx, (info, input_data) in enumerate(zip(input_tensorinfos, inputs)):
            if -1 in info.dims:
                # Create a new tensorinfo with a concrete shape taken from the input data
                new_info = tensorinfo()
                new_info.rank = len(input_data.shape)
                new_info.dims = list(input_data.shape)

                assert input_data.dtype in [torch.float32, torch.float]
                new_info.dtype = "float32"

                try:
                    session_float.session.set_input_tensorinfo(idx, new_info)
                except Exception as e:
                    # If setting tensorinfo fails, try to continue anyway;
                    # some versions of onert might handle this differently.
                    import warnings

                    warnings.warn(
                        f"Failed to set input tensorinfo for input {idx}: {e}. "
                        f"Attempting inference anyway."
                    )

    output = session_float.infer(inputs)

    return output
30 changes: 25 additions & 5 deletions tico/interpreter/infer.py
@@ -69,6 +69,7 @@ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any:
        graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())
    ]
    model_input_shapes_np = [t.ShapeAsNumpy() for t in model_input_tensors]
    model_input_shape_sigs_np = [t.ShapeSignatureAsNumpy() for t in model_input_tensors]
    model_input_types_cm = [t.Type() for t in model_input_tensors]

    # Check that the dtype and shape of the user-given inputs match those in the model binary.
@@ -77,11 +78,30 @@
f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})"
)
for input_idx, user_input in enumerate(user_inputs):
# Shape
if list(user_input.shape) != list(model_input_shapes_np[input_idx]):
raise RuntimeError(
f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})"
)
# Shape - check against shape_signature if available (for dynamic shapes)
model_shape = model_input_shapes_np[input_idx]
model_shape_sig = model_input_shape_sigs_np[input_idx]
user_shape = list(user_input.shape)

# If shape_signature exists, validate against it (supports dynamic dimensions)
if model_shape_sig is not None and len(model_shape_sig) > 0:
if len(user_shape) != len(model_shape_sig):
raise RuntimeError(
f"Mismatch input {input_idx} rank: input({len(user_shape)}) != circle model({len(model_shape_sig)})"
)
for dim_idx, (user_dim, sig_dim) in enumerate(zip(user_shape, model_shape_sig)):
# -1 in shape_signature means dynamic dimension, accept any value
if sig_dim != -1 and user_dim != sig_dim:
raise RuntimeError(
f"Mismatch input {input_idx} shape at dimension {dim_idx}: input({user_dim}) != circle model({sig_dim})"
)
else:
# No shape_signature, validate against static shape
if user_shape != list(model_shape):
raise RuntimeError(
f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_shape})"
)

# Data type
user_input_type_cm = to_circle_dtype(user_input.dtype)
if user_input_type_cm != model_input_types_cm[input_idx]:
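The acceptance rule implemented above boils down to a few lines (a hypothetical standalone helper, not part of tico, shown only to illustrate the -1 convention in Circle shape signatures):

def shape_matches(user_shape, shape_signature):
    # -1 in a Circle shape_signature marks a dynamic dimension and matches any size
    if len(user_shape) != len(shape_signature):
        return False
    return all(sig == -1 or dim == sig for dim, sig in zip(user_shape, shape_signature))

assert shape_matches([8, 3, 4], [-1, 3, 4])      # dynamic batch: accepted
assert not shape_matches([8, 3, 5], [-1, 3, 4])  # static dim mismatch: rejected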