Skip to content

Commit

Permalink
wraps regression
Browse files Browse the repository at this point in the history
  • Loading branch information
mgraffg committed Feb 13, 2025
1 parent 10d1055 commit 77a3664
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions CompStats/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def explained_variance_score(y_true,
use_tqdm=True,
**kwargs):
"""explained_variance_score"""

@wraps(metrics.explained_variance_score)
def inner(y, hy):
return metrics.explained_variance_score(y, hy,
sample_weight=sample_weight,
Expand All @@ -336,6 +338,8 @@ def max_error(y_true, *y_pred,
use_tqdm=True,
**kwargs):
"""max_error"""

@wraps(metrics.max_error)
def inner(y, hy):
return metrics.max_error(y, hy)
return Perf(y_true, *y_pred, score_func=None, error_func=inner,
Expand All @@ -354,6 +358,8 @@ def mean_absolute_error(y_true,
use_tqdm=True,
**kwargs):
"""mean_absolute_error"""

@wraps(metrics.mean_absolute_error)
def inner(y, hy):
return metrics.mean_absolute_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -375,6 +381,8 @@ def mean_squared_error(y_true,
use_tqdm=True,
**kwargs):
"""mean_squared_error"""

@wraps(metrics.mean_squared_error)
def inner(y, hy):
return metrics.mean_squared_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -396,6 +404,8 @@ def root_mean_squared_error(y_true,
use_tqdm=True,
**kwargs):
"""root_mean_squared_error"""

@wraps(metrics.root_mean_squared_error)
def inner(y, hy):
return metrics.root_mean_squared_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -417,6 +427,8 @@ def mean_squared_log_error(y_true,
use_tqdm=True,
**kwargs):
"""mean_squared_log_error"""

@wraps(metrics.mean_squared_log_error)
def inner(y, hy):
return metrics.mean_squared_log_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -438,6 +450,8 @@ def root_mean_squared_log_error(y_true,
use_tqdm=True,
**kwargs):
"""root_mean_squared_log_error"""

@wraps(metrics.root_mean_squared_log_error)
def inner(y, hy):
return metrics.root_mean_squared_log_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -459,6 +473,8 @@ def median_absolute_error(y_true,
use_tqdm=True,
**kwargs):
"""median_absolute_error"""

@wraps(metrics.median_absolute_error)
def inner(y, hy):
return metrics.median_absolute_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -481,6 +497,8 @@ def r2_score(y_true,
use_tqdm=True,
**kwargs):
"""r2_score"""

@wraps(metrics.r2_score)
def inner(y, hy):
return metrics.r2_score(y, hy,
sample_weight=sample_weight,
Expand All @@ -502,6 +520,8 @@ def mean_poisson_deviance(y_true,
use_tqdm=True,
**kwargs):
"""mean_poisson_deviance"""

@wraps(metrics.mean_poisson_deviance)
def inner(y, hy):
return metrics.mean_poisson_deviance(y, hy,
sample_weight=sample_weight)
Expand All @@ -521,6 +541,8 @@ def mean_gamma_deviance(y_true,
use_tqdm=True,
**kwargs):
"""mean_gamma_deviance"""

@wraps(metrics.mean_gamma_deviance)
def inner(y, hy):
return metrics.mean_gamma_deviance(y, hy,
sample_weight=sample_weight)
Expand All @@ -541,6 +563,8 @@ def mean_absolute_percentage_error(y_true,
use_tqdm=True,
**kwargs):
"""mean_absolute_percentage_error"""

@wraps(metrics.mean_absolute_percentage_error)
def inner(y, hy):
return metrics.mean_absolute_percentage_error(y, hy,
sample_weight=sample_weight,
Expand All @@ -562,6 +586,8 @@ def d2_absolute_error_score(y_true,
use_tqdm=True,
**kwargs):
"""d2_absolute_error_score"""

@wraps(metrics.d2_absolute_error_score)
def inner(y, hy):
return metrics.d2_absolute_error_score(y, hy,
sample_weight=sample_weight,
Expand Down

0 comments on commit 77a3664

Please sign in to comment.