Browse Source

Token auth (#314)

Patrick Devine 1 year ago
parent
commit
be989d89d1
3 changed files with 233 additions and 6 deletions
  1. 4 0
      api/types.go
  2. 164 0
      server/auth.go
  3. 65 6
      server/images.go

+ 4 - 0
api/types.go

@@ -98,6 +98,10 @@ type ListResponseModel struct {
 	Size       int       `json:"size"`
 	Size       int       `json:"size"`
 }
 }
 
 
+type TokenResponse struct {
+	Token string `json:"token"`
+}
+
 type GenerateResponse struct {
 type GenerateResponse struct {
 	Model     string    `json:"model"`
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	CreatedAt time.Time `json:"created_at"`

+ 164 - 0
server/auth.go

@@ -0,0 +1,164 @@
+package server
+
+import (
+	"bytes"
+	"crypto/rand"
+	"crypto/sha256"
+	"encoding/base64"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"net/http"
+	"os"
+	"path"
+	"strings"
+	"time"
+
+	"golang.org/x/crypto/ssh"
+
+	"github.com/jmorganca/ollama/api"
+)
+
+type AuthRedirect struct {
+	Realm   string
+	Service string
+	Scope   string
+}
+
+type SignatureData struct {
+	Method string
+	Path   string
+	Data   []byte
+}
+
+func generateNonce(length int) (string, error) {
+	nonce := make([]byte, length)
+	_, err := rand.Read(nonce)
+	if err != nil {
+		return "", err
+	}
+	return base64.RawURLEncoding.EncodeToString(nonce), nil
+}
+
+func (r AuthRedirect) URL() (string, error) {
+	nonce, err := generateNonce(16)
+	if err != nil {
+		return "", err
+	}
+	return fmt.Sprintf("%s?service=%s&scope=%s&ts=%d&nonce=%s", r.Realm, r.Service, r.Scope, time.Now().Unix(), nonce), nil
+}
+
+func getAuthToken(redirData AuthRedirect, regOpts *RegistryOptions) (string, error) {
+	url, err := redirData.URL()
+	if err != nil {
+		return "", err
+	}
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	keyPath := path.Join(home, ".ollama/id_ed25519")
+
+	rawKey, err := ioutil.ReadFile(keyPath)
+	if err != nil {
+		log.Printf("Failed to load private key: %v", err)
+		return "", err
+	}
+
+	s := SignatureData{
+		Method: "GET",
+		Path:   url,
+		Data:   nil,
+	}
+
+	if !strings.HasPrefix(s.Path, "http") {
+		if regOpts.Insecure {
+			s.Path = "http://" + url
+		} else {
+			s.Path = "https://" + url
+		}
+	}
+
+	sig, err := s.Sign(rawKey)
+	if err != nil {
+		return "", err
+	}
+
+	headers := map[string]string{
+		"Authorization": sig,
+	}
+
+	resp, err := makeRequest("GET", url, headers, nil, regOpts)
+	if err != nil {
+		log.Printf("couldn't get token: %q", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return "", fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
+	}
+
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", err
+	}
+
+	var tok api.TokenResponse
+	if err := json.Unmarshal(respBody, &tok); err != nil {
+		return "", err
+	}
+
+	return tok.Token, nil
+}
+
+// Bytes returns a byte slice of the data to sign for the request
+func (s SignatureData) Bytes() []byte {
+	// We first derive the content hash of the request body using:
+	//     base64(hex(sha256(request body)))
+
+	hash := sha256.Sum256(s.Data)
+	hashHex := make([]byte, hex.EncodedLen(len(hash)))
+	hex.Encode(hashHex, hash[:])
+	contentHash := base64.StdEncoding.EncodeToString(hashHex)
+
+	// We then put the entire request together in a serialize string using:
+	//       "<method>,<uri>,<content hash>"
+	// e.g.  "GET,http://localhost,OTdkZjM1O..."
+
+	return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ","))
+}
+
+// SignData takes a SignatureData object and signs it with a raw private key
+func (s SignatureData) Sign(rawKey []byte) (string, error) {
+	privateKey, err := ssh.ParseRawPrivateKey(rawKey)
+	if err != nil {
+		return "", err
+	}
+
+	signer, err := ssh.NewSignerFromKey(privateKey)
+	if err != nil {
+		return "", err
+	}
+
+	// get the pubkey, but remove the type
+	pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
+	parts := bytes.Split(pubKey, []byte(" "))
+	if len(parts) < 2 {
+		return "", fmt.Errorf("malformed public key")
+	}
+
+	signedData, err := signer.Sign(nil, s.Bytes())
+	if err != nil {
+		return "", err
+	}
+
+	// signature is <pubkey>:<signature>
+	sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob))
+	return sig, nil
+}

