Skip to content

Commit

Permalink
Merge pull request #7 from ruivieira/RHOAIENG-7082
Browse files Browse the repository at this point in the history
RHOAIENG-7082: Add SHAP support for KServe explainer
  • Loading branch information
ruivieira authored Jun 25, 2024
2 parents 80882b1 + f454419 commit 65dd33e
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 62 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ spec:

The explanation request will be identical to the LIME explainer case.

## Configuration

The following environment variables can be used in the `InferenceService` to customize the explainer:

| Name | Description | Default |
|--------------------------------------------------------------------------|--------------------------------------------------------------------|---------------|
| `EXPLAINER_TYPE` | `LIME` or `SHAP`, the explainer to use. | `LIME` |
| `LIME_SAMPLES` | The number of samples to use in LIME | `200` |
| `LIME_RETRIES` | Number of LIME retries | `2` |
| `LIME_WLR` | Use LIME Weighted Linear Regression, `true` or `false` | `true` |
| `LIME_NORMALIZE_WEIGHTS` | Whether LIME should normalize the weights, `true` or `false` | `true` |
| `EXPLAINER_SHAP_BACKGROUND_QUEUE` | The number of observations to keep in memory for SHAP's background | `10` |
| `EXPLAINER_SHAP_BACKGROUND_DIVERSITY` | The number of synthetic samples to generate for diversity | `10` |

## Contributing

To get started with contributing to this project:
Expand Down
8 changes: 7 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.kie.trustyai</groupId>
<artifactId>trustyai-kserve</artifactId>
<version>1.0-SNAPSHOT</version>
<version>0.2-SNAPSHOT</version>

