Add session to enable multi-turn conversation (deepspeedai#177)
* add session to enable multi-turn conversation

* remove old code

* add example for chat applications

* update README and chat example

* update chat example
tohtana authored May 10, 2023
1 parent a8f6569 commit 9527eb5
Showing 10 changed files with 344 additions and 42 deletions.
67 changes: 67 additions & 0 deletions examples/local/chat/README.md
@@ -0,0 +1,67 @@
# Multi-turn Conversation Example for Chat Applications

MII can manage multi-turn conversations, enabling users to easily create their own chat applications.
The scripts in this folder provide a complete example of a multi-turn conversation scenario.

## Starting the server

Starting the server for your chat application requires nothing special.
Just make sure that the model supports `text-generation` and is trained for conversations.

The example script uses [AdamG012/chat-opt-1.3b-rlhf-actor-deepspeed](https://huggingface.co/AdamG012/chat-opt-1.3b-rlhf-actor-deepspeed), which was trained using [DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/README.md).

```python
name = "AdamG012/chat-opt-1.3b-rlhf-actor-deepspeed"
...
mii.deploy(task='text-generation', model=name, deployment_name="chat_example_deployment")
```

## Running multi-turn conversation

The client creates a *session* so that MII can keep track of the conversation.
`create_session` creates a new session with a given name.

```python
# You can set a session name
session_id = "chat_example_session"
# You need to call `create_session` before you start a multi-turn conversation session
generator.create_session(session_id)
```

The session ID is passed as a keyword argument, as shown below.
Apart from that, you call the `query` function exactly as in normal MII inference.
Note that the prompt template may need to be adapted to your model.

```python
query_template = "Human: {}\n Assistant: "
print("# Start a conversation session. Type 'q' to exit.")

while True:
    user_input = input("You: ")
    if user_input == "q":
        break
    result = generator.query({"query": query_template.format(user_input)},
                             session_id=session_id,
                             max_new_tokens=128)
    print(f"Bot: {result.response[0].replace('</s>', '')}")
```

Once the conversation is finished, call `destroy_session` to free the internal data held for the session.

```python
generator.destroy_session(session_id)
```

The following is an example conversation:
```bash
$ python chat-client-example.py
# Start a conversation session. Type 'q' to exit.
You: Can you tell me about deep learning?
Bot: Yes, it is a type of artificial intelligence that learns from data. It can process large amounts of data quickly and accurately, and it can develop sophisticated models to analyze data. Deep learning techniques are being applied to a wide variety of tasks, including image recognition, speech recognition, recommendation systems, and self-driving cars.
You: I want to try it.
Bot: Yes, it is a very powerful technology. It has the potential to revolutionize many fields, including artificial intelligence, data analysis, and machine learning.
You: Is it hard to learn?
Bot: Yes, it is a challenging field of study. It requires a lot of training and practice to develop effective models. It also requires a high level of computational power and data storage.
You: Where can I start?
Bot: Yes, it is a good idea to start with a basic model. It will help you develop the necessary skills and knowledge to progress to more advanced models..
```
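Conceptually, a session works by threading the accumulated conversation back into each new prompt. The following is a hedged sketch of that idea, not the MII implementation: the `preprocess`/`postprocess` helpers and the storage format are illustrative, loosely mirroring the `preprocess_session`/`postprocess_session` hooks this commit adds server-side.

```python
# Illustrative only: how a session can accumulate multi-turn context.
session_context = {}

def preprocess(session_id, prompt):
    # Prepend whatever history the session has accumulated so far.
    prior = session_context.get(session_id) or ""
    return prior + prompt

def postprocess(session_id, full_prompt, response):
    # Store prompt + response so the next turn sees the whole history.
    session_context[session_id] = full_prompt + response + "\n"
    return response

sid = "chat_example_session"
session_context[sid] = None  # session created, no history yet

p1 = preprocess(sid, "Human: Hi\n Assistant: ")
r1 = postprocess(sid, p1, "Hello!")
p2 = preprocess(sid, "Human: How are you?\n Assistant: ")
assert p2.startswith("Human: Hi\n Assistant: Hello!")
```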
29 changes: 29 additions & 0 deletions examples/local/chat/chat-client-example.py
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

# Run `chat-server-example.py` before running this script
generator = mii.mii_query_handle("chat_example_deployment")

# You can set a session name
session_id = "chat_example_session"
# You need to call `create_session` before you start a multi-turn conversation session
generator.create_session(session_id)

print("# Start a conversation session. Type 'q' to exit.")
query_template = "Human: {}\n Assistant: "
while True:
    user_input = input("You: ")
    if user_input == "q":
        break

    # A session ID is given as a keyword argument
    result = generator.query({"query": query_template.format(user_input)},
                             session_id=session_id,
                             max_new_tokens=128)
    print(f"Bot: {result.response[0].replace('</s>', '').strip()}")

# You need to destroy the session after finishing the conversation
generator.destroy_session(session_id)
16 changes: 16 additions & 0 deletions examples/local/chat/chat-server-example.py
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

mii_configs = {'tensor_parallel': 1}

# This checkpoint was created using DeepSpeed-Chat
# https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/README.md
name = "AdamG012/chat-opt-1.3b-rlhf-actor-deepspeed"

print(f"Deploying {name}...")

# Deploy as a "text-generation" task; pass the MII config so that
# `tensor_parallel` takes effect
mii.deploy(task='text-generation',
           model=name,
           deployment_name="chat_example_deployment",
           mii_config=mii_configs)
23 changes: 23 additions & 0 deletions mii/client.py
@@ -86,6 +86,21 @@ async def terminate_async(self):
    def terminate(self):
        self.asyncio_loop.run_until_complete(self.terminate_async())

    async def create_session_async(self, session_id):
        return await self.stub.CreateSession(
            modelresponse_pb2.SessionID(session_id=session_id))

    def create_session(self, session_id):
        return self.asyncio_loop.run_until_complete(
            self.create_session_async(session_id))

    async def destroy_session_async(self, session_id):
        await self.stub.DestroySession(
            modelresponse_pb2.SessionID(session_id=session_id))

    def destroy_session(self, session_id):
        self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))


class MIITensorParallelClient():
"""
@@ -133,6 +148,14 @@ def terminate(self):
        for client in self.clients:
            client.terminate()

    def create_session(self, session_id):
        for client in self.clients:
            client.create_session(session_id)

    def destroy_session(self, session_id):
        for client in self.clients:
            client.destroy_session(session_id)


def terminate_restful_gateway(deployment_name):
    _, mii_configs = _get_deployment_info(deployment_name)
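The client pairs each blocking method with an async counterpart driven on a private event loop, so callers get a synchronous API over the async gRPC stubs. A minimal self-contained sketch of that sync-over-async pattern (the `SyncFacade` class and its placeholder coroutine are illustrative, not MII code):

```python
import asyncio

class SyncFacade:
    """Sketch of the wrapper pattern used by the MII client classes."""
    def __init__(self):
        # Private event loop that the sync methods drive to completion.
        self.asyncio_loop = asyncio.new_event_loop()

    async def create_session_async(self, session_id):
        await asyncio.sleep(0)  # placeholder for the gRPC stub call
        return session_id

    def create_session(self, session_id):
        # Block until the async counterpart finishes.
        return self.asyncio_loop.run_until_complete(
            self.create_session_async(session_id))

f = SyncFacade()
assert f.create_session("s1") == "s1"
```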
2 changes: 2 additions & 0 deletions mii/constants.py
@@ -117,6 +117,8 @@ class ModelProvider(enum.Enum):
GRPC_MAX_MSG_SIZE = 2**27 # ~100MB

TERMINATE_METHOD = "Terminate"
CREATE_SESSION_METHOD = "CreateSession"
DESTROY_SESSION_METHOD = "DestroySession"

LB_MAX_WORKER_THREADS = 32

56 changes: 51 additions & 5 deletions mii/grpc_related/modelresponse_server.py
@@ -13,10 +13,10 @@
import threading
import time

from mii.constants import GRPC_MAX_MSG_SIZE, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT
from mii.constants import GRPC_MAX_MSG_SIZE, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT
from mii.method_table import GRPC_METHOD_TABLE
from mii.client import create_channel
from mii.utils import get_task
from mii.utils import get_task, unpack_proto_query_kwargs


class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
@@ -42,6 +42,7 @@ def __init__(self, inference_pipeline):
        super().__init__()
        self.inference_pipeline = inference_pipeline
        self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()}
        self.session_context = {}
        self.lock = threading.Lock()

    def _get_model_time(self, model, sum_times=False):
@@ -61,6 +62,18 @@ def _get_model_time(self, model, sum_times=False):
            model_time = -1
        return model_time

    def CreateSession(self, request, context):
        if request.session_id in self.session_context:
            raise ValueError(f"session {request.session_id} already exists")
        self.session_context[request.session_id] = None
        return google_dot_protobuf_dot_empty__pb2.Empty()

    def DestroySession(self, request, context):
        if request.session_id not in self.session_context:
            raise ValueError(f"session {request.session_id} does not exist")
        del self.session_context[request.session_id]
        return google_dot_protobuf_dot_empty__pb2.Empty()

    def _run_inference(self, method_name, request_proto):
        if method_name not in self.method_name_to_task:
            raise ValueError(f"unknown method: {method_name}")
@@ -72,11 +85,23 @@ def _run_inference(self, method_name, request_proto):
        conversions = GRPC_METHOD_TABLE[task]
        args, kwargs = conversions["unpack_request_from_proto"](request_proto)

        session_id = kwargs.pop("session_id", None)
        if session_id and "preprocess_session" in GRPC_METHOD_TABLE[task]:
            args, kwargs = GRPC_METHOD_TABLE[task]["preprocess_session"](
                session_id, self.session_context, args, kwargs)

        start = time.time()
        with self.lock:
            response = self.inference_pipeline(*args, **kwargs)
        end = time.time()

        if session_id and "postprocess_session" in GRPC_METHOD_TABLE[task]:
            response = GRPC_METHOD_TABLE[task]["postprocess_session"](
                session_id,
                self.session_context,
                args,
                kwargs,
                response)

        model_time = self._get_model_time(self.inference_pipeline.model,
                                          sum_times=True) if hasattr(
                                              self.inference_pipeline,
@@ -165,6 +190,7 @@ def __init__(self, task_name, replica_configs):
        ]
        self.counter = AtomicCounter()
        self.task = get_task(task_name)
        self.replica_sessions = {}

        # Start the asyncio loop in a separate thread
        def run_asyncio_loop(loop):
@@ -191,9 +217,29 @@ def invoke_intercept_method(request_proto, context):
                return next_handler.unary_unary(request_proto, context)

            call_count = self.counter.get_and_increment()
            replica_index = call_count % len(self.stubs)

            if method_name == CREATE_SESSION_METHOD:
                if request_proto.session_id in self.replica_sessions:
                    raise ValueError(
                        f"session {request_proto.session_id} already exists")
                self.replica_sessions[request_proto.session_id] = replica_index
                self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto)
                return google_dot_protobuf_dot_empty__pb2.Empty()

            if method_name == DESTROY_SESSION_METHOD:
                replica_index = self.replica_sessions.pop(request_proto.session_id)
                self.stubs[replica_index].invoke(DESTROY_SESSION_METHOD, request_proto)
                return google_dot_protobuf_dot_empty__pb2.Empty()

            kwargs = unpack_proto_query_kwargs(request_proto.query_kwargs)
            if "session_id" in kwargs:
                session_id = kwargs["session_id"]
                if session_id not in self.replica_sessions:
                    raise ValueError(f"session {session_id} not found")
                replica_index = self.replica_sessions[session_id]

            ret = self.stubs[replica_index].invoke(method_name, request_proto)
            return ret

return grpc.unary_unary_rpc_method_handler(
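The routing logic above boils down to session affinity on top of round-robin dispatch: `CreateSession` pins a session to the replica chosen by the counter, and later queries carrying that `session_id` are routed back to the same replica. A minimal self-contained sketch of that idea (class and method names here are illustrative, not the MII API):

```python
class SessionRouter:
    """Round-robin dispatch with session affinity, as in the load balancer."""
    def __init__(self, num_replicas):
        self.num_replicas = num_replicas
        self.counter = 0
        self.replica_sessions = {}  # session_id -> replica index

    def create_session(self, session_id):
        if session_id in self.replica_sessions:
            raise ValueError(f"session {session_id} already exists")
        # Pin the session to the replica the counter currently points at.
        replica = self.counter % self.num_replicas
        self.counter += 1
        self.replica_sessions[session_id] = replica
        return replica

    def route(self, session_id=None):
        # Session queries are pinned; stateless queries round-robin.
        if session_id is not None:
            return self.replica_sessions[session_id]
        replica = self.counter % self.num_replicas
        self.counter += 1
        return replica

    def destroy_session(self, session_id):
        return self.replica_sessions.pop(session_id)

router = SessionRouter(num_replicas=2)
pinned = router.create_session("chat_example_session")
assert router.route("chat_example_session") == pinned
```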
6 changes: 6 additions & 0 deletions mii/grpc_related/proto/modelresponse.proto
@@ -25,6 +25,8 @@ package modelresponse;

service ModelResponse {
  rpc Terminate (google.protobuf.Empty) returns (google.protobuf.Empty) {}
  rpc CreateSession (SessionID) returns (google.protobuf.Empty) {}
  rpc DestroySession (SessionID) returns (google.protobuf.Empty) {}
  rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {}
  rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {}
  rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {}
@@ -43,6 +45,10 @@ message Value {
}
}

message SessionID {
  string session_id = 1;
}

message SingleStringRequest {
  string request = 1;
  map<string,Value> query_kwargs = 2;