Skip to content

Commit

Permalink
implements AutoCloseable now; TAGS_OBJECTIVE is now public; exiting t…
Browse files Browse the repository at this point in the history
…raining loop if training finished; adding class label to options again; skipping class attribute from categorical_features parameter; closing optional validation set now as well
  • Loading branch information
fracpete committed Feb 10, 2023
1 parent 038d2fb commit 6cd0b8f
Showing 1 changed file with 43 additions and 13 deletions.
56 changes: 43 additions & 13 deletions src/main/java/weka/classifiers/functions/LightGBM.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
*/
public class LightGBM
extends RandomizableClassifier
implements TechnicalInformationHandler {
implements TechnicalInformationHandler, AutoCloseable {

private static final long serialVersionUID = -6138516902729782286L;

Expand All @@ -82,7 +82,7 @@ public class LightGBM
public static final int OBJECTIVE_RANK_XENDCG = 15;

/** the available objectives. */
protected static final Tag[] TAGS_OBJECTIVE = {
public static final Tag[] TAGS_OBJECTIVE = {
new Tag(OBJECTIVE_REGRESSION, "REGRESSION", "Regression"),
new Tag(OBJECTIVE_REGRESSION_L1, "REGRESSION_L1", "Regression L1"),
new Tag(OBJECTIVE_HUBER, "HUBER", "Huber loss"),
Expand Down Expand Up @@ -279,16 +279,14 @@ public String[] getOptions() {
result.add("-I");
result.add("" + getNumIterations());

result.add("-V");
result.add("" + getValidationPercentage());
if (getValidationPercentage() > 0) {
result.add("-V");
result.add("" + getValidationPercentage());
}

if (getRandomizeBeforeSplit()) {
if (getRandomizeBeforeSplit())
result.add("-R");

result.add("-S");
result.add("" + getSeed());
}

return result.toArray(new String[0]);
}

Expand Down Expand Up @@ -497,6 +495,7 @@ public void buildClassifier(Instances data) throws Exception {
int i;
int size;
StringBuilder categorical;
boolean finished;

// can classifier handle the data?
getCapabilities().testWithFail(data);
Expand Down Expand Up @@ -527,6 +526,8 @@ public void buildClassifier(Instances data) throws Exception {
// categorical features
categorical = new StringBuilder();
for (i = 0; i < data.numAttributes(); i++) {
if (i == data.classIndex())
continue;
if (data.attribute(i).isNominal()) {
if (categorical.length() > 0)
categorical.append(",");
Expand All @@ -538,23 +539,37 @@ public void buildClassifier(Instances data) throws Exception {
lgbmVal = null;
if (val != null)
lgbmVal = LightGBMUtils.fromInstances(val, lgbmTrain);
m_ActualParameters = "objective=" + getObjective().getSelectedTag().getIDStr().toLowerCase();
m_ActualParameters = "objective=" + getObjective().getSelectedTag().getIDStr().toLowerCase()
+ " label=name:" + data.classAttribute().name();
if (categorical.length() > 0)
m_ActualParameters += " categorical_features=" + categorical.toString();
if (!m_Parameters.isEmpty())
m_ActualParameters += " " + m_Parameters;
if (getDebug())
System.out.println("Actual parameters: " + m_ActualParameters);

try {
m_Booster = LGBMBooster.create(lgbmTrain, m_ActualParameters);
if (lgbmVal != null)
m_Booster.addValidData(lgbmVal);
for (i = 0; i < m_NumIterations; i++)
m_Booster.updateOneIter();
// train
for (i = 0; i < m_NumIterations; i++) {
finished = m_Booster.updateOneIter();
if (finished) {
System.out.println("No more splits possible, stopping training at iteration " + (i+1) + " out of " + m_NumIterations);
break;
}
}
m_Model = m_Booster.saveModelToString(m_NumIterations - 1, m_NumIterations - 1, LGBMBooster.FeatureImportanceType.GAIN);
}
catch (Exception e) {
// m_Booster.close(); // TODO memory leak?
if (m_Booster != null)
m_Booster.close();
}
finally {
lgbmTrain.close();
if (lgbmVal != null)
lgbmVal.close();
}
}

Expand Down Expand Up @@ -619,6 +634,21 @@ public String toString() {
return result.toString();
}

/**
* Closes this resource, relinquishing any underlying resources.
* This method is invoked automatically on objects managed by the
* {@code try}-with-resources statement.
*
* @throws Exception if this resource cannot be closed
*/
@Override
public void close() throws Exception {
if (m_Booster != null) {
m_Booster.close();
m_Booster = null;
}
}

/**
* Main method.
*
Expand Down

0 comments on commit 6cd0b8f

Please sign in to comment.