embed_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. //go:build integration
  2. package integration
  3. import (
  4. "context"
  5. "testing"
  6. "time"
  7. "github.com/ollama/ollama/api"
  8. )
  9. func TestAllMiniLMEmbed(t *testing.T) {
  10. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  11. defer cancel()
  12. req := api.EmbedRequest{
  13. Model: "all-minilm",
  14. Input: "why is the sky blue?",
  15. }
  16. res, err := embedTestHelper(ctx, t, req)
  17. if err != nil {
  18. t.Fatalf("error: %v", err)
  19. }
  20. if len(res.Embeddings) != 1 {
  21. t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
  22. }
  23. if len(res.Embeddings[0]) != 384 {
  24. t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
  25. }
  26. if res.Embeddings[0][0] != 0.010071031 {
  27. t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
  28. }
  29. }
  30. func TestAllMiniLMBatchEmbed(t *testing.T) {
  31. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  32. defer cancel()
  33. req := api.EmbedRequest{
  34. Model: "all-minilm",
  35. Input: []string{"why is the sky blue?", "why is the grass green?"},
  36. }
  37. res, err := embedTestHelper(ctx, t, req)
  38. if err != nil {
  39. t.Fatalf("error: %v", err)
  40. }
  41. if len(res.Embeddings) != 2 {
  42. t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
  43. }
  44. if len(res.Embeddings[0]) != 384 {
  45. t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
  46. }
  47. if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
  48. t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
  49. }
  50. }
  51. func TestAllMiniLmEmbedTruncate(t *testing.T) {
  52. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  53. defer cancel()
  54. truncTrue, truncFalse := true, false
  55. type testReq struct {
  56. Name string
  57. Request api.EmbedRequest
  58. }
  59. reqs := []testReq{
  60. {
  61. Name: "Target Truncation",
  62. Request: api.EmbedRequest{
  63. Model: "all-minilm",
  64. Input: "why",
  65. },
  66. },
  67. {
  68. Name: "Default Truncate",
  69. Request: api.EmbedRequest{
  70. Model: "all-minilm",
  71. Input: "why is the sky blue?",
  72. Options: map[string]any{"num_ctx": 1},
  73. },
  74. },
  75. {
  76. Name: "Explicit Truncate",
  77. Request: api.EmbedRequest{
  78. Model: "all-minilm",
  79. Input: "why is the sky blue?",
  80. Truncate: &truncTrue,
  81. Options: map[string]any{"num_ctx": 1},
  82. },
  83. },
  84. }
  85. res := make(map[string]*api.EmbedResponse)
  86. for _, req := range reqs {
  87. response, err := embedTestHelper(ctx, t, req.Request)
  88. if err != nil {
  89. t.Fatalf("error: %v", err)
  90. }
  91. res[req.Name] = response
  92. }
  93. if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
  94. t.Fatal("expected default request to truncate correctly")
  95. }
  96. if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
  97. t.Fatal("expected default request and truncate true request to be the same")
  98. }
  99. // check that truncate set to false returns an error if context length is exceeded
  100. _, err := embedTestHelper(ctx, t, api.EmbedRequest{
  101. Model: "all-minilm",
  102. Input: "why is the sky blue?",
  103. Truncate: &truncFalse,
  104. Options: map[string]any{"num_ctx": 1},
  105. })
  106. if err == nil {
  107. t.Fatal("expected error, got nil")
  108. }
  109. }
  110. func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
  111. client, _, cleanup := InitServerConnection(ctx, t)
  112. defer cleanup()
  113. if err := PullIfMissing(ctx, client, req.Model); err != nil {
  114. t.Fatalf("failed to pull model %s: %v", req.Model, err)
  115. }
  116. response, err := client.Embed(ctx, &req)
  117. if err != nil {
  118. return nil, err
  119. }
  120. return response, nil
  121. }