Преглед на файлове

server: validate local path on safetensor create (#9379)

More validation during the safetensor creation process.
Properly handle relative paths (like ./model.safetensors) while rejecting absolute paths
Add comprehensive test coverage for various paths
No functionality changes for valid inputs - existing workflows remain unaffected
Leverages Go 1.24's new os.Root functionality for secure containment
Bruce MacDonald преди 2 месеца
родител
ревизия
bebb6823c0
променени са 2 файла, в които са добавени 131 реда и са изтрити 1 реда
  1. 25 1
      server/create.go
  2. 106 0
      server/create_test.go

+ 25 - 1
server/create.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"log/slog"
 	"net/http"
 	"os"
@@ -34,6 +35,7 @@ var (
 	errOnlyGGUFSupported       = errors.New("supplied file was not in GGUF format")
 	errUnknownType             = errors.New("unknown type")
 	errNeitherFromOrFiles      = errors.New("neither 'from' or 'files' was specified")
+	errFilePath                = errors.New("file path must be relative")
 )
 
 func (s *Server) CreateHandler(c *gin.Context) {
@@ -46,6 +48,13 @@ func (s *Server) CreateHandler(c *gin.Context) {
 		return
 	}
 
+	for v := range r.Files {
+		if !fs.ValidPath(v) {
+			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
+			return
+		}
+	}
+
 	name := model.ParseName(cmp.Or(r.Model, r.Name))
 	if !name.IsValid() {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
@@ -104,7 +113,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
 		if r.Adapters != nil {
 			adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
 			if err != nil {
-				for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
+				for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
 					if errors.Is(err, badReq) {
 						ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
 						return
@@ -221,8 +230,22 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
 		return nil, err
 	}
 	defer os.RemoveAll(tmpDir)
+	// Set up a root to validate paths
+	root, err := os.OpenRoot(tmpDir)
+	if err != nil {
+		return nil, err
+	}
+	defer root.Close()
 
 	for fp, digest := range files {
+		if !fs.ValidPath(fp) {
+			return nil, fmt.Errorf("%w: %s", errFilePath, fp)
+		}
+		if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
+			// Path is likely outside the root
+			return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
+		}
+
 		blobPath, err := GetBlobsPath(digest)
 		if err != nil {
 			return nil, err
@@ -270,6 +293,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
 	if err != nil {
 		return nil, err
 	}
+	defer bin.Close()
 
 	f, _, err := ggml.Decode(bin, 0)
 	if err != nil {

+ 106 - 0
server/create_test.go

@@ -0,0 +1,106 @@
+package server
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestConvertFromSafetensors(t *testing.T) {
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
+
+	// Helper function to create a new layer and return its digest
+	makeTemp := func(content string) string {
+		l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
+		if err != nil {
+			t.Fatalf("Failed to create layer: %v", err)
+		}
+		return l.Digest
+	}
+
+	// Create a safetensors compatible file with empty JSON content
+	var buf bytes.Buffer
+	headerSize := int64(len("{}"))
+	binary.Write(&buf, binary.LittleEndian, headerSize)
+	buf.WriteString("{}")
+
+	model := makeTemp(buf.String())
+	config := makeTemp(`{
+		"architectures": ["LlamaForCausalLM"], 
+		"vocab_size": 32000
+	}`)
+	tokenizer := makeTemp(`{
+		"version": "1.0",
+		"truncation": null,
+		"padding": null,
+		"added_tokens": [
+			{
+				"id": 0,
+				"content": "<|endoftext|>",
+				"single_word": false,
+				"lstrip": false,
+				"rstrip": false,
+				"normalized": false,
+				"special": true
+			}
+		]
+	}`)
+
+	tests := []struct {
+		name     string
+		filePath string
+		wantErr  error
+	}{
+		// Invalid
+		{
+			name:     "InvalidRelativePathShallow",
+			filePath: filepath.Join("..", "file.safetensors"),
+			wantErr:  errFilePath,
+		},
+		{
+			name:     "InvalidRelativePathDeep",
+			filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
+			wantErr:  errFilePath,
+		},
+		{
+			name:     "InvalidNestedPath",
+			filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
+			wantErr:  errFilePath,
+		},
+		{
+			name:     "AbsolutePathOutsideRoot",
+			filePath: filepath.Join(os.TempDir(), "model.safetensors"),
+			wantErr:  errFilePath, // Should fail since it's outside tmpDir
+		},
+		{
+			name:     "ValidRelativePath",
+			filePath: "model.safetensors",
+			wantErr:  nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create the minimum required file map for convertFromSafetensors
+			files := map[string]string{
+				tt.filePath:      model,
+				"config.json":    config,
+				"tokenizer.json": tokenizer,
+			}
+
+			_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
+
+			if (tt.wantErr == nil && err != nil) ||
+				(tt.wantErr != nil && err == nil) ||
+				(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
+				t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}