Skip to content

Commit

Permalink
Unify process_data_set and filter_dataset into a single method (proce…
Browse files Browse the repository at this point in the history
…ss_data_set) which now supports returning a CSV
  • Loading branch information
alexhernandezgarcia committed Jul 13, 2024
1 parent 7d49ced commit e12f3ba
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
43 changes: 27 additions & 16 deletions gflownet/envs/crystals/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def states2proxy(
dim=1,
)

def process_data_set(self, df: pd.DataFrame, progress=False) -> List[List]:
def process_data_set(
self, df: pd.DataFrame, return_type: str = "state", progress=False
) -> List[List]:
"""
Converts a data set passed as a pandas DataFrame into a list of states in
environment format.
Expand All @@ -170,6 +172,10 @@ def process_data_set(self, df: pd.DataFrame, progress=False) -> List[List]:
df : DataFrame
A pandas DataFrame containing the necessary columns to represent a crystal
as described above.
return_type: str
Identifier of the data format to be return. Options:
- state: list of states in environment format (default)
- dataframe: pandas DataFrame
progress : bool
Whether to display a progress bar.
Expand All @@ -178,17 +184,29 @@ def process_data_set(self, df: pd.DataFrame, progress=False) -> List[List]:
list
A list of states in environment format.
"""
data_valid = []
is_valid = []
states_valid = []
for row in tqdm(df.iterrows(), total=len(df), disable=not progress):
# Index 0 is the row index; index 1 is the remaining columns
row = row[1]
if self._is_valid_datarow(row):
# TODO: Consider making stack state a dict which would avoid having to
# do this, among other advantages
state = self._state_from_datarow(row)
state_stack = [2] + [state[stage] for stage in self.subenvs]
data_valid.append(state_stack)
return data_valid
if return_type.lower() == "dataframe":
is_valid.append(self._is_valid_datarow(row))
elif return_type.lower().startswith("state"):
if self._is_valid_datarow(row):
# TODO: Consider making stack state a dict which would avoid having
# to do this, among other advantages
state = self._state_from_datarow(row)
state_stack = [2] + [state[stage] for stage in self.subenvs]
states_valid.append(state_stack)
else:
raise ValueError(
f"Unknown return_type. Received {return_type}, expected state or "
"dataframe"
)
if return_type.lower() == "dataframe":
return df[np.array(is_valid)]
else:
return states_valid

def _state_from_datarow(self, row):
state = {}
Expand All @@ -209,10 +227,3 @@ def _is_valid_datarow(self, row):
subenv.is_valid(state[stage]) for stage, subenv in self.subenvs.items()
]
return all(is_valid_subenvs)

def filter_dataset(self, df: pd.DataFrame) -> pd.DataFrame:
is_valid = []
for row in df.iterrows():
row = row[1]
is_valid.append(self._is_valid_datarow(row))
return df[np.array(is_valid)]
4 changes: 2 additions & 2 deletions scripts/crystal/plots_iclm24.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def load_mb_data(
vdf = add_elements_columns(vdf)
tdf[energy_key] = tdf[names[target]]
vdf[energy_key] = vdf[names[target]]
tdf = env.filter_dataset(tdf)
vdf = env.filter_dataset(vdf)
tdf = env.process_data_set(tdf, return_type="dataframe")
vdf = env.process_data_set(vdf, return_type="dataframe")
tdf = tdf.drop(columns=[names[target], "Formulae"])
vdf = vdf.drop(columns=[names[target], "Formulae"])
print("Filtered data sets:")
Expand Down

0 comments on commit e12f3ba

Please sign in to comment.