embedding_test.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. //go:build integration
  2. package integration
  import (
	"context"
	"math"
	"net/http"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)
  10. func TestAllMiniLMEmbedding(t *testing.T) {
  11. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  12. defer cancel()
  13. req := api.EmbeddingRequest{
  14. Model: "all-minilm",
  15. Prompt: "why is the sky blue?",
  16. Options: map[string]interface{}{
  17. "temperature": 0,
  18. "seed": 123,
  19. },
  20. }
  21. res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
  22. if len(res.Embedding) != 384 {
  23. t.Fatalf("Expected 384 floats to be returned, got %v", len(res.Embedding))
  24. }
  25. if res.Embedding[0] != 0.146763876080513 {
  26. t.Fatalf("Expected first embedding float to be 0.146763876080513, got %v", res.Embedding[0])
  27. }
  28. }
  29. func TestAllMiniLMEmbeddings(t *testing.T) {
  30. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
  31. defer cancel()
  32. req := api.EmbeddingRequest{
  33. Model: "all-minilm",
  34. Prompts: []string{"why is the sky blue?", "why is the sky blue?"},
  35. Options: map[string]interface{}{
  36. "temperature": 0,
  37. "seed": 123,
  38. },
  39. }
  40. res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
  41. if len(res.Embeddings) != 2 {
  42. t.Fatal("Expected 2 embeddings to be returned")
  43. }
  44. if len(res.Embeddings[0]) != 384 {
  45. t.Fatalf("Expected first embedding to have 384 floats, got %v", len(res.Embeddings[0]))
  46. }
  47. if res.Embeddings[0][0] != 0.146763876080513 && res.Embeddings[1][0] != 0.146763876080513 {
  48. t.Fatalf("Expected first embedding floats to be 0.146763876080513, got %v, %v", res.Embeddings[0][0], res.Embeddings[1][0])
  49. }
  50. }