Skip to content

Commit

Permalink
refactor MAP/MPE and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
damianoazzolini committed Oct 30, 2024
1 parent 8f6a322 commit 78eb204
Show file tree
Hide file tree
Showing 11 changed files with 122 additions and 111 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade build
pip install clingo
pip install numpy
pip install pytest
Expand All @@ -32,19 +34,17 @@ jobs:
- name: Test with pytest
run: |
cd test && pytest
- name: Login to Docker Hub
uses: docker/login-action@v2
- name: Docker login
uses: docker/login-action@v3
with:
username: damianodamianodamiano
password: ${{ secrets.DOCKER_HUB_TOKEN }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2

- name: Build and push
uses: docker/build-push-action@v4
- name: Docker build and push
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: true
tags: damianodamianodamiano/pasta:latest
17 changes: 17 additions & 0 deletions examples/map/win_mpe_disj.lp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
win ; nwin :- red, green.
win :- blue, yellow.

map 0.4::red.
map 0.9::green.
map 0.1::blue.
map 0.6::yellow.

% $ pastasolver win_mpe.lp --map --query="win" --upper
% MPE: 0.1944
% MPE states: 1
% State 0: ['red', 'green', 'not blue', 'yellow']

% $ pastasolver win_mpe.lp --map --query="win"
% MPE: 0.032400000000000005
% MPE states: 1
% State 0: ['not red', 'green', 'blue', 'yellow']
5 changes: 3 additions & 2 deletions pastasolver/models_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,9 +693,10 @@ def get_highest_prob_and_w_id_map(
'''
max_prob : float = 0.0
w_id_list : 'list[str]' = []

for el, w in current_worlds_dict.items():
if w.model_query_count > 0 and (w.model_not_query_count == 0 if lower else True):
if (lower and w.model_query_count > 0 and w.model_not_query_count == 0) or (not lower and w.model_query_count > 0):
# print("ok")
if w.prob == max_prob:
max_prob = w.prob
w_id_list.append(el)
Expand Down
43 changes: 22 additions & 21 deletions pastasolver/pasta_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ def setup_interface(self, from_string : str = "", approx : bool = False) -> None

asp_program = self.parser.get_asp_program(self.lpmln)

if not self.consider_lower_prob and self.query != "":
asp_program.append(f":- not {self.query}.")
# if not self.consider_lower_prob and self.query != "":
# asp_program.append(f":- not {self.query}.")

self.interface = AspInterface(
self.parser.probabilistic_facts,
Expand Down Expand Up @@ -593,6 +593,7 @@ def map_inference(self, from_string : str = "") -> 'tuple[float,list[list[str]]]
print_error_and_exit("Specify at least one map fact.")
if len(self.parser.map_id_list) == len(self.interface.prob_facts_dict) and not self.consider_lower_prob and not self.stop_if_inconsistent and not self.normalize_prob:
print_warning("Brave (upper) MPE can be solved in a faster way using the --solver flag.")
# self.consider_lower_prob = True
self.interface.compute_probabilities()
max_prob, map_state = self.interface.model_handler.get_map_solution(
self.parser.map_id_list, self.consider_lower_prob)
Expand Down Expand Up @@ -688,7 +689,7 @@ def main():
args.minimal = False
args.stop_if_inconsistent = False
args.normalize = False
if ((args.minimal and args.stop_if_inconsistent) or args.upper) and (not args.dtn and not args.dt and not args.dtopt):
if ((args.minimal and args.stop_if_inconsistent) or args.upper) and (not args.dtn and not args.dt and not args.dtopt and not args.map):
print_warning("The program is assumed to be consistent.")
args.stop_if_inconsistent = False
if args.stop_if_inconsistent:
Expand All @@ -700,23 +701,23 @@ def main():
print_warning("The lower utility may be greater than the upper utility for some strategies.")


pasta_solver = Pasta(args.filename,
args.query,
args.evidence,
args.verbose,
args.pedantic,
args.samples,
not args.upper,
args.minimal,
args.normalize,
args.stop_if_inconsistent,
args.one,
args.xor,
100,
args.dtn,
args.lpmln,
args.processes,
args.aspmc
pasta_solver = Pasta(filename=args.filename,
query=args.query,
evidence=args.evidence,
verbose=args.verbose,
pedantic=args.pedantic,
samples=args.samples,
consider_lower_prob=not args.upper,
minimal=args.minimal,
normalize_prob=args.normalize,
stop_if_inconsistent=args.stop_if_inconsistent,
one=args.one,
xor=args.xor,
k=100,
naive_dt=args.dtn,
lpmln=args.lpmln,
processes=args.processes,
aspmc=args.aspmc
)

if args.convert:
Expand Down Expand Up @@ -751,7 +752,7 @@ def main():
elif args.pl:
pasta_solver.parameter_learning()
elif args.map:
if args.upper or args.solver:
if args.solver:
pasta_solver.for_asp_solver = True
max_p, atoms_list_res = pasta_solver.upper_mpe_inference()
else:
Expand Down
6 changes: 3 additions & 3 deletions test/test_abduction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# import utils_for_tests
from .utils_for_tests import almostEqual, check_if_lists_equal
from .utils_for_tests import almost_equal, check_if_lists_equal

from pastasolver.pasta_solver import Pasta

Expand All @@ -16,8 +16,8 @@ def wrap_test_abduction(
pasta_solver = Pasta(filename, query, evidence)
lp, up, abd = pasta_solver.abduction()

assert almostEqual(lp, expected_lp, 5), test_name + ": wrong lower probability"
assert almostEqual(up, expected_up, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, expected_lp), test_name + ": wrong lower probability"
assert almost_equal(up, expected_up), test_name + ": wrong upper probability"
assert check_if_lists_equal(abd, expected_abd), f"{test_name}: wrong abduction. Found {abd} expected {expected_abd}"

def test_bird_4_abd_prob():
Expand Down
6 changes: 3 additions & 3 deletions test/test_approximate_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pastasolver.pasta_solver import Pasta

from .utils_for_tests import ArgumentsTest, almostEqual
from .utils_for_tests import ArgumentsTest, almost_equal


@pytest.mark.parametrize("parameters", [
Expand Down Expand Up @@ -32,5 +32,5 @@ def test_approximate_inference(parameters : ArgumentsTest):

lp, up = pasta_solver.approximate_solve(args)

assert almostEqual(lp, parameters.expected_lp), f"{parameters.test_name}: wrong lower probability - E: {parameters.expected_lp}, F: {lp}"
assert almostEqual(up, parameters.expected_up), f"{parameters.test_name}: wrong upper probability - E: {parameters.expected_up}, F: {up}"
assert almost_equal(lp, parameters.expected_lp), f"{parameters.test_name}: wrong lower probability - E: {parameters.expected_lp}, F: {lp}"
assert almost_equal(up, parameters.expected_up), f"{parameters.test_name}: wrong upper probability - E: {parameters.expected_up}, F: {up}"
6 changes: 3 additions & 3 deletions test/test_exact_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from .utils_for_tests import almostEqual
from .utils_for_tests import almost_equal

from pastasolver.pasta_solver import Pasta

Expand Down Expand Up @@ -34,8 +34,8 @@ def test_exact_inference(
pasta_solver = Pasta(filename, query, evidence, normalize_prob = normalize)
lp, up = pasta_solver.inference()

assert almostEqual(lp,expected_lp,5), test_name + ": wrong lower probability"
assert almostEqual(up,expected_up,5), test_name + ": wrong upper probability"
assert almost_equal(lp,expected_lp), f"{test_name}: wrong lower probability - E: {expected_lp}, F: {lp}"
assert almost_equal(up,expected_up), f"{test_name}: wrong upper probability - E: {expected_up}, F: {up}"


def test_certain_fact_a1_exit():
Expand Down
6 changes: 3 additions & 3 deletions test/test_hybrid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# import unittest
import pytest
from .utils_for_tests import ArgumentsTest, almostEqual
from .utils_for_tests import ArgumentsTest, almost_equal

from pastasolver.pasta_solver import Pasta

Expand All @@ -16,5 +16,5 @@ def test_hybrid(parameters : ArgumentsTest):
pasta_solver = Pasta(parameters.filename, parameters.query, parameters.evidence, normalize_prob=parameters.normalize)
lp, up = pasta_solver.inference()

assert almostEqual(lp, parameters.expected_lp), f"{parameters.test_name}: wrong lower probability - E: {parameters.expected_lp}, F: {lp}"
assert almostEqual(up, parameters.expected_up), f"{parameters.test_name}: wrong upper probability - E: {parameters.expected_up}, F: {up}"
assert almost_equal(lp, parameters.expected_lp), f"{parameters.test_name}: wrong lower probability - E: {parameters.expected_lp}, F: {lp}"
assert almost_equal(up, parameters.expected_up), f"{parameters.test_name}: wrong upper probability - E: {parameters.expected_up}, F: {up}"
50 changes: 25 additions & 25 deletions test/test_lifted.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utils_for_tests import almostEqual
from .utils_for_tests import almost_equal

import pastasolver.lifted.lifted as lft

Expand All @@ -11,69 +11,69 @@
def test_cxy_ax_bxy_multiple_bi_1_lower():
test_name = "test_cxy_ax_bxy_multiple_bi_1_lower"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 3], lower=40)
assert almostEqual(lp, 0.0576, 5), test_name + ": wrong lower probability"
assert almostEqual(up, 0.16, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0.0576), test_name + ": wrong lower probability"
assert almost_equal(up, 0.16), test_name + ": wrong upper probability"
def test_cxy_ax_bxy_multiple_bi_1_upper():
test_name = "test_cxy_ax_bxy_multiple_bi_1_upper"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 3], lower=0, upper=70)
assert almostEqual(lp, 0, 5), test_name + ": wrong lower probability"
assert almostEqual(up, 0.10240, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0), test_name + ": wrong lower probability"
assert almost_equal(up, 0.10240), test_name + ": wrong upper probability"

def test_cxy_ax_bxy_multiple_bi_2_lower():
test_name = "test_cxy_ax_bxy_multiple_bi_2_lower"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 2, 2], lower=40)
assert almostEqual(lp, 0.071424, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.16, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0.071424), test_name + ": wrong lower probability"
assert almost_equal(up, 0.16), test_name + ": wrong upper probability"
def test_cxy_ax_bxy_multiple_bi_2_upper():
test_name = "test_cxy_ax_bxy_multiple_bi_2_upper"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 2, 2], lower=0, upper=70)
assert almostEqual(lp, 0, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.088576, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0), test_name + ": wrong lower probability"
assert almost_equal(up, 0.088576), test_name + ": wrong upper probability"


def test_cxy_ax_bxy_multiple_bi_3_lower():
test_name = "test_cxy_ax_bxy_multiple_bi_3_lower"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 1, 2, 2, 1], lower=40)
assert almostEqual(lp, 0.059996160, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.16, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0.059996160), test_name + ": wrong lower probability"
assert almost_equal(up, 0.16), test_name + ": wrong upper probability"
def test_cxy_ax_bxy_multiple_bi_3_upper():
test_name = "test_cxy_ax_bxy_multiple_bi_3_upper"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 1, 2, 2, 1], lower=0, upper=70)
assert almostEqual(lp, 0, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.10000384, 8), test_name + ": wrong upper probability"
assert almost_equal(lp, 0), test_name + ": wrong lower probability"
assert almost_equal(up, 0.10000384), test_name + ": wrong upper probability"


def test_cxy_ax_bxy_multiple_bi_4_lower():
test_name = "test_cxy_ax_bxy_multiple_bi_4_lower"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 3, 2], lower=40)
assert almostEqual(lp, 0.0428544, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.16, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0.0428544), test_name + ": wrong lower probability"
assert almost_equal(up, 0.16), test_name + ": wrong upper probability"
def test_cxy_ax_bxy_multiple_bi_4_upper():
test_name = "test_cxy_ax_bxy_multiple_bi_4_upper"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 3, 2], lower=0, upper=70)
assert almostEqual(lp, 0, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.1171456, 6), test_name + ": wrong upper probability"
assert almost_equal(lp, 0), test_name + ": wrong lower probability"
assert almost_equal(up, 0.1171456), test_name + ": wrong upper probability"


def test_cxy_ax_bxy_multiple_bi_5_lower():
test_name = "test_cxy_ax_bxy_multiple_bi_5_lower"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 1, 3, 2, 1], lower=40)
assert almostEqual(lp, 0.035997696, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.16, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0.035997696), test_name + ": wrong lower probability"
assert almost_equal(up, 0.16), test_name + ": wrong upper probability"
def test_cxy_ax_bxy_multiple_bi_5_upper():
test_name = "test_cxy_ax_bxy_multiple_bi_5_upper"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1, 1, 1, 3, 2, 1], lower=0, upper=70)
assert almostEqual(lp, 0, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.1240023, 6), test_name + ": wrong upper probability"
assert almost_equal(lp, 0), test_name + ": wrong lower probability"
assert almost_equal(up, 0.1240023), test_name + ": wrong upper probability"


def test_cxy_ax_bxy_multiple_bi_6_lower():
test_name = "test_cxy_ax_bxy_multiple_bi_6_lower"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1,1,1,1,1,3,2,2,1,1], lower=10)
assert almostEqual(lp, 0.02249712, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.16, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0.02249712), test_name + ": wrong lower probability"
assert almost_equal(up, 0.16), test_name + ": wrong upper probability"
def test_cxy_ax_bxy_multiple_bi_6_upper():
test_name = "test_cxy_ax_bxy_multiple_bi_6_upper"
lp, up, _, _ = lft.cxy_ax_bxy_multiple_bi(0.4, [1,1,1,1,1,3,2,2,1,1], lower=0, upper=80)
assert almostEqual(lp, 0, 6), test_name + ": wrong lower probability"
assert almostEqual(up, 0.137502879, 5), test_name + ": wrong upper probability"
assert almost_equal(lp, 0), test_name + ": wrong lower probability"
assert almost_equal(up, 0.137502879), test_name + ": wrong upper probability"
Loading

0 comments on commit 78eb204

Please sign in to comment.