Browse Source

use int64 consistently

Michael Yang cách đây 1 năm
mục cha
commit
f40b3de758
7 tập tin đã thay đổi với 59 bổ sung và 59 xóa
  1. 3 3
      api/types.go
  2. 11 11
      cmd/cmd.go
  3. 3 3
      llm/llama.go
  4. 14 14
      server/download.go
  5. 8 8
      server/images.go
  6. 3 3
      server/modelpath_test.go
  7. 17 17
      server/upload.go

+ 3 - 3
api/types.go

@@ -88,8 +88,8 @@ type PullRequest struct {
 type ProgressResponse struct {
 	Status    string `json:"status"`
 	Digest    string `json:"digest,omitempty"`
-	Total     int    `json:"total,omitempty"`
-	Completed int    `json:"completed,omitempty"`
+	Total     int64  `json:"total,omitempty"`
+	Completed int64  `json:"completed,omitempty"`
 }
 
 type PushRequest struct {
@@ -106,7 +106,7 @@ type ListResponse struct {
 type ModelResponse struct {
 	Name       string    `json:"name"`
 	ModifiedAt time.Time `json:"modified_at"`
-	Size       int       `json:"size"`
+	Size       int64     `json:"size"`
 	Digest     string    `json:"digest"`
 }
 

+ 11 - 11
cmd/cmd.go

@@ -78,18 +78,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			currentDigest = resp.Digest
 			switch {
 			case strings.Contains(resp.Status, "embeddings"):
-				bar = progressbar.Default(int64(resp.Total), resp.Status)
-				bar.Set(resp.Completed)
+				bar = progressbar.Default(resp.Total, resp.Status)
+				bar.Set64(resp.Completed)
 			default:
 				// pulling
 				bar = progressbar.DefaultBytes(
-					int64(resp.Total),
+					resp.Total,
 					resp.Status,
 				)
-				bar.Set(resp.Completed)
+				bar.Set64(resp.Completed)
 			}
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
-			bar.Set(resp.Completed)
+			bar.Set64(resp.Completed)
 		} else {
 			currentDigest = ""
 			if spinner != nil {
@@ -160,13 +160,13 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 		if resp.Digest != currentDigest && resp.Digest != "" {
 			currentDigest = resp.Digest
 			bar = progressbar.DefaultBytes(
-				int64(resp.Total),
+				resp.Total,
 				fmt.Sprintf("pushing %s...", resp.Digest[7:19]),
 			)
 
-			bar.Set(resp.Completed)
+			bar.Set64(resp.Completed)
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
-			bar.Set(resp.Completed)
+			bar.Set64(resp.Completed)
 		} else {
 			currentDigest = ""
 			fmt.Println(resp.Status)
@@ -349,13 +349,13 @@ func pull(model string, insecure bool) error {
 		if resp.Digest != currentDigest && resp.Digest != "" {
 			currentDigest = resp.Digest
 			bar = progressbar.DefaultBytes(
-				int64(resp.Total),
+				resp.Total,
 				fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
 			)
 
-			bar.Set(resp.Completed)
+			bar.Set64(resp.Completed)
 		} else if resp.Digest == currentDigest && resp.Digest != "" {
-			bar.Set(resp.Completed)
+			bar.Set64(resp.Completed)
 		} else {
 			currentDigest = ""
 			fmt.Println(resp.Status)

+ 3 - 3
llm/llama.go

@@ -187,7 +187,7 @@ type llama struct {
 var errNoGPU = errors.New("nvidia-smi command failed")
 
 // CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
-func CheckVRAM() (int, error) {
+func CheckVRAM() (int64, error) {
 	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
 	var stdout bytes.Buffer
 	cmd.Stdout = &stdout
@@ -196,11 +196,11 @@ func CheckVRAM() (int, error) {
 		return 0, errNoGPU
 	}
 
-	var total int
+	var total int64
 	scanner := bufio.NewScanner(&stdout)
 	for scanner.Scan() {
 		line := scanner.Text()
-		vram, err := strconv.Atoi(line)
+		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
 		if err != nil {
 			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
 		}

+ 14 - 14
server/download.go

@@ -46,8 +46,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 		// we already have the file, so return
 		opts.fn(api.ProgressResponse{
 			Digest:    opts.digest,
-			Total:     int(fi.Size()),
-			Completed: int(fi.Size()),
+			Total:     fi.Size(),
+			Completed: fi.Size(),
 		})
 
 		return nil
@@ -93,8 +93,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
 					// successful download while monitoring
 					opts.fn(api.ProgressResponse{
 						Digest:    f.Digest,
-						Total:     int(fi.Size()),
-						Completed: int(fi.Size()),
+						Total:     fi.Size(),
+						Completed: fi.Size(),
 					})
 					return true, false, nil
 				}
@@ -109,8 +109,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
 			opts.fn(api.ProgressResponse{
 				Status:    fmt.Sprintf("downloading %s", f.Digest),
 				Digest:    f.Digest,
-				Total:     int(f.Total),
-				Completed: int(f.Completed),
+				Total:     f.Total,
+				Completed: f.Completed,
 			})
 			return false, false, nil
 		}()
@@ -129,8 +129,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
 }
 
 var (
-	chunkSize   = 1024 * 1024 // 1 MiB in bytes
-	errDownload = fmt.Errorf("download failed")
+	chunkSize   int64 = 1024 * 1024 // 1 MiB in bytes
+	errDownload       = fmt.Errorf("download failed")
 )
 
 // doDownload downloads a blob from the registry and stores it in the blobs directory
@@ -147,7 +147,7 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
 	default:
 		size = fi.Size()
 		// Ensure the size is divisible by the chunk size by removing excess bytes
-		size -= size % int64(chunkSize)
+		size -= size % chunkSize
 
 		err := os.Truncate(f.FilePath+"-partial", size)
 		if err != nil {
@@ -200,8 +200,8 @@ outerLoop:
 			opts.fn(api.ProgressResponse{
 				Status:    fmt.Sprintf("downloading %s", f.Digest),
 				Digest:    f.Digest,
-				Total:     int(f.Total),
-				Completed: int(f.Completed),
+				Total:     f.Total,
+				Completed: f.Completed,
 			})
 
 			if f.Completed >= f.Total {
@@ -213,8 +213,8 @@ outerLoop:
 					opts.fn(api.ProgressResponse{
 						Status:    fmt.Sprintf("error renaming file: %v", err),
 						Digest:    f.Digest,
-						Total:     int(f.Total),
-						Completed: int(f.Completed),
+						Total:     f.Total,
+						Completed: f.Completed,
 					})
 					return err
 				}
@@ -223,7 +223,7 @@ outerLoop:
 			}
 		}
 
-		n, err := io.CopyN(out, resp.Body, int64(chunkSize))
+		n, err := io.CopyN(out, resp.Body, chunkSize)
 		if err != nil && !errors.Is(err, io.EOF) {
 			return fmt.Errorf("%w: %w", errDownload, err)
 		}

+ 8 - 8
server/images.go

@@ -103,7 +103,7 @@ type ManifestV2 struct {
 type Layer struct {
 	MediaType string `json:"mediaType"`
 	Digest    string `json:"digest"`
-	Size      int    `json:"size"`
+	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
 }
 
@@ -129,11 +129,11 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 }
 
-func (m *ManifestV2) GetTotalSize() int {
-	var total int
+func (m *ManifestV2) GetTotalSize() (total int64) {
 	for _, layer := range m.Layers {
 		total += layer.Size
 	}
+
 	total += m.Config.Size
 	return total
 }
@@ -649,8 +649,8 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
 					e.fn(api.ProgressResponse{
 						Status:    fmt.Sprintf("creating embeddings for file %s", filePath),
 						Digest:    fileDigest,
-						Total:     len(data) - 1,
-						Completed: i,
+						Total:     int64(len(data) - 1),
+						Completed: int64(i),
 					})
 					if len(existing[d]) > 0 {
 						// already have an embedding for this line
@@ -675,7 +675,7 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
 					Layer: Layer{
 						MediaType: "application/vnd.ollama.image.embed",
 						Digest:    digest,
-						Size:      r.Len(),
+						Size:      r.Size(),
 					},
 					Reader: r,
 				}
@@ -1356,14 +1356,14 @@ func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
 }
 
 // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
-func GetSHA256Digest(r io.Reader) (string, int) {
+func GetSHA256Digest(r io.Reader) (string, int64) {
 	h := sha256.New()
 	n, err := io.Copy(h, r)
 	if err != nil {
 		log.Fatal(err)
 	}
 
-	return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
+	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
 }
 
 // Function to check if a blob already exists in the Docker registry

+ 3 - 3
server/modelpath_test.go

@@ -4,9 +4,9 @@ import "testing"
 
 func TestParseModelPath(t *testing.T) {
 	tests := []struct {
-		name    string
-		arg    string
-		want    ModelPath
+		name string
+		arg  string
+		want ModelPath
 	}{
 		{
 			"full path https",

+ 17 - 17
server/upload.go

@@ -15,8 +15,8 @@ import (
 )
 
 const (
-	redirectChunkSize = 1024 * 1024 * 1024
-	regularChunkSize  = 95 * 1024 * 1024
+	redirectChunkSize int64 = 1024 * 1024 * 1024
+	regularChunkSize  int64 = 95 * 1024 * 1024
 )
 
 func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
@@ -48,7 +48,7 @@ func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *Regis
 		return nil, 0, err
 	}
 
-	return locationURL, int64(chunkSize), nil
+	return locationURL, chunkSize, nil
 }
 
 func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
@@ -73,10 +73,10 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
 		fn:     fn,
 	}
 
-	for offset := int64(0); offset < int64(layer.Size); {
-		chunk := int64(layer.Size) - offset
-		if chunk > int64(chunkSize) {
-			chunk = int64(chunkSize)
+	for offset := int64(0); offset < layer.Size; {
+		chunk := layer.Size - offset
+		if chunk > chunkSize {
+			chunk = chunkSize
 		}
 
 		resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
@@ -85,7 +85,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
 				Status:    fmt.Sprintf("error uploading chunk: %v", err),
 				Digest:    layer.Digest,
 				Total:     layer.Size,
-				Completed: int(offset),
+				Completed: offset,
 			})
 
 			return err
@@ -127,7 +127,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
 }
 
 func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
-	sectionReader := io.NewSectionReader(r, int64(offset), limit)
+	sectionReader := io.NewSectionReader(r, offset, limit)
 
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
@@ -152,7 +152,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
 				return nil, err
 			}
 
-			pw.completed = int(offset)
+			pw.completed = offset
 			if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
 				// retry
 				log.Printf("retrying redirected upload: %v", err)
@@ -170,7 +170,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
 
 			opts.Token = token
 
-			pw.completed = int(offset)
+			pw.completed = offset
 			sectionReader = io.NewSectionReader(r, offset, limit)
 			continue
 		case resp.StatusCode >= http.StatusBadRequest:
@@ -187,19 +187,19 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
 type ProgressWriter struct {
 	status    string
 	digest    string
-	bucket    int
-	completed int
-	total     int
+	bucket    int64
+	completed int64
+	total     int64
 	fn        func(api.ProgressResponse)
 }
 
 func (pw *ProgressWriter) Write(b []byte) (int, error) {
 	n := len(b)
-	pw.bucket += n
-	pw.completed += n
+	pw.bucket += int64(n)
 
 	// throttle status updates to not spam the client
-	if pw.bucket >= 1024*1024 || pw.completed >= pw.total {
+	if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
+		pw.completed += pw.bucket
 		pw.fn(api.ProgressResponse{
 			Status:    pw.status,
 			Digest:    pw.digest,