Skip to content

Commit

Permalink
Cleanup API (#8)
Browse files Browse the repository at this point in the history
Co-authored-by: Robby <h0rv@users.noreply.github.com>
  • Loading branch information
h0rv and h0rv authored May 20, 2024
1 parent a6bdacb commit cc529cf
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 111 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@

# Dependency directories (remove the comment below to include it)
# vendor/
.env
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Import:

```go
import (
wx "github.com/h0rv/go-watsonx/models"
wx "github.com/h0rv/go-watsonx/pkg/models"
)
```

Expand All @@ -24,16 +24,15 @@ import (
model, _ := wx.NewModel(
wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"),
wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"),
wx.WithModel(wx.LLAMA_2_70B_CHAT),
)

result, _ := model.GenerateText(
"Hi, who are you?",
"meta-llama/llama-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
wx.WithTopK(10),
wx.WithMaxNewTokens(512),
wx.WithDecodingMethod(wx.Greedy),
)

println(result.Text)
Expand Down Expand Up @@ -67,4 +66,4 @@ git config --local core.hooksPath .githooks/
## Resources

- [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk)
- [watsonx REST API Docs (Internal)](https://test.cloud.ibm.com/apidocs/watsonx-ai)
- [watsonx REST API Docs](https://cloud.ibm.com/apidocs/watsonx-ai)
60 changes: 0 additions & 60 deletions models/types.go

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,22 @@ import (
"os"
"testing"

wx "github.com/h0rv/go-watsonx/models"
wx "github.com/h0rv/go-watsonx/pkg/models"
)

func getModel(t *testing.T) *wx.Model {
apiKey := os.Getenv(wx.IBMCloudAPIKeyEnvVarName)
apiKey := os.Getenv(wx.WatsonxAPIKeyEnvVarName)
projectID := os.Getenv(wx.WatsonxProjectIDEnvVarName)
if apiKey == "" {
t.Fatal("No IBM Cloud API key provided")
t.Fatal("No watsonx API key provided")
}
if projectID == "" {
t.Fatal("No watsonx project ID provided")
}

model, err := wx.NewModel(
wx.WithIBMCloudAPIKey(apiKey),
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
wx.WithModel(wx.FLAN_UL2),
)
if err != nil {
t.Fatalf("Failed to create model for testing. Error: %v", err)
Expand All @@ -32,7 +31,10 @@ func getModel(t *testing.T) *wx.Model {
func TestEmptyPromptError(t *testing.T) {
model := getModel(t)

_, err := model.GenerateText("")
_, err := model.GenerateText(
"dumby model",
"",
)
if err == nil {
t.Fatalf("Expected error for an empty prompt, but got nil")
}
Expand All @@ -41,7 +43,11 @@ func TestEmptyPromptError(t *testing.T) {
func TestNilOptions(t *testing.T) {
model := getModel(t)

_, err := model.GenerateText("What day is it?", nil)
_, err := model.GenerateText(
"meta-llama/llama-3-70b-instruct",
"What day is it?",
nil,
)
if err != nil {
t.Fatalf("Expected no error for nil options, but got %v", err)
}
Expand All @@ -50,8 +56,10 @@ func TestNilOptions(t *testing.T) {
func TestValidPrompt(t *testing.T) {
model := getModel(t)

prompt := "Test prompt"
_, err := model.GenerateText(prompt)
_, err := model.GenerateText(
"meta-llama/llama-3-70b-instruct",
"Test prompt",
)
if err != nil {
t.Fatalf("Expected no error, but got an error: %v", err)
}
Expand All @@ -60,14 +68,13 @@ func TestValidPrompt(t *testing.T) {
func TestGenerateText(t *testing.T) {
model := getModel(t)

prompt := "Hi, who are you?"
result, err := model.GenerateText(
prompt,
"meta-llama/llama-3-70b-instruct",
"Hi, who are you?",
wx.WithTemperature(0.9),
wx.WithTopP(.5),
wx.WithTopK(10),
wx.WithMaxNewTokens(512),
wx.WithDecodingMethod(wx.Greedy),
)
if err != nil {
t.Fatalf("Expected no error, but got an error: %v", err)
Expand All @@ -80,9 +87,9 @@ func TestGenerateText(t *testing.T) {
func TestGenerateTextWithNilOptions(t *testing.T) {
model := getModel(t)

prompt := "Who are you?"
result, err := model.GenerateText(
prompt,
"meta-llama/llama-3-70b-instruct",
"Who are you?",
nil,
)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions models/generate.go → pkg/models/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type generateTextResponse struct {
}

// GenerateText generates completion text based on a given prompt and parameters
func (m *Model) GenerateText(prompt string, options ...GenerateOption) (GenerateTextResult, error) {
func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) {
m.CheckAndRefreshToken()

if prompt == "" {
Expand All @@ -65,7 +65,7 @@ func (m *Model) GenerateText(prompt string, options ...GenerateOption) (Generate

payload := GenerateTextPayload{
ProjectID: m.projectID,
Model: m.modelType,
Model: model,
Prompt: prompt,
Parameters: opts,
}
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions models/iam.go → pkg/models/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ type TokenResponse struct {
Expiration int64 `json:"expiration"`
}

func GenerateToken(client Doer, ibmCloudAPIKey IBMCloudAPIKey) (IAMToken, error) {
func GenerateToken(client Doer, watsonxApiKey WatsonxAPIKey) (IAMToken, error) {
values := url.Values{
"grant_type": {"urn:ibm:params:oauth:grant-type:apikey"},
"apikey": {ibmCloudAPIKey},
"apikey": {watsonxApiKey},
}

payload := strings.NewReader(values.Encode())
Expand Down
24 changes: 8 additions & 16 deletions models/model.go → pkg/models/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ type Model struct {
region IBMCloudRegion
apiVersion string

ibmCloudAPIKey IBMCloudAPIKey
projectID WatsonxProjectID

modelType ModelTypes

token IAMToken
token IAMToken
apiKey WatsonxAPIKey
projectID WatsonxProjectID

httpClient Doer
}
Expand All @@ -44,12 +41,9 @@ func NewModel(options ...ModelOption) (*Model, error) {
region: opts.Region,
apiVersion: opts.APIVersion,

ibmCloudAPIKey: opts.ibmCloudAPIKey,
projectID: opts.projectID,

modelType: opts.Model,

// token: set below
apiKey: opts.watsonxAPIKey,
projectID: opts.projectID,

httpClient: &http.Client{},
}
Expand All @@ -72,7 +66,7 @@ func (m *Model) CheckAndRefreshToken() error {

// RefreshToken generates and sets the model with a new token
func (m *Model) RefreshToken() error {
token, err := GenerateToken(m.httpClient, m.ibmCloudAPIKey)
token, err := GenerateToken(m.httpClient, m.apiKey)
if err != nil {
return err
}
Expand All @@ -90,9 +84,7 @@ func defaulModelOptions() *ModelOptions {
Region: DefaultRegion,
APIVersion: DefaultAPIVersion,

ibmCloudAPIKey: os.Getenv(IBMCloudAPIKeyEnvVarName),
projectID: os.Getenv(WatsonxProjectIDEnvVarName),

Model: DefaultModelType,
watsonxAPIKey: os.Getenv(WatsonxAPIKeyEnvVarName),
projectID: os.Getenv(WatsonxProjectIDEnvVarName),
}
}
16 changes: 4 additions & 12 deletions models/model_option.go → pkg/models/model_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ type ModelOptions struct {
Region IBMCloudRegion
APIVersion string

ibmCloudAPIKey IBMCloudAPIKey
projectID WatsonxProjectID

Model ModelTypes
watsonxAPIKey WatsonxAPIKey
projectID WatsonxProjectID
}

func WithURL(url string) ModelOption {
Expand All @@ -31,9 +29,9 @@ func WithAPIVersion(apiVersion string) ModelOption {
}
}

func WithIBMCloudAPIKey(ibmCloudAPIKey IBMCloudAPIKey) ModelOption {
func WithWatsonxAPIKey(watsonxAPIKey WatsonxAPIKey) ModelOption {
return func(o *ModelOptions) {
o.ibmCloudAPIKey = ibmCloudAPIKey
o.watsonxAPIKey = watsonxAPIKey
}
}

Expand All @@ -42,9 +40,3 @@ func WithWatsonxProjectID(projectID WatsonxProjectID) ModelOption {
o.projectID = projectID
}
}

func WithModel(model ModelTypes) ModelOption {
return func(o *ModelOptions) {
o.Model = model
}
}
36 changes: 36 additions & 0 deletions pkg/models/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package models

import (
"net/http"
)

/*
* https://ibm.github.io/watson-machine-learning-sdk/model.html#ibm_watson_machine_learning.foundation_models.utils.enums.ModelTypes
*/

type (
WatsonxAPIKey = string
WatsonxProjectID = string
IBMCloudRegion = string
ModelType = string
)

const (
WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"

US_South IBMCloudRegion = "us-south"
Dallas IBMCloudRegion = US_South
EU_DE IBMCloudRegion = "eu-de"
Frankfurt IBMCloudRegion = EU_DE
JP_TOK IBMCloudRegion = "jp-tok"
Tokyo IBMCloudRegion = JP_TOK

DefaultRegion = US_South
BaseURLFormatStr = "%s.ml.cloud.ibm.com" // Need to call SPrintf on it with region
DefaultAPIVersion = "2024-05-20"
)

type Doer interface {
Do(req *http.Request) (*http.Response, error)
}

0 comments on commit cc529cf

Please sign in to comment.