소스 검색

isLocal firstdraft

Josh Yan 10 달 전
부모
커밋
154b59c0b6
4개의 변경된 파일119개의 추가작업 그리고 1개의 파일을 삭제
  1. 31 0
      api/client.go
  2. 9 0
      auth/auth.go
  3. 6 0
      cmd/cmd.go
  4. 73 1
      server/routes.go

+ 31 - 0
api/client.go

@@ -17,14 +17,20 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
 	"net/url"
+	"os"
+	"path/filepath"
 	"runtime"
+	"strings"
+	"time"
 
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/version"
@@ -403,3 +409,28 @@ func (c *Client) IsLocal() bool {
 
 	return false
 }
+
+func Authorization(ctx context.Context, request *http.Request) (string, error) {
+
+	data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
+	if err != nil {
+		return "", err
+	}
+	defer knownHostsFile.Close()
+
+	token, err := auth.Sign(ctx, data)
+	if err != nil {
+		return "", err
+	}
+
+	// interleave request data into the token
+	key, sig, _ := strings.Cut(token, ":")
+	return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
+}

+ 9 - 0
auth/auth.go

@@ -24,6 +24,7 @@ func privateKey() (ssh.Signer, error) {
 		return nil, err
 	}
 
+	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if os.IsNotExist(err) {
@@ -36,11 +37,19 @@ func privateKey() (ssh.Signer, error) {
 	} else if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
 		return nil, err
+		return nil, err
 	}
 
 	return ssh.ParsePrivateKey(privateKeyFile)
 }
 
+func GetPublicKey() (ssh.PublicKey, error) {
+	privateKey, err := keyPath()
+	// if privateKey, try public key directly
+
+	return ssh.ParsePrivateKey(privateKeyFile)
+}
+
 func GetPublicKey() (ssh.PublicKey, error) {
 	// try to read pubkey first
 	home, err := os.UserHomeDir()

+ 6 - 0
cmd/cmd.go

@@ -351,6 +351,12 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
 		return "", err
 	}
 
+	authz, err := api.Authorization(ctx, request)
+	if err != nil {
+		return "", err
+	}
+
+	request.Header.Set("Authorization", authz)
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 	request.Header.Set("X-Redirect-Create", "1")
 

+ 73 - 1
server/routes.go

@@ -4,10 +4,12 @@ import (
 	"bytes"
 	"cmp"
 	"context"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"log/slog"
 	"math"
 	"net"
@@ -23,8 +25,10 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"golang.org/x/crypto/ssh"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
@@ -941,7 +945,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		return
 	}
 
-	if c.GetHeader("X-Redirect-Create") == "1" {
+	if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
 		c.Header("LocalLocation", path)
 		c.Status(http.StatusTemporaryRedirect)
 		return
@@ -961,6 +965,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 	c.Status(http.StatusCreated)
 }
 
+func (s *Server) IsLocal(c *gin.Context) bool {
+	if authz := c.GetHeader("Authorization"); authz != "" {
+		parts := strings.Split(authz, ":")
+		if len(parts) != 3 {
+			return false
+		}
+
+		clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
+		if err != nil {
+			return false
+		}
+
+		// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
+		partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
+		if err != nil {
+			return false
+		}
+
+		partialRequestDataParts := strings.Split(string(partialRequestData), ",")
+		if len(partialRequestDataParts) != 4 {
+			return false
+		}
+
+		/* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
+		if err != nil {
+			return false
+		}
+
+		t := time.Unix(timestamp, 0)
+		if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
+			// token is invalid if timestamp +/- 5 minutes from current time
+			return false
+		} */
+
+		/* nonce := partialRequestDataParts[3]
+		if nonceCache.has(nonce) {
+			return false
+		}
+		nonceCache.add(nonce, 5*time.Minute) */
+
+		signature, err := base64.StdEncoding.DecodeString(parts[2])
+		if err != nil {
+			return false
+		}
+
+		serverPublicKey, err := auth.GetPublicKey()
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
+		requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
+
+		if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
+			return false
+		}
+
+		if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
+			return true
+		}
+
+		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
+		return false
+	}
+
+	return false
+}
+
 func isLocalIP(ip netip.Addr) bool {
 	if interfaces, err := net.Interfaces(); err == nil {
 		for _, iface := range interfaces {