diff --git a/notebooks_alex/pizzaclock.py b/notebooks_alex/pizzaclock.py
index 44c45a6f..73357177 100644
--- a/notebooks_alex/pizzaclock.py
+++ b/notebooks_alex/pizzaclock.py
@@ -65,7 +65,6 @@ def forward(self, x):
         )
 
 
-<<<<<<< HEAD
 def loss_fn(logits, labels):
     print(logits.shape)
     if len(logits.shape) == 3:
@@ -74,7 +73,6 @@ def loss_fn(logits, labels):
     else:
         logits = logits[:, -1, :]
     logits = logits.to(torch.float64)
-||||||| constructed merge base
 def loss_fn(logits, labels):
     print(logits.shape)
     if len(logits.shape) == 3:
@@ -83,7 +81,6 @@ def loss_fn(logits, labels):
     else:
         logits = logits[:,-1,:]
     logits = logits.to(torch.float64)
-=======
 def loss_fn(logits, labels, softmax=True):
     if softmax:
         if len(logits.shape) == 3:
@@ -94,7 +91,6 @@ def loss_fn(logits, labels, softmax=True):
         logits = logits.to(torch.float64)
         log_probs = logits.log_softmax(dim=-1)
         correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
->>>>>>> changed pizzaclock slightly
         return -correct_log_probs.mean()
 
     else:
@@ -118,33 +114,9 @@ def loss_fn(logits, labels, softmax=True):
 equals_vector = einops.repeat(torch.tensor(q), " -> (i j)", i=q, j=q)
 dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)
 
-<<<<<<< HEAD
-if subtracting:
-    subtractedset = (dataset[:, 0] - dataset[:, 1]) % q
-    if freeze_model:
-        labels = (subtractedset - ((subtractedset > (q // 2)) * subtractedset) * 2) % q
-    else:
-        labels = subtractedset  # Finds either a-b or b-a depending on which one is lower than q//2. Symmetric in a and b.
-else:
-    labels = (dataset[:, 0] + dataset[:, 1]) % q
-||||||| constructed merge base
-
-if subtracting:
-
-    subtractedset = (dataset[:, 0] - dataset[:, 1]) % q
-    if freeze_model:
-        labels = (
-            subtractedset - ((subtractedset > (q // 2)) * subtractedset) * 2
-        ) % q
-    else:
-        labels = subtractedset  # Finds either a-b or b-a depending on which one is lower than q//2. Symmetric in a and b.
-else:
-    labels = (dataset[:, 0] + dataset[:, 1]) % q
-=======
-labels = ((dataset[:, 0] - dataset[:, 1]) ) % q
->>>>>>> changed pizzaclock slightly
+labels = (dataset[:, 0] + dataset[:, 1]) % q
 
 
 optimizer = torch.optim.AdamW(
     full_model.parameters(), lr=1e-3, weight_decay=1, betas=(0.9, 0.98)
 )
@@ -159,41 +131,34 @@ def loss_fn(logits, labels, softmax=True):
 test_labels = labels[test_indices]
 for epoch in tqdm.tqdm(range(num_epochs)):
-<<<<<<< HEAD
     if freeze_model:
         train_logits = full_model(train_data)
     else:
         train_logits = full_model.run_with_hooks(
             train_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
         )
-||||||| constructed merge base
     if freeze_model:
         train_logits = full_model(train_data)
     else:
         train_logits = full_model.run_with_hooks(train_data,fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)])
-=======
     train_logits = full_model.run_with_hooks(
         train_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
     )
->>>>>>> changed pizzaclock slightly
     train_loss = loss_fn(train_logits, train_labels)
     train_loss.backward()
     optimizer.step()
     optimizer.zero_grad()
 
     with torch.inference_mode():
-<<<<<<< HEAD
         if freeze_model:
             test_logits = full_model(test_data)
         else:
             test_logits = full_model.run_with_hooks(
                 test_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
             )
-||||||| constructed merge base
         if freeze_model:
             test_logits = full_model(test_data)
         else:
             test_logits = full_model.run_with_hooks(test_data,fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)])
-=======
         test_logits = full_model.run_with_hooks(
             test_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
        )
@@ -201,7 +166,6 @@ def loss_fn(logits, labels, softmax=True):
         test_logits = full_model.run_with_hooks(
             test_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
         )
->>>>>>> changed pizzaclock slightly
         test_loss = loss_fn(test_logits, test_labels)
 
     if ((epoch + 1) % 10) == 0: