|
@@ -3,20 +3,27 @@ package server
|
|
|
import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ "crypto/ed25519"
|
|
|
+ "crypto/rand"
|
|
|
"encoding/binary"
|
|
|
"encoding/json"
|
|
|
+ "encoding/pem"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"math"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
+ "net/url"
|
|
|
"os"
|
|
|
+ "path/filepath"
|
|
|
"sort"
|
|
|
"strings"
|
|
|
"testing"
|
|
|
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
+ "golang.org/x/crypto/ssh"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/envconfig"
|
|
@@ -527,3 +534,108 @@ func TestNormalize(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestIsLocalReal(t *testing.T) {
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
+ clientPubLoc := t.TempDir()
|
|
|
+ t.Setenv("HOME", clientPubLoc)
|
|
|
+
|
|
|
+ err := initializeKeypair()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ w := httptest.NewRecorder()
|
|
|
+ ctx, _ := gin.CreateTestContext(w)
|
|
|
+ ctx.Request = &http.Request{
|
|
|
+ Header: make(http.Header),
|
|
|
+ }
|
|
|
+
|
|
|
+ requestURL := url.URL{
|
|
|
+ Scheme: "http",
|
|
|
+ Host: "localhost:8080",
|
|
|
+ Path: "/api/blobs",
|
|
|
+ }
|
|
|
+ request := &http.Request{
|
|
|
+ Method: http.MethodPost,
|
|
|
+ URL: &requestURL,
|
|
|
+ }
|
|
|
+ s := &Server{}
|
|
|
+
|
|
|
+ authz, err := api.Authorization(ctx, request)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set client authorization header
|
|
|
+ ctx.Request.Header.Set("Authorization", authz)
|
|
|
+ if !s.isLocal(ctx) {
|
|
|
+ t.Fatal("Expected isLocal to return true")
|
|
|
+ }
|
|
|
+
|
|
|
+ t.Run("different server pubkey", func(t *testing.T) {
|
|
|
+ serverPubLoc := t.TempDir()
|
|
|
+ t.Setenv("HOME", serverPubLoc)
|
|
|
+ err := initializeKeypair()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if s.isLocal(ctx) {
|
|
|
+ t.Fatal("Expected isLocal to return false")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("invalid pubkey", func(t *testing.T) {
|
|
|
+ ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
|
|
|
+ if s.isLocal(ctx) {
|
|
|
+ t.Fatal("Expected isLocal to return false")
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func initializeKeypair() error {
|
|
|
+ home, err := os.UserHomeDir()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
|
|
+ pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
|
|
+
|
|
|
+ _, err = os.Stat(privKeyPath)
|
|
|
+ if os.IsNotExist(err) {
|
|
|
+ fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
|
|
+ cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
|
|
+ return fmt.Errorf("could not create directory %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
|
|
+
|
|
|
+ if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|