Skip to content

Commit

Permalink
pizzaclock stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexGibson0 committed Jan 20, 2024
1 parent f857fae commit 8c4a673
Showing 1 changed file with 2 additions and 38 deletions.
40 changes: 2 additions & 38 deletions notebooks_alex/pizzaclock.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def forward(self, x):
)


<<<<<<< HEAD
def loss_fn(logits, labels):
print(logits.shape)
if len(logits.shape) == 3:
Expand All @@ -74,7 +73,6 @@ def loss_fn(logits, labels):
else:
logits = logits[:, -1, :]
logits = logits.to(torch.float64)
||||||| constructed merge base
def loss_fn(logits, labels):
print(logits.shape)
if len(logits.shape) == 3:
Expand All @@ -83,7 +81,6 @@ def loss_fn(logits, labels):
else:
logits = logits[:,-1,:]
logits = logits.to(torch.float64)
=======
def loss_fn(logits, labels, softmax=True):
if softmax:
if len(logits.shape) == 3:
Expand All @@ -94,7 +91,6 @@ def loss_fn(logits, labels, softmax=True):
logits = logits.to(torch.float64)
log_probs = logits.log_softmax(dim=-1)
correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
>>>>>>> changed pizzaclock slightly

return -correct_log_probs.mean()
else:
Expand All @@ -118,33 +114,9 @@ def loss_fn(logits, labels, softmax=True):
equals_vector = einops.repeat(torch.tensor(q), " -> (i j)", i=q, j=q)
dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)

<<<<<<< HEAD

if subtracting:
subtractedset = (dataset[:, 0] - dataset[:, 1]) % q
if freeze_model:
labels = (subtractedset - ((subtractedset > (q // 2)) * subtractedset) * 2) % q
else:
labels = subtractedset # Finds either a-b or b-a depending on which one is lower than q//2. Symmetric in a and b.
else:
labels = (dataset[:, 0] + dataset[:, 1]) % q
||||||| constructed merge base

if subtracting:

subtractedset = (dataset[:, 0] - dataset[:, 1]) % q
if freeze_model:
labels = (
subtractedset - ((subtractedset > (q // 2)) * subtractedset) * 2
) % q
else:
labels = subtractedset # Finds either a-b or b-a depending on which one is lower than q//2. Symmetric in a and b.
else:
labels = (dataset[:, 0] + dataset[:, 1]) % q
=======
labels = ((dataset[:, 0] - dataset[:, 1]) ) % q
>>>>>>> changed pizzaclock slightly
optimizer = torch.optim.AdamW(
labels = (dataset[:, 0] + dataset[:, 1]) % q

full_model.parameters(), lr=1e-3, weight_decay=1, betas=(0.9, 0.98)

Check failure on line 120 in notebooks_alex/pizzaclock.py

View workflow job for this annotation

GitHub Actions / ci (3.11, 1.7.1, ubuntu-latest)

IndentationError: unexpected indent
)

Expand All @@ -159,49 +131,41 @@ def loss_fn(logits, labels, softmax=True):
test_labels = labels[test_indices]

for epoch in tqdm.tqdm(range(num_epochs)):
<<<<<<< HEAD
if freeze_model:
train_logits = full_model(train_data)
else:
train_logits = full_model.run_with_hooks(
train_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
)
||||||| constructed merge base
if freeze_model:
train_logits = full_model(train_data)
else:
train_logits = full_model.run_with_hooks(train_data,fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)])
=======
train_logits = full_model.run_with_hooks(
train_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
)
>>>>>>> changed pizzaclock slightly
train_loss = loss_fn(train_logits, train_labels)
train_loss.backward()
optimizer.step()
optimizer.zero_grad()
with torch.inference_mode():
<<<<<<< HEAD
if freeze_model:
test_logits = full_model(test_data)
else:
test_logits = full_model.run_with_hooks(
test_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
)
||||||| constructed merge base
if freeze_model:
test_logits = full_model(test_data)
else:
test_logits = full_model.run_with_hooks(test_data,fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)])
=======
test_logits = full_model.run_with_hooks(
test_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
)

test_logits = full_model.run_with_hooks(
test_data, fwd_hooks=[("blocks.0.attn.hook_pattern", hook_fn)]
)
>>>>>>> changed pizzaclock slightly
test_loss = loss_fn(test_logits, test_labels)

if ((epoch + 1) % 10) == 0:
Expand Down

0 comments on commit 8c4a673

Please sign in to comment.