Skip to content

Commit

Permalink
changes required for using prefsampling>=0.1.19
Browse files Browse the repository at this point in the history
thanks to @Simon-Rey
  • Loading branch information
martinlackner committed Sep 17, 2024
1 parent 469ce77 commit e3e4111
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 31 deletions.
4 changes: 2 additions & 2 deletions abcvoting/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def write_profile_to_preflib_cat_file(filepath, profile):
Parameters
----------
filename : str
File name of the Preflib file.
filepath : str
File path of the Preflib file.
profile : abcvoting.preferences.Profile
Profile to be written.
Expand Down
82 changes: 54 additions & 28 deletions abcvoting/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,8 @@ def random_profile(num_voters, num_cand, prob_distribution):
>>> prob_distribution = {"id": "IC", "p": 0.5}
>>> profile = random_profile(num_voters=5, num_cand=5, prob_distribution=prob_distribution)
>>> print(profile)
profile with 5 voters and 5 candidates:
voter 0: {0, 1, 2, 3},
voter 1: {0, 2, 4},
voter 2: {1, 3},
voter 3: {0, 3, 4},
voter 4: {0}
>>> print(len(profile))
5
"""
if "id" not in prob_distribution:
raise KeyError('Probability distribution requires key "id".')
Expand Down Expand Up @@ -826,6 +821,7 @@ class PointProbabilityDistribution:
from abcvoting import generate
from abcvoting.generate import PointProbabilityDistribution
import matplotlib.pyplot as plt
import numpy as np
.. testcode::
Expand All @@ -846,6 +842,8 @@ class PointProbabilityDistribution:
from abcvoting.generate import PointProbabilityDistribution
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 4})
# distributions to generate points in 1- and 2-dimensional space
distributions = [
PointProbabilityDistribution("1d_interval", center_point=[0]),
Expand All @@ -857,7 +855,7 @@ class PointProbabilityDistribution:
PointProbabilityDistribution("2d_gaussian_disc", center_point=[6, 2], sigma=0.25),
]
fig, ax = plt.subplots(dpi=600, figsize=(7, 3))
fig, ax = plt.subplots(dpi=600, figsize=(14, 6))
points = []
for dist in distributions:
if dist.name.startswith("2d"):
Expand Down Expand Up @@ -903,48 +901,76 @@ def __init__(self, name, center_point=0.5, sigma=0.15, width=1.0):
self.center_point = np.array([center_point] * self.dimension, dtype=float)

def prefsampling_function(self):
# todo: the following lambda functions have a parameter num_dimensions, even though
# the dimension is specified by the name. FIX
if self.name == "1d_interval":
return lambda num_points, num_dimensions: point_samplers.cube(
num_points, 1, center_point=self.center_point, widths=self.width
return lambda num_points, num_dimensions=1, seed=None: point_samplers.cube(
num_points,
num_dimensions=1,
center_point=self.center_point,
widths=self.width,
seed=seed,
)
elif self.name == "1d_gaussian":
return lambda num_points, num_dimensions: point_samplers.gaussian(
num_points, 1, center_point=self.center_point, sigmas=self.sigma
return lambda num_points, num_dimensions=1, seed=None: point_samplers.gaussian(
num_points,
num_dimensions=1,
center_point=self.center_point,
sigmas=self.sigma,
seed=seed,
)
elif self.name == "1d_gaussian_interval":
return lambda num_points, num_dimensions: point_samplers.gaussian(
return lambda num_points, num_dimensions=1, seed=None: point_samplers.gaussian(
num_points,
1,
num_dimensions=1,
center_point=self.center_point,
sigmas=self.sigma,
widths=[self.width],
seed=seed,
)
elif self.name == "2d_disc":
return lambda num_points, num_dimensions: point_samplers.ball_uniform(
num_points, 2, center_point=self.center_point, widths=self.width
return lambda num_points, num_dimensions=2, seed=None: point_samplers.ball_uniform(
num_points,
num_dimensions=2,
center_point=self.center_point,
widths=self.width,
seed=seed,
)
elif self.name == "2d_square":
return lambda num_points, num_dimensions: point_samplers.cube(
num_points, 2, center_point=self.center_point, widths=self.width
return lambda num_points, num_dimensions=2, seed=None: point_samplers.cube(
num_points,
num_dimensions=2,
center_point=self.center_point,
widths=self.width,
seed=seed,
)
elif self.name == "2d_gaussian":
return lambda num_points, num_dimensions: point_samplers.gaussian(
num_points, 2, center_point=self.center_point, sigmas=self.sigma
return lambda num_points, num_dimensions=2, seed=None: point_samplers.gaussian(
num_points,
num_dimensions=2,
center_point=self.center_point,
sigmas=self.sigma,
seed=seed,
)
elif self.name == "2d_gaussian_disc":
return lambda num_points, num_dimensions: point_samplers.ball_resampling(
return lambda num_points, num_dimensions=2, seed=None: point_samplers.ball_resampling(
num_points,
2,
lambda: point_samplers.gaussian(
1, 2, center_point=self.center_point, sigmas=self.sigma
num_dimensions=2,
inner_sampler=lambda: point_samplers.gaussian(
1, num_dimensions=2, center_point=self.center_point, sigmas=self.sigma
)[0],
{},
inner_sampler_args={},
center_point=self.center_point,
width=self.width,
seed=seed,
)
elif self.name == "3d_cube":
return lambda num_points, num_dimensions: point_samplers.cube(
num_points, 3, center_point=self.center_point, widths=self.width
return lambda num_points, num_dimensions=3, seed=None: point_samplers.cube(
num_points,
num_dimensions=3,
center_point=self.center_point,
widths=self.width,
seed=seed,
)
else:
raise ValueError(f"unknown name of point distribution: {self.name}")
Expand All @@ -963,7 +989,7 @@ def random_point(prob_distribution):
-------
np.ndarray
"""
return prob_distribution.prefsampling_function()(1)[0]
return prob_distribution.prefsampling_function()(num_points=1)[0]


PROBABILITY_DISTRIBUTIONS = {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def read_version():
"numpy>=1.21",
"gurobipy>=11.0",
"preflibtools>=2.0.12",
"prefsampling>=0.1.18",
"prefsampling>=0.1.19",
],
extras_require={
"dev": [
Expand Down

0 comments on commit e3e4111

Please sign in to comment.