Skip to content

Commit

Permalink
fixes events
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Jul 20, 2024
1 parent 02fe7aa commit ee2b1b2
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 64 deletions.
53 changes: 21 additions & 32 deletions Sources/oramacloud-client/answer-session.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct AnswerParams<Doc: Encodable & Decodable> {
var onStateChange: (([Interaction<Doc>]) -> Void)?
}

enum Event {
enum Event: String {
case messageChange
case messageLoading
case answerAborted
Expand All @@ -86,6 +86,7 @@ class AnswerSession<Doc: Encodable & Decodable> {
let message: String
}

private let eventEmitter = EventEmitter()
private let endpointBaseURL = "https://answer.api.orama.com"
private var abortController: Task<Void, Error>?
private var endpoint: String
Expand All @@ -108,25 +109,9 @@ class AnswerSession<Doc: Encodable & Decodable> {
}

public func on(event: AnswerParams<Doc>.Event, callback: @escaping AnswerParams<Doc>.Events.Callback) -> AnswerSession<Doc> {
switch event {
case .messageChange:
events?.onMessageChange = { callback($0) }
case .messageLoading:
events?.onMessageLoading = { callback($0) }
case .answerAborted:
events?.onAnswerAborted = { callback($0) }
case .sourceChange:
events?.onSourceChange = { callback($0) }
case .queryTranslated:
events?.onQueryTranslated = { callback($0) }
case .relatedQueries:
events?.onRelatedQueries = { callback($0) }
case .newInteractionStarted:
events?.onNewInteractionStarted = { callback($0) }
case .stateChange:
events?.onStateChange = { callback($0) }
eventEmitter.on(event.rawValue) { data in
callback(data)
}

return self
}

Expand Down Expand Up @@ -177,7 +162,7 @@ class AnswerSession<Doc: Encodable & Decodable> {
throw URLError(.badServerResponse)
}

self.events?.onMessageLoading?(true)
self.eventEmitter.emit("messageLoading", data: true)
self.addNewEmptyAssistantMessage()

var buffer = ""
Expand All @@ -199,29 +184,30 @@ class AnswerSession<Doc: Encodable & Decodable> {
let sources = try? JSONDecoder().decode(SearchResults<Doc>.self, from: sourcesData)
{
self.state[currentInteractionIndex].sources = sources
self.events?.onSourceChange?(sources)
self.events?.onStateChange?(self.state)
self.eventEmitter.emit("onSourceChange", data: sources)
self.eventEmitter.emit("onStateChange", data: self.state)
}
case "query-translated":
if let queryData = parsedMessage.message.data(using: .utf8),
let query = try? JSONDecoder().decode(ClientSearchParams.self, from: queryData)
{
self.state[currentInteractionIndex].translatedQuery = query
self.events?.onQueryTranslated?(query)
self.events?.onStateChange?(self.state)
self.eventEmitter.emit("queryTranslated", data: query)
self.eventEmitter.emit("stateChange", data: self.state)
}
case "related-queries":
if let queriesData = parsedMessage.message.data(using: .utf8),
let queries = try? JSONDecoder().decode([String].self, from: queriesData)
{
self.state[currentInteractionIndex].relatedQueries = queries
self.events?.onRelatedQueries?(queries)
self.events?.onStateChange?(self.state)
self.eventEmitter.emit("relatedQueries", data: queries)
self.eventEmitter.emit("stateChange", data: self.state)
}
case "text":
self.state[currentInteractionIndex].response += parsedMessage.message
self.events?.onMessageChange?(self.messages)
self.events?.onStateChange?(self.state)
self.eventEmitter.emit("messageChange", data: self.messages)
self.eventEmitter.emit("stateChange", data: self.state)

continuation.yield(self.state[currentInteractionIndex].response)
default:
break
Expand All @@ -236,8 +222,11 @@ class AnswerSession<Doc: Encodable & Decodable> {
let index = self.state.firstIndex(where: { $0.interactionId == interactionId })!
self.state[index].loading = false
self.state[index].aborted = true
self.events?.onAnswerAborted?(true)
self.events?.onStateChange?(self.state)

self.eventEmitter.emit("messageLoading", data: false)
self.eventEmitter.emit("onAnswerAborted", data: true)
self.eventEmitter.emit("stateChange", data: self.state)

continuation.finish()
} else {
continuation.finish(throwing: error)
Expand All @@ -246,8 +235,8 @@ class AnswerSession<Doc: Encodable & Decodable> {

let index = self.state.firstIndex(where: { $0.interactionId == interactionId })!
self.state[index].loading = false
self.events?.onStateChange?(self.state)
self.events?.onMessageLoading?(false)
self.eventEmitter.emit("messageLoading", data: false)
self.eventEmitter.emit("stateChange", data: self.state)
continuation.finish()
}
}
Expand Down
23 changes: 23 additions & 0 deletions Sources/oramacloud-client/utils.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Foundation

@available(macOS 13.0, iOS 16.0, watchOS 9.0, tvOS 16.0, *)
class Debouncer {
private var task: Task<Void, Never>?
Expand All @@ -16,3 +18,24 @@ class Debouncer {
}
}
}

class EventEmitter {
typealias EventCallback = (Any) -> Void

private var eventCallbacks: [String: [EventCallback]] = [:]

func on(_ eventName: String, callback: @escaping EventCallback) {
if eventCallbacks[eventName] == nil {
eventCallbacks[eventName] = []
}
eventCallbacks[eventName]?.append(callback)
fflush(stdout)
}

func emit(_ eventName: String, data: Any) {
fflush(stdout)
eventCallbacks[eventName]?.forEach { callback in
callback(data)
}
}
}
96 changes: 64 additions & 32 deletions Tests/oramacloud-clientTests/oramacloud_clientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,62 +8,94 @@ struct E2ETest1Document: Encodable & Decodable {
let e2eEndpoint = "https://cloud.orama.run/v1/indexes/e2e-index-client-rv4bdd"
let e2eApiKey = "eaXWAKLxn05lefXAfB3wAhuTq3VaXGqx"

func testE2EAnswerSession() async throws {
@available(macOS 12.0, *)
final class oramacloud_clientTests: XCTestCase {
struct E2EDoc: Codable {
let breed: String
}

let clientParams = OramaClientParams(endpoint: e2eEndpoint, apiKey: e2eApiKey)
let orama = OramaClient(params: clientParams)
let answerSessionParams = AnswerParams<E2EDoc>(
initialMessages: [],
inferenceType: .documentation,
oramaClient: orama,
userContext: nil,
events: nil
)

let answerSession = AnswerSession(params: answerSessionParams)

let askParams = AnswerParams<E2EDoc>.AskParams(query: "german", userData: nil, related: nil)

do {
let response = try await answerSession.ask(params: askParams)
XCTAssertFalse(response.isEmpty, "Response should not be empty")
} catch {
XCTFail("AnswerSession failed with error: \(error)")
var oramaClient: OramaClient!
var answerSession: AnswerSession<E2EDoc>!

override func setUp() {
super.setUp()
let clientParams = OramaClientParams(endpoint: e2eEndpoint, apiKey: e2eApiKey)
oramaClient = OramaClient(params: clientParams)

let answerParams = AnswerParams<E2EDoc>(
initialMessages: [],
inferenceType: .documentation,
oramaClient: oramaClient,
userContext: nil,
events: nil
)
answerSession = AnswerSession(params: answerParams)
}
}

@available(macOS 12.0, *)
final class oramacloud_clientTests: XCTestCase {
func testE2ESearch() async throws {
let expectation = XCTestExpectation(description: "Async search completes")
let clientParams = OramaClientParams(endpoint: e2eEndpoint, apiKey: e2eApiKey)
let orama = OramaClient(params: clientParams)

let params = ClientSearchParams.builder(term: "German", mode: .fulltext)
.limit(10)
.build()

Task {
do {
let searchResults: SearchResults<E2ETest1Document> = try await orama.search(query: params)
let params = ClientSearchParams.builder(term: "German", mode: .fulltext)
.limit(10)
.offset(0)
.build()
let searchResults: SearchResults<E2ETest1Document> = try await oramaClient.search(query: params)

XCTAssertGreaterThan(searchResults.count, 0)
XCTAssertNotNil(searchResults.elapsed.raw)
XCTAssertNotNil(searchResults.elapsed.formatted)
XCTAssertGreaterThan(searchResults.hits.count, 0)
expectation.fulfill()
} catch {
print("Search failed with error: \(error)")
debugPrint("Search failed with error: \(error)")
fflush(stdout)
XCTFail("Search failed with error: \(error)")
expectation.fulfill()
}
}

wait(for: [expectation], timeout: 10.0)
await fulfillment(of: [expectation], timeout: 10.0)
}

func testE2EAnswerSession() async throws {
let answerSessionParams = AnswerParams<E2EDoc>(
initialMessages: [],
inferenceType: .documentation,
oramaClient: oramaClient,
userContext: nil,
events: nil
)

let answerSession = AnswerSession(params: answerSessionParams)

let askParams = AnswerParams<E2EDoc>.AskParams(query: "german", userData: nil, related: nil)

do {
let response = try await answerSession.ask(params: askParams)
XCTAssertFalse(response.isEmpty, "Response should not be empty")
} catch {
XCTFail("AnswerSession failed with error: \(error)")
}
}

func testOnMessageLoading() async throws {
let expectation = XCTestExpectation(description: "Message loading event called")
var events: [Bool] = []

_ = answerSession.on(event: .messageLoading) { isLoading in
events.append(isLoading as! Bool)
if events.count == 2 { // Expecting two events: true and false
expectation.fulfill()
}
}

let _ = try await answerSession.ask(params: AnswerParams.AskParams(query: "german", userData: nil, related: nil))

await fulfillment(of: [expectation], timeout: 10.0)

XCTAssertEqual(events, [true, false], "Expected two message loading events: true followed by false")
}

func testAsyncE2EAnswerSession() throws {
Expand Down

0 comments on commit ee2b1b2

Please sign in to comment.