浏览代码

client create modelfile

Michael Yang 1 年之前
父节点
当前提交
1552cee59f
共有 4 个文件被更改,包括 147 次插入14 次删除
  1. 26 1
      api/client.go
  2. 4 0
      api/types.go
  3. 60 13
      cmd/cmd.go
  4. 57 0
      server/routes.go

+ 26 - 1
api/client.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	var reqBody io.Reader
 	var data []byte
 	var err error
-	if reqData != nil {
+
+	switch reqData := reqData.(type) {
+	case io.Reader:
+		// reqData is already an io.Reader
+		reqBody = reqData
+	case nil:
+		// noop
+	default:
 		data, err = json.Marshal(reqData)
 		if err != nil {
 			return err
 		}
+
 		reqBody = bytes.NewReader(data)
 	}
 
@@ -287,3 +296,19 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	}
 	return nil
 }
+
+func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) {
+	var response CreateBlobResponse
+	if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil {
+		var statusError StatusError
+		if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
+			return "", err
+		}
+
+		if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil {
+			return "", err
+		}
+	}
+
+	return response.Path, nil
+}

+ 4 - 0
api/types.go

@@ -105,6 +105,10 @@ type CreateRequest struct {
 	Stream    *bool  `json:"stream,omitempty"`
 }
 
+type CreateBlobResponse struct {
+	Path string `json:"path"`
+}
+
 type DeleteRequest struct {
 	Name string `json:"name"`
 }

+ 60 - 13
cmd/cmd.go

@@ -1,9 +1,11 @@
 package cmd
 
 import (
+	"bytes"
 	"context"
 	"crypto/ed25519"
 	"crypto/rand"
+	"crypto/sha256"
 	"encoding/pem"
 	"errors"
 	"fmt"
@@ -27,6 +29,7 @@ import (
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
+	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/progressbar"
 	"github.com/jmorganca/ollama/readline"
 	"github.com/jmorganca/ollama/server"
@@ -45,17 +48,65 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	var spinner *Spinner
+	modelfile, err := os.ReadFile(filename)
+	if err != nil {
+		return err
+	}
+
+	spinner := NewSpinner("transferring context")
+	go spinner.Spin(100 * time.Millisecond)
+
+	commands, err := parser.Parse(bytes.NewReader(modelfile))
+	if err != nil {
+		return err
+	}
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return err
+	}
+
+	for _, c := range commands {
+		switch c.Name {
+		case "model", "adapter":
+			path := c.Args
+			if path == "~" {
+				path = home
+			} else if strings.HasPrefix(path, "~/") {
+				path = filepath.Join(home, path[2:])
+			}
+
+			bin, err := os.Open(path)
+			if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
+				// value might be a model reference and not a real file
+			} else if err != nil {
+				return err
+			}
+			defer bin.Close()
+
+			hash := sha256.New()
+			if _, err := io.Copy(hash, bin); err != nil {
+				return err
+			}
+			bin.Seek(0, io.SeekStart)
+
+			digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
+			path, err = client.CreateBlob(cmd.Context(), digest, bin)
+			if err != nil {
+				return err
+			}
+
+			modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path))
+		}
+	}
 
 	var currentDigest string
 	var bar *progressbar.ProgressBar
 
-	request := api.CreateRequest{Name: args[0], Path: filename}
+	request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)}
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != currentDigest && resp.Digest != "" {
-			if spinner != nil {
-				spinner.Stop()
-			}
+			spinner.Stop()
 			currentDigest = resp.Digest
 			// pulling
 			bar = progressbar.DefaultBytes(
@@ -67,9 +118,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			bar.Set64(resp.Completed)
 		} else {
 			currentDigest = ""
-			if spinner != nil {
-				spinner.Stop()
-			}
+			spinner.Stop()
 			spinner = NewSpinner(resp.Status)
 			go spinner.Spin(100 * time.Millisecond)
 		}
@@ -81,11 +130,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	if spinner != nil {
-		spinner.Stop()
-		if spinner.description != "success" {
-			return errors.New("unexpected end to create model")
-		}
+	spinner.Stop()
+	if spinner.description != "success" {
+		return errors.New("unexpected end to create model")
 	}
 
 	return nil

+ 57 - 0
server/routes.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"context"
+	"crypto/sha256"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
 	}
 }
 
+func GetBlobHandler(c *gin.Context) {
+	path, err := GetBlobsPath(c.Param("digest"))
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if _, err := os.Stat(path); err != nil {
+		c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
+		return
+	}
+
+	c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path})
+}
+
+func CreateBlobHandler(c *gin.Context) {
+	hash := sha256.New()
+	temp, err := os.CreateTemp("", c.Param("digest"))
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer temp.Close()
+	defer os.Remove(temp.Name())
+
+	if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
+		return
+	}
+
+	if err := temp.Close(); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	targetPath, err := GetBlobsPath(c.Param("digest"))
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	if err := os.Rename(temp.Name(), targetPath); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath})
+}
+
 var defaultAllowOrigins = []string{
 	"localhost",
 	"127.0.0.1",
@@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 	r.POST("/api/copy", CopyModelHandler)
 	r.DELETE("/api/delete", DeleteModelHandler)
 	r.POST("/api/show", ShowModelHandler)
+	r.POST("/api/blobs/:digest", CreateBlobHandler)
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
@@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		})
 
 		r.Handle(method, "/api/tags", ListModelsHandler)
+		r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler)
 	}
 
 	log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)