Bayesian Classifiers for PU-learning

This module contains a custom implementation of Bayesian network classifiers for PU learning, written from scratch in Python 3 with an API inspired by scikit-learn. It implements multiple Bayesian classifiers, including PNB and PSTAN.

PU learning is the setting where a learner only has access to positive examples and unlabeled data. The assumption is that the unlabeled data can contain both positive and negative examples.

Note: All algorithms make the "Selected Completely At Random" (SCAR) labeling assumption. We consider the "case-control" and "single training" sampling scenarios.
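As a short reminder of what SCAR implies, the identities below use notation that is not from this README and is introduced only for illustration: y is the true class, s indicates whether an example is labeled, and c = p(s=1 | y=1) is the label frequency. Only positives can be labeled, so p(s=1 | x, y=0) = 0.

```latex
\begin{align}
p(s=1 \mid x, y=1) &= p(s=1 \mid y=1) = c && \text{(SCAR)} \\
p(s=1 \mid x) &= p(s=1 \mid x, y=1)\, p(y=1 \mid x) = c \, p(y=1 \mid x) \\
p(x \mid s=1) &= p(x \mid y=1)
\end{align}
```

The last line is why the labeled set can be treated as a sample from the positive class distribution.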

Current features

PU Generator

Generates positive and unlabeled data from a fully labeled data set, following either the "case-control" or the "single training" scenario.

from PUgenerator import PUgenerator  # assumes the class lives in a module of the same name

p450_pu = PUgenerator()
p450_pu.fit(X, y, nl=400, nu=800, case_control=True)  # "case-control" scenario
p450_pu.prevalence_  # p(y=1)
p450_pu.plot_dist()  # plot all feature distributions

p450_pu.X_1abeled_    # labeled set
p450_pu.X_Unlabeled_  # unlabeled set
p450_pu.y_Unlabeled_  # unlabeled target values (unknown in practice)
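To make the two sampling scenarios concrete, here is a minimal NumPy sketch. It is not the PUgenerator implementation: the function name `make_pu`, its arguments, and the exact sampling rules are assumptions chosen for illustration. In the case-control branch the unlabeled set is drawn from the whole data set; in the single training branch the labeled positives are excluded from the pool before the unlabeled set is drawn.

```python
import numpy as np


def make_pu(X, y, nl, nu, case_control=True, seed=0):
    """Hypothetical helper: split fully labeled data into labeled positives and unlabeled data."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    y = np.asarray(y)

    # Labeled set: nl positives chosen uniformly at random (SCAR).
    pos_idx = np.flatnonzero(y == 1)
    labeled_idx = rng.choice(pos_idx, size=nl, replace=False)

    if case_control:
        # Unlabeled set is an independent draw from the whole population.
        pool = np.arange(len(y))
    else:
        # Single training (simplified): labeled examples leave the pool.
        pool = np.setdiff1d(np.arange(len(y)), labeled_idx)

    unlabeled_idx = rng.choice(pool, size=nu, replace=False)
    return X[labeled_idx], X[unlabeled_idx], y[unlabeled_idx]
```

The returned labels of the unlabeled set are only available because the source data is fully labeled; they are meant for evaluation, not training.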

PU Bayesian classifiers

Demo of PNB and PSTAN, from training to prediction:

from PNB import PNB  # assumes the class lives in a module of the same name

pnb = PNB()
pnb.fit(p450_pu.X_1abeled_, p450_pu.X_Unlabeled_, p450_pu.prevalence_)  # model fitting
pnb.predict(p450_pu.X_Unlabeled_)  # class prediction
pnb.predict_proba(p450_pu.X_Unlabeled_)  # probability prediction

from PSTAN import PSTAN  # same assumption as above

pstan = PSTAN()
pstan.fit(p450_pu.X_1abeled_, p450_pu.X_Unlabeled_, p450_pu.prevalence_, M)  # model fitting
pstan.plot_tree_structure()  # plot learned tree structure
pstan.predict(p450_pu.X_Unlabeled_)  # class prediction
pstan.predict_proba(p450_pu.X_Unlabeled_)  # probability prediction
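The fit calls above take only labeled positives, unlabeled data, and the prevalence, so the negative class has to be inferred. One common way to do that for a categorical feature, sketched below, uses the mixture identity p(x) = p(y=1) p(x | y=1) + (1 - p(y=1)) p(x | y=0) with the unlabeled set standing in for p(x) (the case-control setting). This is an illustration of the idea only, not necessarily what PNB does internally; the function name, the smoothing parameter `alpha`, and the clipping step are assumptions.

```python
import numpy as np


def estimate_class_conditionals(x_pos, x_unl, prevalence, n_values, alpha=1.0):
    """Hypothetical helper: estimate p(x=k | y=1) and p(x=k | y=0) for one integer-coded feature."""
    x_pos = np.asarray(x_pos)
    x_unl = np.asarray(x_unl)

    # Smoothed frequency estimates from labeled positives and from unlabeled data.
    p_pos = (np.bincount(x_pos, minlength=n_values) + alpha) / (len(x_pos) + alpha * n_values)
    p_all = (np.bincount(x_unl, minlength=n_values) + alpha) / (len(x_unl) + alpha * n_values)

    # Solve the mixture identity for the negative class conditional.
    p_neg = (p_all - prevalence * p_pos) / (1.0 - prevalence)

    # Sampling noise can push estimates below zero, so clip and renormalize.
    p_neg = np.clip(p_neg, 1e-12, None)
    return p_pos, p_neg / p_neg.sum()
```

With both class conditionals in hand, a naive Bayes posterior follows from Bayes' rule using the supplied prevalence as the class prior.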

Evaluation

Evaluate over multiple runs and average the results.

import numpy as np
from get_cv import get_cv  # assumes the function lives in a module of the same name

Accuracy, CLL, Precision, Recall = get_cv(PNB, X, y, 400, 800, M)
print(np.mean(Accuracy))
print(np.mean(CLL))
print(np.mean(Precision))
print(np.mean(Recall))
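For readers who want to see what the four reported quantities could look like for a single run, here is a small sketch. It is not the get_cv implementation: the function name `pu_metrics`, the 0.5 threshold, and the choice to define CLL as the mean log-probability assigned to the true class are all assumptions.

```python
import numpy as np


def pu_metrics(y_true, proba_pos, threshold=0.5):
    """Hypothetical helper: accuracy, CLL, precision, recall for one run."""
    y_true = np.asarray(y_true)
    # Clip probabilities so the logarithms below stay finite.
    proba = np.clip(np.asarray(proba_pos, dtype=float), 1e-12, 1 - 1e-12)
    y_pred = (proba >= threshold).astype(int)

    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))

    accuracy = float(np.mean(y_pred == y_true))
    precision = float(tp / (tp + fp)) if tp + fp > 0 else 0.0
    recall = float(tp / (tp + fn)) if tp + fn > 0 else 0.0
    # Conditional log-likelihood, taken here as a per-example mean.
    cll = float(np.mean(np.where(y_true == 1, np.log(proba), np.log(1 - proba))))
    return accuracy, cll, precision, recall
```

Collecting these tuples over several independently generated PU splits and averaging each component gives per-metric means like the ones printed above.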

Built With

  • scikit-learn - the API design that this module is modeled on

Versioning

We use SemVer for versioning. For the versions available, see the tags on this repository.

Authors

License

This project is licensed under the MIT License - see the LICENSE.md file for details

Acknowledgments

  • Hat tip to anyone whose code was used