Skip to content

Commit

Permalink
WIP Debug Preference Learning
Browse files Browse the repository at this point in the history
  • Loading branch information
lenhoanglnh committed Jan 10, 2025
1 parent 6acf0b5 commit 977cf57
Show file tree
Hide file tree
Showing 17 changed files with 1,485 additions and 1,378 deletions.
63 changes: 61 additions & 2 deletions solidago/experiments/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@

from pandas import DataFrame, Series
from pathlib import Path
from numba import njit

from solidago import *
from solidago.primitives.optimize import coordinate_descent, njit_brentq


# t = TournesolExport("tests/tiny_tournesol.zip")

generative_model = GenerativeModel.load("tests/generative_model/test_generative_model.json")
# generative_model = GenerativeModel.load("tests/generative_model/test_generative_model.json")
# s = generative_model()
# s.save("tests/pipeline/saved")

# s = State.load("tests/pipeline/saved")
s = State.load("tests/pipeline/saved")
pipeline= Sequential.load("tests/pipeline/test_pipeline.json")

# s = pipeline(s, "tests/pipeline/saved")
Expand All @@ -27,3 +31,58 @@
# s = pipeline.scaling.state2state_function(s, save_directory="tests/pipeline/saved")
# s = pipeline.aggregation.state2state_function(s, save_directory="tests/pipeline/saved")
# s = pipeline.post_process.state2state_function(s, save_directory="tests/pipeline/saved")

# Debug scaffold: rebuild by hand, at module level, the inputs needed for one
# coordinate update, for a single user and a single criterion of the loaded state.
# NOTE(review): `self` is an ordinary module-level variable here, presumably so
# that method bodies can be pasted below without editing -- confirm.
self = NumbaUniformGBT()

# Put "username" first in the key order, then select one user's data by indexing.
assessments = s.assessments.reorder_keys(["username", "criterion", "entity_name"])
comparisons = s.comparisons.reorder_keys(["username", "criterion", "left_name", "right_name"])
# First user yielded when iterating over the state's users.
user = next(iter(s.users))
assessments = assessments[user]
comparisons = comparisons[user]

# Entities that appear on either side of one of this user's comparisons.
compared_entity_names = comparisons.get_set("left_name") | comparisons.get_set("right_name")
entities = s.entities.get(compared_entity_names)
# Evaluate the user's current model on those entities.
# NOTE(review): presumably this yields the initial scores -- confirm.
init = s.user_models[user](entities).reorder_keys(["criterion", "entity_name"])
comparisons = comparisons.reorder_keys(["criterion", "left_name", "right_name"])
# Criteria present in the comparisons or in the initial values; a single one is kept.
# NOTE(review): if `get_set` returns a plain set, which criterion gets picked
# is not guaranteed to be stable across runs -- confirm.
criteria = comparisons.get_set("criterion") | init.get_set("criterion")
criterion = next(iter(criteria))
comparisons = comparisons[criterion]
init = init[criterion]

# Map str(entity) to its position when iterating over `entities`.
entity_name2index = { str(entity): index for index, entity in enumerate(entities) }
# After this call `comparisons` is indexed by a single entity name (see its use below).
comparisons = comparisons.order_by_entities()

# Sample one entity position uniformly at random, and build placeholder scores
# 0.0, 1.0, 2.0, ... (one per entity) to feed the helper defined below.
entity_index = np.random.randint(len(entities))
entity_name = entities.iloc[entity_index].name
scores = np.arange(len(entities), dtype=np.float64)

