Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/modules/op/dyn_shapes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
It can be accessed as a module as `test.modules.op.dyn_shapes*`

## How to test models with inputs which have dynamic shapes?

The folder contains tests for single-op models that have dynamic inputs.
Such test requires adding additional method `get_input_dynamic_shapes` to a test class inheriting from `nn.Module`.
The format of value returned by `get_input_dynamic_shapes` should match an `dynamic_shapes` argument of [torch.export](https://pytorch.org/docs/stable/export.html) function.


### An example:
```py
from torch.export import Dim

class TwoInputsDynSimpleAdd(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
z = x + y
return z

def get_example_inputs(self):
return (torch.ones(1, 2, 3), torch.ones(1, 2, 3))

def get_input_dynamic_shapes(self):
return (1, Dim("d2"), Dim("d3")), (1, Dim("d2"), Dim("d3"))
```
50 changes: 50 additions & 0 deletions test/modules/op/dyn_shapes/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
from torch.export import Dim

from test.utils import tag


class SingleInputDynSimpleAdd(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
z = x + y
return z

def get_example_inputs(self):
return (torch.ones(4, 5, 6), torch.ones(1, 1, 1))

def get_input_dynamic_shapes(self):
return (4, Dim("d2"), Dim("d3")), (1, 1, 1)


class TwoInputsDynSimpleAdd(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
z = x + y
return z

def get_example_inputs(self):
return (torch.ones(1, 2, 3), torch.ones(1, 2, 3))

def get_input_dynamic_shapes(self):
return (1, Dim("d2"), Dim("d3")), (1, Dim("d2"), Dim("d3"))


@tag.test_negative(expected_err=f"Failed running call_function")
class DynSimpleAddNotMatchedShape(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
z = x + y
return z

def get_example_inputs(self):
return (torch.ones(1, 2, 3), torch.ones(1, 4, 3))

def get_input_dynamic_shapes(self):
return (1, Dim("d"), 3), (1, Dim("d"), 3)
2 changes: 1 addition & 1 deletion test/pt2_to_circle_test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def load_tests(loader, standard_tests, pattern):

# Add test files to be found by `unittest`
# WHY? Not to include other files by mistake and to make it clear which files are being tested
for testfile in ["test_net.py", "test_op.py"]:
for testfile in ["test_net.py", "test_op.py", "test_dyn_op.py"]:
package_tests = loader.discover(start_dir=this_dir, pattern=testfile)
standard_tests.addTests(package_tests)

Expand Down
18 changes: 15 additions & 3 deletions test/pt2_to_circle_test/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
convert_pt2_to_circle,
infer_circle,
infer_nnmodule,
resize_circle,
validate_result,
verify_circle,
)
Expand Down Expand Up @@ -85,6 +86,7 @@ def _run(self, without_pt2=False):
os.makedirs(os.path.dirname(test_prefix), exist_ok=True)

circle_model_path = str(test_prefix) + ".circle"
resized_circle_model_path = str(test_prefix) + ".resized.circle"
opt_circle_model_path = str(test_prefix) + ".opt.circle"
pt2_model_path = str(test_prefix) + ".pt2"

Expand All @@ -96,13 +98,23 @@ def _run(self, without_pt2=False):
)
else:
# torch.nn.Module --> ExportedProgram ----------------------------------------> circle
convert_nnmodule_to_pt2(self.nnmodule, self.example_inputs, pt2_model_path)
convert_nnmodule_to_pt2(
self.nnmodule, self.example_inputs, pt2_model_path, self.dynamic_shapes
)
convert_pt2_to_circle(pt2_model_path, circle_model_path)

verify_circle(circle_model_path, opt_circle_model_path)
if self.dynamic_shapes:
resize_circle(
circle_model_path, resized_circle_model_path, self.example_inputs
)

verify_circle(
resized_circle_model_path if self.dynamic_shapes else circle_model_path,
opt_circle_model_path,
)

torch_result = infer_nnmodule(self.nnmodule, self.example_inputs)
circle_result = infer_circle(circle_model_path, self.example_inputs)
circle_result = infer_circle(opt_circle_model_path, self.example_inputs)
validate_result(torch_result, circle_result, **self.tolerance)


Expand Down
21 changes: 21 additions & 0 deletions test/pt2_to_circle_test/test_dyn_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from test.pt2_to_circle_test.builder import NormalTestDictBuilder
from test.utils.helper import declare_unittests

# NOTE Thie file's name must start with `test_` to be found by unittest


declare_unittests(globals(), "test.modules.op.dyn_shapes", NormalTestDictBuilder)
43 changes: 39 additions & 4 deletions test/pt2_to_circle_test/test_pt2_to_circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import subprocess
from functools import wraps
from pathlib import Path
from typing import List, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
import numpy as np
Expand All @@ -26,14 +26,16 @@
import tico.utils.model
import torch
from tico.utils.convert import convert_exported_module_to_circle
from tico.utils.utils import SuppressWarning
from tico.utils.utils import run_bash_cmd, SuppressWarning
from torch.export import export
from torch.utils import _pytree as pytree

# TODO Move this to utils or helper

__test_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "artifacts"
__circle2circle_path = "/usr/share/one/bin/circle2circle"
__circle_resizer_path = "/usr/share/one/bin/circle-resizer"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Waiting for Samsung/ONE#14727



# Create empty test directories
if not os.path.exists(__test_dir):
Expand Down Expand Up @@ -66,9 +68,22 @@ def get_args_kwargs(example_inputs: tuple):
return example_inputs, {}


def extract_shapes_from_input_tensors(tensors: tuple[torch.Tensor]) -> str:
shapes = []
for tensor in tensors:
shape = []
for dim in tensor.size():
shape.append(str(dim))
shapes.append("[" + ",".join(shape) + "]")
return ",".join(shapes)


@print_name_on_exception
def convert_nnmodule_to_pt2(
model: torch.nn.Module, example_inputs: tuple, pt2_model_path: str
model: torch.nn.Module,
example_inputs: tuple,
pt2_model_path: str,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
):
# Create .pt2 model
with torch.no_grad(), SuppressWarning(UserWarning, ".*quantize_per_tensor"):
Expand All @@ -77,7 +92,9 @@ def convert_nnmodule_to_pt2(
# UserWarning: At pre-dispatch tracing, we assume that any custom op marked with
# CompositeImplicitAutograd and have functional schema are safe to not decompose.
_args, _kwargs = get_args_kwargs(example_inputs)
exported = export(model.eval(), args=_args, kwargs=_kwargs)
exported = export(
model.eval(), args=_args, kwargs=_kwargs, dynamic_shapes=dynamic_shapes
)
torch.export.save(exported, pt2_model_path)


Expand Down Expand Up @@ -148,6 +165,24 @@ def infer_nnmodule(model: torch.nn.Module, example_inputs: tuple):
return torch_result


@print_name_on_exception
def resize_circle(
circle_model_path: str,
resized_circle_model_str: str,
example_inputs: tuple[torch.Tensor],
):
cmd = [
__circle_resizer_path,
"--input_path",
circle_model_path,
"--output_path",
resized_circle_model_str,
"--input_shapes",
extract_shapes_from_input_tensors(example_inputs),
]
run_bash_cmd(cmd)


@print_name_on_exception
def infer_circle(circle_path: str, example_inputs: tuple):
circle_model = tico.utils.model.CircleModel.load(circle_path)
Expand Down
4 changes: 4 additions & 0 deletions test/utils/base_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module):
assert hasattr(nnmodule, "get_example_inputs")
assert isinstance(nnmodule.get_example_inputs(), tuple) # type: ignore[operator]

self.dynamic_shapes = None
if hasattr(nnmodule, "get_input_dynamic_shapes"):
self.dynamic_shapes = nnmodule.get_input_dynamic_shapes() # type: ignore[operator]

self.nnmodule = nnmodule
self.example_inputs = nnmodule.get_example_inputs() # type: ignore[operator]

Expand Down
Loading