From d20695fcd4301526cf273cb217db2c5f7b8fb2b4 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Wed, 17 Jul 2024 17:26:02 -0700 Subject: [PATCH] adds answer session interface --- .../oramacloud-client/answer-session.swift | 222 ++++++++++++++++++ Sources/oramacloud-client/client.swift | 4 +- 2 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 Sources/oramacloud-client/answer-session.swift diff --git a/Sources/oramacloud-client/answer-session.swift b/Sources/oramacloud-client/answer-session.swift new file mode 100644 index 0000000..913898d --- /dev/null +++ b/Sources/oramacloud-client/answer-session.swift @@ -0,0 +1,222 @@ +import Foundation + +@available(macOS 13.0, *) +struct AnswerParams { + let initialMessages: [Message] + let inferenceType: InferenceType + let oramaClient: OramaClient + let userContext: UserSpecs + let events: Events? + + enum InferenceType: Encodable, Decodable { + case documentation + } + + enum RelatedFormat: Encodable, Decodable { + case question + case query + } + + enum Role: Encodable, Decodable { + case user + case assistant + } + + struct Message: Encodable, Decodable { + let role: Role + var content: String + } + + enum UserSpecs: Encodable, Decodable { + case string + case JSObject + } + + struct Related: Encodable, Decodable { + var howMany: Int? = 3 + var format: RelatedFormat? = .question + } + + struct AskParams: Encodable, Decodable { + let query: String + let userData: UserSpecs? + let related: Related? + } + + struct Events { + var onMessageChange: (([Message]) -> Void)? + var onMessageLoading: ((Bool) -> Void)? + var onAnswerAborted: ((Bool) -> Void)? + var onSourceChange: ((SearchResults) -> Void)? + var onQueryTranslated: ((ClientSearchParams) -> Void)? + var onRelatedQueries: (([String]) -> Void)? + } +} + +@available(macOS 13.0, *) +class AnswerSession { + + struct SSEMessage: Codable { + let type: String + let message: String + } + + private let endpointBaseURL = "https://answer.api.orama.com" + private var abortController: Task? + private var endpoint: String + private let userContext: AnswerParams.UserSpecs + private let events: AnswerParams.Events? + private let searchEndpoint: String + private var messages: [AnswerParams.Message] + private var inferenceType: AnswerParams.InferenceType + + init(params: AnswerParams) { + self.userContext = params.userContext + self.events = params.events + self.messages = params.initialMessages + self.inferenceType = params.inferenceType + self.endpoint = "\(self.endpointBaseURL)/v1/answer?api-key=\(params.oramaClient.apiKey)" + self.searchEndpoint = params.oramaClient.endpoint + } + + func fetchAnswer(params: AnswerParams.AskParams) async throws -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + self.abortController = Task { + do { + guard let oramaAnswerEndpointURL = URL(string: self.endpoint) else { + throw URLError(.badURL) + } + + var request = URLRequest(url: oramaAnswerEndpointURL) + + request.httpMethod = "POST" + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + request.httpBody = try self.buildRequestBody(params: params) + + let (responseStream, response) = try await URLSession.shared.bytes(for: request) + + guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else { + throw NSError(domain: "NetworkError", code: 0, userInfo: [NSLocalizedDescriptionKey: "Invalid response"]) + } + + self.events?.onMessageLoading?(true) + self.addNewEmptyAssistantMessage() + + guard var lastMessage = self.messages.last else { + throw NSError(domain: "MessageError", code: 0, userInfo: [NSLocalizedDescriptionKey: "No messages available"]) + } + + var buffer = "" + for try await byte in responseStream { + if Task.isCancelled { break } + + buffer += String(bytes: [byte], encoding: .utf8) ?? "" + + while let endIndex = buffer.firstIndex(of: "\n") { + let rawMessage = String(buffer[...self, from: sourcesData) { + self.events?.onSourceChange?(sources) + } + case "query-translated": + if let queryData = parsedMessage.message.data(using: .utf8), + let query = try? JSONDecoder().decode(ClientSearchParams.self, from: queryData) { + self.events?.onQueryTranslated?(query) + } + case "related-queries": + if let queriesData = parsedMessage.message.data(using: .utf8), + let queries = try? JSONDecoder().decode([String].self, from: queriesData) { + self.events?.onRelatedQueries?(queries) + } + case "text": + lastMessage.content += parsedMessage.message + self.events?.onMessageChange?(self.messages) + continuation.yield(lastMessage.content) + default: + break + } + } + } + } + continuation.finish() + + } catch { + if error is CancellationError { + self.events?.onAnswerAborted?(true) + } else { + continuation.finish(throwing: error) + } + } + + self.events?.onMessageLoading?(false) + } + } + } + + private func addNewEmptyAssistantMessage() { + self.messages.append(AnswerParams.Message(role: .assistant, content: "")) + self.events?.onMessageChange?(self.messages) + } + + private func buildRequestBody(params: AnswerParams.AskParams) throws -> Data { + var components = URLComponents() + components.queryItems = [ + URLQueryItem(name: "type", value: "documentation"), // @todo: remove hardcoded value here + URLQueryItem(name: "messages", value: encodeJSON(self.messages)), + URLQueryItem(name: "query", value: params.query ?? ""), + // URLQueryItem(name: "conversationId", value: self.conversationID), + // URLQueryItem(name: "userId", value: self.userID), + URLQueryItem(name: "endpoint", value: self.searchEndpoint), + URLQueryItem(name: "searchParams", value: encodeJSON(params)) + ] + + if params.userData != nil { + components.queryItems?.append(URLQueryItem(name: "userData", value: encodeJSON(params.userData))) + } + + if params.related != nil { + components.queryItems?.append(URLQueryItem(name: "related", value: encodeJSON(params.related))) + } + + guard let encodedQuery = components.percentEncodedQuery?.data(using: .utf8) else { + throw URLError(.badURL) + } + + return encodedQuery + } + + private func parseSSE(_ rawMessage: String) -> (event: String, data: String)? { + let lines = rawMessage.split(separator: "\n") + var event = "" + var data = "" + + for line in lines { + if line.starts(with: "event:") { + event = String(line.dropFirst(6).trimmingCharacters(in: .whitespaces)) + } else if line.starts(with: "data:") { + data = String(line.dropFirst(5).trimmingCharacters(in: .whitespaces)) + } + } + + return (event, data) + } + + private func encodeJSON(_ value: T) -> String? { + let encoder = JSONEncoder() + do { + let data = try encoder.encode(value) + return String(data: data, encoding: .utf8) + } catch { + print("Failed to encode JSON: \(error)") + return nil + } + } +} diff --git a/Sources/oramacloud-client/client.swift b/Sources/oramacloud-client/client.swift index 2043dee..8f9e828 100644 --- a/Sources/oramacloud-client/client.swift +++ b/Sources/oramacloud-client/client.swift @@ -2,9 +2,9 @@ import Foundation @available(macOS 13.0, *) final class OramaClient { + let apiKey: String + let endpoint: String private let id: String - private let apiKey: String - private let endpoint: String private let debouncer = Debouncer() private var searchRequestCounter: Int = 0