Skip to content

Commit

Permalink
feat: adds 'regenerateLast' feature to answer session
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Jul 24, 2024
1 parent 58692e5 commit 5fcfa73
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
33 changes: 33 additions & 0 deletions Sources/oramacloud-client/answer-session.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ class AnswerSession<Doc: Codable> {
let message: String
}

enum MaybeStream {
case stream(AsyncThrowingStream<String, Error>)
case string(String)
case error(Error)
}

private let eventEmitter = EventEmitter()
private let endpointBaseURL = "https://answer.api.orama.com"
private var abortController: Task<Void, Error>?
Expand All @@ -98,6 +104,7 @@ class AnswerSession<Doc: Codable> {
private var messages: [AnswerParams<Doc>.Message]
private var inferenceType: AnswerParams<Doc>.InferenceType
private var state: [AnswerParams<Doc>.Interaction<Doc>] = []
private var lastInteractionParams: AnswerParams<Doc>.AskParams?

init(params: AnswerParams<Doc>) {
userContext = params.userContext ?? .string
Expand Down Expand Up @@ -142,9 +149,35 @@ class AnswerSession<Doc: Codable> {
return response
}

public func regenerateLast(stream: Bool = true) async throws -> MaybeStream {
if state.isEmpty || self.messages.isEmpty {
throw OramaClientError.runtimeError("No messages to regenerate")
}

let isLastMessageAssistant = self.messages.last?.role == .assistant

if !isLastMessageAssistant {
throw OramaClientError.runtimeError("Last message is not an assistant message")
}

if self.lastInteractionParams == nil {
throw OramaClientError.runtimeError("Cannot find last interaction params")
}

self.messages.removeLast()
self.state.removeLast()

if stream {
return .stream(try await fetchAnswer(params: self.lastInteractionParams!))
} else {
return .string(try await ask(params: self.lastInteractionParams!))
}
}

private func fetchAnswer(params: AnswerParams<Doc>.AskParams) async throws -> AsyncThrowingStream<String, Error> {
AsyncThrowingStream { continuation in
let interactionId = Cuid.generateId()
self.lastInteractionParams = params
self.abortController = Task {
do {
self.state.append(AnswerParams<Doc>.Interaction(
Expand Down
6 changes: 6 additions & 0 deletions Sources/oramacloud-client/types.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@ struct SearchRequestPayload: Encodable {
return jsonObject
}
}

// ======================== ERRORS ========================

enum OramaClientError: Error {
case runtimeError(String)
}
27 changes: 27 additions & 0 deletions Tests/oramacloud-clientTests/oramacloud_clientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,31 @@ final class oramacloud_clientTests: XCTestCase {
wait(for: [expectation], timeout: 60.0)
}

func testE2ERegenerateLastAnswer() async throws {
let expectation = XCTestExpectation(description: "Can correctly regenerate the last answer")
var state: [AnswerParams<E2EDoc>.Interaction<E2EDoc>] = []

do {
_ = answerSession.on(event: .stateChange) {
state = $0 as! [AnswerParams<E2EDoc>.Interaction<E2EDoc>]
}

let _ = try await answerSession.ask(params: AnswerParams.AskParams(query: "german", userData: nil, related: nil))
let _ = try await answerSession.ask(params: AnswerParams.AskParams(query: "labrador", userData: nil, related: nil))
let _ = try await answerSession.regenerateLast(stream: false)

expectation.fulfill()
XCTAssertEqual(state.count, 2, "Should contain 2 interactions")
XCTAssertEqual(state.last!.query, "labrador", "Second query should be 'labrador'")

await fulfillment(of: [expectation], timeout: 120.0)

} catch {
XCTFail("Test failed with error: \(error)")
expectation.fulfill()
}
}

func testE2EIndexManager() throws {
struct DocumentStruct: Codable {
let breed: String
Expand Down Expand Up @@ -181,5 +206,7 @@ extension oramacloud_clientTests {
("testE2EAnswerSession", testE2EAnswerSession),
("testOnMessageLoading", testOnMessageLoading),
("testAsyncE2EAnswerSession", testAsyncE2EAnswerSession),
("testE2ERegenerateLastAnswer", testE2ERegenerateLastAnswer),
("testE2EIndexManager", testE2EIndexManager)
]
}

0 comments on commit 5fcfa73

Please sign in to comment.