Parcourir la source

Merge pull request #260 from jmorganca/embed-ggml-metal

override ggml-metal if the file is different
Michael Yang il y a 1 an
Parent
commit
cc509a994e
1 fichiers modifiés avec 35 ajouts et 8 suppressions
  1. 35 8
      llama/llama_darwin.go

+ 35 - 8
llama/llama_darwin.go

@@ -1,6 +1,8 @@
 package llama
 
 import (
+	"bytes"
+	"crypto/sha256"
 	"errors"
 	"io"
 	"log"
@@ -27,26 +29,51 @@ func initBackend() error {
 	}
 
 	metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal")
-	if _, err := os.Stat(metal); err != nil {
-		if !errors.Is(err, os.ErrNotExist) {
+	fi, err := os.Stat(metal)
+	if err != nil && !errors.Is(err, os.ErrNotExist) {
+		return err
+	}
+
+	if fi != nil {
+		actual, err := os.Open(metal)
+		if err != nil {
 			return err
 		}
 
-		dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
-		if err != nil {
+		actualSum := sha256.New()
+		if _, err := io.Copy(actualSum, actual); err != nil {
 			return err
 		}
-		defer dst.Close()
 
-		src, err := fs.Open("ggml-metal.metal")
+		expect, err := fs.Open("ggml-metal.metal")
 		if err != nil {
 			return err
 		}
-		defer src.Close()
 
-		if _, err := io.Copy(dst, src); err != nil {
+		expectSum := sha256.New()
+		if _, err := io.Copy(expectSum, expect); err != nil {
 			return err
 		}
+
+		if bytes.Equal(actualSum.Sum(nil), expectSum.Sum(nil)) {
+			return nil
+		}
+	}
+
+	dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
+	if err != nil {
+		return err
+	}
+	defer dst.Close()
+
+	src, err := fs.Open("ggml-metal.metal")
+	if err != nil {
+		return err
+	}
+	defer src.Close()
+
+	if _, err := io.Copy(dst, src); err != nil {
+		return err
 	}
 
 	return nil