Josh Yan 9 months ago
parent
commit
76b4dfcc9e
2 changed files with 95 additions and 76 deletions
  1. 92 25
      auth/auth.go
  2. 3 51
      cmd/cmd.go

+ 92 - 25
auth/auth.go

@@ -3,49 +3,75 @@ package auth
 import (
 	"bytes"
 	"context"
+	"crypto/ed25519"
 	"crypto/rand"
 	"encoding/base64"
+	"encoding/pem"
 	"fmt"
 	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
-	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
 const defaultPrivateKey = "id_ed25519"
 
-func keyPath() (string, error) {
+func privateKey() (ssh.Signer, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {
-		return "", err
+		return nil, err
+	}
+
+	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
+	privateKeyFile, err := os.ReadFile(keyPath)
+	if err != nil {
+		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
+		return nil, err
 	}
 
-	return filepath.Join(home, ".ollama", defaultPrivateKey), nil
+	return ssh.ParsePrivateKey(privateKeyFile)
 }
 
-func GetPublicKey() (string, error) {
-	keyPath, err := keyPath()
-	if err != nil {
-		return "", err
+func GetPublicKey() (ssh.PublicKey, error) {
+	// try to read pubkey first
+	pubkey, err := readPubkey()
+	if err == nil {
+		return pubkey, nil
 	}
 
-	privateKeyFile, err := os.ReadFile(keyPath)
+	privateKey, err := privateKey()
+	if err == nil {
+		return privateKey.PublicKey(), nil
+	}
+
+	err = initializeKeypair()
 	if err != nil {
-		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
-		return "", err
+		return nil, err
 	}
 
-	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	return readPubkey()
+}
+
+func readPubkey() (ssh.PublicKey, error) {
+	home, err := os.UserHomeDir()
 	if err != nil {
-		return "", err
+		return nil, err
+	}
+
+	pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
+	_, err = os.Stat(pubkeyPath)
+	if os.IsNotExist(err) {
+		return nil, fmt.Errorf("public key not found")
 	}
 
-	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
+	pubKeyFile, err := os.ReadFile(pubkeyPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read public key: %w", err)
+	}
 
-	return strings.TrimSpace(string(publicKey)), nil
+	return ssh.ParsePublicKey(pubKeyFile)
 }
 
 func NewNonce(r io.Reader, length int) (string, error) {
@@ -58,25 +84,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
 }
 
 func Sign(ctx context.Context, bts []byte) (string, error) {
-	keyPath, err := keyPath()
+	privateKey, err := privateKey()
 	if err != nil {
 		return "", err
 	}
 
-	privateKeyFile, err := os.ReadFile(keyPath)
+	// get the pubkey, but remove the type
+	publicKey, err := GetPublicKey()
 	if err != nil {
-		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
 		return "", err
 	}
 
-	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
-	if err != nil {
-		return "", err
-	}
+	publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
 
-	// get the pubkey, but remove the type
-	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
-	parts := bytes.Split(publicKey, []byte(" "))
+	parts := bytes.Split(publicKeyBytes, []byte(" "))
 	if len(parts) < 2 {
 		return "", fmt.Errorf("malformed public key")
 	}
@@ -89,3 +110,49 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
 	// signature is <pubkey>:<signature>
 	return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
 }
+
+func initializeKeypair() error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return err
+	}
+
+	privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
+	pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
+
+	_, err = os.Stat(privKeyPath)
+	if os.IsNotExist(err) {
+		fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
+		cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			return err
+		}
+
+		privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
+		if err != nil {
+			return err
+		}
+
+		if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
+			return fmt.Errorf("could not create directory %w", err)
+		}
+
+		if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
+			return err
+		}
+
+		sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
+		if err != nil {
+			return err
+		}
+
+		publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
+
+		if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
+			return err
+		}
+
+		fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
+	}
+	return nil
+}

+ 3 - 51
cmd/cmd.go

@@ -4,10 +4,7 @@ import (
 	"archive/zip"
 	"bytes"
 	"context"
-	"crypto/ed25519"
-	"crypto/rand"
 	"crypto/sha256"
-	"encoding/pem"
 	"errors"
 	"fmt"
 	"io"
@@ -379,11 +376,12 @@ func errFromUnknownKey(unknownKeyErr error) error {
 	if len(matches) > 0 {
 		serverPubKey := matches[0]
 
-		localPubKey, err := auth.GetPublicKey()
+		publicKey, err := auth.GetPublicKey()
 		if err != nil {
 			return unknownKeyErr
 		}
 
+		localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
 		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
 			// try the ollama service public key
 			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
@@ -1072,7 +1070,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	if err := initializeKeypair(); err != nil {
+	if _, err := auth.GetPublicKey(); err != nil {
 		return err
 	}
 
@@ -1089,52 +1087,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 	return err
 }
 
-func initializeKeypair() error {
-	home, err := os.UserHomeDir()
-	if err != nil {
-		return err
-	}
-
-	privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
-	pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
-
-	_, err = os.Stat(privKeyPath)
-	if os.IsNotExist(err) {
-		fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
-		cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
-		if err != nil {
-			return err
-		}
-
-		privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
-		if err != nil {
-			return err
-		}
-
-		if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
-			return fmt.Errorf("could not create directory %w", err)
-		}
-
-		if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
-			return err
-		}
-
-		sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
-		if err != nil {
-			return err
-		}
-
-		publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
-
-		if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
-			return err
-		}
-
-		fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
-	}
-	return nil
-}
-
 func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {