|
@@ -4,10 +4,12 @@ import (
|
|
|
"bytes"
|
|
|
"cmp"
|
|
|
"context"
|
|
|
+ "encoding/base64"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "log"
|
|
|
"log/slog"
|
|
|
"net"
|
|
|
"net/http"
|
|
@@ -22,8 +24,10 @@ import (
|
|
|
|
|
|
"github.com/gin-contrib/cors"
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "golang.org/x/crypto/ssh"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
+ "github.com/ollama/ollama/auth"
|
|
|
"github.com/ollama/ollama/envconfig"
|
|
|
"github.com/ollama/ollama/gpu"
|
|
|
"github.com/ollama/ollama/llm"
|
|
@@ -783,7 +787,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if c.GetHeader("X-Redirect-Create") == "1" {
|
|
|
+ if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
|
|
|
c.Header("LocalLocation", path)
|
|
|
c.Status(http.StatusTemporaryRedirect)
|
|
|
return
|
|
@@ -803,6 +807,74 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|
|
c.Status(http.StatusCreated)
|
|
|
}
|
|
|
|
|
|
+func (s *Server) IsLocal(c *gin.Context) bool {
|
|
|
+ if authz := c.GetHeader("Authorization"); authz != "" {
|
|
|
+ parts := strings.Split(authz, ":")
|
|
|
+ if len(parts) != 3 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ // partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
|
|
+ partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ partialRequestDataParts := strings.Split(string(partialRequestData), ",")
|
|
|
+ if len(partialRequestDataParts) != 4 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ /* timestamp, err := strconv.ParseInt(partialRequestDataParts[2], 10, 0)
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ t := time.Unix(timestamp, 0)
|
|
|
+ if time.Since(t) > 5*time.Minute || time.Until(t) > 5*time.Minute {
|
|
|
+ // token is invalid if timestamp +/- 5 minutes from current time
|
|
|
+ return false
|
|
|
+ } */
|
|
|
+
|
|
|
+ /* nonce := partialRequestDataParts[3]
|
|
|
+ if nonceCache.has(nonce) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ nonceCache.add(nonce, 5*time.Minute) */
|
|
|
+
|
|
|
+ signature, err := base64.StdEncoding.DecodeString(parts[2])
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ serverPublicKey, err := auth.GetPublicKey()
|
|
|
+ if err != nil {
|
|
|
+ log.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ _, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
|
|
|
+ requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
|
|
|
+
|
|
|
+ if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
func isLocalIP(ip netip.Addr) bool {
|
|
|
if interfaces, err := net.Interfaces(); err == nil {
|
|
|
for _, iface := range interfaces {
|