improve: adds better error handling #5

Merged 2 commits on Oct 27, 2024

1 change: 1 addition & 0 deletions Playground/Playground/ViewModels/AppViewModel.swift
@@ -7,6 +7,7 @@

import Foundation

@MainActor
@Observable
final class AppViewModel {
var cohereAPIKey: String
71 changes: 47 additions & 24 deletions Playground/Playground/Views/ModelListView.swift
@@ -11,8 +11,12 @@ import AIModelRetriever
struct ModelListView: View {
private let title: String
private let provider: AIProvider
private let retriever = AIModelRetriever()

@Environment(AppViewModel.self) private var viewModel

@State private var isFetching: Bool = false
@State private var fetchTask: Task<Void, Never>?
@State private var models: [AIModel] = []

init(title: String, provider: AIProvider) {
@@ -21,36 +25,55 @@ struct ModelListView: View {
}

var body: some View {
List(models) { model in
VStack(alignment: .leading) {
Text(model.id)
.font(.footnote)
.foregroundStyle(.secondary)

Text(model.name)
VStack {
if models.isEmpty, isFetching {
VStack(spacing: 16) {
ProgressView()

Button("Cancel") {
fetchTask?.cancel()
}
}
} else {
List(models) { model in
VStack(alignment: .leading) {
Text(model.id)
.font(.footnote)
.foregroundStyle(.secondary)

Text(model.name)
}
}
}
}
.navigationTitle(title)
.task {
let retriever = AIModelRetriever()
isFetching = true

do {
switch provider {
case .anthropic:
models = retriever.anthropic()
case .cohere:
models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey)
case .google:
models = retriever.google()
case .ollama:
models = try await retriever.ollama()
case .openai:
models = try await retriever.openAI(apiKey: viewModel.openaiAPIKey)
case .groq:
models = try await retriever.openAI(apiKey: viewModel.groqAPIKey, endpoint: URL(string: "https://api.groq.com/openai/v1/models"))
fetchTask = Task {
do {
defer {
self.isFetching = false
self.fetchTask = nil
}

switch provider {
case .anthropic:
models = retriever.anthropic()
case .cohere:
models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey)
case .google:
models = retriever.google()
case .ollama:
models = try await retriever.ollama()
case .openai:
models = try await retriever.openAI(apiKey: viewModel.openaiAPIKey)
case .groq:
models = try await retriever.openAI(apiKey: viewModel.groqAPIKey, endpoint: URL(string: "https://api.groq.com/openai/v1/models"))
}
} catch {
print(String(describing: error))
}
} catch {
print(String(describing: error))
}
}
}
29 changes: 27 additions & 2 deletions README.md
@@ -152,14 +152,39 @@ do {
}
```

## Donations
### Error Handling

The package provides structured error handling through the `AIModelRetrieverError` enum. This enum contains four cases that represent different types of errors you might encounter:

```swift
do {
let models = try await modelRetriever.openAI(apiKey: "your-api-key")
} catch let error as AIModelRetrieverError {
switch error {
case .serverError(let message):
// Handle server-side errors (e.g., invalid API key, rate limits)
print("Server Error: \(message)")
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
// Handle invalid server responses
print("Invalid response received from server")
case .cancelled:
// Handle cancelled requests
print("Request cancelled")
}
} catch {
print(error.localizedDescription)
}
```
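
The new `cancelled` case also pairs with task cancellation in callers, which is how the playground's `ModelListView` offers its Cancel button. A minimal sketch (not part of this PR's README, and assuming a locally reachable Ollama server) of catching it from a cancellable `Task`:

```swift
import AIModelRetriever

let retriever = AIModelRetriever()

// Run the fetch inside a Task so it can be cancelled from elsewhere.
let fetchTask = Task {
    do {
        let models = try await retriever.ollama()
        print("Fetched \(models.count) models")
    } catch AIModelRetrieverError.cancelled {
        // Raised when the underlying request is cancelled mid-flight.
        print("Request cancelled")
    } catch {
        print(String(describing: error))
    }
}

// Cancelling the task propagates to URLSession; performRequest maps the
// resulting URLError(.cancelled) to AIModelRetrieverError.cancelled.
fetchTask.cancel()
```

This mirrors how `ModelListView` keeps the task in `@State` and cancels it from its Cancel button.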

## Support

If you find `AIModelRetriever` helpful and would like to support its development, consider making a donation. Your contribution helps maintain the project and develop new features.

- [GitHub Sponsors](https://github.com/sponsors/kevinhermawan)
- [Buy Me a Coffee](https://buymeacoffee.com/kevinhermawan)

Your support is greatly appreciated!
Your support is greatly appreciated! ❤️

## Contributing

76 changes: 62 additions & 14 deletions Sources/AIModelRetriever/AIModelRetriever.swift
@@ -15,18 +15,33 @@ public struct AIModelRetriever: Sendable {
/// Initializes a new instance of ``AIModelRetriever``.
public init() {}

private func performRequest<T: Decodable>(_ request: URLRequest) async throws -> T {
let (data, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
throw AIModelRetrieverError.badServerResponse
}

guard 200...299 ~= httpResponse.statusCode else {
throw AIModelRetrieverError.serverError(statusCode: httpResponse.statusCode, error: String(data: data, encoding: .utf8))
private func performRequest<T: Decodable, E: ProviderError>(_ request: URLRequest, errorType: E.Type) async throws -> T {
do {
let (data, response) = try await URLSession.shared.data(for: request)

if let errorResponse = try? JSONDecoder().decode(E.self, from: data) {
throw AIModelRetrieverError.serverError(errorResponse.errorMessage)
}

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw AIModelRetrieverError.badServerResponse
}

let models = try JSONDecoder().decode(T.self, from: data)

return models
} catch let error as AIModelRetrieverError {
throw error
} catch let error as URLError {
switch error.code {
case .cancelled:
throw AIModelRetrieverError.cancelled
default:
throw AIModelRetrieverError.networkError(error)
}
} catch {
throw AIModelRetrieverError.networkError(error)
}

return try JSONDecoder().decode(T.self, from: data)
}

private func createRequest(for endpoint: URL, with headers: [String: String]? = nil) -> URLRequest {
Expand Down Expand Up @@ -74,7 +89,7 @@ public extension AIModelRetriever {
let allHeaders = ["Authorization": "Bearer \(apiKey)"]

let request = createRequest(for: defaultEndpoint, with: allHeaders)
let response: CohereResponse = try await performRequest(request)
let response: CohereResponse = try await performRequest(request, errorType: CohereError.self)

return response.models.map { AIModel(id: $0.name, name: $0.name) }
}
@@ -86,6 +101,12 @@ public extension AIModelRetriever {
private struct CohereModel: Decodable {
let name: String
}

private struct CohereError: ProviderError {
let message: String

var errorMessage: String { message }
}
}

// MARK: - Google
@@ -122,7 +143,7 @@ public extension AIModelRetriever {
guard let defaultEndpoint = URL(string: "http://localhost:11434/api/tags") else { return [] }

let request = createRequest(for: endpoint ?? defaultEndpoint, with: headers)
let response: OllamaResponse = try await performRequest(request)
let response: OllamaResponse = try await performRequest(request, errorType: OllamaError.self)

return response.models.map { AIModel(id: $0.model, name: $0.name) }
}
@@ -135,6 +156,16 @@ public extension AIModelRetriever {
let name: String
let model: String
}

private struct OllamaError: ProviderError {
let error: Error

struct Error: Decodable {
let message: String
}

var errorMessage: String { error.message }
}
}

// MARK: - OpenAI
@@ -156,7 +187,7 @@ public extension AIModelRetriever {
allHeaders["Authorization"] = "Bearer \(apiKey)"

let request = createRequest(for: endpoint ?? defaultEndpoint, with: allHeaders)
let response: OpenAIResponse = try await performRequest(request)
let response: OpenAIResponse = try await performRequest(request, errorType: OpenAIError.self)

return response.data.map { AIModel(id: $0.id, name: $0.id) }
}
@@ -168,4 +199,21 @@ public extension AIModelRetriever {
private struct OpenAIModel: Decodable {
let id: String
}

private struct OpenAIError: ProviderError {
let error: Error

struct Error: Decodable {
let message: String
}

var errorMessage: String { error.message }
}
}

// MARK: - Supporting Types
private extension AIModelRetriever {
protocol ProviderError: Decodable {
var errorMessage: String { get }
}
}
34 changes: 27 additions & 7 deletions Sources/AIModelRetriever/AIModelRetrieverError.swift
@@ -9,13 +9,33 @@ import Foundation

/// An enum that represents errors that can occur during AI model retrieval.
public enum AIModelRetrieverError: Error, Sendable {
/// Indicates that the server response was not in the expected format.
case badServerResponse
/// A case that represents a server-side error response.
///
/// - Parameter message: The error message from the server.
case serverError(String)

/// Indicates that the server returned an error.
/// A case that represents a network-related error.
///
/// - Parameters:
/// - statusCode: The HTTP status code returned by the server.
/// - error: An optional string that contains additional error information provided by the server.
case serverError(statusCode: Int, error: String?)
/// - Parameter error: The underlying network error.
case networkError(Error)

/// A case that represents an invalid server response.
case badServerResponse

/// A case that represents a cancelled request.
case cancelled

/// A localized message that describes the error.
public var errorDescription: String? {
switch self {
case .serverError(let error):
return error
case .networkError(let error):
return error.localizedDescription
case .badServerResponse:
return "Invalid response received from server"
case .cancelled:
return "Request was cancelled"
}
}
}
25 changes: 14 additions & 11 deletions Sources/AIModelRetriever/Documentation.docc/Documentation.md
@@ -103,22 +103,25 @@ do {

### Error Handling

The package uses ``AIModelRetrieverError`` to represent specific errors that may occur. You can catch and handle these errors as follows:
The package provides structured error handling through the ``AIModelRetrieverError`` enum. This enum contains four cases that represent different types of errors you might encounter:

```swift
let apiKey = "your-openai-api-key"

do {
let models = try await modelRetriever.openai(apiKey: apiKey)
// Process models
} catch let error as AIModelRetrieverError {
let models = try await modelRetriever.openAI(apiKey: "your-api-key")
} catch let error as AIModelRetrieverError {
switch error {
case .serverError(let message):
// Handle server-side errors (e.g., invalid API key, rate limits)
print("Server Error: \(message)")
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
print("Received an invalid response from the server")
case .serverError(let statusCode, let errorMessage):
print("Server error (status \(statusCode)): \(errorMessage ?? "No error message provided")")
// Handle invalid server responses
print("Invalid response received from server")
case .cancelled:
// Handle cancelled requests
print("Request cancelled")
}
} catch {
print("An unexpected error occurred: \(error)")
}
```