Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions test/pt2_to_circle_test/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from test.utils.base_builders import TestDictBuilderBase, TestRunnerBase

from test.utils.tag import is_tagged
from test.utils.tag import TestTag


class NNModuleTest(TestRunnerBase):
Expand All @@ -41,9 +41,11 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module):
self.test_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "artifacts"

# Get tags
self.test_without_pt2: bool = is_tagged(self.nnmodule, "test_without_pt2")
self.test_without_inference: bool = is_tagged(
self.nnmodule, "test_without_inference"
self.test_without_pt2: bool = TestTag.get(
self.nnmodule, "test_without_pt2", False
)
self.test_without_inference: bool = TestTag.get(
self.nnmodule, "test_without_inference", False
)

# Set tolerance
Expand Down
28 changes: 17 additions & 11 deletions test/utils/base_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import inspect
import pkgutil
from abc import abstractmethod
from typing import Optional

import torch

from test.utils.tag import is_tagged
from test.utils.tag import get_tag, has_tag, TestTag


class TestRunnerBase:
Expand All @@ -31,12 +31,17 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module):
self.nnmodule = nnmodule
self.example_inputs = nnmodule.get_example_inputs() # type: ignore[operator]

# Get tags
self.skip: bool = is_tagged(self.nnmodule, "skip")
self.skip_reason: str = getattr(self.nnmodule, "__tag_skip_reason", "")
self.test_negative: bool = is_tagged(self.nnmodule, "test_negative")
self.expected_err: str = getattr(self.nnmodule, "__tag_expected_err", "")
self.use_onert: bool = is_tagged(self.nnmodule, "use_onert")
skip: Optional[object] = TestTag.get(type(self.nnmodule), "skip")
self.skip: bool = skip is not None
self.skip_reason: str = skip.get("reason") if skip else ""

test_negative: Optional[object] = TestTag.get(
type(self.nnmodule), "test_negative"
)
self.test_negative: bool = test_negative is not None
self.expected_err: str = test_negative.get("reason") if test_negative else ""

self.use_onert: bool = TestTag.get(type(self.nnmodule), "use_onert", False)

@abstractmethod
def make(self):
Expand Down Expand Up @@ -79,16 +84,17 @@ def _get_nnmodules(self, submodule: str):
)
)

# If any of the nnmodule_classes has a tag `__tag_target`, only those nnmodule_classes will be added
# If any of the nnmodule_classes is marked as target, only those will be added
target_only: bool = any(
hasattr(nnmodule_cls, "__tag_target") for nnmodule_cls in nnmodule_classes
TestTag.get(nnmodule_cls, "target", False)
for nnmodule_cls in nnmodule_classes
)

if target_only:
nnmodule_classes = [
nnmodule_cls
for nnmodule_cls in nnmodule_classes
if hasattr(nnmodule_cls, "__tag_target")
if TestTag.get(nnmodule_cls, "target", False)
]

return nnmodule_classes
Expand Down
138 changes: 78 additions & 60 deletions test/utils/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,95 +12,113 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Type

def skip(reason):
def __inner_skip(orig_class):
setattr(orig_class, "__tag_skip", True)
setattr(orig_class, "__tag_skip_reason", reason)

def __init__(self, *args_, **kwargs_):
pass
class TestTag:
"""Central registry for managing test tag"""

# Ignore initialization of skipped modules
orig_class.__init__ = __init__
_registry: Dict[Type, Dict[str, Any]] = {}

return orig_class
@classmethod
def add(cls, test_class: Type, tag_key: str, tag_value: Any = None) -> None:
"""Add test tag object to a class

return __inner_skip
Args:
test_class: The test class to add tag to
tag_key: Name of Tag object to add
tag_value: Tag object to add
"""
if test_class not in cls._registry:
cls._registry[test_class] = {}

cls._registry[test_class][tag_key] = tag_value

def skip_if(predicate, reason):
def __inner_skip(orig_class):
setattr(orig_class, "__tag_skip", True)
setattr(orig_class, "__tag_skip_reason", reason)
@classmethod
def has(cls, test_class: Type, tag_key: str) -> bool:
"""Check if a class has specific tag type

def __init__(self, *args_, **kwargs_):
pass
Args:
test_class: The test class to check
tag_key: Type of tag object to check for

# Ignore initialization of skipped modules
orig_class.__init__ = __init__
Returns:
bool: True if the tag exists, False otherwise
"""
return test_class in cls._registry and tag_key in cls._registry[test_class]

return orig_class
@classmethod
def get(cls, test_class: Type, tag_key: str, default: Any = None) -> Any:
"""Get tag object for a class

if predicate:
return __inner_skip
else:
return lambda x: x
Args:
test_class: The test class to get tag from
tag_key: Type of tag object to retrieve
default: Default value to return if tag not found

Returns:
The tag object or default if not found
"""
return cls._registry.get(test_class, {}).get(tag_key, default)

def test_without_inference(orig_class):
setattr(orig_class, "__tag_test_without_inference", True)
return orig_class

####################################################################
################## Add tag here ##################
####################################################################

def test_without_pt2(orig_class):
setattr(orig_class, "__tag_test_without_pt2", True)
return orig_class

def skip(reason):
"""
Mark a test class to be skipped with a reason

def test_negative(expected_err):
def __inner_test_negative(orig_class):
setattr(orig_class, "__tag_test_negative", True)
setattr(orig_class, "__tag_expected_err", expected_err)
e.g.
@skip(reason="Not implemented yet")
class MyTest(unittest.TestCase): # <-- This test will be skipped
"""

return orig_class
def decorator(cls):
TestTag.add(cls, "skip", {"reason": reason})
return cls

return __inner_test_negative
return decorator


def target(orig_class):
setattr(orig_class, "__tag_target", True)
return orig_class
def skip_if(predicate, reason):
"""Conditionally mark a test class to be skipped with a reason"""
if predicate:
return skip(reason)
return lambda cls: cls


def use_onert(orig_class):
"""
Decorator to mark a test class so that Circle models are executed
with the 'onert' runtime.
def test_negative(expected_err):
"""Mark a test class as negative test case with expected error"""

Useful when the default 'circle-interpreter' cannot run the model
under test.
"""
setattr(orig_class, "__tag_use_onert", True)
return orig_class
def decorator(cls):
TestTag.add(cls, "test_negative", {"expected_err": expected_err})
return cls

return decorator


def init_args(*args, **kwargs):
def __inner_init_args(orig_class):
orig_init = orig_class.__init__
# Make copy of original __init__, so we can call it without recursion
def target(cls):
"""Mark a test class as target test case"""
TestTag.add(cls, "target")
return cls

def __init__(self, *args_, **kwargs_):
args_ = (*args, *args_)
kwargs_ = {**kwargs, **kwargs_}

orig_init(self, *args_, **kwargs_) # Call the original __init__
def use_onert(cls):
"""Mark a test class to use ONERT runtime"""
TestTag.add(cls, "use_onert")
return cls

orig_class.__init__ = __init__
return orig_class

return __inner_init_args
def test_without_pt2(cls):
"""Mark a test class to not convert along pt2 during test execution"""
TestTag.add(cls, "test_without_pt2")
return cls


def is_tagged(cls, tag: str):
return hasattr(cls, f"__tag_{tag}")
def test_without_inference(cls):
"""Mark a test class to not run inference during test execution"""
TestTag.add(cls, "test_without_inference")
return cls