def get_partial_derivative_args(entity_index: int, scores: np.ndarray) -> tuple:
    """Collect the extra arguments of the partial derivative for one entity.

    Reads the module-level `entities`, `comparisons`, `entity_name2index`
    and `self` defined earlier in this script.

    Parameters
    ----------
    entity_index: int
        Position of the entity in `entities`.
    scores: np.ndarray
        Current score vector, indexed by entity position.

    Returns
    -------
    tuple
        `(compared_scores, normalized_comparisons)`: the entries of `scores`
        at the positions of the entities listed in the "other_name" column
        of this entity's comparisons, and the values returned by
        `normalized_comparisons` as a numpy array.
    """
    name = entities.iloc[entity_index].name
    last_only = self.last_comparison_only
    entity_comparisons = comparisons[name]
    normalized = entity_comparisons.normalized_comparisons(last_only)
    frame = entity_comparisons.to_df(last_row_only=last_only)
    other_positions = frame["other_name"].map(entity_name2index)
    return scores[other_positions], np.array(normalized)

# Smoke call: the result is discarded, this only checks that the helper runs
# on the randomly sampled entity with the placeholder scores.
get_partial_derivative_args(entity_index, scores)

# Default "no extra arguments" provider: ignores its inputs and returns an empty tuple.
# NOTE(review): appears to mirror the default used by `coordinate_descent`
# in solidago.primitives.optimize -- confirm.
empty_function = lambda coordinate, variable: tuple()
get_update_coordinate_function_args = empty_function

def coordinate_function(coordinate: int, variable: np.ndarray[np.float64]):
    """Build the one-dimensional function whose root gives the new coordinate value.

    Parameters
    ----------
    coordinate: int
        Index of the coordinate being updated.
    variable: np.ndarray
        Current point; it is read, never modified.

    Returns
    -------
    Callable
        `f(value, *partial_derivative_args)`, which evaluates
        `self.partial_derivative` at a copy of `variable` whose entry
        `coordinate` is replaced by `value`.
    """
    @njit
    def f(value: np.float64, *partial_derivative_args) -> np.float64:
        # Trial point: identical to `variable` except at `coordinate`.
        trial_point = np.array([
            variable[i] if i != coordinate else value
            for i in range(len(variable))
        ], dtype=np.float64)
        return self.partial_derivative(coordinate, trial_point, *partial_derivative_args)

    return f

# Tolerance passed as `xtol` to the root finder for each coordinate update.
coordinate_optimization_xtol = 1e-5


def update_coordinate_function(coordinate: int, variable: np.ndarray[np.float64], *coordinate_update_args) -> float:
    """Compute the updated value of one coordinate with `njit_brentq`.

    The root of the coordinate's partial-derivative function is searched
    between `variable[coordinate] - 1.0` and `variable[coordinate] + 1.0`.

    Parameters
    ----------
    coordinate: int
        Index of the coordinate to update.
    variable: np.ndarray
        Current point.
    *coordinate_update_args
        Forwarded to `get_partial_derivative_args`.
        NOTE(review): the helper defined in this script accepts exactly two
        positional arguments, so this only works while these are empty.

    Returns
    -------
    float
        Whatever `njit_brentq` returns for this coordinate.
    """
    objective = coordinate_function(coordinate, variable)
    derivative_args = get_partial_derivative_args(coordinate, variable, *coordinate_update_args)
    current_value = variable[coordinate]
    return njit_brentq(
        f=objective,
        args=derivative_args,
        xtol=coordinate_optimization_xtol,
        a=current_value - 1.0,
        b=current_value + 1.0,
    )
22 changes: 9 additions & 13 deletions solidago/src/solidago/_pipeline/_preference_learning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,21 @@ def __call__(self,
user_models: UserModels
) -> UserModels:
""" Learns a scoring model, given user judgments of entities """
learned_models = UserModels()
comparison_key_names = ["username", "criterion", "left_name", "right_name"]
reordered_comparisons = comparisons.reorder_keys(comparison_key_names)
result = UserModels()
assessments = assessments.reorder_keys(["username", "criterion", "entity_name"])
comparisons = comparisons.reorder_keys(["username", "criterion", "left_name", "right_name"])
for user in users:
learned_models[user] = self.user_learn(
user,
entities,
assessments[user],
comparisons[user],
user_models[user].base_model()[0] if user in user_models else DirectScoring()
)
return learned_models
logger.info(f" Learning user {user}'s base model")
result[user] = self.user_learn(user, entities, assessments[user], comparisons[user],
user_models[user].base_model()[0])
return result

