浏览代码

Merge pull request #461 from jmorganca/mxyng/fix-inherit-params

fix inherit params
Michael Yang 1 年之前
父节点
当前提交
d1c2558f7e
共有 3 个文件被更改,包括 33 次插入8 次删除
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 30 8
      server/images.go

+ 1 - 0
go.mod

@@ -39,6 +39,7 @@ require (
 	github.com/ugorji/go/codec v1.2.11 // indirect
 	github.com/ugorji/go/codec v1.2.11 // indirect
 	golang.org/x/arch v0.3.0 // indirect
 	golang.org/x/arch v0.3.0 // indirect
 	golang.org/x/crypto v0.10.0
 	golang.org/x/crypto v0.10.0
+	golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
 	golang.org/x/net v0.10.0 // indirect
 	golang.org/x/net v0.10.0 // indirect
 	golang.org/x/sys v0.11.0 // indirect
 	golang.org/x/sys v0.11.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/term v0.10.0

+ 2 - 0
go.sum

@@ -121,6 +121,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y
 golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
 golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
 golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
 golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
+golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
+golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
 golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=

+ 30 - 8
server/images.go

@@ -22,6 +22,8 @@ import (
 	"strings"
 	"strings"
 	"text/template"
 	"text/template"
 
 
+	"golang.org/x/exp/slices"
+
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/parser"
@@ -274,6 +276,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 
 
 	var layers []*LayerReader
 	var layers []*LayerReader
 	params := make(map[string][]string)
 	params := make(map[string][]string)
+	var sourceParams map[string]any
 	embed := EmbeddingParams{fn: fn}
 	embed := EmbeddingParams{fn: fn}
 	for _, c := range commands {
 	for _, c := range commands {
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		log.Printf("[%s] - %s\n", c.Name, c.Args)
@@ -357,6 +360,23 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 				config.FileType = source.FileType
 				config.FileType = source.FileType
 
 
 				for _, l := range mf.Layers {
 				for _, l := range mf.Layers {
+					if l.MediaType == "application/vnd.ollama.image.params" {
+						sourceParamsBlobPath, err := GetBlobsPath(l.Digest)
+						if err != nil {
+							return err
+						}
+
+						sourceParamsBlob, err := os.Open(sourceParamsBlobPath)
+						if err != nil {
+							return err
+						}
+						defer sourceParamsBlob.Close()
+
+						if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil {
+							return err
+						}
+					}
+
 					newLayer, err := GetLayerWithBufferFromLayer(l)
 					newLayer, err := GetLayerWithBufferFromLayer(l)
 					if err != nil {
 					if err != nil {
 						return err
 						return err
@@ -427,12 +447,19 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 	// Create a single layer for the parameters
 	// Create a single layer for the parameters
 	if len(params) > 0 {
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameter layer"})
 		fn(api.ProgressResponse{Status: "creating parameter layer"})
+
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
 		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
 		formattedParams, err := formatParams(params)
 		formattedParams, err := formatParams(params)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("couldn't create params json: %v", err)
 			return fmt.Errorf("couldn't create params json: %v", err)
 		}
 		}
 
 
+		for k, v := range sourceParams {
+			if _, ok := formattedParams[k]; !ok {
+				formattedParams[k] = v
+			}
+		}
+
 		bts, err := json.Marshal(formattedParams)
 		bts, err := json.Marshal(formattedParams)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -630,14 +657,9 @@ func existingFileEmbeddings(digest string) (map[string][]float64, error) {
 }
 }
 
 
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
 func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
-	j := 0
-	for _, l := range layers {
-		if l.MediaType != mediaType {
-			layers[j] = l
-			j++
-		}
-	}
-	return layers[:j]
+	return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
+		return layer.MediaType == mediaType
+	})
 }
 }
 
 
 func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
 func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {