From f4e9cda91f04bf6cf30366a7c5949c8f6512c8f1 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Fri, 19 Jul 2024 18:33:21 +0200 Subject: [PATCH] ADD: *args for positional argumenting --- choice_learn/models/baseline_models.py | 28 ++++++++++++++++++++------ pyproject.toml | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/choice_learn/models/baseline_models.py b/choice_learn/models/baseline_models.py index 72473cae..ab2572df 100644 --- a/choice_learn/models/baseline_models.py +++ b/choice_learn/models/baseline_models.py @@ -53,14 +53,16 @@ def compute_batch_utility( np.random.uniform(size=(available_items_by_choice.shape), low=0.0, high=1.0) ).astype(np.float32) - def fit(**kwargs): + def fit(self, *args, **kwargs): """Make sure that nothing happens during .fit.""" _ = kwargs + _ = args return {} - def _fit_with_lbfgs(**kwargs): + def _fit_with_lbfgs(self, *args, **kwargs): """Make sure that nothing happens during .fit.""" _ = kwargs + _ = args return {} @@ -80,17 +82,31 @@ def trainable_weights(self): """Return the weights.""" return self.weigths - def fit(self, choice_dataset, **kwargs): - """Compute the choice frequency of each product and defines it as choice probabilities.""" + def fit(self, choice_dataset, *args, **kwargs): + """Compute the choice frequency of each product and defines it as choice probabilities. + + Parameters + ---------- + choice_dataset : ChoiceDataset + Dataset to be used for fitting + """ _ = kwargs + _ = args choices = choice_dataset.choices for i in range(choice_dataset.get_n_items()): self.weights.append(tf.reduce_sum(tf.cast(choices == i, tf.float32))) self.weights = tf.stack(self.weights) / len(choices) - def _fit_with_lbfgs(self, choice_dataset, **kwargs): - """Compute the choice frequency of each product and defines it as choice probabilities.""" + def _fit_with_lbfgs(self, choice_dataset, *args, **kwargs): + """Compute the choice frequency of each product and defines it as choice probabilities. + + Parameters + ---------- + choice_dataset : ChoiceDataset + Dataset to be used for fitting + """ _ = kwargs + _ = args choices = choice_dataset.choices for i in range(choice_dataset.get_n_items()): self.weights.append(tf.reduce_sum(tf.cast(choices == i, tf.float32))) diff --git a/pyproject.toml b/pyproject.toml index 14c7d596..4e7dcc74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ select = [ "PTH", "PD", ] # See: https://beta.ruff.rs/docs/rules/ -ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH118", "PTH123"] +ignore = ["D203", "D213", "ANN101", "ANN102", "ANN204", "ANN001", "ANN002", "ANN202", "ANN201", "ANN206", "ANN003", "PTH100", "PTH118", "PTH123"] line-length = 100 exclude = [ ".bzr",