Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation to log-regression benchmark #436

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ import org.renaissance.Benchmark
import org.renaissance.Benchmark._
import org.renaissance.BenchmarkContext
import org.renaissance.BenchmarkResult
import org.renaissance.BenchmarkResult.Validators
import org.renaissance.BenchmarkResult.Assert
import org.renaissance.License
import org.renaissance.apache.spark.ResourceUtil.duplicateLinesFromUrl

import java.nio.file.Files
import java.nio.file.Path

@Name("log-regression")
Expand All @@ -34,28 +33,53 @@ import java.nio.file.Path
defaultValue = "20",
summary = "Maximum number of iterations of the logistic regression algorithm."
)
@Configuration(name = "test", settings = Array("copy_count = 5"))
@Parameter(name = "expected_coefficient_sum", defaultValue = "-0.0653998570980114")
@Parameter(name = "expected_coefficient_sum_squares", defaultValue = "9.401759355004592E-5")
@Parameter(name = "expected_intercept_value", defaultValue = "2.287050116462375")
@Parameter(name = "expected_intercept_count", defaultValue = "1")
@Parameter(name = "expected_class_count", defaultValue = "2")
@Configuration(
name = "test",
settings = Array(
"copy_count = 5",
"expected_coefficient_sum = -0.06538768469885561",
"expected_coefficient_sum_squares = 9.395555567324299E-5",
"expected_intercept_value = 2.286718680950285",
"expected_class_count = 2"
)
)
@Configuration(name = "jmh")
final class LogRegression extends Benchmark with SparkUtil {

// Utility class for validation.

private case class ModelSummary(
coefficientSum: Double,
coefficientSumSquares: Double,
coefficientCount: Int,
interceptValue: Double,
interceptCount: Int,
classCount: Int
)

// TODO: Consolidate benchmark parameters across the suite.
// See: https://github.com/renaissance-benchmarks/renaissance/issues/27

private val inputResource = "/sample_libsvm_data.txt"

private val inputFeatureCount = 692

private var maxIterationsParam: Int = _

private val lrRegularizationParam = 0.1

private val lrElasticNetMixingParam = 0.0

private val lrConvergenceToleranceParam = 0.0

private var inputDataFrame: DataFrame = _
private var lrMaxIterationsParam: Int = _

private var expectedModelSummary: ModelSummary = _

private var outputLogisticRegression: LogisticRegressionModel = _
private var inputDataFrame: DataFrame = _

private def loadData(inputFile: Path, featureCount: Int) = {
sparkSession.read
Expand All @@ -67,7 +91,17 @@ final class LogRegression extends Benchmark with SparkUtil {
override def setUpBeforeAll(bc: BenchmarkContext): Unit = {
setUpSparkContext(bc)

maxIterationsParam = bc.parameter("max_iterations").toPositiveInteger
lrMaxIterationsParam = bc.parameter("max_iterations").toPositiveInteger

// Validation parameters.
expectedModelSummary = ModelSummary(
bc.parameter("expected_coefficient_sum").toDouble,
bc.parameter("expected_coefficient_sum_squares").toDouble,
inputFeatureCount,
bc.parameter("expected_intercept_value").toDouble,
bc.parameter("expected_intercept_count").toPositiveInteger,
bc.parameter("expected_class_count").toPositiveInteger
)

val inputFile = duplicateLinesFromUrl(
getClass.getResource(inputResource),
Expand All @@ -79,43 +113,81 @@ final class LogRegression extends Benchmark with SparkUtil {
}

override def run(bc: BenchmarkContext): BenchmarkResult = {
val lor = new LogisticRegression()
val logRegression = new LogisticRegression()
.setElasticNetParam(lrElasticNetMixingParam)
.setRegParam(lrRegularizationParam)
.setTol(lrConvergenceToleranceParam)
.setMaxIter(maxIterationsParam)

outputLogisticRegression = lor.fit(inputDataFrame)

// TODO: add more in-depth validation
Validators.compound(
Validators.simple("class count", 2, outputLogisticRegression.numClasses),
Validators.simple(
"feature count",
inputFeatureCount,
outputLogisticRegression.numFeatures
)
.setMaxIter(lrMaxIterationsParam)

val logRegressionModel = logRegression.fit(inputDataFrame)
() => validate(logRegressionModel)
}

private def validate(model: LogisticRegressionModel): Unit = {
//
// Validation currently supports only binary classification which returns
// a single intercept value. If multinomial logistic regression is needed,
// the validation needs to be updated to support multiple intercept values.
//
val actualModelSummary = summarizeModel(model)
validateSummary(
expectedModelSummary,
actualModelSummary,
coefficientSumTolerance = 0.1e-14,
coefficientSumSquaresTolerance = 0.1e-17,
interceptTolerance = 0.1e-13
)
}

override def tearDownAfterAll(bc: BenchmarkContext): Unit = {
if (dumpResultsBeforeTearDown && outputLogisticRegression != null) {
val outputFile = bc.scratchDirectory().resolve("output.txt")
dumpResult(outputLogisticRegression, outputFile)
}
private def summarizeModel(model: LogisticRegressionModel): ModelSummary = {
val coefficients = model.coefficients.toArray

tearDownSparkContext()
ModelSummary(
coefficients.sum,
coefficients.map(num => num * num).sum,
coefficients.length,
model.interceptVector(0),
model.interceptVector.size,
model.numClasses
)
}

private def dumpResult(lrm: LogisticRegressionModel, outputFile: Path) = {
val output = new StringBuilder
output.append(s"num features: ${lrm.numFeatures}\n")
output.append(s"num classes: ${lrm.numClasses}\n")
output.append(s"intercepts: ${lrm.interceptVector.toString}\n")
output.append(s"coefficients: ${lrm.coefficients.toString}\n")
private def validateSummary(
expected: ModelSummary,
actual: ModelSummary,
coefficientSumTolerance: Double,
coefficientSumSquaresTolerance: Double,
interceptTolerance: Double
): Unit = {
Assert.assertEquals(
expected.coefficientSum,
actual.coefficientSum,
coefficientSumTolerance,
"coefficients sum"
)

Assert.assertEquals(
expected.coefficientSumSquares,
actual.coefficientSumSquares,
coefficientSumSquaresTolerance,
"coefficients sum of squares"
)

Assert.assertEquals(expected.coefficientCount, actual.coefficientCount, "coefficient count")

// Files.writeString() is only available from Java 11.
Files.write(outputFile, output.toString.getBytes)
Assert.assertEquals(
expected.interceptValue,
actual.interceptValue,
interceptTolerance,
"intercept value"
)

Assert.assertEquals(expected.interceptCount, actual.interceptCount, "intercept count")

Assert.assertEquals(expected.classCount, actual.classCount, "class count")
}

override def tearDownAfterAll(bc: BenchmarkContext): Unit = {
tearDownSparkContext()
}
}