
reachml

License: MIT

reachml is a library for recourse verification.

Background

Recourse is the ability of a decision subject to change the prediction of a machine learning model through actions on their features. Recourse verification aims to determine whether a decision subject has been assigned a fixed prediction, that is, a prediction they cannot change through any permissible action.
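The same idea can be written in symbols. The notation here is introduced for illustration: f is the model, x is a decision subject's feature vector, R(x) is its reachable set (every feature vector x can reach through permissible actions, including x itself), and f(x') = 1 is taken to be the desired outcome.

$$
\underbrace{\exists\, x' \in R(x) \;:\; f(x') = 1}_{\text{$x$ has recourse}}
\qquad\qquad
\underbrace{f(x') = f(x) \;\; \forall\, x' \in R(x)}_{\text{$x$ is assigned a fixed prediction}}
$$

For a point with f(x) = 0, having no recourse is the same as having a fixed prediction.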

Installation

You can install the library as follows:

pip install "git+https://github.com/ustunb/reachml#egg=reachml[cplex]"

Many of the functions in reachml require CPLEX to run properly. The command above installs CPLEX Community Edition, which has a strict limit on the number of constraints it can support. To avoid this limit, install reachml without the cplex option, then download and install the full version of IBM CPLEX following these instructions.
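Before running the quickstart, it can help to confirm that your Python environment can see a CPLEX package at all. The sketch below uses only the standard library; cplex_available is a hypothetical helper, not part of reachml. It only reports whether a module named cplex is importable, so it cannot tell the Community Edition apart from the full version.

```python
import importlib.util


def cplex_available() -> bool:
    """Return True if a module named `cplex` can be imported in this environment."""
    # find_spec locates the module without importing it
    return importlib.util.find_spec("cplex") is not None


if __name__ == "__main__":
    print("cplex found" if cplex_available() else "cplex not found")
```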

Quickstart

The following example shows how to specify actionability constraints using ActionSet, and how to build a database containing the ReachableSet of each point.

import pandas as pd
from reachml import ActionSet, ReachableSet, ReachableSetDatabase
from reachml.constraints import OneHotEncoding, DirectionalLinkage

# feature matrix with 3 points
X = pd.DataFrame(
    {
        "age": [32, 19, 52],
        "marital_status": [1, 0, 0],
        "years_since_last_default": [5, 0, 21],
        "job_type_a": [0, 1, 1], # categorical feature with one-hot encoding
        "job_type_b": [1, 0, 0],
        "job_type_c": [0, 0, 0],
    }
)

# Create an action set
action_set = ActionSet(X)

# `ActionSet` infers the type and bounds on each feature from `X`. To see them:
print(action_set)

## print(action_set) should return the following output
##+---+--------------------------+--------+------------+----+----+----------------+---------+---------+
##|   | name                     |  type  | actionable | lb | ub | step_direction | step_ub | step_lb |
##+---+--------------------------+--------+------------+----+----+----------------+---------+---------+
##| 0 | age                      | <int>  |   False    | 19 | 52 |              0 |         |         |
##| 1 | marital_status           | <bool> |   False    | 0  | 1  |              0 |         |         |
##| 2 | years_since_last_default | <int>  |    True    | 0  | 21 |              1 |         |         |
##| 3 | job_type_a               | <bool> |    True    | 0  | 1  |              0 |         |         |
##| 4 | job_type_b               | <bool> |    True    | 0  | 1  |              0 |         |         |
##| 5 | job_type_c               | <bool> |    True    | 0  | 1  |              0 |         |         |
##+---+--------------------------+--------+------------+----+----+----------------+---------+---------+

# Specify constraints on individual features
action_set[["age", "marital_status"]].actionable = False # these features cannot or should not change
action_set["years_since_last_default"].ub = 100 # set maximum value of feature to 100
action_set["years_since_last_default"].step_direction = 1 # actions can only increase value
action_set["years_since_last_default"].step_ub = 1 # each action can increase the value by at most 1

# Specify constraint to maintain one hot-encoding on `job_type`
action_set.constraints.add(
    constraint=OneHotEncoding(names=["job_type_a", "job_type_b", "job_type_c"])
)

# Specify deterministic causal relationships
# if `years_since_last_default` increases, then `age` must increase commensurately
# This will force `age` to change even though it is not immediately actionable
action_set.constraints.add(
    constraint=DirectionalLinkage(
        names=["years_since_last_default", "age"], scales=[1, 1]
    )
)

# Check that `ActionSet` is consistent with observed data
# For example, if features must obey one-hot encoding, this should be the case for X
assert action_set.validate(X)

# Build a database of reachable sets for all points
db = ReachableSetDatabase(action_set, path="reachable_db.h5") # database stored in file `./reachable_db.h5`
db.generate(X, overwrite=True)

# Pull reachable set for first point in dataset
x = X.iloc[0]
reachable_set = db[x]
print(reachable_set) # should return the following output:
##    age  marital_status  years_since_last_default  job_type_a  job_type_b  job_type_c
## 0  32.0             1.0                       5.0         0.0         1.0         0.0
## 1  32.0             1.0                       5.0         0.0         0.0         1.0
## 2  32.0             1.0                       5.0         1.0         0.0         0.0
## 3  33.0             1.0                       6.0         0.0         0.0         1.0
## 4  33.0             1.0                       6.0         0.0         1.0         0.0
## 5  33.0             1.0                       6.0         1.0         0.0         0.0
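To see why the reachable set of the first point has exactly these six rows, here is a brute-force sketch that enumerates it by hand from the constraints in the quickstart. It is an illustration of what the constraints mean, not how reachml computes reachable sets. reachable_points is a hypothetical helper, and it assumes DirectionalLinkage with scales=[1, 1] means age increases by the same amount as years_since_last_default. It also ignores the bounds on age, which are not binding for this point.

```python
from itertools import product

# First point of X from the quickstart, columns in the same order:
# (age, marital_status, years_since_last_default, job_type_a, job_type_b, job_type_c)
x = (32, 1, 5, 0, 1, 0)

# OneHotEncoding: exactly one of job_type_a, job_type_b, job_type_c equals 1
JOB_TYPES = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]


def reachable_points(x, years_ub=100, step_ub=1):
    """Enumerate points reachable from x under the quickstart's constraints (sketch)."""
    age, marital, years = x[0], x[1], x[2]
    points = set()
    # years_since_last_default can only increase (step_direction = 1),
    # by at most step_ub, and cannot exceed its upper bound
    for delta, job in product(range(step_ub + 1), JOB_TYPES):
        if years + delta > years_ub:
            continue
        # assumed DirectionalLinkage semantics: age rises by the same amount
        new_age = age + delta
        # marital_status is not actionable, so it stays fixed
        points.add((new_age, marital, years + delta) + job)
    return sorted(points)


for p in reachable_points(x):
    print(p)
```

The six tuples it prints match the six rows of the reachable set shown above, in a different order.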

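Given a classifier clf whose predict method returns 1 for the desired outcome and 0 otherwise, you can test whether a point has recourse with np.any(clf.predict(reachable_set.X)). If this returns False, every point in the reachable set is assigned the undesired outcome, so the point is assigned a fixed prediction.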
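A self-contained sketch of this check, using the six reachable points printed above. ThresholdClassifier and has_recourse are hypothetical names introduced here; ThresholdClassifier stands in for a trained model and only needs a predict method that returns 0/1 labels.

```python
import numpy as np

# Reachable set of the first point, as printed above (columns in the same order as X)
R = np.array([
    [32, 1, 5, 0, 1, 0],
    [32, 1, 5, 0, 0, 1],
    [32, 1, 5, 1, 0, 0],
    [33, 1, 6, 0, 0, 1],
    [33, 1, 6, 0, 1, 0],
    [33, 1, 6, 1, 0, 0],
])


class ThresholdClassifier:
    """Hypothetical stand-in for a model: predicts 1 when column j is at least t."""

    def __init__(self, j, t):
        self.j, self.t = j, t

    def predict(self, X):
        return (np.asarray(X)[:, self.j] >= self.t).astype(int)


def has_recourse(clf, X_reachable, target=1):
    # recourse exists if at least one reachable point receives the target label
    return bool(np.any(clf.predict(X_reachable) == target))


# approve if years_since_last_default >= 6: one action reaches it, so recourse exists
print(has_recourse(ThresholdClassifier(j=2, t=6), R))   # True
# approve if years_since_last_default >= 10: no reachable point qualifies, prediction is fixed
print(has_recourse(ThresholdClassifier(j=2, t=10), R))  # False
```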

For more examples, check out this script, which sets up the action set for the FICO dataset.

Resources and Citation

For more about recourse verification, check out our ICLR 2024 spotlight paper: Prediction without Preclusion: Recourse Verification with Reachable Sets

If you use this library in your research, we would appreciate a citation:

@inproceedings{kothari2024prediction,
    title={Prediction without Preclusion: Recourse Verification with Reachable Sets},
    author={Avni Kothari and Bogdan Kulynych and Tsui-Wei Weng and Berk Ustun},
    booktitle={The Twelfth International Conference on Learning Representations},
    year={2024},
    url={https://openreview.net/forum?id=SCQfYpdoGE}
}

The code for the paper is available under research/iclr2024.
