|
@@ -0,0 +1,106 @@
|
|
|
+package server
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "encoding/binary"
|
|
|
+ "errors"
|
|
|
+ "os"
|
|
|
+ "path/filepath"
|
|
|
+ "strings"
|
|
|
+ "testing"
|
|
|
+
|
|
|
+ "github.com/ollama/ollama/api"
|
|
|
+)
|
|
|
+
|
|
|
+func TestConvertFromSafetensors(t *testing.T) {
|
|
|
+ t.Setenv("OLLAMA_MODELS", t.TempDir())
|
|
|
+
|
|
|
+ // Helper function to create a new layer and return its digest
|
|
|
+ makeTemp := func(content string) string {
|
|
|
+ l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Failed to create layer: %v", err)
|
|
|
+ }
|
|
|
+ return l.Digest
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a safetensors compatible file with empty JSON content
|
|
|
+ var buf bytes.Buffer
|
|
|
+ headerSize := int64(len("{}"))
|
|
|
+ binary.Write(&buf, binary.LittleEndian, headerSize)
|
|
|
+ buf.WriteString("{}")
|
|
|
+
|
|
|
+ model := makeTemp(buf.String())
|
|
|
+ config := makeTemp(`{
|
|
|
+ "architectures": ["LlamaForCausalLM"],
|
|
|
+ "vocab_size": 32000
|
|
|
+ }`)
|
|
|
+ tokenizer := makeTemp(`{
|
|
|
+ "version": "1.0",
|
|
|
+ "truncation": null,
|
|
|
+ "padding": null,
|
|
|
+ "added_tokens": [
|
|
|
+ {
|
|
|
+ "id": 0,
|
|
|
+ "content": "<|endoftext|>",
|
|
|
+ "single_word": false,
|
|
|
+ "lstrip": false,
|
|
|
+ "rstrip": false,
|
|
|
+ "normalized": false,
|
|
|
+ "special": true
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }`)
|
|
|
+
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ filePath string
|
|
|
+ wantErr error
|
|
|
+ }{
|
|
|
+ // Invalid
|
|
|
+ {
|
|
|
+ name: "InvalidRelativePathShallow",
|
|
|
+ filePath: filepath.Join("..", "file.safetensors"),
|
|
|
+ wantErr: errFilePath,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "InvalidRelativePathDeep",
|
|
|
+ filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
|
|
|
+ wantErr: errFilePath,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "InvalidNestedPath",
|
|
|
+ filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
|
|
|
+ wantErr: errFilePath,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "AbsolutePathOutsideRoot",
|
|
|
+ filePath: filepath.Join(os.TempDir(), "model.safetensors"),
|
|
|
+ wantErr: errFilePath, // Should fail since it's outside tmpDir
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "ValidRelativePath",
|
|
|
+ filePath: "model.safetensors",
|
|
|
+ wantErr: nil,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ // Create the minimum required file map for convertFromSafetensors
|
|
|
+ files := map[string]string{
|
|
|
+ tt.filePath: model,
|
|
|
+ "config.json": config,
|
|
|
+ "tokenizer.json": tokenizer,
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
|
|
|
+
|
|
|
+ if (tt.wantErr == nil && err != nil) ||
|
|
|
+ (tt.wantErr != nil && err == nil) ||
|
|
|
+ (tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
|
|
|
+ t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|