create_test.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package server
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "os"
  7. "path/filepath"
  8. "strings"
  9. "testing"
  10. "github.com/ollama/ollama/api"
  11. )
  12. func TestConvertFromSafetensors(t *testing.T) {
  13. t.Setenv("OLLAMA_MODELS", t.TempDir())
  14. // Helper function to create a new layer and return its digest
  15. makeTemp := func(content string) string {
  16. l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
  17. if err != nil {
  18. t.Fatalf("Failed to create layer: %v", err)
  19. }
  20. return l.Digest
  21. }
  22. // Create a safetensors compatible file with empty JSON content
  23. var buf bytes.Buffer
  24. headerSize := int64(len("{}"))
  25. binary.Write(&buf, binary.LittleEndian, headerSize)
  26. buf.WriteString("{}")
  27. model := makeTemp(buf.String())
  28. config := makeTemp(`{
  29. "architectures": ["LlamaForCausalLM"],
  30. "vocab_size": 32000
  31. }`)
  32. tokenizer := makeTemp(`{
  33. "version": "1.0",
  34. "truncation": null,
  35. "padding": null,
  36. "added_tokens": [
  37. {
  38. "id": 0,
  39. "content": "<|endoftext|>",
  40. "single_word": false,
  41. "lstrip": false,
  42. "rstrip": false,
  43. "normalized": false,
  44. "special": true
  45. }
  46. ]
  47. }`)
  48. tests := []struct {
  49. name string
  50. filePath string
  51. wantErr error
  52. }{
  53. // Invalid
  54. {
  55. name: "InvalidRelativePathShallow",
  56. filePath: filepath.Join("..", "file.safetensors"),
  57. wantErr: errFilePath,
  58. },
  59. {
  60. name: "InvalidRelativePathDeep",
  61. filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
  62. wantErr: errFilePath,
  63. },
  64. {
  65. name: "InvalidNestedPath",
  66. filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
  67. wantErr: errFilePath,
  68. },
  69. {
  70. name: "AbsolutePathOutsideRoot",
  71. filePath: filepath.Join(os.TempDir(), "model.safetensors"),
  72. wantErr: errFilePath, // Should fail since it's outside tmpDir
  73. },
  74. {
  75. name: "ValidRelativePath",
  76. filePath: "model.safetensors",
  77. wantErr: nil,
  78. },
  79. }
  80. for _, tt := range tests {
  81. t.Run(tt.name, func(t *testing.T) {
  82. // Create the minimum required file map for convertFromSafetensors
  83. files := map[string]string{
  84. tt.filePath: model,
  85. "config.json": config,
  86. "tokenizer.json": tokenizer,
  87. }
  88. _, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
  89. if (tt.wantErr == nil && err != nil) ||
  90. (tt.wantErr != nil && err == nil) ||
  91. (tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
  92. t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
  93. }
  94. })
  95. }
  96. }