Skip to content

Commit

Permalink
fix lbfgs initialization when past scores are available
Browse files Browse the repository at this point in the history
  • Loading branch information
amatissart committed Oct 17, 2024
1 parent 8f1ff44 commit 5da2c22
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion backend/ml/management/commands/ml_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_solidago_pipeline(run_trust_propagation: bool = True):
convergence_error=1e-5,
cumulant_generating_function_error=1e-5,
high_likelihood_range_threshold=0.25,
# max_iter=100,
# max_iter=300,
),
scaling=ScalingCompose(
Mehestan(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,15 @@ def comparison_learning(
(comparisons["comparison"] / comparisons["comparison_max"]).to_numpy()
)

solution = torch.normal(
0, 1, (len(entities),), requires_grad=True, dtype=torch.float64, device=self.device
)

solution = np.random.normal(0.0, 1.0, size=len(entities))
if initialization is not None:
for (entity_id, values) in initialization.iter_entities():
entity_coord = entity_coordinates.get(entity_id)
if entity_coord is not None:
score, _left, _right = values
solution[entity_coord] = score

solution = torch.tensor(solution, requires_grad=True, device=self.device)
lbfgs = torch.optim.LBFGS(
(solution,),
max_iter=self.max_iter,
Expand All @@ -114,7 +112,7 @@ def closure():

n_iter = lbfgs.state_dict()["state"][0]["n_iter"]
if n_iter >= self.max_iter:
raise RuntimeError(f"LBFGS failed to converge in {n_iter} iteratiions")
raise RuntimeError(f"LBFGS failed to converge in {n_iter} iterations")

solution = solution.detach()
if solution.isnan().any():
Expand Down

0 comments on commit 5da2c22

Please sign in to comment.