Skip to content

Commit

Permalink
Cherry pick 2 OnnxRuntime Server fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
csteegz authored and shahasad committed Oct 12, 2019
1 parent 7d5b089 commit 7963e77
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 19 deletions.
5 changes: 0 additions & 5 deletions onnxruntime/server/http/core/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,6 @@ http::status HttpSession::ExecuteUserFunction(HttpContext& context) {
context.client_request_id = context.request[util::MS_CLIENT_REQUEST_ID_HEADER].to_string();
}

if (path == "/score") {
// This is a shortcut since we have only one model instance currently.
// This code path will be removed once we start supporting multiple models or multiple versions of one model.
path = "/v1/models/default/versions/1:predict";
}

auto status = routes_.ParseUrl(context.request.method(), path, model_name, model_version, action, func);

Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/server/http/predict_request_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ void Predict(const std::string& name,
auto logger = env->GetLogger(context.request_id);
logger->info("Model Name: {}, Version: {}, Action: {}", name, version, action);

auto effective_name = name.empty() ? "default" : name;
auto effective_version = version.empty() ? "1" : version;

if (!context.client_request_id.empty()) {
logger->info("{}: [{}]", util::MS_CLIENT_REQUEST_ID_HEADER, context.client_request_id);
}
Expand All @@ -64,7 +67,7 @@ void Predict(const std::string& name,
// Run Prediction
Executor executor(env.get(), context.request_id);
PredictResponse predict_response{};
auto status = executor.Predict(name, version, predict_request, predict_response);
auto status = executor.Predict(effective_name, effective_version, predict_request, predict_response);
if (!status.ok()) {
GenerateErrorResponse(logger, GetHttpStatusCode((status)), status.error_message(), context);
return;
Expand Down
9 changes: 8 additions & 1 deletion onnxruntime/server/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,18 @@ int main(int argc, char* argv[]) {
});

app.RegisterPost(
R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
R"(/(?:v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))|(?:score()()()))",
[&env](const auto& name, const auto& version, const auto& action, auto& context) -> void {
server::Predict(name, version, action, context, env);
});

app.RegisterPost(
R"(/score()()())",
[&env](const auto& name, const auto& version, const auto& action, auto& context) -> void {
server::Predict(name, version, action, context, env);
}
);

app.Bind(boost_address, config.http_port)
.NumThreads(config.num_http_threads)
.Run();
Expand Down
44 changes: 44 additions & 0 deletions onnxruntime/test/server/integration_tests/function_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,50 @@ def test_single_model_shortcut(self):
for i in range(0, 10):
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))

def test_single_version_shortcut(self):
input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json')
output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json')

with open(input_data_file, 'r') as f:
request_payload = f.read()

with open(output_data_file, 'r') as f:
expected_response_json = f.read()
expected_response = json.loads(expected_response_json)

request_headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'x-ms-client-request-id': 'This~is~my~id'
}

url = "http://{0}:{1}/v1/models/{2}:predict".format(self.server_ip, self.server_port, 'default')
test_util.test_log(url)
r = requests.post(url, headers=request_headers, data=request_payload)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.headers.get('Content-Type'), 'application/json')
self.assertTrue(r.headers.get('x-ms-request-id'))
self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id')

actual_response = json.loads(r.content.decode('utf-8'))

# Note:
# The 'dims' field is defined as "repeated int64" in protobuf.
# When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string
# Reference: https://developers.google.com/protocol-buffers/docs/proto3#json

self.assertTrue(actual_response['outputs'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims'])
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10'])
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType'])
self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1)
self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData'])
actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f')
expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f')

for i in range(0, 10):
self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i]))

class HttpProtobufPayloadTests(unittest.TestCase):
server_ip = '127.0.0.1'
Expand Down
34 changes: 22 additions & 12 deletions onnxruntime/test/server/unit_tests/http_routes_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace onnxruntime {
namespace server {
namespace test {

static const std::string predict_regex = R"(/(?:v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict)))";
using test_data = std::tuple<http::verb, std::string, std::string, std::string, std::string, http::status>;

void do_something(const std::string& name, const std::string& version,
Expand All @@ -20,7 +21,6 @@ void do_something(const std::string& name, const std::string& version,
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data);

TEST(HttpRouteTests, RegisterTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
Routes routes;
EXPECT_TRUE(routes.RegisterController(http::verb::post, predict_regex, do_something));

Expand All @@ -29,20 +29,22 @@ TEST(HttpRouteTests, RegisterTest) {
}

TEST(HttpRouteTests, PostRouteTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

std::vector<test_data> actions{
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/abc:predict", "abc", "", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/models/versions/45:predict", "models", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/??$$%%@@$^^/versions/45:predict", "??$$%%@@$^^", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok)};
std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/versions:predict", "versions", "", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/default:predict", "default", "", "predict", http::status::ok)
};


run_route(predict_regex, http::verb::post, actions, true);
}

TEST(HttpRouteTests, PostRouteInvalidURLTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

std::vector<test_data> actions{
std::make_tuple(http::verb::post, "", "", "", "", http::status::not_found),
Expand All @@ -56,27 +58,35 @@ TEST(HttpRouteTests, PostRouteInvalidURLTest) {
std::make_tuple(http::verb::post, "/models/abc/versions/2:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/versions/2:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/foo/versions/:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/foo/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "v1/models/foo/versions/12:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)};
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)
};

run_route(predict_regex, http::verb::post, actions, false);
}

// These tests are because we currently only support POST and GET
// Some HTTP methods should be removed from test data if we support more (e.g. PUT)
TEST(HttpRouteTests, PostRouteInvalidMethodTest) {
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

std::vector<test_data> actions{
std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed),
std::make_tuple(http::verb::put, "/v1/models", "", "", "", http::status::method_not_allowed),
std::make_tuple(http::verb::delete_, "/v1/models", "", "", "", http::status::method_not_allowed),
std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)};
std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)
};

run_route(predict_regex, http::verb::post, actions, false);
}

TEST(HttpRouteTests, PostRouteSpecialMethodTest){
std::vector<test_data> actions{
std::make_tuple(http::verb::post, "/score", "", "", "", http::status::ok)
};

run_route(R"(/score()()())", http::verb::post, actions, true);
}

void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data) {
Routes routes;
EXPECT_TRUE(routes.RegisterController(method, pattern, do_something));
Expand All @@ -95,11 +105,11 @@ void run_route(const std::string& pattern, http::verb method, const std::vector<
http::status expected_status;

std::tie(test_method, url_string, expected_name, expected_version, expected_action, expected_status) = i;
EXPECT_EQ(expected_status, routes.ParseUrl(test_method, url_string, name, version, action, fn));
EXPECT_EQ(expected_status, routes.ParseUrl(test_method, url_string, name, version, action, fn)) << "On route " << url_string;
if (does_validate_data) {
EXPECT_EQ(name, expected_name);
EXPECT_EQ(version, expected_version);
EXPECT_EQ(action, expected_action);
EXPECT_EQ(name, expected_name) << "On route " << url_string;
EXPECT_EQ(version, expected_version) << "On route " << url_string;
EXPECT_EQ(action, expected_action) << "On route " << url_string;
}
}
}
Expand Down

0 comments on commit 7963e77

Please sign in to comment.