Skip to content

Commit

Permalink
WIP Preference Learning
Browse files Browse the repository at this point in the history
  • Loading branch information
lenhoanglnh committed Jan 10, 2025
1 parent c6659f6 commit 6acf0b5
Show file tree
Hide file tree
Showing 14 changed files with 2,801 additions and 1,224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def init_scores(self,
entity_name2index: dict[str, int],
init_multiscores: MultiScore, # key_names == "entity_name"
) -> npt.NDArray:
scores = np.zeros(len(entity_name2index))
scores = np.zeros(len(entity_name2index), dtype=np.float64)
for entity, init_score in init_multiscores:
if not init_score.isnan():
scores[entity_name2index[str(entity)]] = init_score.value
Expand Down Expand Up @@ -171,10 +171,10 @@ def compute_uncertainties(self,
rights: npt.NDArray
rights[i] is the right uncertainty on scores[i]
"""
compared_entity_indices = comparisons.compared_entity_indices(entity_name2index)
indices = { loc: np.array(compared_entity_indices[loc]) for loc in ("left", "right") }
indices = comparisons.compared_entity_indices(entity_name2index, self.last_comparison_only)
indices = { loc: np.array(indices[loc]) for loc in ("left", "right") }
score_diffs = scores[indices["left"]] - scores[indices["right"]]
normalized_comparisons = comparisons.normalized_comparisons()
normalized_comparisons = comparisons.normalized_comparisons(self.last_comparison_only)
score_negative_log_likelihood = self.negative_log_likelihood(score_diffs, normalized_comparisons)

kwargs = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self,
convergence_error: float=1e-5,
max_iter: int=100,
device: torch.device=default_device,
last_comparison_only: bool=True,
):
""" Generalized Bradley Terry is a class of probability models of comparisons,
introduced in the paper "Generalized Bradley-Terry Models for Score Estimation
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(self,
prior_std_dev=prior_std_dev,
uncertainty_nll_increase=uncertainty_nll_increase,
max_uncertainty=max_uncertainty,
last_comparison_only=last_comparison_only,
)
self.convergence_error = convergence_error
self.max_iter = max_iter
Expand Down Expand Up @@ -124,9 +126,9 @@ def negative_log_posterior(self,
comparisons: Comparisons,
) -> torch.Tensor:
""" Negative log posterior """
entity_indices = comparisons.compared_entity_indices(entity_name2index)
score_diffs = scores[entity_indices["left"]] - scores[entity_indices["right"]]
normalized_comparisons = comparisons.normalized_comparisons()
indices = comparisons.compared_entity_indices(entity_name2index, self.last_comparison_only)
score_diffs = scores[indices["left"]] - scores[indices["right"]]
normalized_comparisons = comparisons.normalized_comparisons(self.last_comparison_only)
loss = self.cumulant_generating_function(score_diffs).sum()
loss -= (score_diffs * torch.tensor(normalized_comparisons)).sum()
return loss + (scores**2).sum() / (2 * self.prior_std_dev**2)
Expand All @@ -140,6 +142,7 @@ def __init__(self,
convergence_error: float=1e-5,
max_iter: int=100,
device: torch.device=default_device,
last_comparison_only: bool=True,
):
"""
Parameters (TODO)
Expand All @@ -152,6 +155,7 @@ def __init__(self,
convergence_error=convergence_error,
max_iter=max_iter,
device=device,
last_comparison_only=last_comparison_only
)

def cumulant_generating_function(self, score_diffs: torch.Tensor) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self,
uncertainty_nll_increase: float=1.0,
max_uncertainty: float=1e3,
convergence_error: float=1e-5,
last_comparison_only: bool=True,
):
""" Generalized Bradley Terry is a class of probability models of comparisons,
introduced in the paper "Generalized Bradley-Terry Models for Score Estimation
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self,
prior_std_dev=prior_std_dev,
uncertainty_nll_increase=uncertainty_nll_increase,
max_uncertainty=max_uncertainty,
last_comparison_only=last_comparison_only,
)
self.convergence_error = convergence_error

