auth.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. package auth
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/ed25519"
  6. "crypto/rand"
  7. "encoding/base64"
  8. "encoding/pem"
  9. "fmt"
  10. "io"
  11. "log/slog"
  12. "os"
  13. "path/filepath"
  14. "golang.org/x/crypto/ssh"
  15. )
  // defaultPrivateKey is the file name, under ~/.ollama, of the SSH private
  // key. GetPublicKey looks for the matching public key at the same name with
  // a ".pub" suffix.
  const (
  	defaultPrivateKey = "id_ed25519"
  )
  17. func privateKey() (ssh.Signer, error) {
  18. home, err := os.UserHomeDir()
  19. if err != nil {
  20. return nil, err
  21. }
  22. keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
  23. privateKeyFile, err := os.ReadFile(keyPath)
  24. if os.IsNotExist(err) {
  25. err := initializeKeypair()
  26. if err != nil {
  27. return nil, err
  28. }
  29. return privateKey()
  30. } else if err != nil {
  31. slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
  32. return nil, err
  33. return nil, err
  34. }
  35. return ssh.ParsePrivateKey(privateKeyFile)
  36. }
  37. func GetPublicKey() (ssh.PublicKey, error) {
  38. privateKey, err := keyPath()
  39. // if privateKey, try public key directly
  40. return ssh.ParsePrivateKey(privateKeyFile)
  41. }
  42. func GetPublicKey() (ssh.PublicKey, error) {
  43. // try to read pubkey first
  44. home, err := os.UserHomeDir()
  45. if err != nil {
  46. return nil, err
  47. }
  48. pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
  49. pubKeyFile, err := os.ReadFile(pubkeyPath)
  50. if os.IsNotExist(err) {
  51. // try from privateKey
  52. privateKey, err := privateKey()
  53. if err != nil {
  54. return nil, fmt.Errorf("failed to read public key: %w", err)
  55. }
  56. return privateKey.PublicKey(), nil
  57. } else if err != nil {
  58. return nil, fmt.Errorf("failed to read public key: %w", err)
  59. }
  60. pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
  61. return pubKey, err
  62. }
  // NewNonce reads exactly length bytes from r and returns them encoded as
  // unpadded, URL-safe base64.
  //
  // If r yields fewer than length bytes, the read error is returned unchanged:
  // io.EOF if nothing was read, io.ErrUnexpectedEOF after a partial read.
  func NewNonce(r io.Reader, length int) (string, error) {
  	buf := make([]byte, length)
  	// ReadAtLeast with min == len(buf) is exactly what io.ReadFull does.
  	_, err := io.ReadAtLeast(r, buf, len(buf))
  	if err != nil {
  		return "", err
  	}
  	encoded := base64.RawURLEncoding.EncodeToString(buf)
  	return encoded, nil
  }
  70. func Sign(ctx context.Context, bts []byte) (string, error) {
  71. privateKey, err := privateKey()
  72. if err != nil {
  73. return "", err
  74. }
  75. // get the pubkey, but remove the type
  76. publicKey, err := GetPublicKey()
  77. if err != nil {
  78. return "", err
  79. }
  80. publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
  81. parts := bytes.Split(publicKeyBytes, []byte(" "))
  82. if len(parts) < 2 {
  83. return "", fmt.Errorf("malformed public key")
  84. }
  85. signedData, err := privateKey.Sign(rand.Reader, bts)
  86. if err != nil {
  87. return "", err
  88. }
  89. // signature is <pubkey>:<signature>
  90. return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
  91. }
  92. func initializeKeypair() error {
  93. home, err := os.UserHomeDir()
  94. if err != nil {
  95. return err
  96. }
  97. privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
  98. pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
  99. _, err = os.Stat(privKeyPath)
  100. if os.IsNotExist(err) {
  101. fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
  102. cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
  103. if err != nil {
  104. return err
  105. }
  106. privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
  107. if err != nil {
  108. return err
  109. }
  110. if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
  111. return fmt.Errorf("could not create directory %w", err)
  112. }
  113. if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
  114. return err
  115. }
  116. sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
  117. if err != nil {
  118. return err
  119. }
  120. publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
  121. if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
  122. return err
  123. }
  124. fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
  125. }
  126. return nil
  127. }