Skip to content

Commit

Permalink
refactor: improves error handling (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhermawan authored Oct 30, 2024
1 parent bda4433 commit ce1fd65
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 88 deletions.
32 changes: 26 additions & 6 deletions Playground/Playground/Views/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ struct ChatView: View {
@State private var outputTokens: Int = 0
@State private var totalTokens: Int = 0

@State private var isGenerating: Bool = false
@State private var generationTask: Task<Void, Never>?

var body: some View {
@Bindable var viewModelBindable = viewModel

Expand All @@ -37,8 +40,11 @@ struct ChatView: View {
}

VStack {
SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream)
.disabled(viewModel.models.isEmpty)
if isGenerating {
CancelButton(onCancel: { generationTask?.cancel() })
} else {
SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream)
}
}
}
.toolbar {
Expand All @@ -64,15 +70,22 @@ struct ChatView: View {
private func onSend() {
clear()

isGenerating = true

let messages = [
ChatMessage(role: .system, content: viewModel.systemPrompt),
ChatMessage(role: .user, content: prompt)
]

let options = ChatOptions(temperature: viewModel.temperature)

Task {
generationTask = Task {
do {
defer {
self.isGenerating = false
self.generationTask = nil
}

let completion = try await viewModel.chat.send(model: viewModel.selectedModel, messages: messages, options: options)

if let content = completion.choices.first?.message.content {
Expand All @@ -85,23 +98,30 @@ struct ChatView: View {
self.totalTokens = usage.totalTokens
}
} catch {
print(String(describing: error))
print(error)
}
}
}

private func onStream() {
clear()

isGenerating = true

let messages = [
ChatMessage(role: .system, content: viewModel.systemPrompt),
ChatMessage(role: .user, content: prompt)
]

let options = ChatOptions(temperature: viewModel.temperature)

Task {
generationTask = Task {
do {
defer {
self.isGenerating = false
self.generationTask = nil
}

for try await chunk in viewModel.chat.stream(model: viewModel.selectedModel, messages: messages, options: options) {
if let content = chunk.choices.first?.delta.content {
self.response += content
Expand All @@ -114,7 +134,7 @@ struct ChatView: View {
}
}
} catch {
print(String(describing: error))
print(error)
}
}
}
Expand Down
27 changes: 27 additions & 0 deletions Playground/Playground/Views/Subviews/CancelButton.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//
// CancelButton.swift
// Playground
//
// Created by Kevin Hermawan on 10/31/24.
//

import SwiftUI

/// A full-width bordered button that invokes a caller-supplied cancellation
/// handler when tapped. Used in place of the send button while a generation
/// task is in flight.
struct CancelButton: View {
    private let onCancel: () -> Void
    
    /// Creates a cancel button.
    /// - Parameter onCancel: The action to run when the user taps "Cancel".
    init(onCancel: @escaping () -> Void) {
        self.onCancel = onCancel
    }
    
    var body: some View {
        Button(action: onCancel) {
            label
        }
        .buttonStyle(.bordered)
        .padding([.horizontal, .bottom])
        .padding(.top, 8)
    }
    
    // The button label, stretched to the full available width with a little
    // vertical breathing room so the tap target is comfortable.
    private var label: some View {
        Text("Cancel")
            .padding(.vertical, 8)
            .frame(maxWidth: .infinity)
    }
}
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,19 @@ do {
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
// Handle invalid server responses
print("Invalid response received from server")
case .decodingError(let error):
// Handle errors that occur when the response cannot be decoded
print("Decoding Error: \(error.localizedDescription)")
case .streamError:
// Handle errors that occur while streaming responses
print("Stream Error")
case .cancelled:
// Handle requests that are cancelled
print("Request was cancelled")
}
} catch {
// Handle any other errors
print("An unexpected error occurred: \(error)")
}
```

Expand Down
6 changes: 3 additions & 3 deletions Sources/LLMChatOpenAI/ChatOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ public struct ChatOptions: Encodable, Sendable {
}
}

public struct Tool: Encodable {
public struct Tool: Encodable, Sendable {
/// The type of the tool. Currently, only function is supported.
public let type: String

Expand All @@ -201,7 +201,7 @@ public struct ChatOptions: Encodable, Sendable {
self.function = function
}

public struct Function: Encodable {
public struct Function: Encodable, Sendable {
/// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
public let name: String

Expand Down Expand Up @@ -240,7 +240,7 @@ public struct ChatOptions: Encodable, Sendable {
}
}

public enum ToolChoice: Encodable {
public enum ToolChoice: Encodable, Sendable {
case none
case auto
case function(name: String)
Expand Down
15 changes: 12 additions & 3 deletions Sources/LLMChatOpenAI/Documentation.docc/Documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,19 @@ do {
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
// Handle invalid server responses
print("Invalid response received from server")
case .decodingError(let error):
// Handle errors that occur when the response cannot be decoded
print("Decoding Error: \(error.localizedDescription)")
case .streamError:
// Handle errors that occur while streaming responses
print("Stream Error")
case .cancelled:
// Handle requests that are cancelled
print("Request was cancelled")
}
} catch {
// Handle any other errors
print("An unexpected error occurred: \(error)")
}
```

Expand Down
85 changes: 54 additions & 31 deletions Sources/LLMChatOpenAI/LLMChatOpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,26 @@ private extension LLMChatOpenAI {
let request = try createRequest(for: endpoint, with: body)
let (data, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
throw LLMChatOpenAIError.serverError(response.description)
}

// Check for API errors first, as they might come with 200 status
if let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) {
throw LLMChatOpenAIError.serverError(errorResponse.error.message)
}

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw LLMChatOpenAIError.badServerResponse
guard 200...299 ~= httpResponse.statusCode else {
throw LLMChatOpenAIError.serverError(response.description)
}

return try JSONDecoder().decode(ChatCompletion.self, from: data)
} catch is CancellationError {
throw LLMChatOpenAIError.cancelled
} catch let error as URLError where error.code == .cancelled {
throw LLMChatOpenAIError.cancelled
} catch let error as DecodingError {
throw LLMChatOpenAIError.decodingError(error)
} catch let error as LLMChatOpenAIError {
throw error
} catch {
Expand All @@ -157,46 +168,58 @@ private extension LLMChatOpenAI {

func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream<ChatCompletionChunk, Error> {
AsyncThrowingStream { continuation in
Task {
do {
let request = try createRequest(for: endpoint, with: body)
let (bytes, response) = try await URLSession.shared.bytes(for: request)

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
for try await line in bytes.lines {
if let data = line.data(using: .utf8), let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) {
throw LLMChatOpenAIError.serverError(errorResponse.error.message)
}

break
let task = Task {
await withTaskCancellationHandler {
do {
let request = try createRequest(for: endpoint, with: body)
let (bytes, response) = try await URLSession.shared.bytes(for: request)

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw LLMChatOpenAIError.serverError(response.description)
}

throw LLMChatOpenAIError.badServerResponse
}

for try await line in bytes.lines {
if line.hasPrefix("data: ") {
let jsonString = line.dropFirst(6)
for try await line in bytes.lines {
try Task.checkCancellation()

if jsonString.trimmingCharacters(in: .whitespacesAndNewlines) == "[DONE]" {
guard line.hasPrefix("data: ") else { continue }

let jsonString = line.dropFirst(6).trimmingCharacters(in: .whitespacesAndNewlines)

if jsonString == "[DONE]" {
break
}

if let data = jsonString.data(using: .utf8) {
let chunk = try JSONDecoder().decode(ChatCompletionChunk.self, from: data)
continuation.yield(chunk)
guard let data = jsonString.data(using: .utf8) else { continue }

guard (try? JSONDecoder().decode(ChatCompletionError.self, from: data)) == nil else {
throw LLMChatOpenAIError.streamError
}

let chunk = try JSONDecoder().decode(ChatCompletionChunk.self, from: data)

continuation.yield(chunk)
}

continuation.finish()
} catch is CancellationError {
continuation.finish(throwing: LLMChatOpenAIError.cancelled)
} catch let error as URLError where error.code == .cancelled {
continuation.finish(throwing: LLMChatOpenAIError.cancelled)
} catch let error as DecodingError {
continuation.finish(throwing: LLMChatOpenAIError.decodingError(error))
} catch let error as LLMChatOpenAIError {
continuation.finish(throwing: error)
} catch {
continuation.finish(throwing: LLMChatOpenAIError.networkError(error))
}

continuation.finish()
} catch let error as LLMChatOpenAIError {
continuation.finish(throwing: error)
} catch {
continuation.finish(throwing: LLMChatOpenAIError.networkError(error))
} onCancel: {
continuation.finish(throwing: LLMChatOpenAIError.cancelled)
}
}

continuation.onTermination = { @Sendable _ in
task.cancel()
}
}
}
}
Expand Down
32 changes: 14 additions & 18 deletions Sources/LLMChatOpenAI/LLMChatOpenAIError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,25 @@
import Foundation

/// An enum that represents errors from the chat completion request.
public enum LLMChatOpenAIError: LocalizedError, Sendable {
/// A case that represents a server-side error response.
public enum LLMChatOpenAIError: Error, Sendable {
/// An error that occurs during JSON decoding.
///
/// - Parameter message: The error message from the server.
case serverError(String)
/// - Parameter error: The underlying decoding error.
case decodingError(Error)

/// A case that represents a network-related error.
/// An error that occurs during network operations.
///
/// - Parameter error: The underlying network error.
case networkError(Error)

/// A case that represents an invalid server response.
case badServerResponse
/// An error returned by the server.
///
/// - Parameter message: The error message received from the server.
case serverError(String)

/// An error that occurs during stream processing.
case streamError

/// A localized message that describes the error.
public var errorDescription: String? {
switch self {
case .serverError(let error):
return error
case .networkError(let error):
return error.localizedDescription
case .badServerResponse:
return "Invalid response received from server"
}
}
/// An error that occurs when the request is cancelled.
case cancelled
}
Loading

0 comments on commit ce1fd65

Please sign in to comment.