Expand Down Expand Up @@ -79,10 +81,10 @@ def compute_scores(self,
entity_ordered_comparisons = comparisons.order_by_entities()
def get_partial_derivative_args(entity_index: int, scores: np.ndarray) -> tuple:
entity_name = entities.iloc[entity_index].name
df = entity_ordered_comparisons[entity_name].to_df()
normalized_comparisons = np.array(df["comparison"] / df["comparison_max"])
normalized_comparisons = comparisons.normalized_comparisons(self.last_comparison_only)
df = entity_ordered_comparisons[entity_name].to_df(last_row_only=self.last_comparison_only)
indices = df["other_name"].map(entity_name2index)
return scores[indices], normalized_comparisons
return scores[indices], np.array(normalized_comparisons)

return coordinate_descent(
self.partial_derivative,
Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(self,
uncertainty_nll_increase: float=1.0,
max_uncertainty: float=1e3,
convergence_error: float=1e-5,
last_comparison_only: bool=True,
):
"""
Expand All @@ -134,6 +137,7 @@ def __init__(self,
uncertainty_nll_increase=uncertainty_nll_increase,
max_uncertainty=max_uncertainty,
convergence_error=convergence_error,
last_comparison_only=last_comparison_only,
)

@cached_property
Expand Down
6 changes: 3 additions & 3 deletions solidago/src/solidago/_state/_comparisons/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ def order_by_entities(self) -> "Comparisons": # key_names == ["entity_name", "ot
def compared_entity_indices(self,
entity_name2index: dict[str, int],
last_comparison_only: bool=True,
returns: Literal["rows", "row_list", "last_row"]="rows"
) -> dict[str, list[int]]:
key_indices = { loc: self.key_names.index(f"{loc}_name") for loc in ("left", "right") }
returns = "last_row" if last_comparison_only else "rows"
return {
location: [
entity_name2index[keys[key_indices[location]]]
for keys, _ in self.iter(returns)
] for location in ("left", "right")
}

def normalized_comparisons(self) -> Series:
df = self.to_df()
def normalized_comparisons(self, last_comparison_only: bool) -> Series:
df = self.to_df(last_row_only=last_comparison_only)
return df["comparison"] / df["comparison_max"]
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,18 @@ def __len__(self) -> int:
return sum([len(row_list) for _, row_list in self._dict.items()])
return sum([ len(sub_dicts) for sub_dicts in self._dict.values() ])

def to_rows(self, row_kwargs: Optional[dict]) -> list[dict]:
def to_rows(self, row_kwargs: Optional[dict]=None, last_row_only: bool=False) -> list[dict]:
if row_kwargs is None:
row_kwargs = dict()
returns = "last_row" if last_row_only else "rows"
return [
dict(zip(self.key_names, keys)) | row_kwargs | row
for keys, row in self.iter(returns="rows", value_process=False, key_process=False)
for keys, row in self.iter(returns=returns, value_process=False, key_process=False)
]

def to_df(self, row_kwargs: Optional[dict]=None, last_row_only: bool=False) -> DataFrame:
    """ Build a DataFrame from the rows produced by `to_rows`.

    Parameters
    ----------
    row_kwargs: Optional[dict]
        Forwarded unchanged to `to_rows` as its first positional argument.
    last_row_only: bool
        Forwarded unchanged to `to_rows` as its second positional argument.

    Returns
    -------
    DataFrame
        DataFrame constructed from the list returned by `to_rows`.
    """
    row_dicts = self.to_rows(row_kwargs, last_row_only)
    return DataFrame(row_dicts)

def iter(self,
returns: Literal["rows", "row_list", "last_row"]="rows",
value_process: bool=True,
Expand Down
182 changes: 84 additions & 98 deletions solidago/tests/load_save/generated_state/assessments.csv
Original file line number Diff line number Diff line change
@@ -1,99 +1,85 @@
username,criterion,entity_name,assessment,assessment_min,assessment_max
user_0,default,entity_16,-2.0749057611785386,-inf,inf
user_0,default,entity_15,-2.39370352625087,-inf,inf
user_0,default,entity_11,-1.5992836978625538,-inf,inf
user_0,default,entity_18,-2.619031854481638,-inf,inf
user_0,default,entity_6,-1.2481876740748055,-inf,inf
user_0,default,entity_17,-0.8195927297894479,-inf,inf
user_0,default,entity_14,-1.2667207574469999,-inf,inf
user_0,default,entity_19,0.8509795934379412,-inf,inf
user_0,default,entity_4,3.053938668091461,-inf,inf
user_1,default,entity_10,1.666205471029322,-inf,inf
user_1,default,entity_15,1.3982102010500272,-inf,inf
user_1,default,entity_7,0.31743364012724,-inf,inf
user_1,default,entity_2,2.0757722679531048,-inf,inf
user_1,default,entity_1,0.24777798376771804,-inf,inf
user_1,default,entity_17,1.2172607977951482,-inf,inf
user_1,default,entity_14,0.7587810143743123,-inf,inf
user_1,default,entity_16,-0.45322198998446633,-inf,inf
user_1,default,entity_3,-0.6293299841653897,-inf,inf
user_1,default,entity_12,-2.053760441767217,-inf,inf
user_1,default,entity_18,-0.9601193237500203,-inf,inf
user_1,default,entity_11,-0.7764255417543154,-inf,inf
user_2,default,entity_8,-3.1840313925258865,-inf,inf
user_2,default,entity_12,-1.3582673358306874,-inf,inf
user_2,default,entity_4,-0.3920343878647034,-inf,inf
user_2,default,entity_13,-1.0307685022464559,-inf,inf
user_2,default,entity_19,1.2140338299536473,-inf,inf
user_2,default,entity_1,0.34488131370003594,-inf,inf
user_2,default,entity_5,0.4963366907015018,-inf,inf
user_2,default,entity_16,1.3239011724971825,-inf,inf
user_3,default,entity_1,1.7159560105752791,-inf,inf
user_3,default,entity_2,0.900609937541631,-inf,inf
user_3,default,entity_10,0.010155796202459877,-inf,inf
user_3,default,entity_19,1.9939222430641497,-inf,inf
user_3,default,entity_18,-0.2487610576189404,-inf,inf
user_3,default,entity_6,-1.201379294506745,-inf,inf
user_3,default,entity_8,-1.989535014457278,-inf,inf
user_3,default,entity_14,-1.421354527295856,-inf,inf
user_3,default,entity_5,-0.32127542248226026,-inf,inf
user_4,default,entity_8,0.28435161869848846,-inf,inf
user_4,default,entity_6,2.4003825078321817,-inf,inf
user_4,default,entity_19,0.42259228898088025,-inf,inf
user_4,default,entity_11,0.055986495346031506,-inf,inf
user_4,default,entity_17,0.2099984004735524,-inf,inf
user_4,default,entity_13,0.4161872688546774,-inf,inf
user_4,default,entity_9,1.6036019683037805,-inf,inf
user_4,default,entity_4,-0.18158399657965602,-inf,inf
user_4,default,entity_2,2.0454369805034163,-inf,inf
user_4,default,entity_0,0.3854155927801495,-inf,inf
user_4,default,entity_15,-0.14632039422818038,-inf,inf
user_4,default,entity_5,-2.022605231132655,-inf,inf
user_4,default,entity_18,-0.872532836583685,-inf,inf
user_4,default,entity_3,-0.3827672112907701,-inf,inf
user_4,default,entity_12,0.6017254822778608,-inf,inf
user_4,default,entity_7,-2.9180266496509044,-inf,inf
user_4,default,entity_16,-0.7599310550355487,-inf,inf
user_5,default,entity_3,-2.1101274264627663,-inf,inf
user_5,default,entity_14,-2.762781599519231,-inf,inf
user_5,default,entity_2,-0.5100542990101968,-inf,inf
user_5,default,entity_17,0.05710431263487662,-inf,inf
user_6,default,entity_16,-0.06305869415402687,-inf,inf
user_6,default,entity_6,-0.7677539107375391,-inf,inf
user_6,default,entity_3,-2.8643609493870237,-inf,inf
user_6,default,entity_11,-2.1819602163698275,-inf,inf
user_6,default,entity_4,-0.43687483459641485,-inf,inf
user_6,default,entity_9,-0.7354562201117167,-inf,inf
user_6,default,entity_13,-1.6808275477504329,-inf,inf
user_6,default,entity_8,0.9663711729296698,-inf,inf
user_6,default,entity_14,0.621626369112255,-inf,inf
user_6,default,entity_2,-0.35611410923840825,-inf,inf
user_6,default,entity_10,1.408450182481708,-inf,inf
user_6,default,entity_5,0.5043003970870001,-inf,inf
user_6,default,entity_7,2.0966207883047416,-inf,inf
user_7,default,entity_15,-0.5655431776581341,-inf,inf
user_7,default,entity_12,-0.18604211953544048,-inf,inf
user_7,default,entity_8,-0.719456200755322,-inf,inf
user_7,default,entity_10,-0.824253089052783,-inf,inf
user_7,default,entity_1,-1.1268342583835582,-inf,inf
user_7,default,entity_3,1.6072561062359942,-inf,inf
user_7,default,entity_18,0.43807952480223084,-inf,inf
user_7,default,entity_2,1.1942940555235904,-inf,inf
user_7,default,entity_4,-0.4542680865093558,-inf,inf
user_7,default,entity_14,2.5219389650833213,-inf,inf
user_7,default,entity_19,-0.6650799718202678,-inf,inf
user_7,default,entity_9,-0.13580933330195982,-inf,inf
user_7,default,entity_11,-0.4812107308028669,-inf,inf
user_7,default,entity_17,1.96090138537457,-inf,inf
user_8,default,entity_8,2.1127574448260513,-inf,inf
user_8,default,entity_4,2.642406825318763,-inf,inf
user_8,default,entity_9,0.5444086059726662,-inf,inf
user_8,default,entity_11,0.5846999636827297,-inf,inf
user_8,default,entity_2,0.7698291989015807,-inf,inf
user_8,default,entity_18,1.3419798902453495,-inf,inf
user_8,default,entity_5,1.7797980695330238,-inf,inf
user_9,default,entity_3,-0.01253707482118327,-inf,inf
user_9,default,entity_13,1.3216410202684816,-inf,inf
user_9,default,entity_11,1.694370071188104,-inf,inf
user_9,default,entity_14,1.5236251305880073,-inf,inf
user_9,default,entity_9,2.4498918215182184,-inf,inf
user_0,default,entity_0,-3.0840321464854243,-inf,inf
user_0,default,entity_16,-4.0878539983787,-inf,inf
user_0,default,entity_15,-1.4902020610943079,-inf,inf
user_0,default,entity_18,-1.0038588539696902,-inf,inf
user_0,default,entity_9,0.15862347293943008,-inf,inf
user_0,default,entity_13,0.6825297084581011,-inf,inf
user_0,default,entity_19,-1.1330286216672034,-inf,inf
user_0,default,entity_5,1.527715830950634,-inf,inf
user_0,default,entity_4,3.4585969415576345,-inf,inf
user_1,default,entity_3,1.667515895613929,-inf,inf
user_1,default,entity_14,1.006242789479527,-inf,inf
user_1,default,entity_7,2.664075720463755,-inf,inf
user_1,default,entity_18,-0.8301415440372613,-inf,inf
user_1,default,entity_8,-0.17420902398789284,-inf,inf
user_1,default,entity_10,-1.1835582791092611,-inf,inf
user_1,default,entity_6,-1.1948480725179846,-inf,inf
user_1,default,entity_16,-0.9539951632086848,-inf,inf
user_1,default,entity_12,-2.180286196407139,-inf,inf
user_1,default,entity_2,-0.42652666498132275,-inf,inf
user_2,default,entity_11,-1.387325030216633,-inf,inf
user_2,default,entity_9,-0.9284619449437295,-inf,inf
user_2,default,entity_18,-0.4850896006442237,-inf,inf
user_2,default,entity_14,-0.5877998959704127,-inf,inf
user_2,default,entity_17,-0.8219599742190198,-inf,inf
user_2,default,entity_0,-0.7750717258171586,-inf,inf
user_2,default,entity_4,-0.7272800370864463,-inf,inf
user_2,default,entity_13,-0.33567087330509393,-inf,inf
user_2,default,entity_5,1.0737649403543752,-inf,inf
user_2,default,entity_1,1.3235236743850214,-inf,inf
user_2,default,entity_16,2.548852058606136,-inf,inf
user_3,default,entity_19,0.945310811705114,-inf,inf
user_3,default,entity_1,2.280503660781391,-inf,inf
user_3,default,entity_13,-1.5769377112488974,-inf,inf
user_3,default,entity_16,1.5165254267520498,-inf,inf
user_3,default,entity_10,2.072081238298041,-inf,inf
user_3,default,entity_3,-1.1848498093087512,-inf,inf
user_3,default,entity_7,1.3271180042064845,-inf,inf
user_3,default,entity_18,-0.25013042087067283,-inf,inf
user_3,default,entity_12,-1.0216656841229712,-inf,inf
user_3,default,entity_6,-1.0810028748325473,-inf,inf
user_3,default,entity_2,1.203173894386199,-inf,inf
user_4,default,entity_19,0.2639409528014471,-inf,inf
user_4,default,entity_8,0.440512579577302,-inf,inf
user_4,default,entity_2,-0.43449183897222454,-inf,inf
user_4,default,entity_17,0.6740247220648662,-inf,inf
user_4,default,entity_18,1.2408002798143847,-inf,inf
user_4,default,entity_3,0.48288392306535105,-inf,inf
user_4,default,entity_15,0.8328232652751022,-inf,inf
user_4,default,entity_12,-0.49299657597255997,-inf,inf
user_4,default,entity_14,-0.5966527881357853,-inf,inf
user_4,default,entity_7,0.11951141508835739,-inf,inf
user_4,default,entity_16,-0.5090007368392502,-inf,inf
user_5,default,entity_6,-0.21462643170061715,-inf,inf
user_5,default,entity_17,0.05710431263487674,-inf,inf
user_6,default,entity_13,-0.6865530798790694,-inf,inf
user_6,default,entity_16,0.27755782227896963,-inf,inf
user_6,default,entity_4,-0.6587271432487041,-inf,inf
user_6,default,entity_8,-0.5583548331157635,-inf,inf
user_6,default,entity_18,1.9506928061461883,-inf,inf
user_6,default,entity_1,0.49050679907782496,-inf,inf
user_6,default,entity_2,-1.3252639653101213,-inf,inf
user_6,default,entity_0,-0.3233871818836944,-inf,inf
user_7,default,entity_4,0.18732339474126747,-inf,inf
user_7,default,entity_8,1.9008743051155528,-inf,inf
user_7,default,entity_19,0.0966132005912641,-inf,inf
user_7,default,entity_3,0.9342862554614455,-inf,inf
user_7,default,entity_11,1.6668385201561917,-inf,inf
user_7,default,entity_5,0.7660559749152163,-inf,inf
user_7,default,entity_10,-1.394815185465011,-inf,inf
user_7,default,entity_2,-0.5664965551290375,-inf,inf
user_7,default,entity_14,3.0454263676652222,-inf,inf
user_8,default,entity_8,2.4404515835847844,-inf,inf
user_8,default,entity_4,2.1409310166269293,-inf,inf
user_8,default,entity_11,0.8428911156251948,-inf,inf
user_8,default,entity_14,0.11540882613427506,-inf,inf
user_8,default,entity_6,0.45813136118948833,-inf,inf
user_8,default,entity_2,0.31277617960683546,-inf,inf
user_9,default,entity_3,-1.962426062799881,-inf,inf
user_9,default,entity_12,-0.6855884162192276,-inf,inf
user_9,default,entity_2,-0.885308032694007,-inf,inf
user_9,default,entity_14,0.5609584765999138,-inf,inf
user_9,default,entity_9,-0.023301797645628702,-inf,inf
user_9,default,entity_5,-0.7796727684945641,-inf,inf
user_9,default,entity_1,-2.172493689682739,-inf,inf
Loading

0 comments on commit 6acf0b5

Please sign in to comment.