浏览代码

validate the format of the digest when getting the model path (#4175)

Patrick Devine 1 年之前
父节点
当前提交
2a21363bb7
共有 2 个文件被更改,包括 84 次插入4 次删除
  1. 16 3
      server/modelpath.go
  2. 68 1
      server/modelpath_test.go

+ 16 - 3
server/modelpath.go

@@ -6,6 +6,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"regexp"
 	"strings"
 )
 
@@ -25,9 +26,10 @@ const (
 )
 
 var (
-	ErrInvalidImageFormat = errors.New("invalid image format")
-	ErrInvalidProtocol    = errors.New("invalid protocol scheme")
-	ErrInsecureProtocol   = errors.New("insecure protocol http")
+	ErrInvalidImageFormat  = errors.New("invalid image format")
+	ErrInvalidProtocol     = errors.New("invalid protocol scheme")
+	ErrInsecureProtocol    = errors.New("insecure protocol http")
+	ErrInvalidDigestFormat = errors.New("invalid digest format")
 )
 
 func ParseModelPath(name string) ModelPath {
@@ -149,6 +151,17 @@ func GetBlobsPath(digest string) (string, error) {
 		return "", err
 	}
 
+	// only accept actual sha256 digests
+	pattern := "^sha256[:-][0-9a-fA-F]{64}$"
+	re := regexp.MustCompile(pattern)
+	if err != nil {
+		return "", err
+	}
+
+	if digest != "" && !re.MatchString(digest) {
+		return "", ErrInvalidDigestFormat
+	}
+
 	digest = strings.ReplaceAll(digest, ":", "-")
 	path := filepath.Join(dir, "blobs", digest)
 	dirPath := filepath.Dir(path)

+ 68 - 1
server/modelpath_test.go

@@ -1,6 +1,73 @@
 package server
 
-import "testing"
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetBlobsPath(t *testing.T) {
+	// GetBlobsPath expects an actual directory to exist
+	dir, err := os.MkdirTemp("", "ollama-test")
+	assert.Nil(t, err)
+	defer os.RemoveAll(dir)
+
+	tests := []struct {
+		name     string
+		digest   string
+		expected string
+		err      error
+	}{
+		{
+			"empty digest",
+			"",
+			filepath.Join(dir, "blobs"),
+			nil,
+		},
+		{
+			"valid with colon",
+			"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
+			filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
+			nil,
+		},
+		{
+			"valid with dash",
+			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
+			filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
+			nil,
+		},
+		{
+			"digest too short",
+			"sha256-45640291",
+			"",
+			ErrInvalidDigestFormat,
+		},
+		{
+			"digest too long",
+			"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
+			"",
+			ErrInvalidDigestFormat,
+		},
+		{
+			"digest invalid chars",
+			"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
+			"",
+			ErrInvalidDigestFormat,
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Setenv("OLLAMA_MODELS", dir)
+
+			got, err := GetBlobsPath(tc.digest)
+
+			assert.ErrorIs(t, tc.err, err, tc.name)
+			assert.Equal(t, tc.expected, got, tc.name)
+		})
+	}
+}
 
 func TestParseModelPath(t *testing.T) {
 	tests := []struct {