123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- package auth
- import (
- "bytes"
- "context"
- "crypto/ed25519"
- "crypto/rand"
- "encoding/base64"
- "encoding/pem"
- "fmt"
- "io"
- "log/slog"
- "os"
- "path/filepath"
- "golang.org/x/crypto/ssh"
- )
const (
	// defaultPrivateKey is the file name, inside ~/.ollama, of the private key
	// loaded by privateKey. The matching public key is stored next to it with
	// a ".pub" suffix.
	defaultPrivateKey = "id_ed25519"
)
- func privateKey() (ssh.Signer, error) {
- home, err := os.UserHomeDir()
- if err != nil {
- return nil, err
- }
- keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
- privateKeyFile, err := os.ReadFile(keyPath)
- if os.IsNotExist(err) {
- err := initializeKeypair()
- if err != nil {
- return nil, err
- }
- return privateKey()
- } else if err != nil {
- slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
- return nil, err
- }
- return ssh.ParsePrivateKey(privateKeyFile)
- }
- func GetPublicKey() (ssh.PublicKey, error) {
- // try to read pubkey first
- home, err := os.UserHomeDir()
- if err != nil {
- return nil, err
- }
- pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
- pubKeyFile, err := os.ReadFile(pubkeyPath)
- if os.IsNotExist(err) {
- // try from privateKey
- privateKey, err := privateKey()
- if err != nil {
- return nil, fmt.Errorf("failed to read public key: %w", err)
- }
- return privateKey.PublicKey(), nil
- } else if err != nil {
- return nil, fmt.Errorf("failed to read public key: %w", err)
- }
- pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
- return pubKey, err
- }
// NewNonce reads exactly length bytes from r and returns them encoded with
// unpadded URL-safe base64 (base64.RawURLEncoding).
//
// It returns an error if length is negative; previously make would panic.
// It also returns the io.ReadFull error, unwrapped, if r yields fewer than
// length bytes. A length of zero yields "" without reading from r.
func NewNonce(r io.Reader, length int) (string, error) {
	if length < 0 {
		return "", fmt.Errorf("nonce length must be non-negative, got %d", length)
	}

	nonce := make([]byte, length)
	if _, err := io.ReadFull(r, nonce); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(nonce), nil
}
- func Sign(ctx context.Context, bts []byte) (string, error) {
- privateKey, err := privateKey()
- if err != nil {
- return "", err
- }
- // get the pubkey, but remove the type
- publicKey, err := GetPublicKey()
- if err != nil {
- return "", err
- }
- publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
- parts := bytes.Split(publicKeyBytes, []byte(" "))
- if len(parts) < 2 {
- return "", fmt.Errorf("malformed public key")
- }
- signedData, err := privateKey.Sign(rand.Reader, bts)
- if err != nil {
- return "", err
- }
- // signature is <pubkey>:<signature>
- return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
- }
- func initializeKeypair() error {
- home, err := os.UserHomeDir()
- if err != nil {
- return err
- }
- privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
- pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
- _, err = os.Stat(privKeyPath)
- if os.IsNotExist(err) {
- fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
- cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
- if err != nil {
- return err
- }
- privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
- if err != nil {
- return err
- }
- if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
- return fmt.Errorf("could not create directory %w", err)
- }
- if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
- return err
- }
- sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
- if err != nil {
- return err
- }
- publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
- if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
- return err
- }
- fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
- }
- return nil
- }
|