diff --git a/.gitignore b/.gitignore index 66fd13c..238d5cf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ # Dependency directories (remove the comment below to include it) # vendor/ +.env diff --git a/README.md b/README.md index cd23396..17bdd8b 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Import: ```go import ( - wx "github.com/h0rv/go-watsonx/models" + wx "github.com/h0rv/go-watsonx/pkg/models" ) ``` @@ -24,16 +24,15 @@ import ( model, _ := wx.NewModel( wx.WithIBMCloudAPIKey("YOUR IBM CLOUD API KEY"), wx.WithWatsonxProjectID("YOUR WATSONX PROJECT ID"), - wx.WithModel(wx.LLAMA_2_70B_CHAT), ) result, _ := model.GenerateText( - "Hi, who are you?", + "meta-llama/llama-3-70b-instruct", + "Hi, who are you?", wx.WithTemperature(0.9), wx.WithTopP(.5), wx.WithTopK(10), wx.WithMaxNewTokens(512), - wx.WithDecodingMethod(wx.Greedy), ) println(result.Text) @@ -67,4 +66,4 @@ git config --local core.hooksPath .githooks/ ## Resources - [watsonx Python SDK Docs](https://ibm.github.io/watson-machine-learning-sdk) -- [watsonx REST API Docs (Internal)](https://test.cloud.ibm.com/apidocs/watsonx-ai) +- [watsonx REST API Docs](https://cloud.ibm.com/apidocs/watsonx-ai) diff --git a/models/types.go b/models/types.go deleted file mode 100644 index 35784a5..0000000 --- a/models/types.go +++ /dev/null @@ -1,60 +0,0 @@ -package models - -import ( - "net/http" -) - -/* - * https://ibm.github.io/watson-machine-learning-sdk/model.html#ibm_watson_machine_learning.foundation_models.utils.enums.ModelTypes - */ - -type ( - IBMCloudAPIKey = string - WatsonxProjectID = string - IBMCloudRegion = string - - ModelTypes = string - DecodingMethods = string -) - -const ( - IBMCloudAPIKeyEnvVarName = "IBMCLOUD_API_KEY" - WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" - - US_South IBMCloudRegion = "us-south" - Dallas IBMCloudRegion = US_South - EU_DE IBMCloudRegion = "eu-de" - Frankfurt IBMCloudRegion = EU_DE - JP_TOK IBMCloudRegion = "jp-tok" - Tokyo IBMCloudRegion = JP_TOK - - 
DefaultRegion = US_South - BaseURLFormatStr = "%s.ml.cloud.ibm.com" // Need to call SPrintf on it with region - DefaultAPIVersion = "2023-05-02" - - // https://ibm.github.io/watson-machine-learning-sdk/_modules/ibm_watson_machine_learning/foundation_models/utils/enums.html#ModelTypes - FLAN_T5_XXL ModelTypes = "google/flan-t5-xxl" - FLAN_UL2 ModelTypes = "google/flan-ul2" - MT0_XXL ModelTypes = "bigscience/mt0-xxl" - GPT_NEOX ModelTypes = "eleutherai/gpt-neox-20b" - MPT_7B_INSTRUCT2 ModelTypes = "ibm/mpt-7b-instruct2" - STARCODER ModelTypes = "bigcode/starcoder" - LLAMA_2_70B_CHAT ModelTypes = "meta-llama/llama-2-70b-chat" - LLAMA_2_13B_CHAT ModelTypes = "meta-llama/llama-2-13b-chat" - GRANITE_13B_INSTRUCT ModelTypes = "ibm/granite-13b-instruct-v1" - GRANITE_13B_CHAT ModelTypes = "ibm/granite-13b-chat-v1" - FLAN_T5_XL ModelTypes = "google/flan-t5-xl" - GRANITE_13B_CHAT_V2 ModelTypes = "ibm/granite-13b-chat-v2" - GRANITE_13B_INSTRUCT_V2 ModelTypes = "ibm/granite-13b-instruct-v2" - - // https://ibm.github.io/watson-machine-learning-sdk/_modules/ibm_watson_machine_learning/foundation_models/utils/enums.html#DecodingMethods - Sample DecodingMethods = "sample" - Greedy DecodingMethods = "greedy" - - DefaultModelType = FLAN_T5_XL - DefaultDecodingMethod = Greedy -) - -type Doer interface { - Do(req *http.Request) (*http.Response, error) -} diff --git a/models/test/generate_test.go b/pkg/internal/tests/models/generate_test.go similarity index 73% rename from models/test/generate_test.go rename to pkg/internal/tests/models/generate_test.go index 4f0713a..a21c62f 100644 --- a/models/test/generate_test.go +++ b/pkg/internal/tests/models/generate_test.go @@ -4,23 +4,22 @@ import ( "os" "testing" - wx "github.com/h0rv/go-watsonx/models" + wx "github.com/h0rv/go-watsonx/pkg/models" ) func getModel(t *testing.T) *wx.Model { - apiKey := os.Getenv(wx.IBMCloudAPIKeyEnvVarName) + apiKey := os.Getenv(wx.WatsonxAPIKeyEnvVarName) projectID := os.Getenv(wx.WatsonxProjectIDEnvVarName) if 
apiKey == "" { - t.Fatal("No IBM Cloud API key provided") + t.Fatal("No watsonx API key provided") } if projectID == "" { t.Fatal("No watsonx project ID provided") } model, err := wx.NewModel( - wx.WithIBMCloudAPIKey(apiKey), + wx.WithWatsonxAPIKey(apiKey), wx.WithWatsonxProjectID(projectID), - wx.WithModel(wx.FLAN_UL2), ) if err != nil { t.Fatalf("Failed to create model for testing. Error: %v", err) @@ -32,7 +31,10 @@ func getModel(t *testing.T) *wx.Model { func TestEmptyPromptError(t *testing.T) { model := getModel(t) - _, err := model.GenerateText("") + _, err := model.GenerateText( + "dumby model", + "", + ) if err == nil { t.Fatalf("Expected error for an empty prompt, but got nil") } @@ -41,7 +43,11 @@ func TestEmptyPromptError(t *testing.T) { func TestNilOptions(t *testing.T) { model := getModel(t) - _, err := model.GenerateText("What day is it?", nil) + _, err := model.GenerateText( + "meta-llama/llama-3-70b-instruct", + "What day is it?", + nil, + ) if err != nil { t.Fatalf("Expected no error for nil options, but got %v", err) } @@ -50,8 +56,10 @@ func TestNilOptions(t *testing.T) { func TestValidPrompt(t *testing.T) { model := getModel(t) - prompt := "Test prompt" - _, err := model.GenerateText(prompt) + _, err := model.GenerateText( + "meta-llama/llama-3-70b-instruct", + "Test prompt", + ) if err != nil { t.Fatalf("Expected no error, but got an error: %v", err) } @@ -60,14 +68,13 @@ func TestValidPrompt(t *testing.T) { func TestGenerateText(t *testing.T) { model := getModel(t) - prompt := "Hi, who are you?" result, err := model.GenerateText( - prompt, + "meta-llama/llama-3-70b-instruct", + "Hi, who are you?", wx.WithTemperature(0.9), wx.WithTopP(.5), wx.WithTopK(10), wx.WithMaxNewTokens(512), - wx.WithDecodingMethod(wx.Greedy), ) if err != nil { t.Fatalf("Expected no error, but got an error: %v", err) @@ -80,9 +87,9 @@ func TestGenerateText(t *testing.T) { func TestGenerateTextWithNilOptions(t *testing.T) { model := getModel(t) - prompt := "Who are you?" 
result, err := model.GenerateText( - prompt, + "meta-llama/llama-3-70b-instruct", + "Who are you?", nil, ) if err != nil { diff --git a/models/generate.go b/pkg/models/generate.go similarity index 96% rename from models/generate.go rename to pkg/models/generate.go index 213e94a..5fc43fd 100644 --- a/models/generate.go +++ b/pkg/models/generate.go @@ -49,7 +49,7 @@ type generateTextResponse struct { } // GenerateText generates completion text based on a given prompt and parameters -func (m *Model) GenerateText(prompt string, options ...GenerateOption) (GenerateTextResult, error) { +func (m *Model) GenerateText(model, prompt string, options ...GenerateOption) (GenerateTextResult, error) { m.CheckAndRefreshToken() if prompt == "" { @@ -65,7 +65,7 @@ func (m *Model) GenerateText(prompt string, options ...GenerateOption) (Generate payload := GenerateTextPayload{ ProjectID: m.projectID, - Model: m.modelType, + Model: model, Prompt: prompt, Parameters: opts, } diff --git a/models/generate_option.go b/pkg/models/generate_option.go similarity index 100% rename from models/generate_option.go rename to pkg/models/generate_option.go diff --git a/models/iam.go b/pkg/models/iam.go similarity index 90% rename from models/iam.go rename to pkg/models/iam.go index 7422f98..06f5787 100644 --- a/models/iam.go +++ b/pkg/models/iam.go @@ -23,10 +23,10 @@ type TokenResponse struct { Expiration int64 `json:"expiration"` } -func GenerateToken(client Doer, ibmCloudAPIKey IBMCloudAPIKey) (IAMToken, error) { +func GenerateToken(client Doer, watsonxApiKey WatsonxAPIKey) (IAMToken, error) { values := url.Values{ "grant_type": {"urn:ibm:params:oauth:grant-type:apikey"}, - "apikey": {ibmCloudAPIKey}, + "apikey": {watsonxApiKey}, } payload := strings.NewReader(values.Encode()) diff --git a/models/model.go b/pkg/models/model.go similarity index 78% rename from models/model.go rename to pkg/models/model.go index f7b7b36..f61aeee 100644 --- a/models/model.go +++ b/pkg/models/model.go @@ -15,12 +15,9 
@@ type Model struct { region IBMCloudRegion apiVersion string - ibmCloudAPIKey IBMCloudAPIKey - projectID WatsonxProjectID - - modelType ModelTypes - - token IAMToken + token IAMToken + apiKey WatsonxAPIKey + projectID WatsonxProjectID httpClient Doer } @@ -44,12 +41,9 @@ func NewModel(options ...ModelOption) (*Model, error) { region: opts.Region, apiVersion: opts.APIVersion, - ibmCloudAPIKey: opts.ibmCloudAPIKey, - projectID: opts.projectID, - - modelType: opts.Model, - // token: set below + apiKey: opts.watsonxAPIKey, + projectID: opts.projectID, httpClient: &http.Client{}, } @@ -72,7 +66,7 @@ func (m *Model) CheckAndRefreshToken() error { // RefreshToken generates and sets the model with a new token func (m *Model) RefreshToken() error { - token, err := GenerateToken(m.httpClient, m.ibmCloudAPIKey) + token, err := GenerateToken(m.httpClient, m.apiKey) if err != nil { return err } @@ -90,9 +84,7 @@ func defaulModelOptions() *ModelOptions { Region: DefaultRegion, APIVersion: DefaultAPIVersion, - ibmCloudAPIKey: os.Getenv(IBMCloudAPIKeyEnvVarName), - projectID: os.Getenv(WatsonxProjectIDEnvVarName), - - Model: DefaultModelType, + watsonxAPIKey: os.Getenv(WatsonxAPIKeyEnvVarName), + projectID: os.Getenv(WatsonxProjectIDEnvVarName), } } diff --git a/models/model_option.go b/pkg/models/model_option.go similarity index 68% rename from models/model_option.go rename to pkg/models/model_option.go index 6547711..af77e14 100644 --- a/models/model_option.go +++ b/pkg/models/model_option.go @@ -7,10 +7,8 @@ type ModelOptions struct { Region IBMCloudRegion APIVersion string - ibmCloudAPIKey IBMCloudAPIKey - projectID WatsonxProjectID - - Model ModelTypes + watsonxAPIKey WatsonxAPIKey + projectID WatsonxProjectID } func WithURL(url string) ModelOption { @@ -31,9 +29,9 @@ func WithAPIVersion(apiVersion string) ModelOption { } } -func WithIBMCloudAPIKey(ibmCloudAPIKey IBMCloudAPIKey) ModelOption { +func WithWatsonxAPIKey(watsonxAPIKey WatsonxAPIKey) ModelOption { return func(o 
*ModelOptions) { - o.ibmCloudAPIKey = ibmCloudAPIKey + o.watsonxAPIKey = watsonxAPIKey } } @@ -42,9 +40,3 @@ func WithWatsonxProjectID(projectID WatsonxProjectID) ModelOption { o.projectID = projectID } } - -func WithModel(model ModelTypes) ModelOption { - return func(o *ModelOptions) { - o.Model = model - } -} diff --git a/pkg/models/types.go b/pkg/models/types.go new file mode 100644 index 0000000..2bc0e0e --- /dev/null +++ b/pkg/models/types.go @@ -0,0 +1,36 @@ +package models + +import ( + "net/http" +) + +/* + * https://ibm.github.io/watson-machine-learning-sdk/model.html#ibm_watson_machine_learning.foundation_models.utils.enums.ModelTypes + */ + +type ( + WatsonxAPIKey = string + WatsonxProjectID = string + IBMCloudRegion = string + ModelType = string +) + +const ( + WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY" + WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID" + + US_South IBMCloudRegion = "us-south" + Dallas IBMCloudRegion = US_South + EU_DE IBMCloudRegion = "eu-de" + Frankfurt IBMCloudRegion = EU_DE + JP_TOK IBMCloudRegion = "jp-tok" + Tokyo IBMCloudRegion = JP_TOK + + DefaultRegion = US_South + BaseURLFormatStr = "%s.ml.cloud.ibm.com" // Need to call Sprintf on it with region + DefaultAPIVersion = "2024-05-20" +) + +type Doer interface { + Do(req *http.Request) (*http.Response, error) +}