Browse Source

Merge pull request #760 from jmorganca/mxyng/more-downloads

Mxyng/more downloads
Michael Yang 1 year ago
parent
commit
788637918a
5 changed files with 70 additions and 33 deletions
  1. 1 1
      api/client.go
  2. 16 0
      format/bytes.go
  3. 1 1
      llm/llama.go
  4. 18 18
      llm/llm.go
  5. 34 13
      server/download.go

+ 1 - 1
api/client.go

@@ -127,7 +127,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * 1024 // 512KB
+const maxBufferSize = 512 * 1000 // 512KB
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer

+ 16 - 0
format/bytes.go

@@ -0,0 +1,16 @@
+package format
+
+import "fmt"
+
+func HumanBytes(b int64) string {
+	switch {
+	case b > 1000*1000*1000:
+		return fmt.Sprintf("%d GB", b/1000/1000/1000)
+	case b > 1000*1000:
+		return fmt.Sprintf("%d MB", b/1000/1000)
+	case b > 1000:
+		return fmt.Sprintf("%d KB", b/1000)
+	default:
+		return fmt.Sprintf("%d B", b)
+	}
+}

+ 1 - 1
llm/llama.go

@@ -454,7 +454,7 @@ type PredictRequest struct {
 	Stop             []string `json:"stop,omitempty"`
 }
 
-const maxBufferSize = 512 * 1024 // 512KB
+const maxBufferSize = 512 * 1000 // 512KB
 
 func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
 	prevConvo, err := llm.Decode(ctx, prevContext)

+ 18 - 18
llm/llm.go

@@ -60,33 +60,33 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
 	totalResidentMemory := memory.TotalMemory()
 	switch ggml.ModelType() {
 	case "3B", "7B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 16*1024*1024 {
-			return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
-		} else if totalResidentMemory < 8*1024*1024 {
-			return nil, fmt.Errorf("model requires at least 8GB of memory")
+		if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 {
+			return nil, fmt.Errorf("F16 model requires at least 16 GB of memory")
+		} else if totalResidentMemory < 8*1000*1000 {
+			return nil, fmt.Errorf("model requires at least 8 GB of memory")
 		}
 	case "13B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 32*1024*1024 {
-			return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
-		} else if totalResidentMemory < 16*1024*1024 {
-			return nil, fmt.Errorf("model requires at least 16GB of memory")
+		if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 {
+			return nil, fmt.Errorf("F16 model requires at least 32 GB of memory")
+		} else if totalResidentMemory < 16*1000*1000 {
+			return nil, fmt.Errorf("model requires at least 16 GB of memory")
 		}
 	case "30B", "34B", "40B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 64*1024*1024 {
-			return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
-		} else if totalResidentMemory < 32*1024*1024 {
-			return nil, fmt.Errorf("model requires at least 32GB of memory")
+		if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 {
+			return nil, fmt.Errorf("F16 model requires at least 64 GB of memory")
+		} else if totalResidentMemory < 32*1000*1000 {
+			return nil, fmt.Errorf("model requires at least 32 GB of memory")
 		}
 	case "65B", "70B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 128*1024*1024 {
-			return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
-		} else if totalResidentMemory < 64*1024*1024 {
-			return nil, fmt.Errorf("model requires at least 64GB of memory")
+		if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 {
+			return nil, fmt.Errorf("F16 model requires at least 128 GB of memory")
+		} else if totalResidentMemory < 64*1000*1000 {
+			return nil, fmt.Errorf("model requires at least 64 GB of memory")
 		}
 	case "180B":
-		if ggml.FileType() == "F16" && totalResidentMemory < 512*1024*1024 {
+		if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 {
 			return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
-		} else if totalResidentMemory < 128*1024*1024 {
+		} else if totalResidentMemory < 128*1000*1000 {
 			return nil, fmt.Errorf("model requires at least 128GB of memory")
 		}
 	}

+ 34 - 13
server/download.go

@@ -20,6 +20,7 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/format"
 )
 
 var blobDownloadManager sync.Map
@@ -34,6 +35,9 @@ type blobDownload struct {
 	Parts []*blobDownloadPart
 
 	context.CancelFunc
+
+	done       bool
+	err        error
 	references atomic.Int32
 }
 
@@ -46,6 +50,12 @@ type blobDownloadPart struct {
 	*blobDownload `json:"-"`
 }
 
+const (
+	numDownloadParts          = 64
+	minDownloadPartSize int64 = 32 * 1000 * 1000
+	maxDownloadPartSize int64 = 256 * 1000 * 1000
+)
+
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -91,9 +101,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 
 		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
-		var offset int64
-		var size int64 = 64 * 1024 * 1024
+		var size = b.Total / numDownloadParts
+		switch {
+		case size < minDownloadPartSize:
+			size = minDownloadPartSize
+		case size > maxDownloadPartSize:
+			size = maxDownloadPartSize
+		}
 
+		var offset int64
 		for offset < b.Total {
 			if offset+size > b.Total {
 				size = b.Total - offset
@@ -107,11 +123,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 		}
 	}
 
-	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
+	log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
 	return nil
 }
 
-func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
+func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
+	b.err = b.run(ctx, requestURL, opts)
+}
+
+func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 
 	ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -124,9 +144,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 
 	file.Truncate(b.Total)
 
-	g, ctx := errgroup.WithContext(ctx)
-	// TODO(mxyng): download concurrency should be configurable
-	g.SetLimit(64)
+	g, _ := errgroup.WithContext(ctx)
+	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed == part.Size {
@@ -168,7 +187,12 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
 		}
 	}
 
-	return os.Rename(file.Name(), b.Name)
+	if err := os.Rename(file.Name(), b.Name); err != nil {
+		return err
+	}
+
+	b.done = true
+	return nil
 }
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
@@ -267,11 +291,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 			Completed: b.Completed.Load(),
 		})
 
-		if b.Completed.Load() >= b.Total {
-			// wait for the file to get renamed
-			if _, err := os.Stat(b.Name); err == nil {
-				return nil
-			}
+		if b.done || b.err != nil {
+			return b.err
 		}
 	}
 }