@abstractmethod
def user_learn(self,
user: User,
entities: Entities,
assessments: Assessments,
comparisons: Comparisons,
assessments: Assessments, # key_names == ["criterion", "entity_name"]
comparisons: Comparisons, # key_names == ["criterion", "left_name", "right_name"]
base_model: BaseModel
) -> BaseModel:
"""Learns a scoring model, given user judgments of entities """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def compute_scores(self,
entities: Entities,
entity_name2index: dict[str, int],
comparisons: Comparisons, # key_names == ["left_name", "right_name"]
init_multiscores : MultiScore, # key_names == "entity_name"
init_multiscores : MultiScore, # key_names == ["entity_name"]
) -> npt.NDArray:
""" Computes the scores given comparisons """
scores = self.init_scores(entity_name2index, init_multiscores)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def compute_scores(self,
init_multiscores : MultiScore, # key_names == ["entity_name"]
) -> npt.NDArray:
""" Computes the scores given comparisons """
entity_ordered_comparisons = comparisons.order_by_entities()
comparisons = comparisons.order_by_entities()
def get_partial_derivative_args(entity_index: int, scores: np.ndarray) -> tuple:
entity_name = entities.iloc[entity_index].name
normalized_comparisons = comparisons.normalized_comparisons(self.last_comparison_only)
df = entity_ordered_comparisons[entity_name].to_df(last_row_only=self.last_comparison_only)
normalized_comparisons = comparisons[entity_name].normalized_comparisons(self.last_comparison_only)
df = comparisons[entity_name].to_df(last_row_only=self.last_comparison_only)
indices = df["other_name"].map(entity_name2index)
return scores[indices], np.array(normalized_comparisons)

