Browse Source

isLocal testing

Josh Yan 9 months ago
parent
commit
c507325288
4 changed files with 115 additions and 14 deletions
  1. 0 1
      auth/auth.go
  2. 0 1
      cmd/cmd.go
  3. 3 12
      server/routes.go
  4. 112 0
      server/routes_test.go

+ 0 - 1
auth/auth.go

@@ -24,7 +24,6 @@ func privateKey() (ssh.Signer, error) {
 		return nil, err
 	}
 
-	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if os.IsNotExist(err) {

+ 0 - 1
cmd/cmd.go

@@ -343,7 +343,6 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
 	}
 
 	request.Header.Set("Authorization", authz)
-	request.Header.Set("Timestamp", time.Now().Format(time.RFC3339))
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 	request.Header.Set("X-Redirect-Create", "1")
 

+ 3 - 12
server/routes.go

@@ -942,7 +942,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		c.Status(http.StatusOK)
 		return
 	}
-	if c.GetHeader("X-Redirect-Create") == "1" && s.IsServerKeyPublicKey(c) {
+	if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
 		c.Header("LocalLocation", path)
 		c.Status(http.StatusTemporaryRedirect)
 		return
@@ -962,7 +962,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 	c.Status(http.StatusCreated)
 }
 
-func (s *Server) IsServerKeyPublicKey(c *gin.Context) bool {
+func (s *Server) isLocal(c *gin.Context) bool {
 	if authz := c.GetHeader("Authorization"); authz != "" {
 		parts := strings.Split(authz, ":")
 		if len(parts) != 3 {
@@ -999,16 +999,7 @@ func (s *Server) IsServerKeyPublicKey(c *gin.Context) bool {
 			slog.Error(fmt.Sprintf("failed to get server public key: %v", err))
 			return false
 		}
-
-		timestamp, err := time.Parse(time.RFC3339, c.GetHeader("Timestamp"))
-		if err != nil {
-			return false
-		}
-
-		if time.Since(timestamp) > time.Minute {
-			return false
-		}
-
+		
 		if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
 			return true
 		}

+ 112 - 0
server/routes_test.go

@@ -3,20 +3,27 @@ package server
 import (
 	"bytes"
 	"context"
+	"crypto/ed25519"
+	"crypto/rand"
 	"encoding/binary"
 	"encoding/json"
+	"encoding/pem"
 	"fmt"
 	"io"
 	"math"
 	"net/http"
 	"net/http/httptest"
+	"net/url"
 	"os"
+	"path/filepath"
 	"sort"
 	"strings"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/crypto/ssh"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
@@ -527,3 +534,108 @@ func TestNormalize(t *testing.T) {
 		})
 	}
 }
+
+func TestIsLocalReal(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	clientPubLoc := t.TempDir()
+	t.Setenv("HOME", clientPubLoc)
+
+	err := initializeKeypair()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	w := httptest.NewRecorder()
+    ctx, _ := gin.CreateTestContext(w)
+	ctx.Request = &http.Request{
+		Header: make(http.Header),
+	}
+
+	requestURL := url.URL{
+		Scheme: "http",
+		Host:   "localhost:8080",
+		Path:   "/api/blobs",
+	}
+	request := &http.Request{
+		Method: http.MethodPost,
+		URL:    &requestURL,
+	}
+	s := &Server{}
+	
+	authz, err := api.Authorization(ctx, request)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Set client authorization header
+	ctx.Request.Header.Set("Authorization", authz)
+	if !s.isLocal(ctx) {
+		t.Fatal("Expected isLocal to return true")
+	}
+
+	t.Run("different server pubkey", func(t *testing.T) {
+		serverPubLoc := t.TempDir()
+		t.Setenv("HOME", serverPubLoc)
+		err := initializeKeypair()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if s.isLocal(ctx) {
+			t.Fatal("Expected isLocal to return false")
+		}
+	})
+
+	t.Run("invalid pubkey", func(t *testing.T) {
+		ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
+		if s.isLocal(ctx) {
+			t.Fatal("Expected isLocal to return false")
+		}
+	})
+}
+
+func initializeKeypair() error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return err
+	}
+
+	privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
+	pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
+
+	_, err = os.Stat(privKeyPath)
+	if os.IsNotExist(err) {
+		fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
+		cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			return err
+		}
+
+		privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
+		if err != nil {
+			return err
+		}
+
+		if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
+			return fmt.Errorf("could not create directory %w", err)
+		}
+
+		if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
+			return err
+		}
+
+		sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
+		if err != nil {
+			return err
+		}
+
+		publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
+
+		if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
+			return err
+		}
+
+		fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
+	}
+	return nil
+}