store.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. package vector
  2. import (
  3. "container/heap"
  4. "sort"
  5. "gonum.org/v1/gonum/mat"
  6. )
  // Embedding pairs a vector with the data it was computed from.
  type Embedding struct {
  	Vector []float64 // the embedding vector
  	Data   string    // the data represented by the embedding
  }

  // EmbeddingSimilarity records how similar a stored Embedding is to a query.
  type EmbeddingSimilarity struct {
  	Embedding  Embedding // the embedding that was used to calculate the similarity
  	Similarity float64   // the similarity between the embedding and the query
  }

  // Heap is a min-heap of EmbeddingSimilarity values ordered by Similarity, so
  // the least similar entry sits at index 0. It implements heap.Interface and is
  // meant to be manipulated through the container/heap functions.
  type Heap []EmbeddingSimilarity

  // Compile-time check that *Heap satisfies heap.Interface.
  var _ heap.Interface = (*Heap)(nil)

  // Len reports the number of entries in the heap.
  func (h Heap) Len() int { return len(h) }

  // Less orders entries by ascending Similarity, which makes this a min-heap.
  func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }

  // Swap exchanges the entries at positions i and j.
  func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

  // Push appends e, which must be an EmbeddingSimilarity, to the end of the
  // slice. It is the hook used by heap.Push; callers should use heap.Push.
  func (h *Heap) Push(e any) {
  	*h = append(*h, e.(EmbeddingSimilarity))
  }

  // Pop removes and returns the last element of the slice. It is the hook used
  // by heap.Pop, which moves the minimum to the end before calling it. It
  // panics if the heap is empty.
  func (h *Heap) Pop() any {
  	old := *h
  	n := len(old)
  	x := old[n-1]
  	// Clear the vacated slot so the shared backing array does not keep the
  	// popped embedding's Vector and Data reachable after removal.
  	old[n-1] = EmbeddingSimilarity{}
  	*h = old[:n-1]
  	return x
  }
  29. // cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
  30. // This value will range from -1 to 1, where 1 means the vectors are identical.
  31. func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
  32. dotProduct := mat.Dot(vec1, vec2)
  33. norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
  34. if norms == 0 {
  35. return 0
  36. }
  37. return dotProduct / norms
  38. }
  39. func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
  40. h := &Heap{}
  41. heap.Init(h)
  42. for _, emb := range embeddings {
  43. similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
  44. heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
  45. if h.Len() > k {
  46. heap.Pop(h)
  47. }
  48. }
  49. topK := make([]EmbeddingSimilarity, 0, h.Len())
  50. for h.Len() > 0 {
  51. topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
  52. }
  53. sort.Slice(topK, func(i, j int) bool {
  54. return topK[i].Similarity > topK[j].Similarity
  55. })
  56. return topK
  57. }