Skip to content

Commit

Permalink
0. 添加iou和dice曲线的绘制功能。
Browse files Browse the repository at this point in the history
1. 简化绘图代码中关于轴对应的指标的选择逻辑
2. 基于ruff的格式,屏蔽对终端参数选项的格式化。
3. 调整converter.py中关于最优指标使用的指令(`\best` -> `\first`),以和两外两个位次的名称对应。
  • Loading branch information
lartpang committed Sep 26, 2024
1 parent 8d8c86e commit ad3bfd8
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 180 deletions.
80 changes: 13 additions & 67 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,82 +64,28 @@ def get_args():
),
formatter_class=argparse.RawTextHelpFormatter,
)
# fmt: off
parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
parser.add_argument(
"--method-json", required=True, nargs="+", type=str, help="Json file for methods."
)
parser.add_argument("--method-json", required=True, nargs="+", type=str, help="Json file for methods.")
parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
parser.add_argument(
"--include-methods",
type=str,
nargs="+",
help="Names of only specific methods you want to evaluate.",
)
parser.add_argument(
"--exclude-methods",
type=str,
nargs="+",
help="Names of some specific methods you do not want to evaluate.",
)
parser.add_argument(
"--include-datasets",
type=str,
nargs="+",
help="Names of only specific datasets you want to evaluate.",
)
parser.add_argument(
"--exclude-datasets",
type=str,
nargs="+",
help="Names of some specific datasets you do not want to evaluate.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for multi-threading or multi-processing. Default: 4",
)
parser.add_argument(
"--num-bits",
type=int,
default=3,
help="Number of decimal places for showing results. Default: 3",
)
parser.add_argument(
"--metric-names",
type=str,
nargs="+",
default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"],
choices=SUPPORTED_METRICS,
help="Names of metrics",
)
parser.add_argument(
"--data-type",
type=str,
default="image",
choices=["image", "video"],
help="Type of data.",
)
parser.add_argument("--include-methods", type=str, nargs="+", help="Names of only specific methods you want to evaluate.")
parser.add_argument("--exclude-methods", type=str, nargs="+", help="Names of some specific methods you do not want to evaluate.")
parser.add_argument("--include-datasets", type=str, nargs="+", help="Names of only specific datasets you want to evaluate.")
parser.add_argument("--exclude-datasets", type=str, nargs="+", help="Names of some specific datasets you do not want to evaluate.")
parser.add_argument("--num-workers", type=int, default=4, help="Number of workers for multi-threading or multi-processing. Default: 4")
parser.add_argument("--num-bits", type=int, default=3, help="Number of decimal places for showing results. Default: 3")
parser.add_argument("--metric-names", type=str, nargs="+", default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"], choices=SUPPORTED_METRICS, help="Names of metrics")
parser.add_argument("--data-type", type=str, default="image", choices=["image", "video"], help="Type of data.")

known_args = parser.parse_known_args()[0]
if known_args.data_type == "video":
parser.add_argument(
"--valid-frame-start",
type=int,
default=0,
help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.",
)
parser.add_argument(
"--valid-frame-end",
type=int,
default=0,
help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.",
)

parser.add_argument("--valid-frame-start", type=int, default=0, help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.")
parser.add_argument("--valid-frame-end", type=int, default=0, help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.")
# fmt: on
args = parser.parse_args()

if args.data_type == "video":
Expand Down
55 changes: 21 additions & 34 deletions metrics/draw_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@

from utils.recorders import CurveDrawer

# Align the mode with those in GRAYSCALE_METRICS
_YX_AXIS_NAMES = {
"pr": ("precision", "recall"),
"fm": ("fmeasure", None),
"fmeasure": ("fmeasure", None),
"em": ("em", None),
"iou": ("iou", None),
"dice": ("dice", None),
}


def draw_curves(
mode: str,
Expand Down Expand Up @@ -40,12 +50,8 @@ def draw_curves(
line_width (int, optional): Width of lines. Defaults to 3.
save_name (str, optional): Name or path (without the extension format). Defaults to None.
"""
assert mode in ["pr", "fm", "em"]
save_name = save_name or mode
mode_axes_setting = axes_setting[mode]

x_label, y_label = mode_axes_setting["x_label"], mode_axes_setting["y_label"]
x_ticks, y_ticks = mode_axes_setting["x_ticks"], mode_axes_setting["y_ticks"]
y_axis_name, x_axis_name = _YX_AXIS_NAMES[mode]

assert curves_npy_path
if not isinstance(curves_npy_path, (list, tuple)):
Expand Down Expand Up @@ -137,14 +143,6 @@ def draw_curves(

for idx, (dataset_name, dataset_alias) in enumerate(dataset_aliases.items()):
dataset_results = curves[dataset_name]
curve_drawer.set_axis_property(
idx=idx,
title=dataset_alias.upper(),
x_label=x_label,
y_label=y_label,
x_ticks=x_ticks,
y_ticks=y_ticks,
)

for method_name in target_unique_method_names:
method_setting = unique_method_settings[method_name]
Expand All @@ -154,30 +152,19 @@ def draw_curves(
continue

method_results = dataset_results[method_name]
if mode == "pr":
y_data = method_results.get("p")
if y_data is None:
y_data = method_results["precision"]
assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())

x_data = method_results.get("r")
if x_data is None:
x_data = method_results["recall"]
assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys())
elif mode == "fm":
y_data = method_results.get("fm")
if y_data is None:
y_data = method_results["fmeasure"]
assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())

x_data = np.linspace(0, 1, 256)
elif mode == "em":
y_data = method_results["em"]
if y_axis_name is None:
y_data = np.linspace(0, 1, 256)
else:
y_data = method_results[y_axis_name]
assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())

if x_axis_name is None:
x_data = np.linspace(0, 1, 256)
else:
x_data = method_results[x_axis_name]
assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys())

curve_drawer.plot_at_axis(
idx=idx, method_curve_setting=method_setting, x_data=x_data, y_data=y_data
)
curve_drawer.plot_at_axis(idx, method_setting, x_data=x_data, y_data=y_data)
curve_drawer.set_axis_property(idx, dataset_alias, **axes_setting[mode])
curve_drawer.save(path=save_name)
105 changes: 49 additions & 56 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,18 @@ def get_args():
),
formatter_class=argparse.RawTextHelpFormatter,
)
# fmt: off
parser.add_argument("--alias-yaml", type=str, help="Yaml file for datasets and methods alias.")
parser.add_argument(
"--style-cfg",
type=str,
required=True,
help="Yaml file for plotting curves.",
)
parser.add_argument(
"--curves-npys",
required=True,
type=str,
nargs="+",
help="Npy file for saving curve results.",
)
parser.add_argument(
"--our-methods", type=str, nargs="+", help="Names of our methods for highlighting it."
)
parser.add_argument(
"--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1"
)
parser.add_argument(
"--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1"
)
parser.add_argument(
"--mode",
type=str,
choices=["pr", "fm", "em"],
default="pr",
help="Mode for plotting. Default: pr",
)
parser.add_argument(
"--separated-legend", action="store_true", help="Use the separated legend."
)
parser.add_argument("--style-cfg", type=str, required=True, help="Yaml file for plotting curves.")
parser.add_argument("--curves-npys", required=True, type=str, nargs="+", help="Npy file for saving curve results.")
parser.add_argument("--our-methods", type=str, nargs="+", help="Names of our methods for highlighting it.")
parser.add_argument("--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1")
parser.add_argument("--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1")
parser.add_argument("--mode", type=str, choices=["pr", "fm", "em", "iou", "dice"], default="pr", help="Mode for plotting. Default: pr")
parser.add_argument("--separated-legend", action="store_true", help="Use the separated legend.")
parser.add_argument("--sharey", action="store_true", help="Use the shared y-axis.")
parser.add_argument("--save-name", type=str, help="the exported file path")
# fmt: on
args = parser.parse_args()

return args
Expand All @@ -95,32 +72,48 @@ def main(args):
method_aliases = aliases.get("method")
dataset_aliases = aliases.get("dataset")

# TODO: Better method to set axes_setting
axes_setting = {
# pr curve
"pr": {
"x_label": "Recall",
"y_label": "Precision",
"x_ticks": np.linspace(0.5, 1, 6),
"y_ticks": np.linspace(0.7, 1, 6),
},
# fm curve
"fm": {
"x_label": "Threshold",
"y_label": r"F$_{\beta}$",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.6, 1, 6),
},
# em curve
"em": {
"x_label": "Threshold",
"y_label": r"E$_{m}$",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.7, 1, 6),
},
# iou curve
"iou": {
"x_label": "Threshold",
"y_label": "IoU",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.4, 1, 6),
},
# dice curve
"dice": {
"x_label": "Threshold",
"y_label": "Dice",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.4, 1, 6),
},
}

draw_curves.draw_curves(
mode=args.mode,
# 不同曲线的绘图配置
axes_setting={
# pr曲线的配置
"pr": {
"x_label": "Recall",
"y_label": "Precision",
"x_ticks": np.linspace(0.5, 1, 6),
"y_ticks": np.linspace(0.7, 1, 6),
},
# fm曲线的配置
"fm": {
"x_label": "Threshold",
"y_label": r"F$_{\beta}$",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.6, 1, 6),
},
# em曲线的配置
"em": {
"x_label": "Threshold",
"y_label": r"E$_{m}$",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.7, 1, 6),
},
},
axes_setting=axes_setting,
curves_npy_path=args.curves_npys,
row_num=args.num_rows,
method_aliases=method_aliases,
Expand Down
31 changes: 8 additions & 23 deletions tools/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,15 @@
import numpy as np
import yaml

parser = argparse.ArgumentParser(
description="A useful and convenient tool to convert your .npy results into the table code in latex."
)
parser.add_argument(
"-i",
"--result-file",
required=True,
nargs="+",
action="extend",
help="The path of the *_metrics.npy file.",
)
parser.add_argument(
"-o", "--tex-file", required=True, type=str, help="The path of the exported tex file."
)
parser.add_argument(
"-c", "--config-file", type=str, help="The path of the customized config yaml file."
)
parser.add_argument(
"--contain-table-env",
action="store_true",
help="Whether to containe the table env in the exported code.",
)
# fmt: off
parser = argparse.ArgumentParser(description="A useful and convenient tool to convert your .npy results into the table code in latex.")
parser.add_argument("-i", "--result-file", required=True, nargs="+", action="extend", help="The path of the *_metrics.npy file.")
parser.add_argument("-o", "--tex-file", required=True, type=str, help="The path of the exported tex file.")
parser.add_argument("-c", "--config-file", type=str, help="The path of the customized config yaml file.")
parser.add_argument("--contain-table-env", action="store_true", help="Whether to containe the table env in the exported code.")
parser.add_argument("--num-bits", type=int, default=3, help="Number of valid digits.")
parser.add_argument("--transpose", action="store_true", help="Whether to transpose the table.")
# fmt: on
args = parser.parse_args()

arg_head = f"%% Generated by: {vars(args)}"
Expand Down Expand Up @@ -139,7 +124,7 @@ def update_dict(parent_dict, sub_dict):
metric_row_head=" ",
metric_column_head="& ",
body=[
"& \\best{{{txt:.03f}}}", # style for top1
"& \\first{{{txt:.03f}}}", # style for top1
"& \\second{{{txt:.03f}}}", # style for top2
"& \\third{{{txt:.03f}}}", # style for top3
"& {txt:.03f}", # style for other
Expand Down

0 comments on commit ad3bfd8

Please sign in to comment.