auth.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package server
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "crypto/sha256"
  6. "encoding/base64"
  7. "encoding/hex"
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "net/url"
  13. "strconv"
  14. "strings"
  15. "time"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/auth"
  18. )
  19. type registryChallenge struct {
  20. Realm string
  21. Service string
  22. Scope string
  23. Timestamp time.Time
  24. }
  25. func (r registryChallenge) URL() (*url.URL, error) {
  26. redirectURL, err := url.Parse(r.Realm)
  27. if err != nil {
  28. return nil, err
  29. }
  30. values := redirectURL.Query()
  31. values.Add("service", r.Service)
  32. for _, s := range strings.Split(r.Scope, " ") {
  33. values.Add("scope", s)
  34. }
  35. values.Add("ts", strconv.FormatInt(r.Timestamp.Unix(), 10))
  36. nonce, err := auth.NewNonce(rand.Reader, 16)
  37. if err != nil {
  38. return nil, err
  39. }
  40. values.Add("nonce", nonce)
  41. redirectURL.RawQuery = values.Encode()
  42. return redirectURL, nil
  43. }
  44. func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
  45. redirectURL, err := challenge.URL()
  46. if err != nil {
  47. return "", err
  48. }
  49. sha256sum := sha256.Sum256(nil)
  50. data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
  51. headers := make(http.Header)
  52. signature, err := auth.Sign(ctx, data)
  53. if err != nil {
  54. return "", err
  55. }
  56. headers.Add("Authorization", signature)
  57. response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
  58. if err != nil {
  59. return "", err
  60. }
  61. defer response.Body.Close()
  62. body, err := io.ReadAll(response.Body)
  63. if err != nil {
  64. return "", fmt.Errorf("%d: %v", response.StatusCode, err)
  65. }
  66. if response.StatusCode >= http.StatusBadRequest {
  67. if len(body) > 0 {
  68. return "", fmt.Errorf("%d: %s", response.StatusCode, body)
  69. } else {
  70. return "", fmt.Errorf("%d", response.StatusCode)
  71. }
  72. }
  73. var token api.TokenResponse
  74. if err := json.Unmarshal(body, &token); err != nil {
  75. return "", err
  76. }
  77. return token.Token, nil
  78. }