Skip to content

Commit

Permalink
feat(models): pull models from urls (#2750)
Browse files Browse the repository at this point in the history
* feat(models): pull models from urls

When using `run` now we can point directly to hf models via URL, for
instance:

```bash
local-ai run
huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf
```

Will pull the gguf model and place it in the models folder - of course
this depends on the fact that the gguf file should be automatically
detected by our guesser mechanism in order to this to make effective.

Similarly now galleries can refer to single files in the API requests.

This also changes the download code and `yaml` files now are treated in
the same way, so now config files are saved with the appropriate name
(and not hashed anymore).

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Adapt tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  • Loading branch information
mudler authored Jul 11, 2024
1 parent b60acab commit b6b8ab6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 9 deletions.
48 changes: 41 additions & 7 deletions pkg/startup/model_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package startup
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -77,19 +78,35 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath

log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
case downloader.LooksLikeURL(url):
log.Debug().Msgf("[startup] resolved model to download: %s", url)
log.Debug().Msgf("[startup] downloading %s", url)

// Extract filename from URL
fileName, e := filenameFromUrl(url)
if e != nil || fileName == "" {
fileName = utils.MD5(url)
if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") {
fileName = fileName + ".yaml"
}
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL")
//err = errors.Join(err, e)
//continue
}

// md5 of model name
md5Name := utils.MD5(url)
modelPath := filepath.Join(modelPath, fileName)

if e := utils.VerifyPath(fileName, modelPath); e != nil {
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path")
err = errors.Join(err, e)
continue
}

// check if file exists
if _, e := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(e, os.ErrNotExist) {
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) {
e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
})
if e != nil {
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model")
err = errors.Join(err, e)
}
}
Expand Down Expand Up @@ -150,3 +167,20 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl

return nil, true
}

func filenameFromUrl(urlstr string) (string, error) {
// strip anything after @
if strings.Contains(urlstr, "@") {
urlstr = strings.Split(urlstr, "@")[0]
}

u, err := url.Parse(urlstr)
if err != nil {
return "", fmt.Errorf("error due to parsing url: %w", err)
}
x, err := url.QueryUnescape(u.EscapedPath())
if err != nil {
return "", fmt.Errorf("error due to escaping: %w", err)
}
return filepath.Base(x), nil
}
18 changes: 16 additions & 2 deletions pkg/startup/model_preload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ var _ = Describe("Preload test", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
fileName := fmt.Sprintf("%s.yaml", "phi-2")

InstallModels([]config.Gallery{}, libraryURL, tmpdir, true, nil, "phi-2")

Expand All @@ -36,7 +36,7 @@ var _ = Describe("Preload test", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
fileName := fmt.Sprintf("%s.yaml", "phi-2")

InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url)

Expand Down Expand Up @@ -79,5 +79,19 @@ var _ = Describe("Preload test", func() {

Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
})
It("downloads from urls", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")

err = InstallModels([]config.Gallery{}, "", tmpdir, false, nil, url)
Expect(err).ToNot(HaveOccurred())

resultFile := filepath.Join(tmpdir, fileName)

_, err = os.Stat(resultFile)
Expect(err).ToNot(HaveOccurred())
})
})
})

0 comments on commit b6b8ab6

Please sign in to comment.