image.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package llamarunner
  2. import (
  3. "errors"
  4. "fmt"
  5. "hash/maphash"
  6. "log/slog"
  7. "slices"
  8. "sync"
  9. "time"
  10. "github.com/ollama/ollama/llama"
  11. )
  const (
	// imageCacheSize is the number of embedding slots allocated for each
	// ImageContext; when every slot is occupied, addImage replaces the entry
	// with the oldest lastUsed time.
	imageCacheSize = 4
)
  13. type ImageContext struct {
  14. // mu is required to be held when generating embeddings or accessing the cache
  15. mu sync.Mutex
  16. clip *llama.ClipContext
  17. mllama *llama.MllamaContext
  18. // cache of images to embeddings
  19. images []imageCache
  20. imageHash maphash.Hash
  21. }
  22. func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
  23. arch, err := llama.GetModelArch(modelPath)
  24. if err != nil {
  25. return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
  26. }
  27. var c ImageContext
  28. if arch == "clip" {
  29. c.clip, err = llama.NewClipContext(llamaContext, modelPath)
  30. } else if arch == "mllama" {
  31. c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
  32. } else {
  33. return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
  34. }
  35. if err != nil {
  36. return nil, err
  37. }
  38. c.images = make([]imageCache, imageCacheSize)
  39. return &c, nil
  40. }
  41. func (c *ImageContext) Free(modelPath string) {
  42. if c == nil {
  43. return
  44. }
  45. if c.clip != nil {
  46. c.clip.Free()
  47. }
  48. if c.mllama != nil {
  49. c.mllama.Free()
  50. }
  51. }
  52. func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) ([][]float32, error) {
  53. if c == nil {
  54. return nil, nil
  55. }
  56. if len(data) <= 0 {
  57. return nil, errors.New("received zero length image")
  58. }
  59. hash := c.hashImage(data)
  60. c.mu.Lock()
  61. defer c.mu.Unlock()
  62. embed, err := c.findImage(hash)
  63. if err != nil {
  64. if c.mllama != nil {
  65. embed, err = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
  66. if err != nil {
  67. return nil, err
  68. }
  69. } else if c.clip != nil {
  70. embed, err = c.clip.NewEmbed(llamaContext, data)
  71. if err != nil {
  72. return nil, err
  73. }
  74. } else {
  75. return nil, errors.New("received image but vision model not loaded")
  76. }
  77. c.addImage(hash, embed)
  78. }
  79. return embed, nil
  80. }
  81. func (c *ImageContext) BatchSize(configuredBatchSize int) int {
  82. // If images are not supported, we don't need to allocate embedding batches
  83. if c == nil {
  84. return 0
  85. }
  86. // Mllama maps an image to 1 embedding token (llava creates many tokens)
  87. // and doesn't support more than a single image per request.
  88. // The embeddings are large (100 MB), so allocating a big batch can fail
  89. // on some systems
  90. if c.mllama != nil {
  91. return 1
  92. }
  93. return configuredBatchSize
  94. }
  95. func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
  96. if c != nil && c.mllama != nil {
  97. return c.mllama.EmbedSize(llamaContext)
  98. } else {
  99. return llamaContext.Model().NEmbd()
  100. }
  101. }
  102. func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
  103. if c == nil || c.mllama == nil {
  104. return false
  105. }
  106. return slices.ContainsFunc(inputs, func(input input) bool {
  107. return input.embed != nil
  108. })
  109. }
  // imageCache is one slot of ImageContext's embedding cache.
type imageCache struct {
	key      uint64      // hash of the raw image bytes
	val      [][]float32 // embeddings computed for that image
	lastUsed time.Time   // time of last store or hit; zero if never filled
}
  115. func (c *ImageContext) hashImage(image []byte) uint64 {
  116. c.imageHash.Reset()
  117. _, _ = c.imageHash.Write(image)
  118. return c.imageHash.Sum64()
  119. }
  var (
	// errImageNotFound is returned by findImage on a cache miss. NewEmbed
	// treats any findImage error as a miss and computes fresh embeddings.
	errImageNotFound = errors.New("image not found in cache")
)
  121. func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
  122. for i := range c.images {
  123. if c.images[i].key == hash {
  124. slog.Debug("loading image embeddings from cache", "entry", i)
  125. c.images[i].lastUsed = time.Now()
  126. return c.images[i].val, nil
  127. }
  128. }
  129. return nil, errImageNotFound
  130. }
  131. func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
  132. best := time.Now()
  133. var bestImage int
  134. for i := range c.images {
  135. if c.images[i].key == hash {
  136. bestImage = i
  137. break
  138. }
  139. if c.images[i].lastUsed.Compare(best) < 0 {
  140. best = c.images[i].lastUsed
  141. bestImage = i
  142. }
  143. }
  144. slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
  145. c.images[bestImage].key = hash
  146. c.images[bestImage].val = embed
  147. c.images[bestImage].lastUsed = time.Now()
  148. }