Skip to content

Commit

Permalink
Add RunOptions to Servable interface.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550921545
  • Loading branch information
cky9301 authored and tensorflow-copybara committed Jul 25, 2023
1 parent 3d8e8c3 commit 6a9d0fd
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 19 deletions.
2 changes: 2 additions & 0 deletions tensorflow_serving/servables/tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ cc_library(
"//visibility:public",
],
deps = [
":predict_response_tensor_serialization_option",
":thread_pool_factory",
"//tensorflow_serving/apis:classification_cc_proto",
"//tensorflow_serving/apis:get_model_metadata_cc_proto",
"//tensorflow_serving/apis:inference_cc_proto",
Expand Down
17 changes: 11 additions & 6 deletions tensorflow_serving/servables/tensorflow/mock_servable.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_

#include <gmock/gmock.h>
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/test_util/test_util.h"

namespace tensorflow {
namespace serving {
Expand All @@ -31,20 +31,25 @@ class MockServable : public Servable {
~MockServable() override = default;

MOCK_METHOD(absl::Status, Classify,
(const tensorflow::serving::ClassificationRequest& request,
(const tensorflow::serving::Servable::RunOptions& run_options,
const tensorflow::serving::ClassificationRequest& request,
tensorflow::serving::ClassificationResponse* response));
MOCK_METHOD(absl::Status, Regress,
(const tensorflow::serving::RegressionRequest& request,
(const tensorflow::serving::Servable::RunOptions& run_options,
const tensorflow::serving::RegressionRequest& request,
tensorflow::serving::RegressionResponse* response));
MOCK_METHOD(absl::Status, Predict,
(const tensorflow::serving::PredictRequest& request,
(const tensorflow::serving::Servable::RunOptions& run_options,
const tensorflow::serving::PredictRequest& request,
tensorflow::serving::PredictResponse* response));
MOCK_METHOD(absl::Status, PredictStreamed,
(const tensorflow::serving::PredictRequest& request,
(const tensorflow::serving::Servable::RunOptions& run_options,
const tensorflow::serving::PredictRequest& request,
absl::AnyInvocable<void(tensorflow::serving::PredictResponse)>
response_callback));
MOCK_METHOD(absl::Status, MultiInference,
(const tensorflow::serving::MultiInferenceRequest& request,
(const tensorflow::serving::Servable::RunOptions& run_options,
const tensorflow::serving::MultiInferenceRequest& request,
tensorflow::serving::MultiInferenceResponse* response));
MOCK_METHOD(absl::Status, GetModelMetadata,
(const tensorflow::serving::GetModelMetadataRequest& request,
Expand Down
41 changes: 31 additions & 10 deletions tensorflow_serving/servables/tensorflow/servable.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ limitations under the License.
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/apis/regression.pb.h"
#include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"

namespace tensorflow {
namespace serving {
Expand All @@ -48,13 +50,27 @@ class Servable {
// Returns the version associated with this servable.
int64_t version() const { return version_; }

virtual absl::Status Classify(const ClassificationRequest& request,
// RunOptions group the configuration for individual inference executions.
// The per-request configuration (e.g. deadline) can be passed here.
struct RunOptions {
// Priority of the request. Some thread pool implementation will schedule
// ops based on the priority number. Larger number means higher
// priority.
int64_t priority = 1;
// The deadline for this request.
absl::Time deadline = absl::InfiniteFuture();
};

virtual absl::Status Classify(const RunOptions& run_options,
const ClassificationRequest& request,
ClassificationResponse* response) = 0;

virtual absl::Status Regress(const RegressionRequest& request,
virtual absl::Status Regress(const RunOptions& run_options,
const RegressionRequest& request,
RegressionResponse* response) = 0;

virtual absl::Status Predict(const PredictRequest& request,
virtual absl::Status Predict(const RunOptions& run_options,
const PredictRequest& request,
PredictResponse* response) = 0;

// Streamed version of `Predict`. Experimental API that is not yet part of the
Expand All @@ -67,10 +83,11 @@ class Servable {
// callback invocation to be delayed. The implementation guarantees that the
// callback is never called after the `PredictStreamed` method returns.
virtual absl::Status PredictStreamed(
const PredictRequest& request,
const RunOptions& run_options, const PredictRequest& request,
absl::AnyInvocable<void(PredictResponse)> response_callback) = 0;

virtual absl::Status MultiInference(const MultiInferenceRequest& request,
virtual absl::Status MultiInference(const RunOptions& run_options,
const MultiInferenceRequest& request,
MultiInferenceResponse* response) = 0;

virtual absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
Expand All @@ -95,28 +112,32 @@ class EmptyServable : public Servable {
public:
EmptyServable();

absl::Status Classify(const ClassificationRequest& request,
absl::Status Classify(const RunOptions& run_options,
const ClassificationRequest& request,
ClassificationResponse* response) override {
return error_;
}

absl::Status Regress(const RegressionRequest& request,
absl::Status Regress(const RunOptions& run_options,
const RegressionRequest& request,
RegressionResponse* response) override {
return error_;
}

absl::Status Predict(const PredictRequest& request,
absl::Status Predict(const RunOptions& run_options,
const PredictRequest& request,
PredictResponse* response) override {
return error_;
}

absl::Status PredictStreamed(
const PredictRequest& request,
const RunOptions& run_options, const PredictRequest& request,
absl::AnyInvocable<void(PredictResponse)> response_callback) override {
return error_;
}

absl::Status MultiInference(const MultiInferenceRequest& request,
absl::Status MultiInference(const RunOptions& run_options,
const MultiInferenceRequest& request,
MultiInferenceResponse* response) override {
return error_;
}
Expand Down
7 changes: 4 additions & 3 deletions tensorflow_serving/servables/tensorflow/servable_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@ limitations under the License.

#include "tensorflow_serving/servables/tensorflow/servable.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/test_util/test_util.h"

namespace tensorflow {
namespace serving {
namespace {

TEST(EmptyServableTest, Predict) {
PredictResponse response;
EXPECT_EQ(EmptyServable().Predict(PredictRequest(), &response).code(),
EXPECT_EQ(EmptyServable()
.Predict(Servable::RunOptions(), PredictRequest(), &response)
.code(),
absl::StatusCode::kFailedPrecondition);
}

Expand Down

0 comments on commit 6a9d0fd

Please sign in to comment.