diff --git a/eval.py b/eval.py index 0d31271..8d02e3d 100644 --- a/eval.py +++ b/eval.py @@ -64,82 +64,28 @@ def get_args(): ), formatter_class=argparse.RawTextHelpFormatter, ) + # fmt: off parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.") - parser.add_argument( - "--method-json", required=True, nargs="+", type=str, help="Json file for methods." - ) + parser.add_argument("--method-json", required=True, nargs="+", type=str, help="Json file for methods.") parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.") parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.") parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.") parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.") parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.") - parser.add_argument( - "--include-methods", - type=str, - nargs="+", - help="Names of only specific methods you want to evaluate.", - ) - parser.add_argument( - "--exclude-methods", - type=str, - nargs="+", - help="Names of some specific methods you do not want to evaluate.", - ) - parser.add_argument( - "--include-datasets", - type=str, - nargs="+", - help="Names of only specific datasets you want to evaluate.", - ) - parser.add_argument( - "--exclude-datasets", - type=str, - nargs="+", - help="Names of some specific datasets you do not want to evaluate.", - ) - parser.add_argument( - "--num-workers", - type=int, - default=4, - help="Number of workers for multi-threading or multi-processing. Default: 4", - ) - parser.add_argument( - "--num-bits", - type=int, - default=3, - help="Number of decimal places for showing results. Default: 3", - ) - parser.add_argument( - "--metric-names", - type=str, - nargs="+", - default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"], - choices=SUPPORTED_METRICS, - help="Names of metrics", - ) - parser.add_argument( - "--data-type", - type=str, - default="image", - choices=["image", "video"], - help="Type of data.", - ) + parser.add_argument("--include-methods", type=str, nargs="+", help="Names of only specific methods you want to evaluate.") + parser.add_argument("--exclude-methods", type=str, nargs="+", help="Names of some specific methods you do not want to evaluate.") + parser.add_argument("--include-datasets", type=str, nargs="+", help="Names of only specific datasets you want to evaluate.") + parser.add_argument("--exclude-datasets", type=str, nargs="+", help="Names of some specific datasets you do not want to evaluate.") + parser.add_argument("--num-workers", type=int, default=4, help="Number of workers for multi-threading or multi-processing. Default: 4") + parser.add_argument("--num-bits", type=int, default=3, help="Number of decimal places for showing results. Default: 3") + parser.add_argument("--metric-names", type=str, nargs="+", default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"], choices=SUPPORTED_METRICS, help="Names of metrics") + parser.add_argument("--data-type", type=str, default="image", choices=["image", "video"], help="Type of data.") known_args = parser.parse_known_args()[0] if known_args.data_type == "video": - parser.add_argument( - "--valid-frame-start", - type=int, - default=0, - help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.", - ) - parser.add_argument( - "--valid-frame-end", - type=int, - default=0, - help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.", - ) - + parser.add_argument("--valid-frame-start", type=int, default=0, help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.") + parser.add_argument("--valid-frame-end", type=int, default=0, help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.") + # fmt: on args = parser.parse_args() if args.data_type == "video": diff --git a/metrics/draw_curves.py b/metrics/draw_curves.py index 07c2507..b1fb568 100644 --- a/metrics/draw_curves.py +++ b/metrics/draw_curves.py @@ -7,6 +7,16 @@ from utils.recorders import CurveDrawer +# Align the mode with those in GRAYSCALE_METRICS +_YX_AXIS_NAMES = { + "pr": ("precision", "recall"), + "fm": ("fmeasure", None), + "fmeasure": ("fmeasure", None), + "em": ("em", None), + "iou": ("iou", None), + "dice": ("dice", None), +} + def draw_curves( mode: str, @@ -40,12 +50,8 @@ def draw_curves( line_width (int, optional): Width of lines. Defaults to 3. save_name (str, optional): Name or path (without the extension format). Defaults to None. """ - assert mode in ["pr", "fm", "em"] save_name = save_name or mode - mode_axes_setting = axes_setting[mode] - - x_label, y_label = mode_axes_setting["x_label"], mode_axes_setting["y_label"] - x_ticks, y_ticks = mode_axes_setting["x_ticks"], mode_axes_setting["y_ticks"] + y_axis_name, x_axis_name = _YX_AXIS_NAMES[mode] assert curves_npy_path if not isinstance(curves_npy_path, (list, tuple)): @@ -137,14 +143,6 @@ def draw_curves( for idx, (dataset_name, dataset_alias) in enumerate(dataset_aliases.items()): dataset_results = curves[dataset_name] - curve_drawer.set_axis_property( - idx=idx, - title=dataset_alias.upper(), - x_label=x_label, - y_label=y_label, - x_ticks=x_ticks, - y_ticks=y_ticks, - ) for method_name in target_unique_method_names: method_setting = unique_method_settings[method_name] @@ -154,30 +152,19 @@ def draw_curves( continue method_results = dataset_results[method_name] - if mode == "pr": - y_data = method_results.get("p") - if y_data is None: - y_data = method_results["precision"] - assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys()) - - x_data = method_results.get("r") - if x_data is None: - x_data = method_results["recall"] - assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys()) - elif mode == "fm": - y_data = method_results.get("fm") - if y_data is None: - y_data = method_results["fmeasure"] - assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys()) - x_data = np.linspace(0, 1, 256) - elif mode == "em": - y_data = method_results["em"] + if y_axis_name is None: + y_data = np.linspace(0, 1, 256) + else: + y_data = method_results[y_axis_name] assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys()) + if x_axis_name is None: x_data = np.linspace(0, 1, 256) + else: + x_data = method_results[x_axis_name] + assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys()) - curve_drawer.plot_at_axis( - idx=idx, method_curve_setting=method_setting, x_data=x_data, y_data=y_data - ) + curve_drawer.plot_at_axis(idx, method_setting, x_data=x_data, y_data=y_data) + curve_drawer.set_axis_property(idx, dataset_alias, **axes_setting[mode]) curve_drawer.save(path=save_name) diff --git a/plot.py b/plot.py index 7bf2437..34d4cf7 100644 --- a/plot.py +++ b/plot.py @@ -47,41 +47,18 @@ def get_args(): ), formatter_class=argparse.RawTextHelpFormatter, ) + # fmt: off parser.add_argument("--alias-yaml", type=str, help="Yaml file for datasets and methods alias.") - parser.add_argument( - "--style-cfg", - type=str, - required=True, - help="Yaml file for plotting curves.", - ) - parser.add_argument( - "--curves-npys", - required=True, - type=str, - nargs="+", - help="Npy file for saving curve results.", - ) - parser.add_argument( - "--our-methods", type=str, nargs="+", help="Names of our methods for highlighting it." - ) - parser.add_argument( - "--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1" - ) - parser.add_argument( - "--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1" - ) - parser.add_argument( - "--mode", - type=str, - choices=["pr", "fm", "em"], - default="pr", - help="Mode for plotting. Default: pr", - ) - parser.add_argument( - "--separated-legend", action="store_true", help="Use the separated legend." - ) + parser.add_argument("--style-cfg", type=str, required=True, help="Yaml file for plotting curves.") + parser.add_argument("--curves-npys", required=True, type=str, nargs="+", help="Npy file for saving curve results.") + parser.add_argument("--our-methods", type=str, nargs="+", help="Names of our methods for highlighting it.") + parser.add_argument("--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1") + parser.add_argument("--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1") + parser.add_argument("--mode", type=str, choices=["pr", "fm", "em", "iou", "dice"], default="pr", help="Mode for plotting. Default: pr") + parser.add_argument("--separated-legend", action="store_true", help="Use the separated legend.") parser.add_argument("--sharey", action="store_true", help="Use the shared y-axis.") parser.add_argument("--save-name", type=str, help="the exported file path") + # fmt: on args = parser.parse_args() return args @@ -95,32 +72,48 @@ def main(args): method_aliases = aliases.get("method") dataset_aliases = aliases.get("dataset") + # TODO: Better method to set axes_setting + axes_setting = { + # pr curve + "pr": { + "x_label": "Recall", + "y_label": "Precision", + "x_ticks": np.linspace(0.5, 1, 6), + "y_ticks": np.linspace(0.7, 1, 6), + }, + # fm curve + "fm": { + "x_label": "Threshold", + "y_label": r"F$_{\beta}$", + "x_ticks": np.linspace(0, 1, 6), + "y_ticks": np.linspace(0.6, 1, 6), + }, + # em curve + "em": { + "x_label": "Threshold", + "y_label": r"E$_{m}$", + "x_ticks": np.linspace(0, 1, 6), + "y_ticks": np.linspace(0.7, 1, 6), + }, + # iou curve + "iou": { + "x_label": "Threshold", + "y_label": "IoU", + "x_ticks": np.linspace(0, 1, 6), + "y_ticks": np.linspace(0.4, 1, 6), + }, + # dice curve + "dice": { + "x_label": "Threshold", + "y_label": "Dice", + "x_ticks": np.linspace(0, 1, 6), + "y_ticks": np.linspace(0.4, 1, 6), + }, + } + draw_curves.draw_curves( mode=args.mode, - # 不同曲线的绘图配置 - axes_setting={ - # pr曲线的配置 - "pr": { - "x_label": "Recall", - "y_label": "Precision", - "x_ticks": np.linspace(0.5, 1, 6), - "y_ticks": np.linspace(0.7, 1, 6), - }, - # fm曲线的配置 - "fm": { - "x_label": "Threshold", - "y_label": r"F$_{\beta}$", - "x_ticks": np.linspace(0, 1, 6), - "y_ticks": np.linspace(0.6, 1, 6), - }, - # em曲线的配置 - "em": { - "x_label": "Threshold", - "y_label": r"E$_{m}$", - "x_ticks": np.linspace(0, 1, 6), - "y_ticks": np.linspace(0.7, 1, 6), - }, - }, + axes_setting=axes_setting, curves_npy_path=args.curves_npys, row_num=args.num_rows, method_aliases=method_aliases, diff --git a/tools/converter.py b/tools/converter.py index 03eaf88..8d5a56c 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -10,30 +10,15 @@ import numpy as np import yaml -parser = argparse.ArgumentParser( - description="A useful and convenient tool to convert your .npy results into the table code in latex." -) -parser.add_argument( - "-i", - "--result-file", - required=True, - nargs="+", - action="extend", - help="The path of the *_metrics.npy file.", -) -parser.add_argument( - "-o", "--tex-file", required=True, type=str, help="The path of the exported tex file." -) -parser.add_argument( - "-c", "--config-file", type=str, help="The path of the customized config yaml file." -) -parser.add_argument( - "--contain-table-env", - action="store_true", - help="Whether to containe the table env in the exported code.", -) +# fmt: off +parser = argparse.ArgumentParser(description="A useful and convenient tool to convert your .npy results into the table code in latex.") +parser.add_argument("-i", "--result-file", required=True, nargs="+", action="extend", help="The path of the *_metrics.npy file.") +parser.add_argument("-o", "--tex-file", required=True, type=str, help="The path of the exported tex file.") +parser.add_argument("-c", "--config-file", type=str, help="The path of the customized config yaml file.") +parser.add_argument("--contain-table-env", action="store_true", help="Whether to containe the table env in the exported code.") parser.add_argument("--num-bits", type=int, default=3, help="Number of valid digits.") parser.add_argument("--transpose", action="store_true", help="Whether to transpose the table.") +# fmt: on args = parser.parse_args() arg_head = f"%% Generated by: {vars(args)}" @@ -139,7 +124,7 @@ def update_dict(parent_dict, sub_dict): metric_row_head=" ", metric_column_head="& ", body=[ - "& \\best{{{txt:.03f}}}", # style for top1 + "& \\first{{{txt:.03f}}}", # style for top1 "& \\second{{{txt:.03f}}}", # style for top2 "& \\third{{{txt:.03f}}}", # style for top3 "& {txt:.03f}", # style for other