
Commit

classifyInstance now rounds the predictions for non-numeric class attributes to get the index right
fracpete committed Feb 10, 2023
1 parent 6cd0b8f commit 721d66c
Showing 1 changed file with 163 additions and 44 deletions.
207 changes: 163 additions & 44 deletions src/main/java/weka/classifiers/functions/LightGBM.java
@@ -28,7 +28,6 @@
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
@@ -43,14 +42,83 @@
import java.util.Vector;

/**
<!-- globalinfo-start -->
* LightGBM (https://github.com/microsoft/LightGBM) is a gradient boosting framework that uses tree based learning algorithms. It is designed to be distributed and efficient.<br>
* <br>
* Information on parameters:<br>
* https://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html<br>
* The following parameters get filled in automatically:<br>
* - objective<br>
* - categorical_features<br>
* <br>
* For more information see:<br>
* <br>
* Ke, Guolin, Meng, Qi, Finley, Thomas, Wang, Taifeng, Chen, Wei, Ma, Weidong, Ye, Qiwei, Liu, Tie-Yan: LightGBM: A Highly Efficient Gradient Boosting Decision Tree. In: Advances in Neural Information Processing Systems, 3149-3157, 2017.
* <br><br>
<!-- globalinfo-end -->
*
<!-- technical-bibtex-start -->
* BibTeX:
* <pre>
* &#64;inproceedings{Ke2017,
* author = {Ke, Guolin and Meng, Qi and Finley, Thomas and Wang, Taifeng and Chen, Wei and Ma, Weidong and Ye, Qiwei and Liu, Tie-Yan},
* booktitle = {Advances in Neural Information Processing Systems},
* editor = {I. Guyon and U. Von Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
* pages = {3149-3157},
* publisher = {Curran Associates, Inc.},
* title = {LightGBM: A Highly Efficient Gradient Boosting Decision Tree},
* year = {2017},
* URL = {https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf}
* }
* </pre>
* <br><br>
<!-- technical-bibtex-end -->
*
<!-- options-start -->
* Valid options are: <p>
*
* <pre> -O &lt;REGRESSION|REGRESSION_L1|HUBER|FAIR|POISSON|QUANTILE|MAPE|GAMMA|TWEEDIE|BINARY|MULTICLASS|MULTICLASSOVA|CROSSENTROPY|CROSSENTROPY_LAMBDA|LAMBDA_RANK|RANK_XENDCG&gt;
* The type of booster to use:
* REGRESSION = Regression
* REGRESSION_L1 = Regression L1
* HUBER = Huber loss
* FAIR = Fair loss
* POISSON = Poisson regression
* QUANTILE = Quantile regression
* MAPE = MAPE loss
* GAMMA = Gamma regression with log-link
* TWEEDIE = Tweedie regression with log-link
* BINARY = Binary log loss classification
* MULTICLASS = Multi-class (softmax)
* MULTICLASSOVA = Multi-class (one-vs-all)
* CROSSENTROPY = Cross-entropy
* CROSSENTROPY_LAMBDA = Cross-entropy Lambda
* LAMBDA_RANK = Lambda rank
* RANK_XENDCG = Rank Xendcg
* (default: REGRESSION)</pre>
*
* <pre> -P &lt;parameters&gt;
* The parameters for the booster (blank-separated key=value pairs).
* See: https://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html
* (default: none)
* </pre>
*
* <pre> -I &lt;iterations&gt;
* The number of iterations to train for.
* (default: 1000)
* </pre>
*
* <pre> -V &lt;0-100&gt;
* The size of the validation set to split off from the training set.
* (default: 0.0)
* </pre>
*
* <pre> -R
* Turns on randomization before splitting off the validation set.
* (default: off)
* </pre>
*
<!-- options-end -->
*
* @author fracpete (fracpete at waikato dot ac dot nz)
*/
@@ -125,6 +193,9 @@ public class LightGBM
/** the built model. */
protected String m_Model = null;

/** whether the class is numeric. */
protected boolean m_NumericClass;

/**
* Returns a string describing this classifier
*
@@ -186,34 +257,34 @@ public Enumeration listOptions() {
for (i = 0; i < TAGS_OBJECTIVE.length; i++) {
tag = new SelectedTag(TAGS_OBJECTIVE[i].getID(), TAGS_OBJECTIVE);
desc += "\t" + tag.getSelectedTag().getIDStr()
+ " = " + tag.getSelectedTag().getReadable()
+ "\n";
+ " = " + tag.getSelectedTag().getReadable()
+ "\n";
}
result.addElement(new Option(
"\tThe type of booster to use:\n"
+ desc
+ "\t(default: " + new SelectedTag(OBJECTIVE_REGRESSION, TAGS_OBJECTIVE) + ")",
"O", 1, "-O " + Tag.toOptionList(TAGS_OBJECTIVE)));

result.addElement(new Option(
"\tThe parameters for the booster (blank-separated key=value pairs).\n"
+ "\tSee: " + PARAMETERS_URL + "\n"
+ "\t(default: none)\n",
+ "\tSee: " + PARAMETERS_URL + "\n"
+ "\t(default: none)\n",
"P", 1, "-P <parameters>"));

result.addElement(new Option(
"\tThe number of iterations to train for.\n"
+ "\t(default: 1000)\n",
+ "\t(default: 1000)\n",
"I", 1, "-I <iterations>"));

result.addElement(new Option(
"\tThe size of the validation set to split off from the training set.\n"
+ "\t(default: 0.0)\n",
+ "\t(default: 0.0)\n",
"V", 1, "-V <0-100>"));

result.addElement(new Option(
"\tTurns on randomization before splitting off the validation set.\n"
+ "\t(default: off)\n",
+ "\t(default: off)\n",
"R", 0, "-R"));

return result.elements();
@@ -223,6 +294,49 @@ public Enumeration listOptions() {
* Parses the options. <p/>
*
<!-- options-start -->
* Valid options are: <p>
*
* <pre> -O &lt;REGRESSION|REGRESSION_L1|HUBER|FAIR|POISSON|QUANTILE|MAPE|GAMMA|TWEEDIE|BINARY|MULTICLASS|MULTICLASSOVA|CROSSENTROPY|CROSSENTROPY_LAMBDA|LAMBDA_RANK|RANK_XENDCG&gt;
* The type of booster to use:
* REGRESSION = Regression
* REGRESSION_L1 = Regression L1
* HUBER = Huber loss
* FAIR = Fair loss
* POISSON = Poisson regression
* QUANTILE = Quantile regression
* MAPE = MAPE loss
* GAMMA = Gamma regression with log-link
* TWEEDIE = Tweedie regression with log-link
* BINARY = Binary log loss classification
* MULTICLASS = Multi-class (softmax)
* MULTICLASSOVA = Multi-class (one-vs-all)
* CROSSENTROPY = Cross-entropy
* CROSSENTROPY_LAMBDA = Cross-entropy Lambda
* LAMBDA_RANK = Lambda rank
* RANK_XENDCG = Rank Xendcg
* (default: REGRESSION)</pre>
*
* <pre> -P &lt;parameters&gt;
* The parameters for the booster (blank-separated key=value pairs).
* See: https://lightgbm.readthedocs.io/en/v3.3.2/Parameters.html
* (default: none)
* </pre>
*
* <pre> -I &lt;iterations&gt;
* The number of iterations to train for.
* (default: 1000)
* </pre>
*
* <pre> -V &lt;0-100&gt;
* The size of the validation set to split off from the training set.
* (default: 0.0)
* </pre>
*
* <pre> -R
* Turns on randomization before splitting off the validation set.
* (default: off)
* </pre>
*
<!-- options-end -->
*
* @param options the options to parse
@@ -455,19 +569,19 @@ public Capabilities getCapabilities() {
result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
switch (m_Objective) {
case OBJECTIVE_BINARY:
result.enable(Capabilities.Capability.BINARY_CLASS);
result.disable(Capabilities.Capability.UNARY_CLASS);
break;

case OBJECTIVE_MULTICLASS:
case OBJECTIVE_MULTICLASSOVA:
result.enable(Capabilities.Capability.NOMINAL_CLASS);
result.disable(Capabilities.Capability.BINARY_CLASS);
result.disable(Capabilities.Capability.UNARY_CLASS);
break;

default:
result.enable(Capabilities.Capability.NUMERIC_CLASS);
}

// other
@@ -504,23 +618,25 @@ public void buildClassifier(Instances data) throws Exception {
data = new Instances(data);
data.deleteWithMissingClass();

m_NumericClass = data.classAttribute().isNumeric();

// validation set?
train = data;
val = null;
if (m_ValidationPercentage > 0) {
if (m_RandomizeBeforeSplit)
data.randomize(new Random(m_Seed));
size = (int) Math.round(data.size() * m_ValidationPercentage / 100);
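// e.g. with 150 instances and -V 10: size = 15, so the first 135 instances stay in the training set and the last 15 form the validation set (see the loop below)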
train = new Instances(data, data.numInstances() - size);
val = new Instances(data, size);
for (i = 0; i < data.numInstances(); i++) {
if (i < data.numInstances() - size)
train.add((Instance) data.instance(i).copy());
else
val.add((Instance) data.instance(i).copy());
}
if (getDebug())
System.out.println("train size: " + train.numInstances() + ", validation size: " + val.numInstances());
System.out.println("train size: " + train.numInstances() + ", validation size: " + val.numInstances());
}

// categorical features
@@ -529,9 +645,9 @@ public void buildClassifier(Instances data) throws Exception {
if (i == data.classIndex())
continue;
if (data.attribute(i).isNominal()) {
if (categorical.length() > 0)
categorical.append(",");
categorical.append(i);
}
}
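// e.g. nominal attributes at indices 0 and 2 (class attribute excluded) produce the index list "0,2" that is handed to LightGBM as its categorical feature setting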

@@ -551,25 +667,25 @@ public void buildClassifier(Instances data) throws Exception {
try {
m_Booster = LGBMBooster.create(lgbmTrain, m_ActualParameters);
if (lgbmVal != null)
m_Booster.addValidData(lgbmVal);
// train
for (i = 0; i < m_NumIterations; i++) {
finished = m_Booster.updateOneIter();
if (finished) {
System.out.println("No more splits possible, stopping training at iteration " + (i+1) + " out of " + m_NumIterations);
break;
}
}
- m_Model = m_Booster.saveModelToString(m_NumIterations - 1, m_NumIterations - 1, LGBMBooster.FeatureImportanceType.GAIN);
+ m_Model = m_Booster.saveModelToString(0, 0, LGBMBooster.FeatureImportanceType.GAIN);
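// assumption: following the LightGBM C API convention, a num_iteration of 0 means "all iterations", so 0, 0 saves the complete model rather than only the final iteration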
}
catch (Exception e) {
if (m_Booster != null)
m_Booster.close();
}
finally {
lgbmTrain.close();
if (lgbmVal != null)
lgbmVal.close();
}
}

@@ -581,9 +697,9 @@ public void buildClassifier(Instances data) throws Exception {
protected void initBooster() throws Exception {
if (m_Booster == null) {
if (m_Model != null)
m_Booster = LGBMBooster.loadModelFromString(m_Model);
else
throw new IllegalStateException("No model trained?");
throw new IllegalStateException("No model trained?");
}
}

@@ -606,6 +722,9 @@ public double classifyInstance(Instance instance) throws Exception {
values = LightGBMUtils.fromInstance(instance);
result = m_Booster.predictForMatSingleRow(values, PredictionType.C_API_PREDICT_NORMAL);

if (!m_NumericClass)
result = Math.round(result);
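// LightGBM returns a continuous prediction; for nominal class attributes Weka expects the zero-based label index, so the value is rounded to the nearest index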

return result;
}

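For context, a minimal usage sketch of the patched classifier, assuming the LightGBM package and its lightgbm4j dependency are on the classpath; the dataset file name and the concrete option values are only illustrative, while data loading and option handling use the standard Weka API:

import weka.classifiers.functions.LightGBM;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;

public class LightGBMExample {

  public static void main(String[] args) throws Exception {
    // load a dataset with a nominal class attribute (file name is illustrative)
    Instances data = DataSource.read("iris.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // configure the booster via the options documented in the class javadoc:
    // multi-class objective, 500 iterations, 10% validation split, randomized before splitting
    LightGBM cls = new LightGBM();
    cls.setOptions(Utils.splitOptions("-O MULTICLASS -I 500 -V 10 -R"));
    cls.buildClassifier(data);

    // with this commit, classifyInstance returns the rounded, zero-based label index
    double index = cls.classifyInstance(data.instance(0));
    System.out.println("predicted: " + data.classAttribute().value((int) index));
  }
}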