+ 65 - 6
server/images.go

@@ -28,6 +28,7 @@ type RegistryOptions struct {
 	Insecure bool
 	Insecure bool
 	Username string
 	Username string
 	Password string
 	Password string
+	Token    string
 }
 }
 
 
 type Model struct {
 type Model struct {
@@ -1129,18 +1130,30 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
 		}
 		}
 	}
 	}
 
 
-	req, err := http.NewRequest(method, url, body)
+	// make a copy of the body in case we need to try the call to makeRequest again
+	var buf bytes.Buffer
+	if body != nil {
+		_, err := io.Copy(&buf, body)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	bodyCopy := bytes.NewReader(buf.Bytes())
+
+	req, err := http.NewRequest(method, url, bodyCopy)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	for k, v := range headers {
-		req.Header.Set(k, v)
+	if regOpts.Token != "" {
+		req.Header.Set("Authorization", "Bearer "+regOpts.Token)
+	} else if regOpts.Username != "" && regOpts.Password != "" {
+		req.SetBasicAuth(regOpts.Username, regOpts.Password)
 	}
 	}
 
 
-	// TODO: better auth
-	if regOpts.Username != "" && regOpts.Password != "" {
-		req.SetBasicAuth(regOpts.Username, regOpts.Password)
+	for k, v := range headers {
+		req.Header.Set(k, v)
 	}
 	}
 
 
 	client := &http.Client{
 	client := &http.Client{
@@ -1157,9 +1170,55 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	// if the request is unauthenticated, try to authenticate and make the request again
+	if resp.StatusCode == http.StatusUnauthorized {
+		auth := resp.Header.Get("Www-Authenticate")
+		authRedir := ParseAuthRedirectString(string(auth))
+		token, err := getAuthToken(authRedir, regOpts)
+		if err != nil {
+			return nil, err
+		}
+		regOpts.Token = token
+		bodyCopy = bytes.NewReader(buf.Bytes())
+		return makeRequest(method, url, headers, bodyCopy, regOpts)
+	}
+
 	return resp, nil
 	return resp, nil
 }
 }
 
 
+func getValue(header, key string) string {
+	startIdx := strings.Index(header, key+"=")
+	if startIdx == -1 {
+		return ""
+	}
+
+	// Move the index to the starting quote after the key.
+	startIdx += len(key) + 2
+	endIdx := startIdx
+
+	for endIdx < len(header) {
+		if header[endIdx] == '"' {
+			if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
+				endIdx++
+				continue
+			}
+			break
+		}
+		endIdx++
+	}
+	return header[startIdx:endIdx]
+}
+
+func ParseAuthRedirectString(authStr string) AuthRedirect {
+	authStr = strings.TrimPrefix(authStr, "Bearer ")
+
+	return AuthRedirect{
+		Realm:   getValue(authStr, "realm"),
+		Service: getValue(authStr, "service"),
+		Scope:   getValue(authStr, "scope"),
+	}
+}
+
 var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
 var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
 
 
 func verifyBlob(digest string) error {
 func verifyBlob(digest string) error {