|
@@ -2,14 +2,13 @@ package server
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"path"
|
|
|
"strconv"
|
|
|
-
|
|
|
- "github.com/jmorganca/ollama/api"
|
|
|
)
|
|
|
|
|
|
const directoryURL = "https://ollama.ai/api/models"
|
|
@@ -36,12 +35,12 @@ func (m *Model) FullName() string {
|
|
|
return path.Join(home, ".ollama", "models", m.Name+".bin")
|
|
|
}
|
|
|
|
|
|
-func pull(model string, progressCh chan<- api.PullProgress) error {
|
|
|
- remote, err := getRemote(model)
|
|
|
- if err != nil {
|
|
|
- return fmt.Errorf("failed to pull model: %w", err)
|
|
|
- }
|
|
|
- return saveModel(remote, progressCh)
|
|
|
+func (m *Model) TempFile() string {
|
|
|
+ fullName := m.FullName()
|
|
|
+ return path.Join(
|
|
|
+ path.Dir(fullName),
|
|
|
+ fmt.Sprintf(".%s.part", path.Base(fullName)),
|
|
|
+ )
|
|
|
}
|
|
|
|
|
|
func getRemote(model string) (*Model, error) {
|
|
@@ -68,7 +67,7 @@ func getRemote(model string) (*Model, error) {
|
|
|
return nil, fmt.Errorf("model not found in directory: %s", model)
|
|
|
}
|
|
|
|
|
|
-func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|
|
+func saveModel(model *Model, fn func(total, completed int64)) error {
|
|
|
// this models cache directory is created by the server on startup
|
|
|
|
|
|
client := &http.Client{}
|
|
@@ -76,41 +75,45 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to download model: %w", err)
|
|
|
}
|
|
|
- // check for resume
|
|
|
- alreadyDownloaded := int64(0)
|
|
|
- fileInfo, err := os.Stat(model.FullName())
|
|
|
- if err != nil {
|
|
|
- if !os.IsNotExist(err) {
|
|
|
- return fmt.Errorf("failed to check resume model file: %w", err)
|
|
|
- }
|
|
|
- // file doesn't exist, create it now
|
|
|
- } else {
|
|
|
- alreadyDownloaded = fileInfo.Size()
|
|
|
- req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
|
|
|
+
|
|
|
+ // check if completed file exists
|
|
|
+ fi, err := os.Stat(model.FullName())
|
|
|
+ switch {
|
|
|
+ case errors.Is(err, os.ErrNotExist):
|
|
|
+ // noop, file doesn't exist so create it
|
|
|
+ case err != nil:
|
|
|
+ return fmt.Errorf("stat: %w", err)
|
|
|
+ default:
|
|
|
+ fn(fi.Size(), fi.Size())
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ var size int64
|
|
|
+
|
|
|
+ // completed file doesn't exist, check partial file
|
|
|
+ fi, err = os.Stat(model.TempFile())
|
|
|
+ switch {
|
|
|
+ case errors.Is(err, os.ErrNotExist):
|
|
|
+ // noop, file doesn't exist so create it
|
|
|
+ case err != nil:
|
|
|
+ return fmt.Errorf("stat: %w", err)
|
|
|
+ default:
|
|
|
+ size = fi.Size()
|
|
|
}
|
|
|
|
|
|
+ req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
|
|
|
+
|
|
|
resp, err := client.Do(req)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to download model: %w", err)
|
|
|
}
|
|
|
-
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
- if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
|
|
- // already downloaded
|
|
|
- progressCh <- api.PullProgress{
|
|
|
- Total: alreadyDownloaded,
|
|
|
- Completed: alreadyDownloaded,
|
|
|
- Percent: 100,
|
|
|
- }
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
|
|
+ if resp.StatusCode >= 400 {
|
|
|
return fmt.Errorf("failed to download model: %s", resp.Status)
|
|
|
}
|
|
|
|
|
|
- out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
|
|
+ out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
|
|
if err != nil {
|
|
|
panic(err)
|
|
|
}
|
|
@@ -118,37 +121,23 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
|
|
|
|
|
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
|
|
|
|
|
- buf := make([]byte, 1024)
|
|
|
- totalBytes := alreadyDownloaded
|
|
|
- totalSize += alreadyDownloaded
|
|
|
+ totalBytes := size
|
|
|
+ totalSize += size
|
|
|
|
|
|
for {
|
|
|
- n, err := resp.Body.Read(buf)
|
|
|
- if err != nil && err != io.EOF {
|
|
|
+ n, err := io.CopyN(out, resp.Body, 8192)
|
|
|
+ if err != nil && !errors.Is(err, io.EOF) {
|
|
|
return err
|
|
|
}
|
|
|
+
|
|
|
if n == 0 {
|
|
|
break
|
|
|
}
|
|
|
- if _, err := out.Write(buf[:n]); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- totalBytes += int64(n)
|
|
|
-
|
|
|
- // send progress updates
|
|
|
- progressCh <- api.PullProgress{
|
|
|
- Total: totalSize,
|
|
|
- Completed: totalBytes,
|
|
|
- Percent: float64(totalBytes) / float64(totalSize) * 100,
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
- progressCh <- api.PullProgress{
|
|
|
- Total: totalSize,
|
|
|
- Completed: totalSize,
|
|
|
- Percent: 100,
|
|
|
+ totalBytes += n
|
|
|
+ fn(totalSize, totalBytes)
|
|
|
}
|
|
|
|
|
|
- return nil
|
|
|
+ fn(totalSize, totalSize)
|
|
|
+ return os.Rename(model.TempFile(), model.FullName())
|
|
|
}
|