From ce1fd65c3e2d8ca6e72642cae53403dfe949bd30 Mon Sep 17 00:00:00 2001 From: Kevin Hermawan <84965338+kevinhermawan@users.noreply.github.com> Date: Thu, 31 Oct 2024 06:48:11 +0700 Subject: [PATCH] refactor: improves error handling (#15) --- Playground/Playground/Views/ChatView.swift | 32 +++- .../Views/Subviews/CancelButton.swift | 27 ++++ README.md | 15 +- Sources/LLMChatOpenAI/ChatOptions.swift | 6 +- .../Documentation.docc/Documentation.md | 15 +- Sources/LLMChatOpenAI/LLMChatOpenAI.swift | 85 ++++++---- .../LLMChatOpenAI/LLMChatOpenAIError.swift | 32 ++-- .../ChatCompletionTests.swift | 153 +++++++++++++++--- .../Utils/URLProtocolMock.swift | 17 +- 9 files changed, 294 insertions(+), 88 deletions(-) create mode 100644 Playground/Playground/Views/Subviews/CancelButton.swift diff --git a/Playground/Playground/Views/ChatView.swift b/Playground/Playground/Views/ChatView.swift index c2c64cd..3514f70 100644 --- a/Playground/Playground/Views/ChatView.swift +++ b/Playground/Playground/Views/ChatView.swift @@ -20,6 +20,9 @@ struct ChatView: View { @State private var outputTokens: Int = 0 @State private var totalTokens: Int = 0 + @State private var isGenerating: Bool = false + @State private var generationTask: Task<Void, Never>? 
+ var body: some View { @Bindable var viewModelBindable = viewModel @@ -37,8 +40,11 @@ struct ChatView: View { } VStack { - SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream) - .disabled(viewModel.models.isEmpty) + if isGenerating { + CancelButton(onCancel: { generationTask?.cancel() }) + } else { + SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream) + } } } .toolbar { @@ -64,6 +70,8 @@ struct ChatView: View { private func onSend() { clear() + isGenerating = true + let messages = [ ChatMessage(role: .system, content: viewModel.systemPrompt), ChatMessage(role: .user, content: prompt) @@ -71,8 +79,13 @@ struct ChatView: View { let options = ChatOptions(temperature: viewModel.temperature) - Task { + generationTask = Task { do { + defer { + self.isGenerating = false + self.generationTask = nil + } + let completion = try await viewModel.chat.send(model: viewModel.selectedModel, messages: messages, options: options) if let content = completion.choices.first?.message.content { @@ -85,7 +98,7 @@ struct ChatView: View { self.totalTokens = usage.totalTokens } } catch { - print(String(describing: error)) + print(error) } } } @@ -93,6 +106,8 @@ struct ChatView: View { private func onStream() { clear() + isGenerating = true + let messages = [ ChatMessage(role: .system, content: viewModel.systemPrompt), ChatMessage(role: .user, content: prompt) @@ -100,8 +115,13 @@ struct ChatView: View { let options = ChatOptions(temperature: viewModel.temperature) - Task { + generationTask = Task { do { + defer { + self.isGenerating = false + self.generationTask = nil + } + for try await chunk in viewModel.chat.stream(model: viewModel.selectedModel, messages: messages, options: options) { if let content = chunk.choices.first?.delta.content { self.response += content @@ -114,7 +134,7 @@ struct ChatView: View { } } } catch { - print(String(describing: error)) + print(error) } } } diff --git a/Playground/Playground/Views/Subviews/CancelButton.swift 
b/Playground/Playground/Views/Subviews/CancelButton.swift new file mode 100644 index 0000000..ec05f7e --- /dev/null +++ b/Playground/Playground/Views/Subviews/CancelButton.swift @@ -0,0 +1,27 @@ +// +// CancelButton.swift +// Playground +// +// Created by Kevin Hermawan on 10/31/24. +// + +import SwiftUI + +struct CancelButton: View { + private let onCancel: () -> Void + + init(onCancel: @escaping () -> Void) { + self.onCancel = onCancel + } + + var body: some View { + Button(action: onCancel) { + Text("Cancel") + .padding(.vertical, 8) + .frame(maxWidth: .infinity) + } + .buttonStyle(.bordered) + .padding([.horizontal, .bottom]) + .padding(.top, 8) + } +} diff --git a/README.md b/README.md index 8c4d1cb..9ae7b81 100644 --- a/README.md +++ b/README.md @@ -258,10 +258,19 @@ do { case .networkError(let error): // Handle network-related errors (e.g., no internet connection) print("Network Error: \(error.localizedDescription)") - case .badServerResponse: - // Handle invalid server responses - print("Invalid response received from server") + case .decodingError(let error): + // Handle errors that occur when the response cannot be decoded + print("Decoding Error: \(error.localizedDescription)") + case .streamError: + // Handle errors that occur when streaming responses + print("Stream Error") + case .cancelled: + // Handle requests that are cancelled + print("Request was cancelled") } +} catch { + // Handle any other errors + print("An unexpected error occurred: \(error)") } ``` diff --git a/Sources/LLMChatOpenAI/ChatOptions.swift b/Sources/LLMChatOpenAI/ChatOptions.swift index b5390d0..ea8a580 100644 --- a/Sources/LLMChatOpenAI/ChatOptions.swift +++ b/Sources/LLMChatOpenAI/ChatOptions.swift @@ -189,7 +189,7 @@ public struct ChatOptions: Encodable, Sendable { } } - public struct Tool: Encodable { + public struct Tool: Encodable, Sendable { /// The type of the tool. Currently, only function is supported. 
public let type: String @@ -201,7 +201,7 @@ public struct ChatOptions: Encodable, Sendable { self.function = function } - public struct Function: Encodable { + public struct Function: Encodable, Sendable { /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. public let name: String @@ -240,7 +240,7 @@ public struct ChatOptions: Encodable, Sendable { } } - public enum ToolChoice: Encodable { + public enum ToolChoice: Encodable, Sendable { case none case auto case function(name: String) diff --git a/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md b/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md index 2368189..336ef85 100644 --- a/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md +++ b/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md @@ -229,10 +229,19 @@ do { case .networkError(let error): // Handle network-related errors (e.g., no internet connection) print("Network Error: \(error.localizedDescription)") - case .badServerResponse: - // Handle invalid server responses - print("Invalid response received from server") + case .decodingError(let error): + // Handle errors that occur when the response cannot be decoded + print("Decoding Error: \(error.localizedDescription)") + case .streamError: + // Handle errors that occur when streaming responses + print("Stream Error") + case .cancelled: + // Handle requests that are cancelled + print("Request was cancelled") } +} catch { + // Handle any other errors + print("An unexpected error occurred: \(error)") } ``` diff --git a/Sources/LLMChatOpenAI/LLMChatOpenAI.swift b/Sources/LLMChatOpenAI/LLMChatOpenAI.swift index 8cda3dd..fdb973b 100644 --- a/Sources/LLMChatOpenAI/LLMChatOpenAI.swift +++ b/Sources/LLMChatOpenAI/LLMChatOpenAI.swift @@ -139,15 +139,26 @@ private extension LLMChatOpenAI { let request = try createRequest(for: endpoint, with: body) let (data, response) = try await URLSession.shared.data(for: 
request) + guard let httpResponse = response as? HTTPURLResponse else { + throw LLMChatOpenAIError.serverError(response.description) + } + + // Check for API errors first, as they might come with 200 status if let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) { throw LLMChatOpenAIError.serverError(errorResponse.error.message) } - guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { - throw LLMChatOpenAIError.badServerResponse + guard 200...299 ~= httpResponse.statusCode else { + throw LLMChatOpenAIError.serverError(response.description) } return try JSONDecoder().decode(ChatCompletion.self, from: data) + } catch is CancellationError { + throw LLMChatOpenAIError.cancelled + } catch let error as URLError where error.code == .cancelled { + throw LLMChatOpenAIError.cancelled + } catch let error as DecodingError { + throw LLMChatOpenAIError.decodingError(error) } catch let error as LLMChatOpenAIError { throw error } catch { @@ -157,46 +168,58 @@ private extension LLMChatOpenAI { func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream<ChatCompletionChunk, Error> { AsyncThrowingStream { continuation in - Task { - do { - let request = try createRequest(for: endpoint, with: body) - let (bytes, response) = try await URLSession.shared.bytes(for: request) - - guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { - for try await line in bytes.lines { - if let data = line.data(using: .utf8), let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) { - throw LLMChatOpenAIError.serverError(errorResponse.error.message) - } - - break + let task = Task { + await withTaskCancellationHandler { + do { + let request = try createRequest(for: endpoint, with: body) + let (bytes, response) = try await URLSession.shared.bytes(for: request) + + guard let httpResponse = response as? 
HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { + throw LLMChatOpenAIError.serverError(response.description) } - throw LLMChatOpenAIError.badServerResponse - } - - for try await line in bytes.lines { - if line.hasPrefix("data: ") { - let jsonString = line.dropFirst(6) + for try await line in bytes.lines { + try Task.checkCancellation() - if jsonString.trimmingCharacters(in: .whitespacesAndNewlines) == "[DONE]" { + guard line.hasPrefix("data: ") else { continue } + + let jsonString = line.dropFirst(6).trimmingCharacters(in: .whitespacesAndNewlines) + + if jsonString == "[DONE]" { break } - if let data = jsonString.data(using: .utf8) { - let chunk = try JSONDecoder().decode(ChatCompletionChunk.self, from: data) - - continuation.yield(chunk) + guard let data = jsonString.data(using: .utf8) else { continue } + + guard (try? JSONDecoder().decode(ChatCompletionError.self, from: data)) == nil else { + throw LLMChatOpenAIError.streamError } + + let chunk = try JSONDecoder().decode(ChatCompletionChunk.self, from: data) + + continuation.yield(chunk) } + + continuation.finish() + } catch is CancellationError { + continuation.finish(throwing: LLMChatOpenAIError.cancelled) + } catch let error as URLError where error.code == .cancelled { + continuation.finish(throwing: LLMChatOpenAIError.cancelled) + } catch let error as DecodingError { + continuation.finish(throwing: LLMChatOpenAIError.decodingError(error)) + } catch let error as LLMChatOpenAIError { + continuation.finish(throwing: error) + } catch { + continuation.finish(throwing: LLMChatOpenAIError.networkError(error)) } - - continuation.finish() - } catch let error as LLMChatOpenAIError { - continuation.finish(throwing: error) - } catch { - continuation.finish(throwing: LLMChatOpenAIError.networkError(error)) + } onCancel: { + continuation.finish(throwing: LLMChatOpenAIError.cancelled) } } + + continuation.onTermination = { @Sendable _ in + task.cancel() + } } } } diff --git 
a/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift b/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift index ae6740f..f4abef0 100644 --- a/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift +++ b/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift @@ -8,29 +8,25 @@ import Foundation /// An enum that represents errors from the chat completion request. -public enum LLMChatOpenAIError: LocalizedError, Sendable { - /// A case that represents a server-side error response. +public enum LLMChatOpenAIError: Error, Sendable { + /// An error that occurs during JSON decoding. /// - /// - Parameter message: The error message from the server. - case serverError(String) + /// - Parameter error: The underlying decoding error. + case decodingError(Error) - /// A case that represents a network-related error. + /// An error that occurs during network operations. /// /// - Parameter error: The underlying network error. case networkError(Error) - /// A case that represents an invalid server response. - case badServerResponse + /// An error returned by the server. + /// + /// - Parameter message: The error message received from the server. + case serverError(String) + + /// An error that occurs during stream processing. + case streamError - /// A localized message that describes the error. - public var errorDescription: String? { - switch self { - case .serverError(let error): - return error - case .networkError(let error): - return error.localizedDescription - case .badServerResponse: - return "Invalid response received from server" - } - } + /// An error that occurs when the request is cancelled. 
+ case cancelled } diff --git a/Tests/LLMChatOpenAITests/ChatCompletionTests.swift b/Tests/LLMChatOpenAITests/ChatCompletionTests.swift index c5fa47f..38ed15d 100644 --- a/Tests/LLMChatOpenAITests/ChatCompletionTests.swift +++ b/Tests/LLMChatOpenAITests/ChatCompletionTests.swift @@ -23,6 +23,7 @@ final class ChatCompletionTests: XCTestCase { ] URLProtocol.registerClass(URLProtocolMock.self) + URLProtocolMock.reset() } override func tearDown() { @@ -175,7 +176,7 @@ extension ChatCompletionTests { URLProtocolMock.mockData = mockErrorResponse.data(using: .utf8) do { - _ = try await chat.send(model: "gpt-4", messages: messages) + _ = try await chat.send(model: "gpt-4o", messages: messages) XCTFail("Expected serverError to be thrown") } catch let error as LLMChatOpenAIError { @@ -196,7 +197,7 @@ extension ChatCompletionTests { ) do { - _ = try await chat.send(model: "gpt-4", messages: messages) + _ = try await chat.send(model: "gpt-4o", messages: messages) XCTFail("Expected networkError to be thrown") } catch let error as LLMChatOpenAIError { @@ -209,49 +210,163 @@ extension ChatCompletionTests { } } - func testStreamServerError() async throws { - let mockErrorResponse = """ - { - "error": { - "message": "Rate limit exceeded" + func testHTTPError() async throws { + URLProtocolMock.mockStatusCode = 429 + URLProtocolMock.mockData = "Rate limit exceeded".data(using: .utf8) + + do { + _ = try await chat.send(model: "gpt-4o", messages: messages) + + XCTFail("Expected serverError to be thrown") + } catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + XCTAssertTrue(message.contains("429")) + default: + XCTFail("Expected serverError but got \(error)") } } - """ + } + + func testDecodingError() async throws { + let invalidJSON = "{ invalid json }" + URLProtocolMock.mockData = invalidJSON.data(using: .utf8) + + do { + _ = try await chat.send(model: "gpt-4o", messages: messages) + + XCTFail("Expected decodingError to be thrown") + } catch 
let error as LLMChatOpenAIError { + switch error { + case .decodingError: + break + default: + XCTFail("Expected decodingError but got \(error)") + } + } + } + + func testCancellation() async throws { + let task = Task { + _ = try await chat.send(model: "gpt-4o", messages: messages) + } - URLProtocolMock.mockStreamData = [mockErrorResponse] + task.cancel() do { - for try await _ in chat.stream(model: "gpt-4", messages: messages) { - XCTFail("Expected serverError to be thrown") + _ = try await task.value + + XCTFail("Expected cancelled error to be thrown") + } catch let error as LLMChatOpenAIError { + switch error { + case .cancelled: + break + default: + XCTFail("Expected cancelled but got \(error)") + } + } + } +} + +// MARK: - Error Handling (Stream) +extension ChatCompletionTests { + func testStreamServerError() async throws { + URLProtocolMock.mockStreamData = ["data: {\"error\": {\"message\": \"Server error occurred\", \"type\": \"server_error\", \"code\": \"internal_error\"}}\n\n"] + + do { + for try await _ in chat.stream(model: "gpt-4o", messages: messages) { + XCTFail("Expected streamError to be thrown") } } catch let error as LLMChatOpenAIError { switch error { - case .serverError(let message): - XCTAssertEqual(message, "Rate limit exceeded") + case .streamError: + break default: - XCTFail("Expected serverError but got \(error)") + XCTFail("Expected streamError but got \(error)") } } } func testStreamNetworkError() async throws { - URLProtocolMock.mockError = NSError( + let networkError = NSError( domain: NSURLErrorDomain, - code: NSURLErrorNotConnectedToInternet, - userInfo: [NSLocalizedDescriptionKey: "The Internet connection appears to be offline."] + code: NSURLErrorNetworkConnectionLost, + userInfo: [NSLocalizedDescriptionKey: "The network connection was lost."] ) + URLProtocolMock.mockError = networkError + do { - for try await _ in chat.stream(model: "gpt-4", messages: messages) { + for try await _ in chat.stream(model: "gpt-4o", messages: 
messages) { XCTFail("Expected networkError to be thrown") } } catch let error as LLMChatOpenAIError { switch error { case .networkError(let underlyingError): - XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNotConnectedToInternet) + XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNetworkConnectionLost) default: XCTFail("Expected networkError but got \(error)") } } } + + func testStreamHTTPError() async throws { + URLProtocolMock.mockStatusCode = 503 + URLProtocolMock.mockStreamData = [""] + + do { + for try await _ in chat.stream(model: "gpt-4o", messages: messages) { + XCTFail("Expected serverError to be thrown") + } + } catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + XCTAssertTrue(message.contains("503")) + default: + XCTFail("Expected serverError but got \(error)") + } + } + } + + func testStreamDecodingError() async throws { + URLProtocolMock.mockStreamData = ["data: { invalid json }\n\n"] + + do { + for try await _ in chat.stream(model: "gpt-4o", messages: messages) { + XCTFail("Expected decodingError to be thrown") + } + } catch let error as LLMChatOpenAIError { + switch error { + case .decodingError: + break + default: + XCTFail("Expected decodingError but got \(error)") + } + } + } + + func testStreamCancellation() async throws { + URLProtocolMock.mockStreamData = Array(repeating: "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1694268190,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"test\"},\"finish_reason\":null}]}\n\n", count: 1000) + + let expectation = XCTestExpectation(description: "Stream cancelled") + + let task = Task { + do { + for try await _ in chat.stream(model: "gpt-4o", messages: messages) { + try await Task.sleep(nanoseconds: 100_000_000) // 0.1 seconds + } + + XCTFail("Expected stream to be cancelled") + } catch is CancellationError { + expectation.fulfill() + } catch { + XCTFail("Expected 
CancellationError but got \(error)") + } + } + + try await Task.sleep(nanoseconds: 1_000_000_000) // 1 second + task.cancel() + + await fulfillment(of: [expectation], timeout: 5.0) + } } diff --git a/Tests/LLMChatOpenAITests/Utils/URLProtocolMock.swift b/Tests/LLMChatOpenAITests/Utils/URLProtocolMock.swift index 153bc95..990b75c 100644 --- a/Tests/LLMChatOpenAITests/Utils/URLProtocolMock.swift +++ b/Tests/LLMChatOpenAITests/Utils/URLProtocolMock.swift @@ -11,6 +11,7 @@ final class URLProtocolMock: URLProtocol { static var mockData: Data? static var mockStreamData: [String]? static var mockError: Error? + static var mockStatusCode: Int? override class func canInit(with request: URLRequest) -> Bool { return true @@ -28,19 +29,18 @@ final class URLProtocolMock: URLProtocol { return } - if let streamData = URLProtocolMock.mockStreamData { - let response = HTTPURLResponse(url: request.url!, statusCode: 200, httpVersion: nil, headerFields: ["Content-Type": "text/event-stream"])! + if let streamData = URLProtocolMock.mockStreamData, let url = request.url { + let response = HTTPURLResponse(url: url, statusCode: URLProtocolMock.mockStatusCode ?? 200, httpVersion: nil, headerFields: ["Content-Type": "text/event-stream"])! client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) for line in streamData { client.urlProtocol(self, didLoad: Data(line.utf8)) } - } else if let data = URLProtocolMock.mockData { - let response = HTTPURLResponse(url: request.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + } else if let data = URLProtocolMock.mockData, let url = request.url { + let response = HTTPURLResponse(url: url, statusCode: URLProtocolMock.mockStatusCode ?? 200, httpVersion: nil, headerFields: nil)! 
client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) client.urlProtocol(self, didLoad: data) } else { - client.urlProtocol(self, didFailWithError: NSError(domain: "MockURLProtocol", code: -1, userInfo: [NSLocalizedDescriptionKey: "No mock data available"])) return } @@ -48,4 +48,11 @@ final class URLProtocolMock: URLProtocol { } override func stopLoading() {} + + static func reset() { + mockData = nil + mockStreamData = nil + mockError = nil + mockStatusCode = nil + } }