Skip to content

Commit

Permalink
adding averaging for less than 10 samples
Browse files Browse the repository at this point in the history
  • Loading branch information
saanikat committed Sep 12, 2024
1 parent 04024cd commit e534270
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,13 @@ print(results)
```

To see the available schemas, you can run:
```
schemas = model.show_available_schemas()
print(schemas)
```

This will print the available schemas as a list.

You can use the format provided in the `trial.py` script in this repository as a reference.
14 changes: 13 additions & 1 deletion attribute_standardizer/attr_standardizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,16 @@ def standardize(
try:
csv_file = fetch_from_pephub(pep)

X_values_st, X_headers_st, X_values_bow = data_preprocessing(csv_file)
X_values_st, X_headers_st, X_values_bow, num_rows = data_preprocessing(
csv_file
)
(
X_headers_embeddings_tensor,
X_values_embeddings_tensor,
X_values_bow_tensor,
label_encoder,
) = data_encoding(
num_rows,
X_values_st,
X_headers_st,
X_values_bow,
Expand Down Expand Up @@ -192,3 +195,12 @@ def standardize(
logger.error(
f"Error occured during standardization in standardize function: {str(e)}"
)
@staticmethod
def show_available_schemas()-> list[str]:
"""
Stores a list of available schemas.
:return list: List of available schemas.
"""
schemas = ['ENCODE', 'FAIRTRACKS', 'BEDBASE']
return schemas

28 changes: 25 additions & 3 deletions attribute_standardizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def load_from_huggingface(schema: str) -> Optional[Any]:

def data_preprocessing(
df: pd.DataFrame,
) -> Tuple[List[List[str]], List[str], List[List[str]]]:
) -> Tuple[List[List[str]], List[str], List[List[str]], int]:
"""
Preprocessing the DataFrame by extracting the column values and headers.
Expand All @@ -76,13 +76,16 @@ def data_preprocessing(
- Nested list containing the comma separated values in each column for sentence transformer embeddings.
- List containing the headers of the DataFrame.
- Nested list containing the comma separated values in each column for Bag of Words encoding.
- Number of rows in the metadata csv
"""

X_values_st = [df[column].astype(str).tolist() for column in df.columns]
X_headers_st = df.columns.tolist()
X_values_bow = [df[column].astype(str).tolist() for column in df.columns]

return X_values_st, X_headers_st, X_values_bow
num_rows = df.shape[0]

return X_values_st, X_headers_st, X_values_bow, num_rows


def get_top_k_average(val_embedding: List[np.ndarray], k: int) -> np.ndarray:
Expand Down Expand Up @@ -134,7 +137,21 @@ def get_top_cluster_averaged(embeddings: List[np.ndarray]) -> np.ndarray:
return top_k_average.numpy()


def get_averaged(embeddings: List[np.ndarray]) -> np.ndarray:
"""
Averages the embeddings.
:param list embeddings: List of embeddings, each embedding is a vector of values.
:return np.ndarray: The mean of all the embeddings as a NumPy array.
"""
flattened_embeddings = [embedding.tolist() for embedding in embeddings]
flattened_embeddings_array = np.array(flattened_embeddings)
averaged_embedding = np.mean(flattened_embeddings_array, axis=0)

return averaged_embedding


def data_encoding(
num_rows: int,
X_values_st: List[List[str]],
X_headers_st: List[str],
X_values_bow: List[List[str]],
Expand All @@ -144,6 +161,7 @@ def data_encoding(
"""
Encode input data in accordance with the user-specified schemas.
:param int num_rows: Number of rows in the sample metadata
:param list X_values_st: Nested list containing the comma separated values in each column for sentence transformer embeddings.
:param list X_headers_st: List containing the headers of the DataFrame.
:param list X_values_bow: Nested list containing the comma separated values in each column for Bag of Words encoding.
Expand All @@ -159,7 +177,11 @@ def data_encoding(
embeddings = []
for column in X_values_st:
val_embedding = sentence_encoder.encode(column, show_progress_bar=False)
embedding = get_top_cluster_averaged(val_embedding)
if num_rows >= 10:
embedding = get_top_cluster_averaged(val_embedding)
else:
embedding = get_averaged(val_embedding)

embeddings.append(embedding)
X_values_embeddings = embeddings
if schema == "ENCODE":
Expand Down
9 changes: 7 additions & 2 deletions trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

model = AttrStandardizer("ENCODE")

results = model.standardize(pep="geo/gse178283:default")
schemas = model.show_available_schemas()

print(results)
print(schemas)

#results = model.standardize(pep="geo/gse178283:default")
results = model.standardize(pep="geo/gse228634:default")

print(results)

0 comments on commit e534270

Please sign in to comment.