|
@@ -3,49 +3,75 @@ package auth
|
|
|
import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
+ "crypto/ed25519"
|
|
|
"crypto/rand"
|
|
|
"encoding/base64"
|
|
|
+ "encoding/pem"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log/slog"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
- "strings"
|
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
)
|
|
|
|
|
|
const defaultPrivateKey = "id_ed25519"
|
|
|
|
|
|
-func keyPath() (string, error) {
|
|
|
+func privateKey() (ssh.Signer, error) {
|
|
|
home, err := os.UserHomeDir()
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
|
|
+ privateKeyFile, err := os.ReadFile(keyPath)
|
|
|
+ if err != nil {
|
|
|
+ slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
|
|
+ return ssh.ParsePrivateKey(privateKeyFile)
|
|
|
}
|
|
|
|
|
|
-func GetPublicKey() (string, error) {
|
|
|
- keyPath, err := keyPath()
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
+func GetPublicKey() (ssh.PublicKey, error) {
|
|
|
+ // try to read pubkey first
|
|
|
+ pubkey, err := readPubkey()
|
|
|
+ if err == nil {
|
|
|
+ return pubkey, nil
|
|
|
}
|
|
|
|
|
|
- privateKeyFile, err := os.ReadFile(keyPath)
|
|
|
+ privateKey, err := privateKey()
|
|
|
+ if err == nil {
|
|
|
+ return privateKey.PublicKey(), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err = initializeKeypair()
|
|
|
if err != nil {
|
|
|
- slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
|
- return "", err
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
|
+ return readPubkey()
|
|
|
+}
|
|
|
+
|
|
|
+func readPubkey() (ssh.PublicKey, error) {
|
|
|
+ home, err := os.UserHomeDir()
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
|
|
|
+ _, err = os.Stat(pubkeyPath)
|
|
|
+ if os.IsNotExist(err) {
|
|
|
+ return nil, fmt.Errorf("public key not found")
|
|
|
}
|
|
|
|
|
|
- publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
|
|
+ pubKeyFile, err := os.ReadFile(pubkeyPath)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("failed to read public key: %w", err)
|
|
|
+ }
|
|
|
|
|
|
- return strings.TrimSpace(string(publicKey)), nil
|
|
|
+ return ssh.ParsePublicKey(pubKeyFile)
|
|
|
}
|
|
|
|
|
|
func NewNonce(r io.Reader, length int) (string, error) {
|
|
@@ -58,25 +84,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|
|
}
|
|
|
|
|
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
|
|
- keyPath, err := keyPath()
|
|
|
+ privateKey, err := privateKey()
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
- privateKeyFile, err := os.ReadFile(keyPath)
|
|
|
+ // get the pubkey, but remove the type
|
|
|
+ publicKey, err := GetPublicKey()
|
|
|
if err != nil {
|
|
|
- slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
- privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
- }
|
|
|
+ publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
|
|
|
|
|
- // get the pubkey, but remove the type
|
|
|
- publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
|
|
- parts := bytes.Split(publicKey, []byte(" "))
|
|
|
+ parts := bytes.Split(publicKeyBytes, []byte(" "))
|
|
|
if len(parts) < 2 {
|
|
|
return "", fmt.Errorf("malformed public key")
|
|
|
}
|
|
@@ -89,3 +110,49 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
|
|
// signature is <pubkey>:<signature>
|
|
|
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
|
|
}
|
|
|
+
|
|
|
+func initializeKeypair() error {
|
|
|
+ home, err := os.UserHomeDir()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
|
|
+ pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
|
|
+
|
|
|
+ _, err = os.Stat(privKeyPath)
|
|
|
+ if os.IsNotExist(err) {
|
|
|
+ fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
|
|
+ cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
|
|
+ return fmt.Errorf("could not create directory %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
|
|
+
|
|
|
+ if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|