implemented univariate selector methods (from scikit-learn) and added tests.
enriquea committed Sep 23, 2024
1 parent b6e8eab commit 07a9dc5
Showing 3 changed files with 192 additions and 68 deletions.
6 changes: 4 additions & 2 deletions fslite/fs/constants.py
@@ -13,9 +13,11 @@
"treating each feature independently and assessing its contribution to the predictive "
"performance of the model.",
"methods": [
{"name": "anova","description": "Univariate ANOVA feature selection (f-classification)"},
{"name": "u_corr", "description": "Univariate correlation"},
{"name": "anova", "description": "Univariate ANOVA feature selection (f-classification)"},
{"name": "u_corr", "description": "Univariate Pearson's correlation"},
{"name": "f_regression", "description": "Univariate f-regression"},
{"name": "mutual_info_regression", "description": "Univariate mutual information regression"},
{"name": "mutual_info_classification", "description": "Univariate mutual information classification"},
],
},
"multivariate": {
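The five method names registered above are the same strings that univariate_filter dispatches on. A minimal sketch of how the registry helpers imported in univariate.py might be used; the signatures and return values of get_fs_univariate_methods and is_valid_univariate_method are assumed here, since they are not shown in this diff:

from fslite.fs.constants import get_fs_univariate_methods, is_valid_univariate_method

# Assumed behaviour: the registry exposes the names that univariate_filter accepts,
# and the validity check is a membership test against that registry.
print(get_fs_univariate_methods())
# e.g. ['anova', 'u_corr', 'f_regression', 'mutual_info_regression', 'mutual_info_classification']
assert is_valid_univariate_method("mutual_info_classification")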
139 changes: 74 additions & 65 deletions fslite/fs/univariate.py
@@ -2,8 +2,11 @@
from typing import Dict, List

import numpy as np
import pandas as pd
from sklearn.feature_selection import SelectKBest, f_classif, f_regression
from sklearn.feature_selection import (GenericUnivariateSelect,
f_classif,
f_regression,
mutual_info_classif,
mutual_info_regression)

from fslite.fs.constants import get_fs_univariate_methods, is_valid_univariate_method
from fslite.fs.fdataframe import FSDataFrame
@@ -62,14 +65,65 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def univariate_feature_selector(
self,
df: FSDataFrame,
score_func: str = "f_classif",
selection_mode: str = "percentile",
selection_threshold: float = 0.8
) -> List[int]:
"""
Wrapper for scikit-learn's `GenericUnivariateSelect` feature selector, supporting multiple scoring functions.
:param df: Input FSDataFrame
:param score_func: The score function to use for feature selection. Options are:
- 'f_classif': ANOVA F-value for classification tasks.
- 'f_regression': F-value for regression tasks.
- 'mutual_info_classif': Mutual information for classification.
- 'mutual_info_regression': Mutual information for regression.
:param selection_mode: Mode for feature selection (e.g. 'percentile' or 'k_best').
:param selection_threshold: Value passed to `GenericUnivariateSelect` as `param`: a percentile of features to keep in 'percentile' mode, or the number of features to keep in 'k_best' mode.
:return: List of selected feature indices.
"""
# Define the score function based on input
score_func_mapping = {
"f_classif": f_classif,
"f_regression": f_regression,
"mutual_info_classif": mutual_info_classif,
"mutual_info_regression": mutual_info_regression,
}

if score_func not in score_func_mapping:
raise ValueError(f"Invalid score_func '{score_func}'. Valid options are: {list(score_func_mapping.keys())}")

# Extract the score function
selected_score_func = score_func_mapping[score_func]

# Get the feature matrix and label vector from the FSDataFrame
f_matrix = df.get_feature_matrix()
y = df.get_label_vector()

# Configure the selector using the provided score function and selection mode/threshold
selector = GenericUnivariateSelect(score_func=selected_score_func,
mode=selection_mode,
param=selection_threshold)

# Fit the selector and get only the selected feature indices (not the transformed matrix)
_ = selector.fit_transform(f_matrix, y)
selected_features = selector.get_support(indices=True)

return list(selected_features)
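A minimal, self-contained sketch of GenericUnivariateSelect on synthetic data, illustrating the selection semantics relied on above: in 'percentile' mode scikit-learn treats param as a percentile of features to keep, and in 'k_best' mode as a feature count (the data and parameter values here are illustrative only, not taken from the commit):

import numpy as np
from sklearn.feature_selection import GenericUnivariateSelect, f_classif

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 50))    # 100 samples, 50 features
y = rng.integers(0, 2, size=100)  # binary labels

# Keep the top 20% of features by ANOVA F-score
selector = GenericUnivariateSelect(score_func=f_classif, mode="percentile", param=20)
selector.fit(X, y)
selected = selector.get_support(indices=True)
print(len(selected))  # 10 of 50 features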

def univariate_filter(
self, df: FSDataFrame, univariate_method: str = "u_corr", **kwargs
self, df: FSDataFrame, univariate_method: str = "u_corr", **kwargs
) -> FSDataFrame:
"""
Filter features after applying a univariate feature selector method.
:param df: Input DataFrame
:param univariate_method: Univariate selector method ('u_corr', 'anova', 'f_regression')
:param univariate_method: Univariate selector method ('u_corr', 'anova', 'f_regression',
'mutual_info_classification', 'mutual_info_regression')
:return: Filtered DataFrame with selected features
"""

@@ -83,17 +137,21 @@ def univariate_filter(
selected_features = []

if univariate_method == "anova":
# TODO: Implement ANOVA selector
# selected_features = univariate_selector(df, features, label, label_type='categorical', **kwargs)
pass
selected_features = self.univariate_feature_selector(df, score_func="f_classif", **kwargs)
elif univariate_method == "f_regression":
# TODO: Implement F-regression selector
# selected_features = univariate_selector(df, features, label, label_type='continuous', **kwargs)
pass
selected_features = self.univariate_feature_selector(df, score_func="f_regression", **kwargs)
elif univariate_method == "u_corr":
selected_features = univariate_correlation_selector(df, **kwargs)
elif univariate_method == "mutual_info_classification":
selected_features = self.univariate_feature_selector(df, score_func="mutual_info_classif", **kwargs)
elif univariate_method == "mutual_info_regression":
selected_features = self.univariate_feature_selector(df, score_func="mutual_info_regression", **kwargs)

logger.info(f"Applying univariate filter using method: {univariate_method}")
logger.info(
f"Applying univariate filter using method: {univariate_method} \n"
f" with selection mode: {kwargs.get('selection_mode')} \n"
f" and selection threshold: {kwargs.get('selection_threshold')}"
)

if len(selected_features) == 0:
logger.warning("No features selected. Returning original DataFrame.")
@@ -104,14 +162,16 @@


def univariate_correlation_selector(
df: FSDataFrame, corr_threshold: float = 0.3
df: FSDataFrame,
selection_threshold: float = 0.3
) -> List[int]:
"""
TODO: Replace this implementation with scikit-learn's GenericUnivariateSelect with score_func='f_regression'
Select features based on their correlation with the label (class), keeping a feature when its correlation
value does not exceed the specified threshold.
:param df: Input DataFrame
:param corr_threshold: Maximum allowed correlation threshold
:param selection_threshold: Maximum allowed correlation threshold
:return: List of selected feature indices
"""
@@ -139,58 +199,7 @@ def compute_univariate_corr(df: FSDataFrame) -> Dict[int, float]:
selected_features = [
feature_index
for feature_index, corr in correlations.items()
if corr <= corr_threshold
if corr <= selection_threshold
]
return selected_features
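compute_univariate_corr itself is collapsed out of this hunk; a minimal sketch of what such a per-feature correlation could look like, assuming it returns a {feature_index: correlation} mapping of each feature column against the label (the helper name, use of the absolute value, and exact behaviour are assumptions, not taken from this diff):

import numpy as np
from typing import Dict

def compute_univariate_corr_sketch(X: np.ndarray, y: np.ndarray) -> Dict[int, float]:
    # Hypothetical stand-in for compute_univariate_corr: absolute Pearson
    # correlation of each feature column with the label vector.
    return {
        j: float(abs(np.corrcoef(X[:, j], y)[0, 1]))
        for j in range(X.shape[1])
    }

# A feature is then kept when its correlation is at or below the threshold,
# mirroring the `corr <= selection_threshold` filter above.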


def univariate_selector(
df: pd.DataFrame,
features: List[str],
label: str,
label_type: str = "categorical",
selection_mode: str = "percentile",
selection_threshold: float = 0.8,
) -> List[str]:
"""
Wrapper for scikit-learn's `SelectKBest` feature selector.
If the label is categorical, ANOVA test is used; if continuous, F-regression test is used.
:param df: Input DataFrame
:param features: List of feature column names
:param label: Label column name
:param label_type: Type of label ('categorical' or 'continuous')
:param selection_mode: Mode for feature selection ('percentile' or 'k_best')
:param selection_threshold: Number of features to select or the percentage of features
:return: List of selected feature names
"""

X = df[features].values
y = df[label].values

if label_type == "categorical":
logger.info("ANOVA (F-classification) univariate feature selection")
selector = SelectKBest(score_func=f_classif)
elif label_type == "continuous":
logger.info("F-value (F-regression) univariate feature selection")
selector = SelectKBest(score_func=f_regression)
else:
raise ValueError("`label_type` must be one of 'categorical' or 'continuous'")

if selection_mode == "percentile":
selector.set_params(k="all") # We'll handle the percentile threshold manually
selector.fit(X, y)
scores = selector.scores_
selected_indices = [
i
for i, score in enumerate(scores)
if score >= selection_threshold * max(scores)
]
else:
selector.set_params(k=int(selection_threshold))
selector.fit(X, y)
selected_indices = selector.get_support(indices=True)

selected_features = [features[i] for i in selected_indices]
return selected_features
115 changes: 114 additions & 1 deletion fslite/tests/test_univariate_methods.py
@@ -18,7 +18,8 @@ def test_univariate_filter_corr():
fs_df = FSDataFrame(df=df, sample_col="Sample", label_col="label")

# create FSUnivariate instance
fs_univariate = FSUnivariate(fs_method="u_corr", corr_threshold=0.3)
fs_univariate = FSUnivariate(fs_method="u_corr",
selection_threshold=0.3)

fsdf_filtered = fs_univariate.select_features(fs_df)

@@ -28,3 +29,115 @@
# Export the filtered DataFrame as Pandas DataFrame
df_filtered = fsdf_filtered.to_pandas()
df_filtered.to_csv("filtered_tnbc_data.csv", index=False)


# test the univariate_filter method with 'anova' method
def test_univariate_filter_anova():
"""
Test univariate_filter method with 'anova' method.
:return: None
"""

# import tsv as pandas DataFrame
df = pd.read_csv(get_tnbc_data_path(), sep="\t")

# create FSDataFrame instance
fs_df = FSDataFrame(df=df, sample_col="Sample", label_col="label")

# create FSUnivariate instance
fs_univariate = FSUnivariate(fs_method="anova",
selection_mode="percentile",
selection_threshold=0.8)

fsdf_filtered = fs_univariate.select_features(fs_df)

assert fs_df.count_features() == 500
assert fsdf_filtered.count_features() == 4

# Export the filtered DataFrame as Pandas DataFrame
df_filtered = fsdf_filtered.to_pandas()
df_filtered.to_csv("filtered_tnbc_data.csv", index=False)


# test the univariate_filter method with 'mutual_info_classification' method
def test_univariate_filter_mutual_info_classification():
"""
Test univariate_filter method with 'mutual_info_classification' method.
:return: None
"""

# import tsv as pandas DataFrame
df = pd.read_csv(get_tnbc_data_path(), sep="\t")

# create FSDataFrame instance
fs_df = FSDataFrame(df=df, sample_col="Sample", label_col="label")

# create FSUnivariate instance
fs_univariate = FSUnivariate(fs_method="mutual_info_classification",
selection_mode="k_best",
selection_threshold=30)

fsdf_filtered = fs_univariate.select_features(fs_df)

assert fs_df.count_features() == 500
assert fsdf_filtered.count_features() == 30

# Export the filtered DataFrame as Pandas DataFrame
df_filtered = fsdf_filtered.to_pandas()
df_filtered.to_csv("filtered_tnbc_data.csv", index=False)


# test the univariate_filter method with 'mutual_info_regression' method
def test_univariate_filter_mutual_info_regression():
"""
Test univariate_filter method with 'mutual_info_regression' method.
:return: None
"""

# import tsv as pandas DataFrame
df = pd.read_csv(get_tnbc_data_path(), sep="\t")

# create FSDataFrame instance
fs_df = FSDataFrame(df=df, sample_col="Sample", label_col="label")

# create FSUnivariate instance
fs_univariate = FSUnivariate(fs_method="mutual_info_regression",
selection_mode="percentile",
selection_threshold=0.8)

fsdf_filtered = fs_univariate.select_features(fs_df)

assert fs_df.count_features() == 500
assert fsdf_filtered.count_features() == 4

# Export the filtered DataFrame as Pandas DataFrame
df_filtered = fsdf_filtered.to_pandas()
df_filtered.to_csv("filtered_tnbc_data.csv", index=False)


# test the univariate_filter method with f-regression method
def test_univariate_filter_f_regression():
"""
Test univariate_filter method with f_regression method.
:return: None
"""

# import tsv as pandas DataFrame
df = pd.read_csv(get_tnbc_data_path(), sep="\t")

# create FSDataFrame instance
fs_df = FSDataFrame(df=df, sample_col="Sample", label_col="label")

# create FSUnivariate instance
fs_univariate = FSUnivariate(fs_method="f_regression",
selection_mode="percentile",
selection_threshold=0.8)

fsdf_filtered = fs_univariate.select_features(fs_df)

assert fs_df.count_features() == 500
assert fsdf_filtered.count_features() == 4

# Export the filtered DataFrame as Pandas DataFrame
df_filtered = fsdf_filtered.to_pandas()
df_filtered.to_csv("filtered_tnbc_data.csv", index=False)
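The expected feature counts in these tests follow from how GenericUnivariateSelect interprets selection_threshold: as a percentile of features to keep in 'percentile' mode and as a feature count in 'k_best' mode. A back-of-the-envelope check, assuming standard scikit-learn rounding:

n_features = 500

# 'percentile' mode with selection_threshold=0.8 keeps the top 0.8% of features:
print(int(n_features * 0.8 / 100))  # 4, matching the asserts above

# 'k_best' mode with selection_threshold=30 keeps exactly 30 features.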
