Josh Yan 9 months ago
parent
commit
ad36d4ff1b
3 changed files with 105 additions and 17 deletions
  1. 5 1
      api/client.go
  2. 83 16
      cmd/cmd.go
  3. 17 0
      server/routes.go

+ 5 - 1
api/client.go

@@ -367,7 +367,11 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 
 
 // CreateBlob creates a blob from a file on the server. digest is the
 // CreateBlob creates a blob from a file on the server. digest is the
 // expected SHA256 digest of the file, and r represents the file.
 // expected SHA256 digest of the file, and r represents the file.
-func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
+func (c *Client) CreateBlob(ctx context.Context, digest string, local bool, r io.Reader) error {
+	headers := make(http.Header)
+	if local {
+		headers.Set("X-Redirect-Create", "1")
+	}
 	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 }
 }
 
 

+ 83 - 16
cmd/cmd.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
 	"crypto/sha256"
 	"crypto/sha256"
+	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -12,6 +13,7 @@ import (
 	"math"
 	"math"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"net/url"
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
 	"path/filepath"
 	"path/filepath"
@@ -286,8 +288,9 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 
 
 	// Resolve server to IP
 	// Resolve server to IP
 	// Check if server is local
 	// Check if server is local
-	if client.IsLocal() {
-		config, err := client.ServerConfig(cmd.Context())
+	/* if client.IsLocal() {
+		digest = strings.ReplaceAll(digest, ":", "-")
+		config, err := client.HeadBlob(cmd.Context(), digest)
 		if err != nil {
 		if err != nil {
 			return "", err
 			return "", err
 		}
 		}
@@ -295,42 +298,106 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 		modelDir := config.ModelDir
 		modelDir := config.ModelDir
 
 
 		// Get blob destination
 		// Get blob destination
-		digest = strings.ReplaceAll(digest, ":", "-")
+
 		dest := filepath.Join(modelDir, "blobs", digest)
 		dest := filepath.Join(modelDir, "blobs", digest)
 
 
 		err = createBlobLocal(path, dest)
 		err = createBlobLocal(path, dest)
 		if err == nil {
 		if err == nil {
 			return digest, nil
 			return digest, nil
 		}
 		}
+	} */
+	if client.IsLocal() {
+		config, err := getLocalPath(cmd.Context(), digest)
+		if err != nil {
+			return "", err
+		}
+
+		if config == nil {
+			fmt.Println("config is nil")
+			return digest, nil
+		}
+
+		fmt.Println("HI")
+		dest := config.ModelDir
+		fmt.Println("dest is ", dest)
+		err = createBlobLocal(path, dest)
+		if err == nil {
+			fmt.Println("createlocalblob succeed")
+			return digest, nil
+		}
+		fmt.Println("err is ", err)
+		fmt.Println("createlocalblob faileds")
 	}
 	}
 
 
-	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
+	fmt.Println("DEFAULT")
+	if err = client.CreateBlob(cmd.Context(), digest, false, bin); err != nil {
 		return "", err
 		return "", err
 	}
 	}
 	return digest, nil
 	return digest, nil
 }
 }
 
 
+func getLocalPath(ctx context.Context, digest string) (*api.ServerConfig, error) {
+	ollamaHost := envconfig.Host
+
+	client := http.DefaultClient
+	base := &url.URL{
+		Scheme: ollamaHost.Scheme,
+		Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
+	}
+
+	var reqBody io.Reader
+	var respData api.ServerConfig
+	data, err := json.Marshal(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	reqBody = bytes.NewReader(data)
+	path := fmt.Sprintf("/api/blobs/%s", digest)
+	requestURL := base.JoinPath(path)
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
+	if err != nil {
+		return nil, err
+	}
+
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Accept", "application/json")
+	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
+	request.Header.Set("X-Redirect-Create", "1")
+
+	fmt.Println("request", request)
+	resp, err := client.Do(request)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	fmt.Println("made it here")
+	fmt.Println("resp", resp)
+
+	if resp.StatusCode == http.StatusTemporaryRedirect {
+		fmt.Println("redirect")
+		if err := json.Unmarshal([]byte(resp.Header.Get("loc")), &respData); err != nil {
+			fmt.Println("error unmarshalling response data")
+			return nil, err
+		}
+	}
+
+	fmt.Println("!!!!!!!!!!")
+	fmt.Println(respData)
+	return &respData, nil
+}
+
 func createBlobLocal(path string, dest string) error {
 func createBlobLocal(path string, dest string) error {
 	// This function should be called if the server is local
 	// This function should be called if the server is local
 	// It should find the model directory, copy the blob over, and return the digest
 	// It should find the model directory, copy the blob over, and return the digest
 	dirPath := filepath.Dir(dest)
 	dirPath := filepath.Dir(dest)
+	fmt.Println("dirpath is ", dirPath)
 
 
 	if err := os.MkdirAll(dirPath, 0o755); err != nil {
 	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		fmt.Println("failed to create directory")
 		return err
 		return err
 	}
 	}
 
 
-	// Check blob exists
-	_, err := os.Stat(dest)
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// noop
-	case err != nil:
-		return err
-	default:
-		// blob already exists
-		return nil
-	}
-
 	// Copy blob over
 	// Copy blob over
 	sourceFile, err := os.Open(path)
 	sourceFile, err := os.Open(path)
 	if err != nil {
 	if err != nil {

+ 17 - 0
server/routes.go

@@ -940,6 +940,23 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		c.Status(http.StatusOK)
 		c.Status(http.StatusOK)
 		return
 		return
 	}
 	}
+	fmt.Println("HEIAHOEIHFOAHAEFHAO")
+	fmt.Println(c.GetHeader("X-Redirect-Create"))
+	if c.GetHeader("X-Redirect-Create") == "1" {
+		response := api.ServerConfig{ModelDir: path}
+		fmt.Println("Hit redirect")
+		resp, err := json.Marshal(response)
+		fmt.Println("marshalled response")
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		c.Header("loc", string(resp))
+		fmt.Println("!!!!!!!!!", string(resp))
+		c.Status(http.StatusTemporaryRedirect)
+		return
+	}
 
 
 	layer, err := NewLayer(c.Request.Body, "")
 	layer, err := NewLayer(c.Request.Body, "")
 	if err != nil {
 	if err != nil {