Skip to content

Commit

Permalink
ADD: *args for positional argumenting
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Jul 19, 2024
1 parent 4870c05 commit f4e9cda
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
28 changes: 22 additions & 6 deletions choice_learn/models/baseline_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ def compute_batch_utility(
np.random.uniform(size=(available_items_by_choice.shape), low=0.0, high=1.0)
).astype(np.float32)

def fit(**kwargs):
def fit(self, *args, **kwargs):
"""Make sure that nothing happens during .fit."""
_ = kwargs
_ = args
return {}

def _fit_with_lbfgs(**kwargs):
def _fit_with_lbfgs(self, *args, **kwargs):
"""Make sure that nothing happens during .fit."""
_ = kwargs
_ = args
return {}


Expand All @@ -80,17 +82,31 @@ def trainable_weights(self):
"""Return the weights."""
return self.weigths

def fit(self, choice_dataset, **kwargs):
"""Compute the choice frequency of each product and defines it as choice probabilities."""
def fit(self, choice_dataset, *args, **kwargs):
"""Compute the choice frequency of each product and defines it as choice probabilities.
Parameters
----------
choice_dataset : ChoiceDataset
Dataset to be used for fitting
"""
_ = kwargs
_ = args
choices = choice_dataset.choices
for i in range(choice_dataset.get_n_items()):
self.weights.append(tf.reduce_sum(tf.cast(choices == i, tf.float32)))
self.weights = tf.stack(self.weights) / len(choices)

def _fit_with_lbfgs(self, choice_dataset, **kwargs):
"""Compute the choice frequency of each product and defines it as choice probabilities."""
def _fit_with_lbfgs(self, choice_dataset, *args, **kwargs):
"""Compute the choice frequency of each product and defines it as choice probabilities.
Parameters
----------
choice_dataset : ChoiceDataset
Dataset to be used for fitting
"""
_ = kwargs
_ = args
choices = choice_dataset.choices
for i in range(choice_dataset.get_n_items()):
self.weights.append(tf.reduce_sum(tf.cast(choices == i, tf.float32)))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ select = [
"PTH",
"PD",
] # See: https://beta.ruff.rs/docs/rules/
ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH118", "PTH123"]
ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN002", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH118", "PTH123"]
line-length = 100
exclude = [
".bzr",
Expand Down

0 comments on commit f4e9cda

Please sign in to comment.