diff --git a/src/main/java/weka/classifiers/functions/LightGBM.java b/src/main/java/weka/classifiers/functions/LightGBM.java index 3a7aa4b..8c9626c 100644 --- a/src/main/java/weka/classifiers/functions/LightGBM.java +++ b/src/main/java/weka/classifiers/functions/LightGBM.java @@ -56,7 +56,7 @@ */ public class LightGBM extends RandomizableClassifier - implements TechnicalInformationHandler { + implements TechnicalInformationHandler, AutoCloseable { private static final long serialVersionUID = -6138516902729782286L; @@ -82,7 +82,7 @@ public class LightGBM public static final int OBJECTIVE_RANK_XENDCG = 15; /** the available objectives. */ - protected static final Tag[] TAGS_OBJECTIVE = { + public static final Tag[] TAGS_OBJECTIVE = { new Tag(OBJECTIVE_REGRESSION, "REGRESSION", "Regression"), new Tag(OBJECTIVE_REGRESSION_L1, "REGRESSION_L1", "Regression L1"), new Tag(OBJECTIVE_HUBER, "HUBER", "Huber loss"), @@ -279,16 +279,14 @@ public String[] getOptions() { result.add("-I"); result.add("" + getNumIterations()); - result.add("-V"); - result.add("" + getValidationPercentage()); + if (getValidationPercentage() > 0) { + result.add("-V"); + result.add("" + getValidationPercentage()); + } - if (getRandomizeBeforeSplit()) { + if (getRandomizeBeforeSplit()) result.add("-R"); - result.add("-S"); - result.add("" + getSeed()); - } - return result.toArray(new String[0]); } @@ -497,6 +495,7 @@ public void buildClassifier(Instances data) throws Exception { int i; int size; StringBuilder categorical; + boolean finished; // can classifier handle the data? getCapabilities().testWithFail(data); @@ -527,6 +526,8 @@ public void buildClassifier(Instances data) throws Exception { // categorical features categorical = new StringBuilder(); for (i = 0; i < data.numAttributes(); i++) { + if (i == data.classIndex()) + continue; if (data.attribute(i).isNominal()) { if (categorical.length() > 0) categorical.append(","); @@ -538,23 +539,37 @@ public void buildClassifier(Instances data) throws Exception { lgbmVal = null; if (val != null) lgbmVal = LightGBMUtils.fromInstances(val, lgbmTrain); - m_ActualParameters = "objective=" + getObjective().getSelectedTag().getIDStr().toLowerCase(); + m_ActualParameters = "objective=" + getObjective().getSelectedTag().getIDStr().toLowerCase() + + " label=name:" + data.classAttribute().name(); if (categorical.length() > 0) m_ActualParameters += " categorical_features=" + categorical.toString(); if (!m_Parameters.isEmpty()) m_ActualParameters += " " + m_Parameters; + if (getDebug()) + System.out.println("Actual parameters: " + m_ActualParameters); try { m_Booster = LGBMBooster.create(lgbmTrain, m_ActualParameters); if (lgbmVal != null) m_Booster.addValidData(lgbmVal); - for (i = 0; i < m_NumIterations; i++) - m_Booster.updateOneIter(); + // train + for (i = 0; i < m_NumIterations; i++) { + finished = m_Booster.updateOneIter(); + if (finished) { + System.out.println("No more splits possible, stopping training at iteration " + (i+1) + " out of " + m_NumIterations); + break; + } + } m_Model = m_Booster.saveModelToString(m_NumIterations - 1, m_NumIterations - 1, LGBMBooster.FeatureImportanceType.GAIN); } catch (Exception e) { - // m_Booster.close(); // TODO memory leak? + if (m_Booster != null) + m_Booster.close(); + } + finally { lgbmTrain.close(); + if (lgbmVal != null) + lgbmVal.close(); } } @@ -619,6 +634,21 @@ public String toString() { return result.toString(); } + /** + * Closes this resource, relinquishing any underlying resources. + * This method is invoked automatically on objects managed by the + * {@code try}-with-resources statement. + * + * @throws Exception if this resource cannot be closed + */ + @Override + public void close() throws Exception { + if (m_Booster != null) { + m_Booster.close(); + m_Booster = null; + } + } + /** * Main method. *