embed_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. //go:build integration
  2. package integration
  3. import (
  4. "context"
  5. "testing"
  6. "time"
  7. "github.com/ollama/ollama/api"
  8. "github.com/stretchr/testify/require"
  9. )
  10. func TestAllMiniLMEmbed(t *testing.T) {
  11. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  12. defer cancel()
  13. req := api.EmbedRequest{
  14. Model: "all-minilm",
  15. Input: "why is the sky blue?",
  16. }
  17. res, err := embedTestHelper(ctx, t, req)
  18. if err != nil {
  19. t.Fatalf("error: %v", err)
  20. }
  21. if len(res.Embeddings) != 1 {
  22. t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
  23. }
  24. if len(res.Embeddings[0]) != 384 {
  25. t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
  26. }
  27. if res.Embeddings[0][0] != 0.010071029 {
  28. t.Fatalf("expected 0.010071029, got %f", res.Embeddings[0][0])
  29. }
  30. }
  31. func TestAllMiniLMBatchEmbed(t *testing.T) {
  32. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  33. defer cancel()
  34. req := api.EmbedRequest{
  35. Model: "all-minilm",
  36. Input: []string{"why is the sky blue?", "why is the grass green?"},
  37. }
  38. res, err := embedTestHelper(ctx, t, req)
  39. if err != nil {
  40. t.Fatalf("error: %v", err)
  41. }
  42. if len(res.Embeddings) != 2 {
  43. t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
  44. }
  45. if len(res.Embeddings[0]) != 384 {
  46. t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
  47. }
  48. if res.Embeddings[0][0] != 0.010071029 || res.Embeddings[1][0] != -0.0098027075 {
  49. t.Fatalf("expected 0.010071029 and -0.0098027075, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
  50. }
  51. }
  52. func TestAllMiniLmEmbedTruncate(t *testing.T) {
  53. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  54. defer cancel()
  55. truncTrue, truncFalse := true, false
  56. type testReq struct {
  57. Name string
  58. Request api.EmbedRequest
  59. }
  60. reqs := []testReq{
  61. {
  62. Name: "Target Truncation",
  63. Request: api.EmbedRequest{
  64. Model: "all-minilm",
  65. Input: "why",
  66. },
  67. },
  68. {
  69. Name: "Default Truncate",
  70. Request: api.EmbedRequest{
  71. Model: "all-minilm",
  72. Input: "why is the sky blue?",
  73. Options: map[string]any{"num_ctx": 1},
  74. },
  75. },
  76. {
  77. Name: "Explicit Truncate",
  78. Request: api.EmbedRequest{
  79. Model: "all-minilm",
  80. Input: "why is the sky blue?",
  81. Truncate: &truncTrue,
  82. Options: map[string]any{"num_ctx": 1},
  83. },
  84. },
  85. }
  86. res := make(map[string]*api.EmbedResponse)
  87. for _, req := range reqs {
  88. response, err := embedTestHelper(ctx, t, req.Request)
  89. if err != nil {
  90. t.Fatalf("error: %v", err)
  91. }
  92. res[req.Name] = response
  93. }
  94. if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
  95. t.Fatalf("expected default request to truncate correctly")
  96. }
  97. if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
  98. t.Fatalf("expected default request and truncate true request to be the same")
  99. }
  100. // check that truncate set to false returns an error if context length is exceeded
  101. _, err := embedTestHelper(ctx, t, api.EmbedRequest{
  102. Model: "all-minilm",
  103. Input: "why is the sky blue?",
  104. Truncate: &truncFalse,
  105. Options: map[string]any{"num_ctx": 1},
  106. })
  107. if err == nil {
  108. t.Fatalf("expected error, got nil")
  109. }
  110. }
  111. func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
  112. client, _, cleanup := InitServerConnection(ctx, t)
  113. defer cleanup()
  114. require.NoError(t, PullIfMissing(ctx, client, req.Model))
  115. response, err := client.Embed(ctx, &req)
  116. if err != nil {
  117. return nil, err
  118. }
  119. return response, nil
  120. }