44 changes: 9 additions & 35 deletions orca_python/utilities/utilities.py
@@ -14,9 +14,9 @@
 import pandas as pd
 from pkg_resources import get_distribution, parse_version
 from sklearn import preprocessing
-from sklearn.metrics import make_scorer
 from sklearn.model_selection import GridSearchCV, StratifiedKFold

+from orca_python.metrics import compute_metric, load_metric_as_scorer
 from orca_python.results import Results

@@ -169,30 +169,19 @@ def run_experiment(self):
             train_metrics = OrderedDict()
             test_metrics = OrderedDict()
             for metric_name in self.general_conf["metrics"]:
-
-                try:
-                    # Loading metric from file
-                    module = __import__("orca_python").metrics
-                    metric = getattr(
-                        module, self.general_conf["cv_metric"].lower().strip()
-                    )
-
-                except AttributeError:
-                    raise AttributeError(
-                        "No metric named '%s'" % metric_name.strip().lower()
-                    )
-
                 # Get train scores
-                train_score = metric(
-                    partition["train_outputs"], train_predicted_y
+                train_score = compute_metric(
+                    metric_name,
+                    partition["train_outputs"],
+                    train_predicted_y,
                 )
                 train_metrics[metric_name.strip() + "_train"] = train_score

                 # Get test scores
                 test_metrics[metric_name.strip() + "_test"] = np.nan
                 if "test_outputs" in partition:
-                    test_score = metric(
-                        partition["test_outputs"], test_predicted_y
+                    test_score = compute_metric(
+                        metric_name, partition["test_outputs"], test_predicted_y
                     )
                     test_metrics[metric_name.strip() + "_test"] = test_score
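
This hunk replaces the per-metric `__import__`/`getattr` lookup, which also mistakenly resolved `cv_metric` instead of the loop's `metric_name`, with a single call. For context, a minimal sketch of what `compute_metric` plausibly does, reconstructed from the removed code; the actual signature and error handling in `orca_python.metrics` may differ:

    # Hypothetical sketch, not the actual orca_python implementation.
    import orca_python.metrics as metrics_module

    def compute_metric(metric_name, y_true, y_pred):
        """Resolve a metric by name in orca_python.metrics and apply it."""
        name = metric_name.strip().lower()
        try:
            metric = getattr(metrics_module, name)
        except AttributeError:
            raise AttributeError("No metric named '%s'" % name)
        return metric(y_true, y_pred)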

@@ -536,23 +525,8 @@ def _get_optimal_estimator(
             optimal.refit_time_ = elapsed
             return optimal

-        try:
-            module = __import__("orca_python").metrics
-            metric = getattr(module, self.general_conf["cv_metric"].lower().strip())
-
-        except AttributeError:
-
-            if not isinstance(self.general_conf["cv_metric"], str):
-                raise AttributeError("cv_metric must be string")
-
-            raise AttributeError(
-                "No metric named '%s' implemented"
-                % self.general_conf["cv_metric"].strip().lower()
-            )
-
-        # Making custom metrics compatible with sklearn
-        gib = module.greater_is_better(self.general_conf["cv_metric"].lower().strip())
-        scoring_function = make_scorer(metric, greater_is_better=gib)
+        metric_name = self.general_conf["cv_metric"].strip().lower()
+        scoring_function = load_metric_as_scorer(metric_name)

         # Creating object to split train data for cross-validation
         # This will make GridSearch have a pseudo-random behaviour
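
The removed block hints at what `load_metric_as_scorer` now encapsulates: validating the name, resolving the metric, querying `greater_is_better` so that error-style metrics are negated during grid search, and wrapping the result with sklearn's `make_scorer`. A sketch under that assumption, using names taken from the removed code; the real helper may differ:

    # Hypothetical sketch, not the actual orca_python implementation.
    from sklearn.metrics import make_scorer

    import orca_python.metrics as metrics_module

    def load_metric_as_scorer(metric_name):
        """Wrap a named orca_python metric as a scikit-learn scorer."""
        if not isinstance(metric_name, str):
            raise AttributeError("cv_metric must be string")
        name = metric_name.strip().lower()
        try:
            metric = getattr(metrics_module, name)
        except AttributeError:
            raise AttributeError("No metric named '%s' implemented" % name)
        # greater_is_better=False tells GridSearchCV to negate error metrics
        gib = metrics_module.greater_is_better(name)
        return make_scorer(metric, greater_is_better=gib)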