Skip to content

Commit

Permalink
--heads key for correct multihead finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
vue1999 committed Feb 18, 2025
1 parent 21e7269 commit 1555d3a
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 57 deletions.
133 changes: 76 additions & 57 deletions mace/cli/plot_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@



plt.rcParams.update({"font.size": 11})
plt.rcParams.update({"font.size": 8})
plt.style.use("seaborn-v0_8-paper")


Expand Down Expand Up @@ -90,25 +90,34 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--output_format", help="What file type to save plot as", default='pdf', type=str, required=False,
)

parser.add_argument(
"--heads", help="Comma-separated name of the heads used for multihead training", default=None, type=str, required=False,
)

return parser.parse_args()


def plot(data: pd.DataFrame,
min_epoch: int,
output_path: str,
output_path: str,
output_format: str,
linear: bool,
start_swa: int,
error_bars: bool,
keys: str) -> None:
keys: str,
heads: str) -> None:

"""
Plots train,validation loss and errors as a function of epoch.
min_epoch: minimum epoch to plot.
output_path: path to save the plot.
output_format: format to save the plot.
start_swa: whether to plot a dashed line to show epoch when stage two loss (swa) begins.
error_bars: whether to plot standard deviation of loss.
linear: whether to plot in linear scale or logscale (default).
keys: Values to plot.
keys: Values to plot.
heads: Heads used for multihead training.
"""

labels={"mae_e":"MAE E [meV]",
Expand All @@ -127,71 +136,81 @@ def plot(data: pd.DataFrame,
"mae_virials": "MAE Virials [meV]",
"rmse_mu_per_atom": "RMSE MU/atom [mDebye]",
}



data = data[data["epoch"] > min_epoch]
if heads is None:
data = data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index()

valid_data = data[data["mode"] == "eval"]
valid_data_dict={"default": valid_data}
train_data = data[data["mode"] == "opt"]
else:
heads = heads.split(",")
# Separate eval and opt data
valid_data = data[data["mode"] == "eval"].groupby(["name", "mode", "epoch", "head"]).agg(["mean", "std"]).reset_index()
train_data = data[data["mode"] == "opt"].groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index()
valid_data_dict = {
head: valid_data[valid_data["head"] == head]
for head in heads
}

data = data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index()

valid_data = data[data["mode"] == "eval"]
train_data = data[data["mode"] == "opt"]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 2.5), constrained_layout=True)
for head, valid_data in valid_data_dict.items():
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5, 2.5), constrained_layout=True)

# ---- Plot loss ----
ax = axes[0]
if not linear:
ax.set_yscale("log")
# ---- Plot loss ----
ax = axes[0]
if not linear:
ax.set_yscale("log")

ax.plot(train_data["epoch"], train_data["loss"]["mean"], color=colors[1], label="Training", linewidth=1)
ax.plot(valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], label="Validation", linewidth=1)
ax.plot(train_data["epoch"], train_data["loss"]["mean"], color=colors[1], label="Training", linewidth=1)
ax.plot(valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], label="Validation", linewidth=1)

if error_bars:
ax.fill_between(train_data["epoch"], train_data["loss"]["mean"] - train_data["loss"]["std"],
train_data["loss"]["mean"] + train_data["loss"]["std"], alpha=0.3, color=colors[1])
ax.fill_between(valid_data["epoch"], valid_data["loss"]["mean"] - valid_data["loss"]["std"],
valid_data["loss"]["mean"] + valid_data["loss"]["std"], alpha=0.3, color=colors[0])
if error_bars:
ax.fill_between(train_data["epoch"], train_data["loss"]["mean"] - train_data["loss"]["std"],
train_data["loss"]["mean"] + train_data["loss"]["std"], alpha=0.3, color=colors[1])
ax.fill_between(valid_data["epoch"], valid_data["loss"]["mean"] - valid_data["loss"]["std"],
valid_data["loss"]["mean"] + valid_data["loss"]["std"], alpha=0.3, color=colors[0])

if start_swa is not None:
ax.axvline(start_swa, color="black", linestyle="dashed", linewidth=1, alpha=0.6, label="Stage Two Starts")
if start_swa is not None:
ax.axvline(start_swa, color="black", linestyle="dashed", linewidth=1, alpha=0.6, label="Stage Two Starts")

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend(loc="upper right", fontsize=4)
ax.grid(True, linestyle="--", alpha=0.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend(loc="upper right", fontsize=4)
ax.grid(True, linestyle="--", alpha=0.5)

# ---- Plot selected keys ----
ax = axes[1]
twin_axes = [] # Store twin axes for multiple y-scales
for i, key in enumerate(keys.split(",")):
color = colors[(i + 2)]
label = labels.get(key, key)
# ---- Plot selected keys ----
ax = axes[1]
twin_axes = []
for i, key in enumerate(keys.split(",")):
color = colors[(i + 2)]
label = labels.get(key, key)

if i == 0:
main_ax = ax
else:
main_ax = ax.twinx()
main_ax.spines.right.set_position(("outward", 40 * (i - 1)))
twin_axes.append(main_ax)
if i == 0:
main_ax = ax
else:
main_ax = ax.twinx()
main_ax.spines.right.set_position(("outward", 40 * (i - 1)))
twin_axes.append(main_ax)

main_ax.plot(valid_data["epoch"], valid_data[key]["mean"] * 1e3, color=color, label=label, linewidth=1)
main_ax.plot(valid_data["epoch"], valid_data[key]["mean"] * 1e3, color=color, label=label, linewidth=1)

if error_bars:
main_ax.fill_between(valid_data["epoch"], (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3,
(valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, alpha=0.3, color=color)
if error_bars:
main_ax.fill_between(valid_data["epoch"], (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3,
(valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, alpha=0.3, color=color)

main_ax.set_ylabel(label, color=color)
main_ax.tick_params(axis="y", colors=color)

if start_swa is not None:
ax.axvline(start_swa, color="black", linestyle="dashed", linewidth=1, alpha=0.6, label="Stage Two Starts")
main_ax.set_ylabel(label, color=color)
main_ax.tick_params(axis="y", colors=color)
if start_swa is not None:
ax.axvline(start_swa, color="black", linestyle="dashed", linewidth=1, alpha=0.6, label="Stage Two Starts")

ax.set_xlabel("Epoch")
ax.set_xlim(left=min_epoch)
ax.grid(True, linestyle="--", alpha=0.5)
ax.set_xlabel("Epoch")
ax.set_xlim(left=min_epoch)
ax.grid(True, linestyle="--", alpha=0.5)

fig.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close(fig)
fig.savefig(f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight")
plt.close(fig)



Expand Down Expand Up @@ -219,9 +238,9 @@ def run(args: argparse.Namespace) -> None:
)

for name, group in data.groupby("name"):
plot(group, min_epoch=args.min_epoch, output_path=f"{name}.{args.output_format}",
plot(group, min_epoch=args.min_epoch, output_path=name, output_format=args.output_format,
linear=args.linear, start_swa=args.start_swa, error_bars=args.error_bars,
keys=args.keys)
keys=args.keys, heads=args.heads)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def valid_err_log(
):
eval_metrics["mode"] = "eval"
eval_metrics["epoch"] = epoch
eval_metrics["head"] = valid_loader_name
logger.log(eval_metrics)
if epoch is None:
inintial_phrase = "Initial"
Expand Down

0 comments on commit 1555d3a

Please sign in to comment.