浏览代码

isLocal firstdraft

Josh Yan 10 月之前
父节点
当前提交
10ea0987e9
共有 5 个文件被更改,包括 134 次插入31 次删除
  1. 31 0
      api/client.go
  2. 18 28
      auth/auth.go
  3. 9 1
      cmd/cmd.go
  4. 3 1
      server/images.go
  5. 73 1
      server/routes.go

+ 31 - 0
api/client.go

@@ -17,14 +17,20 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"net/url"
+	"os"
+	"path/filepath"
 	"runtime"
+	"strings"
+	"time"
 
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/version"
@@ -394,3 +400,28 @@ func (c *Client) IsLocal() bool {
 
 	return false
 }
+
+func Authorization(ctx context.Context, request *http.Request) (string, error) {
+
+	data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
+	if err != nil {
+		return "", err
+	}
+	defer knownHostsFile.Close()
+
+	token, err := auth.Sign(ctx, data)
+	if err != nil {
+		return "", err
+	}
+
+	// interleave request data into the token
+	key, sig, _ := strings.Cut(token, ":")
+	return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
+}

+ 18 - 28
auth/auth.go

@@ -10,42 +10,37 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
-	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
 const defaultPrivateKey = "id_ed25519"
 
-func keyPath() (string, error) {
+func keyPath() (ssh.Signer, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {
-		return "", err
-	}
-
-	return filepath.Join(home, ".ollama", defaultPrivateKey), nil
-}
-
-func GetPublicKey() (string, error) {
-	keyPath, err := keyPath()
-	if err != nil {
-		return "", err
+		return nil, err
 	}
 
+	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
-		return "", err
+		return nil, err
 	}
 
-	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	return ssh.ParsePrivateKey(privateKeyFile)
+}
+
+func GetPublicKey() (ssh.PublicKey, error) {
+	privateKey, err := keyPath()
+	// if privateKey, try public key directly
+
 	if err != nil {
-		return "", err
+		return nil, err
 	}
 
-	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
-
-	return strings.TrimSpace(string(publicKey)), nil
+	return privateKey.PublicKey(), nil
 }
 
 func NewNonce(r io.Reader, length int) (string, error) {
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
 }
 
 func Sign(ctx context.Context, bts []byte) (string, error) {
-	keyPath, err := keyPath()
+	privateKey, err := keyPath()
 	if err != nil {
 		return "", err
 	}
 
-	privateKeyFile, err := os.ReadFile(keyPath)
+	// get the pubkey, but remove the type
+	publicKey, err := GetPublicKey()
 	if err != nil {
-		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
 		return "", err
 	}
 
-	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
-	if err != nil {
-		return "", err
-	}
+	publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
 
-	// get the pubkey, but remove the type
-	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
-	parts := bytes.Split(publicKey, []byte(" "))
+	parts := bytes.Split(publicKeyBytes, []byte(" "))
 	if len(parts) < 2 {
 		return "", fmt.Errorf("malformed public key")
 	}

+ 9 - 1
cmd/cmd.go

@@ -354,6 +354,12 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
 		return "", err
 	}
 
+	authz, err := api.Authorization(ctx, request)
+	if err != nil {
+		return "", err
+	}
+
+	request.Header.Set("Authorization", authz)
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 	request.Header.Set("X-Redirect-Create", "1")
 
@@ -499,11 +505,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
 	if len(matches) > 0 {
 		serverPubKey := matches[0]
 
-		localPubKey, err := auth.GetPublicKey()
+		publicKey, err := auth.GetPublicKey()
 		if err != nil {
 			return unknownKeyErr
 		}
 
+		localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
+
 		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
 			// try the ollama service public key
 			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")

+ 3 - 1
server/images.go

@@ -32,6 +32,7 @@ import (
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
+	"golang.org/x/crypto/ssh"
 )
 
 var errCapabilityCompletion = errors.New("completion")
@@ -1064,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 	if anonymous {
 		// no user is associated with the public key, and the request requires non-anonymous access
 		pubKey, nestedErr := auth.GetPublicKey()
+		localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
 		if nestedErr != nil {
 			slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
 			return nil, errUnauthorized
 		}
-		return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
+		return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
 	}
 	// user is associated with the public key, but is not authorized to make the request
 	return nil, errUnauthorized

+ 73 - 1
server/routes.go

@@ -4,10 +4,12 @@ import (
 	"bytes"
 	"cmp"
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"log/slog"
 	"net"
 	"net/http"
@@ -22,8 +24,10 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"golang.org/x/crypto/ssh"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
@@ -783,7 +787,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		return
 	}
 
-	if c.GetHeader("X-Redirect-Create") == "1" {
+	if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
 		c.Header("LocalLocation", path)
 		c.Status(http.StatusTemporaryRedirect)
 		return
@@ -803,6 +807,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 	c.Status(http.StatusCreated)
 }
 
+func (s *Server) IsLocal(c *gin.Context) bool {
+	if authz := c.GetHeader("Authorization"); authz != "" {
+		parts := strings.Split(authz, ":")
+		if len(parts) != 3 {
+			return false
+		}
+
+		clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
+		if err != nil {
+			return false
+		}
+
+		// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
+		partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
+		if err != nil {
+			return false
+		}
+
+		partialRequestDataParts := strings.Split(string(partialRequestData), ",")
+		if len(partialRequestDataParts) != 4 {
+			return false
+		}
+
+		/* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
+		if err != nil {
+			return false
+		}
+
+		t := time.Unix(timestamp, 0)
+		if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
+			// token is invalid if timestamp +/- 5 minutes from current time
+			return false
+		} */
+
+		/* nonce := partialRequestDataParts[3]
+		if nonceCache.has(nonce) {
+			return false
+		}
+		nonceCache.add(nonce, 5*time.Minute) */
+
+		signature, err := base64.StdEncoding.DecodeString(parts[2])
+		if err != nil {
+			return false
+		}
+
+		serverPublicKey, err := auth.GetPublicKey()
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
+		requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
+
+		if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
+			return false
+		}
+
+		if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
+			return true
+		}
+
+		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
+		return false
+	}
+
+	return false
+}
+
 func isLocalIP(ip netip.Addr) bool {
 	if interfaces, err := net.Interfaces(); err == nil {
 		for _, iface := range interfaces {