Skip to content

Commit

Permalink
Merge pull request #176 from macrocosm-os/avg_loss
Browse files Browse the repository at this point in the history
Use average loss across all batches.
  • Loading branch information
cryptal-mc authored Oct 4, 2024
2 parents 9b11aee + 2017c51 commit 362f603
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 30 deletions.
33 changes: 14 additions & 19 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class PerUIDEvalState:
# The losses per batch.
losses: typing.List[float] = dataclasses.field(default=None)

def avg_loss(self) -> float:
    """Safely computes the average loss from a list of losses.

    Returns:
        float: The arithmetic mean of ``self.losses``, or ``math.inf`` when
            ``self.losses`` is None or empty.
    """
    # `losses` defaults to None, so a truthiness check covers both the unset
    # and the empty-list case and avoids a division by zero.
    if not self.losses:
        return math.inf
    return sum(self.losses) / len(self.losses)


class Validator:
MODEL_TRACKER_FILENAME = "model_tracker.pickle"
Expand Down Expand Up @@ -795,15 +799,14 @@ async def run_step(self):
tokenizer = pt.model.load_tokenizer(
competition.constraints, cache_dir=self.config.model_dir
)

if cur_block >= constants.sample_pack_block:
pack_samples = True
pages_per_eval = constants.pages_per_eval_pack
else:
pack_samples = False
pages_per_eval = constants.pages_per_eval_unpack



# If the option is set in the config, override
pages_per_eval = (
self.config.pages_per_eval
Expand Down Expand Up @@ -909,12 +912,15 @@ async def run_step(self):
)

# Compute wins and win rates per uid.
losses_per_uid = {uid: state.losses for uid, state in uid_to_state.items()}
# Take the average loss across all batches for comparison of best model.
# Wrap it in a single-element list because compute_wins expects a list of losses per uid.
losses_per_uid = {
uid: [state.avg_loss()] for uid, state in uid_to_state.items()
}
uid_to_block = {uid: state.block for uid, state in uid_to_state.items()}
wins, win_rate = pt.validation.compute_wins(
uids,
losses_per_uid,
batches,
uid_to_block,
competition.constraints.epsilon_func,
cur_block,
Expand Down Expand Up @@ -1042,29 +1048,18 @@ def _record_eval_results(
curr_block (int): The current block.
uid_to_state (typing.Dict[int, PerUIDEvalState]): A dictionary mapping uids to their eval state.
"""
top_model_loss = self._compute_avg_loss(uid_to_state[top_uid].losses)
top_model_loss = uid_to_state[top_uid].avg_loss()
for _, state in uid_to_state.items():
self.model_tracker.on_model_evaluated(
state.hotkey,
EvalResult(
block=curr_block,
score=self._compute_avg_loss(state.losses),
score=state.avg_loss(),
winning_model_block=uid_to_state[top_uid].block,
winning_model_score=top_model_loss,
),
)

def _compute_avg_loss(self, losses: typing.List[float]) -> float:
"""Safely computes the average loss from a list of losses.
Args:
losses (typing.List[float]): A list of losses.
Returns:
float: The average loss.
"""
return sum(losses) / len(losses) if losses else math.inf

def log_step(
self,
competition_id: CompetitionId,
Expand Down Expand Up @@ -1102,7 +1097,7 @@ def log_step(
"block": uid_to_state[uid].block,
"hf": uid_to_state[uid].repo_name,
"competition_id": int(competition_id),
"average_loss": self._compute_avg_loss(uid_to_state[uid].losses),
"average_loss": uid_to_state[uid].avg_loss(),
"epsilon_adv": competition_epsilon_func.compute_epsilon(
current_block, uid_to_state[uid].block
),
Expand Down
22 changes: 11 additions & 11 deletions pretrain/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def iswin(
def compute_wins(
uids: typing.List[int],
losses_per_uid: typing.Dict[int, typing.List[float]],
batches: typing.List[torch.FloatTensor],
uid_to_block: typing.Dict[int, int],
epsilon_func: EpsilonFunc,
current_block: int,
Expand All @@ -78,7 +77,6 @@ def compute_wins(
Parameters:
uids (list): A list of uids to compare.
losses_per_uid (dict): A dictionary mapping each uid to its list of losses; the lists of two uids are compared element by element.
batches (List): A list of data batches.
uid_to_block (dict): A dictionary of blocks for each uid.
epsilon_func (EpsilonFunc): Function that determines how much advantage to give to the earlier block.
current_block: The current block.
Expand All @@ -88,20 +86,22 @@ def compute_wins(
"""
wins = {uid: 0 for uid in uids}
win_rate = {uid: 0 for uid in uids}
for i, uid_i in enumerate(uids):
for uid_i in uids:
total_matches = 0
block_i = uid_to_block[uid_i]
for j, uid_j in enumerate(uids):
if i == j:
for uid_j in uids:
if uid_i == uid_j:
continue
block_j = uid_to_block[uid_j]
for batch_idx, _ in enumerate(batches):
loss_i = losses_per_uid[uid_i][batch_idx]
loss_j = losses_per_uid[uid_j][batch_idx]

for loss_i, loss_j in zip(losses_per_uid[uid_i], losses_per_uid[uid_j]):
wins[uid_i] += (
1
if iswin(
loss_i, loss_j, block_i, block_j, epsilon_func, current_block
loss_i,
loss_j,
uid_to_block[uid_i],
uid_to_block[uid_j],
epsilon_func,
current_block,
)
else 0
)
Expand Down

0 comments on commit 362f603

Please sign in to comment.