Expand All @@ -106,14 +106,14 @@ def partial_derivative(self) -> Callable[[int, np.ndarray[np.float64], dict, dic

@njit
def njit_partial_derivative(
coordinate: int,
entity_index: int,
scores: float,
compared_scores: npt.NDArray,
normalized_comparisons: npt.NDArray,
) -> npt.NDArray:
score_diffs = scores[coordinate] - compared_scores
score_diffs = scores[entity_index] - compared_scores
nll_derivative = np.sum(cfg_deriv(score_diffs) - normalized_comparisons)
prior_derivative = scores[coordinate] / prior_var
prior_derivative = scores[entity_index] / prior_var
return prior_derivative + nll_derivative

return njit_partial_derivative
Expand Down
16 changes: 12 additions & 4 deletions solidago/src/solidago/primitives/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from typing import Callable, Tuple, Literal, Optional, Any
from functools import cache

import numpy as np
from numba import njit
Expand Down Expand Up @@ -180,6 +181,7 @@ def coordinate_updates(
initialization: np.ndarray,
updated_coordinates: Optional[list[int]]=None,
error: float=1e-5,
max_iter: int=10000,
):
"""Minimize a loss function with coordinate descent,
by leveraging the partial derivatives of the loss
Expand All @@ -197,6 +199,8 @@ def coordinate_updates(
Initialization point of the coordinate descent
error: float
Tolerated error
max_iter: int
Maximum number of iterations
Returns
-------
Expand All @@ -207,6 +211,7 @@ def coordinate_updates(
to_pick = list() if updated_coordinates is None else updated_coordinates
variable = initialization
variable_len = len(variable)
iteration_number = 0

def pick_next_coordinate():
nonlocal to_pick
Expand All @@ -215,7 +220,8 @@ def pick_next_coordinate():
np.random.shuffle(to_pick)
return to_pick.pop()

while len(unchanged) < variable_len:
while len(unchanged) < variable_len and iteration_number < max_iter:
iteration_number += 1
coordinate = pick_next_coordinate()
if coordinate in unchanged:
continue
Expand Down Expand Up @@ -271,12 +277,14 @@ def coordinate_descent(
# First define the update_coordinate_function associated to coordinatewise descent
# by leveraging njit and brentq

empty_function = lambda coordinate, variable: tuple()

if get_partial_derivative_args is None:
get_partial_derivative_args = lambda coordinate, variable: tuple()
get_partial_derivative_args = empty_function

if get_update_coordinate_function_args is None:
get_update_coordinate_function_args = lambda coordinate, variable: tuple()
get_update_coordinate_function_args = empty_function

def coordinate_function(
coordinate: int,
variable: np.ndarray[np.float64],
Expand Down
60 changes: 30 additions & 30 deletions solidago/tests/load_save/generated_state/assessments.csv
Original file line number Diff line number Diff line change
@@ -1,68 +1,68 @@
username,criterion,entity_name,assessment,assessment_min,assessment_max
user_0,default,entity_0,-3.0840321464854243,-inf,inf
user_0,default,entity_0,-3.084032146485425,-inf,inf
user_0,default,entity_16,-4.0878539983787,-inf,inf
user_0,default,entity_15,-1.4902020610943079,-inf,inf
user_0,default,entity_15,-1.490202061094308,-inf,inf
user_0,default,entity_18,-1.0038588539696902,-inf,inf
user_0,default,entity_9,0.15862347293943008,-inf,inf
user_0,default,entity_13,0.6825297084581011,-inf,inf
user_0,default,entity_19,-1.1330286216672034,-inf,inf
user_0,default,entity_5,1.527715830950634,-inf,inf
user_0,default,entity_9,0.15862347293943002,-inf,inf
user_0,default,entity_13,0.6825297084581012,-inf,inf
user_0,default,entity_19,-1.1330286216672032,-inf,inf
user_0,default,entity_5,1.5277158309506345,-inf,inf
user_0,default,entity_4,3.4585969415576345,-inf,inf
user_1,default,entity_3,1.667515895613929,-inf,inf
user_1,default,entity_14,1.006242789479527,-inf,inf
user_1,default,entity_7,2.664075720463755,-inf,inf
user_1,default,entity_18,-0.8301415440372613,-inf,inf
user_1,default,entity_8,-0.17420902398789284,-inf,inf
user_1,default,entity_8,-0.1742090239878929,-inf,inf
user_1,default,entity_10,-1.1835582791092611,-inf,inf
user_1,default,entity_6,-1.1948480725179846,-inf,inf
user_1,default,entity_16,-0.9539951632086848,-inf,inf
user_1,default,entity_12,-2.180286196407139,-inf,inf
user_1,default,entity_16,-0.9539951632086847,-inf,inf
user_1,default,entity_12,-2.1802861964071387,-inf,inf
user_1,default,entity_2,-0.42652666498132275,-inf,inf
user_2,default,entity_11,-1.387325030216633,-inf,inf
user_2,default,entity_9,-0.9284619449437295,-inf,inf
user_2,default,entity_18,-0.4850896006442237,-inf,inf
user_2,default,entity_18,-0.48508960064422346,-inf,inf
user_2,default,entity_14,-0.5877998959704127,-inf,inf
user_2,default,entity_17,-0.8219599742190198,-inf,inf
user_2,default,entity_0,-0.7750717258171586,-inf,inf
user_2,default,entity_17,-0.8219599742190199,-inf,inf
user_2,default,entity_0,-0.7750717258171584,-inf,inf
user_2,default,entity_4,-0.7272800370864463,-inf,inf
user_2,default,entity_13,-0.33567087330509393,-inf,inf
user_2,default,entity_5,1.0737649403543752,-inf,inf
user_2,default,entity_1,1.3235236743850214,-inf,inf
user_2,default,entity_1,1.3235236743850212,-inf,inf
user_2,default,entity_16,2.548852058606136,-inf,inf
user_3,default,entity_19,0.945310811705114,-inf,inf
user_3,default,entity_1,2.280503660781391,-inf,inf
user_3,default,entity_19,0.9453108117051138,-inf,inf
user_3,default,entity_1,2.2805036607813913,-inf,inf
user_3,default,entity_13,-1.5769377112488974,-inf,inf
user_3,default,entity_16,1.5165254267520498,-inf,inf
user_3,default,entity_10,2.072081238298041,-inf,inf
user_3,default,entity_3,-1.1848498093087512,-inf,inf
user_3,default,entity_7,1.3271180042064845,-inf,inf
user_3,default,entity_18,-0.25013042087067283,-inf,inf
user_3,default,entity_12,-1.0216656841229712,-inf,inf
user_3,default,entity_12,-1.021665684122971,-inf,inf
user_3,default,entity_6,-1.0810028748325473,-inf,inf
user_3,default,entity_2,1.203173894386199,-inf,inf
user_4,default,entity_19,0.2639409528014471,-inf,inf
user_4,default,entity_8,0.440512579577302,-inf,inf
user_4,default,entity_2,-0.43449183897222454,-inf,inf
user_4,default,entity_17,0.6740247220648662,-inf,inf
user_4,default,entity_18,1.2408002798143847,-inf,inf
user_4,default,entity_3,0.48288392306535105,-inf,inf
user_4,default,entity_3,0.482883923065351,-inf,inf
user_4,default,entity_15,0.8328232652751022,-inf,inf
user_4,default,entity_12,-0.49299657597255997,-inf,inf
user_4,default,entity_14,-0.5966527881357853,-inf,inf
user_4,default,entity_12,-0.4929965759725601,-inf,inf
user_4,default,entity_14,-0.5966527881357854,-inf,inf
user_4,default,entity_7,0.11951141508835739,-inf,inf
user_4,default,entity_16,-0.5090007368392502,-inf,inf
user_5,default,entity_6,-0.21462643170061715,-inf,inf
user_5,default,entity_17,0.05710431263487674,-inf,inf
user_6,default,entity_13,-0.6865530798790694,-inf,inf
user_5,default,entity_6,-0.21462643170061727,-inf,inf
user_5,default,entity_17,0.05710431263487662,-inf,inf
user_6,default,entity_13,-0.6865530798790695,-inf,inf
user_6,default,entity_16,0.27755782227896963,-inf,inf
user_6,default,entity_4,-0.6587271432487041,-inf,inf
user_6,default,entity_4,-0.6587271432487042,-inf,inf
user_6,default,entity_8,-0.5583548331157635,-inf,inf
user_6,default,entity_18,1.9506928061461883,-inf,inf
user_6,default,entity_1,0.49050679907782496,-inf,inf
user_6,default,entity_2,-1.3252639653101213,-inf,inf
user_6,default,entity_0,-0.3233871818836944,-inf,inf
user_6,default,entity_2,-1.3252639653101215,-inf,inf
user_6,default,entity_0,-0.3233871818836943,-inf,inf
user_7,default,entity_4,0.18732339474126747,-inf,inf
user_7,default,entity_8,1.9008743051155528,-inf,inf
user_7,default,entity_8,1.900874305115553,-inf,inf
user_7,default,entity_19,0.0966132005912641,-inf,inf
user_7,default,entity_3,0.9342862554614455,-inf,inf
user_7,default,entity_11,1.6668385201561917,-inf,inf
Expand All @@ -73,13 +73,13 @@ user_7,default,entity_14,3.0454263676652222,-inf,inf
user_8,default,entity_8,2.4404515835847844,-inf,inf
user_8,default,entity_4,2.1409310166269293,-inf,inf
user_8,default,entity_11,0.8428911156251948,-inf,inf
user_8,default,entity_14,0.11540882613427506,-inf,inf
user_8,default,entity_14,0.11540882613427472,-inf,inf
user_8,default,entity_6,0.45813136118948833,-inf,inf
user_8,default,entity_2,0.31277617960683546,-inf,inf
user_9,default,entity_3,-1.962426062799881,-inf,inf
user_9,default,entity_3,-1.9624260627998813,-inf,inf
user_9,default,entity_12,-0.6855884162192276,-inf,inf
user_9,default,entity_2,-0.885308032694007,-inf,inf
user_9,default,entity_2,-0.8853080326940073,-inf,inf
user_9,default,entity_14,0.5609584765999138,-inf,inf
user_9,default,entity_9,-0.023301797645628702,-inf,inf
user_9,default,entity_5,-0.7796727684945641,-inf,inf
user_9,default,entity_5,-0.779672768494564,-inf,inf
user_9,default,entity_1,-2.172493689682739,-inf,inf
44 changes: 22 additions & 22 deletions solidago/tests/load_save/generated_state/vouches.csv
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
by,to,kind,weight,priority
user_0,user_3,Personhood,0.9833949546677002,0
user_0,user_7,Personhood,0.08528809340315868,0
user_0,user_9,Personhood,0.18281300923230315,0
user_0,user_4,Personhood,0.7912319527696452,0
user_0,user_8,Personhood,0.7896824062501574,0
user_0,user_6,Personhood,0.840778792639736,0
user_0,user_2,Personhood,0.5238654715100963,0
user_0,user_1,Personhood,0.8925993384032407,0
user_3,user_9,Personhood,0.9742272864750473,0
user_3,user_6,Personhood,0.6507377163366674,0
user_3,user_1,Personhood,0.09393498035440484,0
user_8,user_7,Personhood,0.9997564507524814,0
user_8,user_4,Personhood,0.936525738441582,0
user_6,user_0,Personhood,0.9866633771143694,0
user_1,user_0,Personhood,0.820148529887458,0
user_1,user_3,Personhood,0.9176014249118304,0
user_1,user_7,Personhood,0.8278937779667541,0
user_1,user_9,Personhood,0.31332771795896475,0
user_1,user_4,Personhood,0.9978833273449196,0
user_1,user_8,Personhood,0.8785342490885062,0
user_1,user_6,Personhood,0.028806646041575235,0
user_1,user_2,Personhood,0.18106849188645013,0
user_1,user_2,Personhood,0.7896824062501574,0
user_1,user_4,Personhood,0.840778792639736,0
user_1,user_8,Personhood,0.5238654715100963,0
user_1,user_7,Personhood,0.8925993384032407,0
user_1,user_9,Personhood,0.5954263337434686,0
user_1,user_6,Personhood,0.9742272864750473,0
user_1,user_0,Personhood,0.07999942762187195,0
user_1,user_3,Personhood,0.6507377163366674,0
user_8,user_4,Personhood,0.8201966493572297,0
user_8,user_6,Personhood,0.08035213916732376,0
user_6,user_8,Personhood,0.9866633771143694,0
user_0,user_2,Personhood,0.5919834173260232,0
user_0,user_1,Personhood,0.021068292613971695,0
user_0,user_4,Personhood,0.3865438146392457,0
user_0,user_8,Personhood,0.9417170498803228,0
user_0,user_7,Personhood,0.9394529089921024,0
user_0,user_9,Personhood,0.7323918980299693,0
user_0,user_6,Personhood,0.6923214349223931,0
user_0,user_3,Personhood,0.5007521844552536,0
user_3,user_2,Personhood,0.8700068988139138,0
user_3,user_8,Personhood,0.9458846822069802,0
user_3,user_7,Personhood,0.3358296375028642,0
Loading

0 comments on commit 977cf57

Please sign in to comment.