diff --git a/pgscatalog_utils/aggregate/aggregate_scores.py b/pgscatalog_utils/aggregate/aggregate_scores.py index d57943d..71fbb43 100644 --- a/pgscatalog_utils/aggregate/aggregate_scores.py +++ b/pgscatalog_utils/aggregate/aggregate_scores.py @@ -16,14 +16,14 @@ def aggregate_scores(): if args.split: logger.debug("Splitting aggregated scores by sampleset") - for sampleset, group in df.groupby('sampleset'): + for sampleset, group in df.groupby("sampleset"): fout = f"{sampleset}_pgs.txt.gz" logger.debug(f"Compressing sampleset {sampleset}, writing to {fout}") - group.to_csv(fout, sep='\t', compression='gzip') + group.to_csv(fout, sep="\t", compression="gzip") else: fout = "aggregated_scores.txt.gz" logger.info(f"Compressing all samplesets and writing combined scores to {fout}") - df.to_csv(fout, sep='\t', compression='gzip') + df.to_csv(fout, sep="\t", compression="gzip") def aggregate(scorefiles: list[str]): @@ -33,11 +33,13 @@ def aggregate(scorefiles: list[str]): for i, path in enumerate(scorefiles): logger.debug(f"Reading {path}") # pandas can automatically detect zst compression, neat! 
- df = (pd.read_table(path, converters={"#IID": str}, header=0) - .assign(sampleset=path.split('_')[0]) - .set_index(['sampleset', '#IID'])) + df = ( + pd.read_table(path, converters={"#IID": str}, header=0) + .assign(sampleset=path.split("_")[0]) + .set_index(["sampleset", "#IID"]) + ) - df.index.names = ['sampleset', 'IID'] + df.index.names = ["sampleset", "IID"] # Subset to aggregatable columns df = df[_select_agg_cols(df.columns)] @@ -45,31 +47,57 @@ def aggregate(scorefiles: list[str]): # Combine DFs if i == 0: - logger.debug('Initialising combined DF') + logger.debug("Initialising combined DF") combined = df.copy() else: - logger.debug('Adding to combined DF') + logger.debug("Adding to combined DF") combined = combined.add(df, fill_value=0) - assert all([x in combined.columns for x in aggcols]), "All Aggregatable Columns are present in the final DF" + assert all( + [x in combined.columns for x in aggcols] + ), "All Aggregatable Columns are present in the final DF" - return combined.pipe(_calculate_average) + sum_df, avg_df = combined.pipe(_calculate_average) + # need to melt sum and avg separately to give correct value_name to melt + dfs = [_melt(x, y) for x, y in zip([sum_df, avg_df], ["SUM", "AVG"])] + # add melted average back + combined = pd.concat([dfs[0], dfs[1]["AVG"]], axis=1) + return combined[["accession", "SUM", "DENOM", "AVG"]] + + +def _melt(df, value_name): + df = df.melt( + id_vars=["DENOM"], + value_name=value_name, + var_name="accession", + ignore_index=False, + ) + df["accession"] = df["accession"].str.replace(f"_{value_name}", "") + return df def _calculate_average(combined: pd.DataFrame): logger.debug("Averaging data") - avgs = combined.loc[:, combined.columns.str.endswith('_SUM')].divide(combined['DENOM'], axis=0) - avgs.columns = avgs.columns.str.replace('_SUM', '_AVG') - return pd.concat([combined, avgs], axis=1) + avgs = combined.loc[:, combined.columns.str.endswith("_SUM")].divide( + combined["DENOM"], axis=0 + ) + avgs.columns = 
avgs.columns.str.replace("_SUM", "_AVG") + avgs["DENOM"] = combined["DENOM"] + return combined, avgs def _select_agg_cols(cols): - keep_cols = ['DENOM'] - return [x for x in cols if (x.endswith('_SUM') and (x != 'NAMED_ALLELE_DOSAGE_SUM')) or (x in keep_cols)] + keep_cols = ["DENOM"] + return [ + x + for x in cols + if (x.endswith("_SUM") and (x != "NAMED_ALLELE_DOSAGE_SUM")) or (x in keep_cols) + ] def _description_text() -> str: - return textwrap.dedent(''' + return textwrap.dedent( + """ Aggregate plink .sscore files into a combined TSV table. This aggregation sums scores that were calculated from plink @@ -80,20 +108,45 @@ def _description_text() -> str: Input .sscore files can be optionally compressed with zstd or gzip. The aggregated output scores are compressed with gzip. - ''') + """ + ) def _parse_args(args=None) -> argparse.Namespace: - parser = argparse.ArgumentParser(description=_description_text(), - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument('-s', '--scores', dest='scores', required=True, nargs='+', - help=' List of scorefile paths. Use a wildcard (*) to select multiple files.') - parser.add_argument('-o', '--outdir', dest='outdir', required=True, - default='scores/', help=' Output directory to store downloaded files') - parser.add_argument('--split', dest='split', required=False, action=argparse.BooleanOptionalAction, - help=' Make one aggregated file per sampleset') - parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', - help=' Extra logging information') + parser = argparse.ArgumentParser( + description=_description_text(), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "-s", + "--scores", + dest="scores", + required=True, + nargs="+", + help=" List of scorefile paths. 
Use a wildcard (*) to select multiple files.", + ) + parser.add_argument( + "-o", + "--outdir", + dest="outdir", + required=True, + default="scores/", + help=" Output directory to store downloaded files", + ) + parser.add_argument( + "--split", + dest="split", + required=False, + action=argparse.BooleanOptionalAction, + help=" Make one aggregated file per sampleset", + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="store_true", + help=" Extra logging information", + ) return parser.parse_args(args)