auth.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package auth
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/rand"
  6. "encoding/base64"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "os"
  11. "path/filepath"
  12. "golang.org/x/crypto/ssh"
  13. )
  14. const defaultPrivateKey = "id_ed25519"
  15. func keyPath() (ssh.Signer, error) {
  16. home, err := os.UserHomeDir()
  17. if err != nil {
  18. return nil, err
  19. }
  20. keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
  21. privateKeyFile, err := os.ReadFile(keyPath)
  22. if err != nil {
  23. slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
  24. return nil, err
  25. }
  26. return ssh.ParsePrivateKey(privateKeyFile)
  27. }
  28. func GetPublicKey() (ssh.PublicKey, error) {
  29. privateKey, err := keyPath()
  30. // if privateKey, try public key directly
  31. if err != nil {
  32. return nil, err
  33. }
  34. return privateKey.PublicKey(), nil
  35. }
  36. func NewNonce(r io.Reader, length int) (string, error) {
  37. nonce := make([]byte, length)
  38. if _, err := io.ReadFull(r, nonce); err != nil {
  39. return "", err
  40. }
  41. return base64.RawURLEncoding.EncodeToString(nonce), nil
  42. }
  43. func Sign(ctx context.Context, bts []byte) (string, error) {
  44. privateKey, err := keyPath()
  45. if err != nil {
  46. return "", err
  47. }
  48. // get the pubkey, but remove the type
  49. publicKey, err := GetPublicKey()
  50. if err != nil {
  51. return "", err
  52. }
  53. publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
  54. parts := bytes.Split(publicKeyBytes, []byte(" "))
  55. if len(parts) < 2 {
  56. return "", fmt.Errorf("malformed public key")
  57. }
  58. signedData, err := privateKey.Sign(rand.Reader, bts)
  59. if err != nil {
  60. return "", err
  61. }
  62. // signature is <pubkey>:<signature>
  63. return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
  64. }