|
@@ -9,15 +9,14 @@ import (
|
|
|
"os"
|
|
|
"path"
|
|
|
"strconv"
|
|
|
+
|
|
|
+ "github.com/jmorganca/ollama/api"
|
|
|
)
|
|
|
|
|
|
// const directoryURL = "https://ollama.ai/api/models"
|
|
|
+// TODO
|
|
|
const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
|
|
|
|
|
|
-type directoryCtxKey string
|
|
|
-
|
|
|
-var dirCtx directoryCtxKey = "directory"
|
|
|
-
|
|
|
type Model struct {
|
|
|
Name string `json:"name"`
|
|
|
DisplayName string `json:"display_name"`
|
|
@@ -31,7 +30,7 @@ type Model struct {
|
|
|
License string `json:"license"`
|
|
|
}
|
|
|
|
|
|
-func pull(model string, progressCh chan<- string) error {
|
|
|
+func pull(model string, progressCh chan<- api.PullProgress) error {
|
|
|
remote, err := getRemote(model)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to pull model: %w", err)
|
|
@@ -64,7 +63,7 @@ func getRemote(model string) (*Model, error) {
|
|
|
return nil, fmt.Errorf("model not found in directory: %s", model)
|
|
|
}
|
|
|
|
|
|
-func saveModel(model *Model, progressCh chan<- string) error {
|
|
|
+func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|
|
// this models cache directory is created by the server on startup
|
|
|
home, err := os.UserHomeDir()
|
|
|
if err != nil {
|
|
@@ -130,11 +129,18 @@ func saveModel(model *Model, progressCh chan<- string) error {
|
|
|
totalBytes += n
|
|
|
|
|
|
// send progress updates
|
|
|
- progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
|
|
|
+ progressCh <- api.PullProgress{
|
|
|
+ Total: totalSize,
|
|
|
+ Completed: totalBytes,
|
|
|
+ Percent: float64(totalBytes) / float64(totalSize) * 100,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- // send completion message
|
|
|
- progressCh <- "Download complete!"
|
|
|
+ progressCh <- api.PullProgress{
|
|
|
+ Total: totalSize,
|
|
|
+ Completed: totalSize,
|
|
|
+ Percent: 100,
|
|
|
+ }
|
|
|
|
|
|
return nil
|
|
|
}
|