<properties>
<compiler-plugin.version>3.11.0</compiler-plugin.version>
Expand Down Expand Up @@ -46,6 +46,12 @@
<artifactId>explainability-connectors</artifactId>
<version>${trustyai.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
<version>4.4</version>
</dependency>


<dependency>
<groupId>io.quarkus</groupId>
Expand Down
19 changes: 12 additions & 7 deletions src/main/java/org/kie/trustyai/CommandLineArgs.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ public class CommandLineArgs {

@CommandLine.Option(names = "--http_port", description = "The HTTP port of the predictor")
private int httpPort;
@CommandLine.Option(names = "--predictor_protocol", defaultValue = "v1", description = "The predictor protocol version (v1 or v2)")
private String predictorProtocol;

public String getPredictorProtocol() {
return predictorProtocol;
}

public String getPredictorHost() {
return predictorHost;
Expand All @@ -27,14 +33,13 @@ public int getHttpPort() {
return httpPort;
}

public String getV1HTTPPredictorURI() {

return "http://" +
predictorHost +
"/v1/models/" +
modelName +
":predict";
public String getV1HTTPPredictorURI(String modelName) {
return "http://" + predictorHost + "/v1/models/" + modelName + ":predict";
}

public String getV2HTTPPredictorURI(String modelName) {

return "http://" + predictorHost + "/v2/models/" + modelName + "/infer";
}

}
17 changes: 8 additions & 9 deletions src/main/java/org/kie/trustyai/ConfigCommand.java
Original file line number Diff line number Diff line change
@@ -1,42 +1,41 @@
package org.kie.trustyai;

import io.quarkus.logging.Log;
import io.quarkus.runtime.Quarkus;
import io.quarkus.runtime.QuarkusApplication;
import io.quarkus.runtime.annotations.QuarkusMain;
import jakarta.inject.Inject;
import picocli.CommandLine;
import org.jboss.logging.Logger;

import java.util.Arrays;

@QuarkusMain
public class ConfigCommand implements QuarkusApplication {

private static final Logger LOGGER = Logger.getLogger(ConfigCommand.class.getName());

@Inject
CommandLineArgs cmdArgs;

@Override
public int run(String... args) {
LOGGER.debug("Starting application...");
Log.info("Starting application...");
final CommandLine commandLine = new CommandLine(cmdArgs);

Log.debug("Using command-line arguments: " + Arrays.toString(args));
try {
commandLine.parseArgs(args);
if (commandLine.isUsageHelpRequested()) {
commandLine.usage(System.out);
return 0;
}


LOGGER.debug("Configuration loaded successfully.");
Log.info("Configuration loaded successfully.");
} catch (CommandLine.ParameterException e) {
LOGGER.error("Error parsing command line: " + e.getMessage());
Log.error("Error parsing command line: " + e.getMessage());
commandLine.usage(System.err);
return 1;
}

Quarkus.waitForExit(); // Wait for Quarkus shutdown events
LOGGER.debug("Quarkus is waiting for exit...");
Log.info("Quarkus is waiting for exit...");
return 0;
}
}
39 changes: 20 additions & 19 deletions src/main/java/org/kie/trustyai/ConfigService.java
Original file line number Diff line number Diff line change
@@ -1,55 +1,56 @@
package org.kie.trustyai;

import io.quarkus.logging.Log;
import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.jboss.logging.Logger;


@ApplicationScoped
public class ConfigService {

private static final Logger LOGGER = Logger.getLogger(ConfigService.class.getName());



@ConfigProperty(name = "explainer.type", defaultValue = "LIME")
ExplainerType explainerType;
@ConfigProperty(name = "lime.samples", defaultValue = "200")
int limeSamples;
@ConfigProperty(name = "lime.retries", defaultValue = "2")
int limeRetries;
@ConfigProperty(name = "lime.wlr", defaultValue = "true")
boolean limeWLR;
@ConfigProperty(name = "lime.normalize.weights", defaultValue = "true")
boolean limeNormalizeWeights;
@ConfigProperty(name = "explainer.shap.background.queue", defaultValue = "10")
int queueSize;
@ConfigProperty(name = "explainer.shap.background.diversity", defaultValue = "10")
int diversitySize;

public int getLimeSamples() {
return limeSamples;
}

@ConfigProperty(name = "lime.samples", defaultValue = "200")
int limeSamples;

public int getLimeRetries() {
return limeRetries;
}

@ConfigProperty(name = "lime.retries", defaultValue = "2")
int limeRetries;

public boolean getLimeWLR() {
return limeWLR;
}

@ConfigProperty(name = "lime.wlr", defaultValue = "true")
boolean limeWLR;


public boolean getLimeNormalizeWeights() {
return limeNormalizeWeights;
}

@ConfigProperty(name = "lime.normalize.weights", defaultValue = "true")
boolean limeNormalizeWeights;
public int getQueueSize() {
return queueSize;
}

public int getDiversitySize() {
return diversitySize;
}

@PostConstruct
private void validateConfig() {
if (explainerType == null) {
LOGGER.error("Unknown explainer type configured. Falling back to LIME.");
Log.error("Unknown explainer type configured. Falling back to LIME.");
explainerType = ExplainerType.LIME;
}
}
Expand Down
23 changes: 18 additions & 5 deletions src/main/java/org/kie/trustyai/ExplainerFactory.java
Original file line number Diff line number Diff line change
@@ -1,34 +1,47 @@
package org.kie.trustyai;

import java.util.List;

import io.quarkus.logging.Log;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.kie.trustyai.explainability.local.LocalExplainer;
import org.kie.trustyai.explainability.local.lime.LimeConfig;
import org.kie.trustyai.explainability.local.lime.LimeExplainer;
import org.kie.trustyai.explainability.local.shap.ShapConfig;
import org.kie.trustyai.explainability.local.shap.ShapKernelExplainer;
import org.kie.trustyai.explainability.model.*;

import java.util.List;
import org.kie.trustyai.explainability.model.PredictionInput;
import org.kie.trustyai.explainability.model.SaliencyResults;

@Singleton
public class ExplainerFactory {

@Inject
ConfigService configService;

public LocalExplainer<SaliencyResults> getExplainer(ExplainerType type, List<PredictionInput> background) throws IllegalArgumentException {
@Inject
StreamingGeneratorManager streamingGeneratorManager;

public LocalExplainer<SaliencyResults> getExplainer(ExplainerType type) throws IllegalArgumentException {
return switch (type) {
case LIME -> {
final LimeConfig limeConfig = new LimeConfig()
.withNormalizeWeights(configService.getLimeNormalizeWeights())
.withSamples(configService.getLimeSamples())
.withRetries(configService.getLimeRetries())
.withUseWLRLinearModel(configService.getLimeWLR());
Log.info("Instating LIME explainer");
yield new LimeExplainer(limeConfig);
}
case SHAP -> {
ShapConfig shapConfig = ShapConfig.builder().withLink(ShapConfig.LinkType.IDENTITY).withBackground(background).build();
final int backgroundSize = configService.getQueueSize() + configService.getDiversitySize();
Log.debug("Requesting " + backgroundSize + " background samples from SHAP's streaming generator");
final List<PredictionInput> background = streamingGeneratorManager.getStreamingGenerator().generate(backgroundSize);
Log.debug("The background has a size of " + background.size());
final ShapConfig shapConfig = ShapConfig.builder().withRegularizer(5)
.withLink(ShapConfig.LinkType.IDENTITY)
.withBackground(background).build();
Log.info("Instantiating SHAP explainer");
yield new ShapKernelExplainer(shapConfig);
}
default -> throw new IllegalArgumentException("Unsupported explainer type: " + type);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
package org.kie.trustyai;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.quarkus.logging.Log;
import jakarta.enterprise.inject.Default;
import jakarta.inject.Inject;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import org.jboss.logging.Logger;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.kie.trustyai.connectors.kserve.v1.KServeV1HTTPPredictionProvider;
import org.kie.trustyai.connectors.kserve.v1.KServeV1RequestPayload;
import org.kie.trustyai.explainability.local.LocalExplainer;
import org.kie.trustyai.explainability.model.*;

import jakarta.inject.Inject;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.kie.trustyai.explainability.model.Prediction;
import org.kie.trustyai.explainability.model.PredictionInput;
import org.kie.trustyai.explainability.model.PredictionOutput;
import org.kie.trustyai.explainability.model.PredictionProvider;
import org.kie.trustyai.explainability.model.SaliencyResults;
import org.kie.trustyai.explainability.model.SimplePrediction;
import org.kie.trustyai.payloads.SaliencyExplanationResponse;

import java.util.List;
import java.util.concurrent.ExecutionException;

@Path("/v1/models/{modelName}:explain")
public class ExplainerEndpoint {

private static final Logger LOGGER = Logger.getLogger(ExplainerEndpoint.class.getName());
public class ExplainerV1Endpoint {

@Inject
ObjectMapper objectMapper;
Expand All @@ -38,37 +43,54 @@ public class ExplainerEndpoint {
@Inject
ExplainerFactory explainerFactory;

@Inject
StreamingGeneratorManager streamingGeneratorManager;

@POST
@Consumes(MediaType.APPLICATION_JSON)
public Response explainIncome(@PathParam("modelName") String modelName, KServeV1RequestPayload data)
public Response explain(@PathParam("modelName") String modelName, KServeV1RequestPayload data)
throws ExecutionException, InterruptedException {
final String predictorURI = cmdArgs.getV1HTTPPredictorURI();

LOGGER.debug("Using explainer type [" + configService.getExplainerType() + "]");
LOGGER.debug("Using predictor URI [" + predictorURI + "]");
Log.info("Using explainer type [" + configService.getExplainerType() + "]");
Log.info("Using V1 HTTP protocol");
final String predictorURI = cmdArgs.getV1HTTPPredictorURI(modelName);
final PredictionProvider provider = new KServeV1HTTPPredictionProvider(null, null, predictorURI, 1);
Log.info("Using predictor URI [" + predictorURI + "]");

final PredictionProvider provider = new KServeV1HTTPPredictionProvider(null, null, predictorURI);
final List<PredictionInput> input = data.toPredictionInputs();
final PredictionOutput output = provider.predictAsync(input).get().get(0);
final Prediction prediction = new SimplePrediction(input.get(0), output);
final int dimensions = input.get(0).getFeatures().size();

if (configService.getExplainerType() == ExplainerType.SHAP) {
if (Objects.isNull(streamingGeneratorManager.getStreamingGenerator())) {
Log.info("Initializing SHAP's Streaming Background Generator with dimension " + dimensions);
streamingGeneratorManager.initialize(dimensions);
}
final double[] numericData = new double[dimensions];
for (int i = 0; i < dimensions; i++) {
numericData[i] = input.get(0).getFeatures().get(i).getValue().asNumber();
}
final RealVector vectorData = new ArrayRealVector(numericData);
streamingGeneratorManager.getStreamingGenerator().update(vectorData);
}

final ExplainerType explainerType = configService.getExplainerType();

try {
final LocalExplainer<SaliencyResults> explainer = explainerFactory.getExplainer(explainerType, input);

final LocalExplainer<SaliencyResults> explainer = explainerFactory.getExplainer(explainerType);
Log.info("Sending explaining request to " + predictorURI);
final SaliencyResults results = explainer.explainAsync(prediction, provider).get();
final SaliencyExplanationResponse response = SaliencyExplanationResponse.fromSaliencyResults(results);

try {
String resultsJson = objectMapper.writeValueAsString(results);
return Response.ok(response, MediaType.APPLICATION_JSON).build();
} catch (Exception e) {
return Response.serverError().entity("Error serializing SaliencyResults to JSON: " + e.getMessage())
.build();
}
} catch (IllegalArgumentException e) {
return Response.serverError().entity("Explainer type not supported: " + explainerType).build();
return Response.serverError().entity("Error: " + e.getMessage()).build();
}

}
Expand Down
29 changes: 29 additions & 0 deletions src/main/java/org/kie/trustyai/StreamingGeneratorManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.kie.trustyai;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.kie.trustyai.explainability.local.shap.background.StreamingGenerator;
import org.kie.trustyai.statistics.MultivariateOnlineEstimator;
import org.kie.trustyai.statistics.distributions.gaussian.MultivariateGaussianParameters;
import org.kie.trustyai.statistics.estimators.WelfordOnlineEstimator;

@Singleton
public class StreamingGeneratorManager {

@Inject
ConfigService configService;

private StreamingGenerator streamingGenerator = null;

public synchronized void initialize(int dimensions) {
if (streamingGenerator == null && configService.getExplainerType() == ExplainerType.SHAP) {
final MultivariateOnlineEstimator<MultivariateGaussianParameters> estimator = new WelfordOnlineEstimator(dimensions);
streamingGenerator = new StreamingGenerator(dimensions, configService.getQueueSize(), configService.getDiversitySize(), estimator);
}
}

public StreamingGenerator getStreamingGenerator() {
return streamingGenerator;
}

}
Loading

0 comments on commit 65dd33e

Please sign in to comment.