diff --git a/gbmi/exp_group_finetuning/groups.py b/gbmi/exp_group_finetuning/groups.py
index 6a0d9862..ecfb258e 100644
--- a/gbmi/exp_group_finetuning/groups.py
+++ b/gbmi/exp_group_finetuning/groups.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, cast, Literal, Generic, TypeVar
-
+import json
+import torch
 
 T = TypeVar("T")
 
@@ -11,6 +12,10 @@ class Group(ABC, Generic[T]):
     def id() -> T:
         ...
 
+    @abstractmethod
+    def toJSON(self):
+        ...
+
     @abstractmethod
     def name(self) -> str:
         ...
@@ -19,29 +24,72 @@ def name(self) -> str:
     def size(self) -> int:
         ...
 
+    @abstractmethod
+    def index(self) -> int:
+        ...
+
     @staticmethod
     @abstractmethod
     def parameternames() -> List[str]:
         ...
 
-    @staticmethod
     @abstractmethod
-    def op(a: T, b: T) -> T:
+    def op(self, a: T, b: T) -> T:
         ...
 
-    @classmethod
-    def reduce(cls, xs: T) -> T:
-        accumulator = cls.id()
+    def reduce(self, xs: T) -> T:
+        accumulator = self.__class__.id()
         for x in xs:
-            accumulator = cls.op(accumulator, x)
+            accumulator = self.op(accumulator, x)
         return accumulator
 
 
+class DihedralGroup(Group):
+    def __init__(self, n: int):
+        self.n = n
+        self.lookup = []
+        for x in range(2 * n):
+            self.lookup.append([])
+            for y in range(2 * n):
+                j = x % 2
+                if j == 0:
+                    result = (y % 2 + (2 * ((x // 2 + y // 2) % n))) % (2 * n)
+                else:
+                    result = ((y % 2 + 1) % 2 + (2 * ((x // 2 - y // 2) % n))) % (2 * n)
+
+                self.lookup[x].append(result)
+        self.lookup = torch.tensor(self.lookup).to("cuda")
+
+    def toJSON(self):
+        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
+
+    def name(self) -> str:
+        return "DihedralGroup" + str(2 * self.n)
+
+    def size(self) -> int:
+        return 2 * self.n
+
+    def index(self) -> int:
+        return self.n
+
+    def parameternames() -> List[str]:
+        return ["modulus"]
+
+    def id():
+        return 0
+
+    def op(self, x, y):
+        return self.lookup[x][:, y]
+
+
 class CyclicGroup(Group):
     def __init__(self, n: int):
         self.n = n
 
+    def toJSON(self):
+        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
+
     def name(self) -> str:
         return "CyclicGroup" + str(self.n)
 
@@ -51,6 +99,9 @@ def size(self) -> int:
     def parameternames() -> List[str]:
         return ["modulus"]
 
+    def index(self) -> int:
+        return self.n
+
     def id():
         return 0
 
@@ -58,5 +109,6 @@ def op(self, x, y):
         return (x + y) % self.n
 
 
-GroupDict = {"Cyclic": CyclicGroup}
+GroupDict = {"CyclicGroup": CyclicGroup, "DihedralGroup": DihedralGroup}
 cycle = CyclicGroup(5)
+dihedral = DihedralGroup(4)
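Reviewer note (not part of the patch): DihedralGroup above encodes the element r^k s^j of D_n as the integer 2*k + j, reading x // 2 as the rotation amount and x % 2 as the reflection bit, so 0 is the identity and the lookup table has 2n rows and columns. A minimal standalone sketch, plain Python with no GPU dependency, that rebuilds the same table and checks the group axioms:

# Standalone sketch: reconstruct the lookup table from DihedralGroup.__init__
# on the CPU and verify that it really is a group multiplication table.
def dihedral_table(n: int):
    table = []
    for x in range(2 * n):
        row = []
        for y in range(2 * n):
            if x % 2 == 0:  # x is a pure rotation: rotation amounts add
                row.append((y % 2 + 2 * ((x // 2 + y // 2) % n)) % (2 * n))
            else:  # x contains a reflection: y's rotation is negated, its flip bit toggles
                row.append(((y % 2 + 1) % 2 + 2 * ((x // 2 - y // 2) % n)) % (2 * n))
        table.append(row)
    return table

n = 4
t = dihedral_table(n)
size = 2 * n
assert all(t[0][y] == y for y in range(size))  # 0 acts as the identity
assert all(t[x][0] == x for x in range(size))
assert all(  # associativity: (x * y) * z == x * (y * z)
    t[t[x][y]][z] == t[x][t[y][z]]
    for x in range(size)
    for y in range(size)
    for z in range(size)
)
assert any(t[x][y] != t[y][x] for x in range(size) for y in range(size))  # non-abelian for n > 2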
diff --git a/gbmi/exp_group_finetuning/train.py b/gbmi/exp_group_finetuning/train.py
index 674d556c..79b5183b 100644
--- a/gbmi/exp_group_finetuning/train.py
+++ b/gbmi/exp_group_finetuning/train.py
@@ -4,9 +4,14 @@
 from dataclasses import field
 from collections.abc import Callable
 
-from groups import Group, GroupDict, CyclicGroup
+from gbmi.exp_group_finetuning.groups import (
+    Group,
+    GroupDict,
+    CyclicGroup,
+    DihedralGroup,
+)
 import sys
-from typing import Any, Dict, List, Optional, cast, Literal, Generic, TypeVar
+from typing import Any, Dict, List, Optional, cast, Literal, Generic, TypeVar, Type
 
 from gbmi import utils
 import numpy as np
@@ -42,7 +47,9 @@ class ModularFineTuning(ExperimentConfig):
 
     model_config: HookedTransformerConfig
     # using int instead of abstract class because i'm clueless what's going on with typing
-    group: Group
+    group_family: str
+    group_index: int
+    group_size: int
     group_name: str
     zero_biases: bool = True
     attention_rate: float = 0  # 0 is use attention, 1 is uniformly constant attention
@@ -62,7 +69,7 @@ def get_datamodule(self):
 
     def get_summary_slug(self, config: Config[ModularFineTuning]) -> str:
         return (
-            f"GroupFineTuning-{config.experiment.model_config.n_ctx}-{config.train_for[0]}-"
+            f"GroupFineTuning-{config.experiment.group_family+str(config.experiment.group_index)}-{config.experiment.model_config.n_ctx}-{config.train_for[0]}-"
             f"{config.train_for[1]}-attention-rate-{config.experiment.attention_rate}"
             f"{'-nondeterministic' if not config.deterministic else ''}"
         )
@@ -83,7 +90,9 @@ def modular_addition_config(attn_rate: float, group: Group, elements: int):
                 attn_only=False,
                 normalization_type=None,
             ),
-            group=group,
+            group_family=type(group).__name__,
+            group_index=group.index(),
+            group_size=group.size(),
             group_name=group.name(),
             zero_biases=True,
            attention_rate=attn_rate,
@@ -105,6 +114,12 @@ def modular_addition_config(attn_rate: float, group: Group, elements: int):
 MODULAR_ADDITION_113_PIZZA_CONFIG = modular_addition_config(
     attn_rate=1, group=CyclicGroup(113), elements=2
 )
+DIHEDRAL_100_CLOCK_CONFIG = modular_addition_config(
+    attn_rate=0, group=DihedralGroup(104), elements=2
+)
+DIHEDRAL_100_PIZZA_CONFIG = modular_addition_config(
+    attn_rate=1, group=DihedralGroup(104), elements=2
+)
 
 
 class ModularFineTuningTrainingWrapper(TrainingWrapper[ModularFineTuning]):
@@ -120,8 +135,8 @@ def build_model(config: Config[ModularFineTuning]) -> HookedTransformer:
         model_config,
         {
             "seed": reseed(config.seed, "model"),
-            "d_vocab": config.experiment.group.size() + 1,
-            "d_vocab_out": config.experiment.group.size(),
+            "d_vocab": config.experiment.group_size + 1,
+            "d_vocab_out": config.experiment.group_size,
         },
         warn_if_not_default=False,
     )
@@ -138,9 +153,14 @@ def loss_fn(
         logits: Float[Tensor, "batch pos d_vocab"],  # noqa: F722
         labels: Integer[Tensor, "batch"],  # noqa: F821
     ) -> Float[Tensor, ""]:  # noqa: F722
+        logits = logits
+        labels = labels
         logits = logits[:, -1, :].to(torch.float64)
+
         log_probs = utils.log_softmax(logits, dim=-1)
+
         correct_log_probs = log_probs.gather(-1, labels.unsqueeze(-1))[:, 0]
+
         return -correct_log_probs.mean()
 
     @staticmethod
@@ -162,7 +182,10 @@ def run_batch(
         self, x: Float[Tensor, "batch pos"], prefix: str  # noqa: F722
     ) -> Float[Tensor, ""]:  # noqa: F722
         self.model.to(x.device, print_details=False)
-        labels = self.config.experiment.group.reduce(list(x[:, :-1]))
+
+        labels = GroupDict[self.config.experiment.group_family](
+            self.config.experiment.group_index
+        ).reduce(list(x[:, :-1].T))
         assert (
             len(labels.shape) == 1
         ), f"labels.shape == {labels.shape} != 1 (from x.shape == {x.shape})"
@@ -170,6 +193,7 @@ def run_batch(
             x, fwd_hooks=[("blocks.0.attn.hook_pattern", self.attention_hook)]
         )
         loss = self.loss_fn(y_preds, labels)
+
         self.log(f"{prefix}loss", loss, prog_bar=True)
         acc = self.acc_fn(y_preds, labels)
         self.log(f"{prefix}acc", acc, prog_bar=True)
@@ -206,10 +230,11 @@ def setup(self, stage: str):
         # Full dataset
         rng = np.random.default_rng(self.dataset_seed)
         pairs = generate_all_sequences(
-            self.config.experiment.group.size(), self.model_config.n_ctx - 1
+            self.config.experiment.group_size,
+            self.model_config.n_ctx - 1,
         )
         # concat a special token of value self.config.experiment.p to the end of each sequence for '='
-        equals_token = self.config.experiment.group.size()
+        equals_token = self.config.experiment.group_size
         data = torch.cat(
             [pairs, equals_token * torch.ones((len(pairs), 1))], dim=1
         ).long()
@@ -296,12 +321,7 @@ def main(argv: List[str] = sys.argv):
 
     add_force_argument(parser)
     add_no_save_argument(parser)
-    HOOKED_TRANSFORMER_CONFIG_EXCLUDE_ARGS = set(
-        (
-            "d_vocab",
-            "d_vocab_out",
-        )
-    )
+    HOOKED_TRANSFORMER_CONFIG_EXCLUDE_ARGS = set(("d_vocab", "d_vocab_out", "group"))
     Config.add_arguments(parser)
     add_HookedTransformerConfig_arguments(
         parser, exclude_arguments=HOOKED_TRANSFORMER_CONFIG_EXCLUDE_ARGS
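Reviewer note (not part of the patch): the experiment config now stores only the serialisable triple (group_family, group_index, group_size), and run_batch rebuilds the group as GroupDict[group_family](group_index). A minimal round-trip sketch, assuming the gbmi package is importable and a CUDA device is available (importing the groups module constructs the module-level DihedralGroup(4), whose lookup table is moved to "cuda"):

from gbmi.exp_group_finetuning.groups import CyclicGroup, GroupDict

# The values modular_addition_config(...) would put on config.experiment.
group = CyclicGroup(113)
group_family = type(group).__name__  # "CyclicGroup", a key of GroupDict
group_index = group.index()          # 113
group_size = group.size()            # 113

# What run_batch now does to recover a Group instance from the config fields.
rebuilt = GroupDict[group_family](group_index)
assert isinstance(rebuilt, CyclicGroup)
assert rebuilt.size() == group_size
assert rebuilt.op(100, 50) == (100 + 50) % 113  # == 37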
diff --git a/notebooks_alex/pizzaclock.py b/notebooks_alex/pizzaclock.py
index cdef1ba3..4651e526 100644
--- a/notebooks_alex/pizzaclock.py
+++ b/notebooks_alex/pizzaclock.py
@@ -1,5 +1,14 @@
-from gbmi.exp_modular_fine_tuning.train import MODULAR_ADDITION_113_CLOCK_CONFIG
-from gbmi.exp_modular_fine_tuning.train import MODULAR_ADDITION_113_PIZZA_CONFIG
+from gbmi.exp_group_finetuning.train import MODULAR_ADDITION_113_CLOCK_CONFIG
+from gbmi.exp_group_finetuning.train import MODULAR_ADDITION_113_PIZZA_CONFIG
+from gbmi.exp_group_finetuning.train import DIHEDRAL_100_CLOCK_CONFIG
+from gbmi.exp_group_finetuning.train import DIHEDRAL_100_PIZZA_CONFIG
+
+from gbmi.exp_group_finetuning.groups import (
+    Group,
+    GroupDict,
+    CyclicGroup,
+    DihedralGroup,
+)
 from gbmi.model import train_or_load_model
 import torch
 from math import sqrt
@@ -8,11 +17,11 @@
 import tqdm
 
 device = "cuda"
-p = 113
-q = p
-freeze_model = False
-config = MODULAR_ADDITION_113_PIZZA_CONFIG
 
+freeze_model = False
+config = DIHEDRAL_100_PIZZA_CONFIG
+p = config.experiment.group_index
+q = p * 2
 frac_train = 0.3
 seed = 999
 num_epochs = 5000
@@ -92,6 +101,7 @@ def loss_fn(logits, labels, softmax=True):
         logits = logits[:, :, -1].squeeze(-1)
     else:
         logits = logits[:, -1, :]
+    logits = logits.to(torch.float64)
 
     log_probs = logits.log_softmax(dim=-1)
     correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
@@ -116,9 +126,8 @@ def loss_fn(logits, labels, softmax=True):
 b_vector = einops.repeat(torch.arange(q), "j -> (i j)", i=q)
 equals_vector = einops.repeat(torch.tensor(q), " -> (i j)", i=q, j=q)
 dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)
-
-
-labels = (dataset[:, 0] - dataset[:, 1]) % q
+labels = DihedralGroup(104).op(dataset[:, 0], dataset[:, 1]).flatten()
+print(labels)
 optimizer = torch.optim.AdamW(
     full_model.parameters(), lr=1e-3, weight_decay=1, betas=(0.9, 0.98)
 )
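Reviewer note (not part of the patch): both loss_fn bodies above compute the mean negative log-probability of the correct label from the final-position logits, cast to float64 first. A self-contained check in plain torch that this pattern agrees with torch.nn.functional.cross_entropy on the same slice:

import torch
import torch.nn.functional as F

batch, pos, d_vocab = 8, 3, 10
logits = torch.randn(batch, pos, d_vocab)
labels = torch.randint(0, d_vocab, (batch,))

final = logits[:, -1, :].to(torch.float64)  # final-position logits in float64
log_probs = final.log_softmax(dim=-1)
correct_log_probs = log_probs.gather(-1, labels.unsqueeze(-1))[:, 0]
loss = -correct_log_probs.mean()

assert torch.allclose(loss, F.cross_entropy(final, labels))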