cache.go

package main

import (
	"errors"
	"hash/maphash"
	"log/slog"
	"reflect"
	"time"

	"github.com/ollama/ollama/llama"
)

type InputCache struct {
	// context window size (per slot)
	numCtx int

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	// cache of images to embeddings
	images    []imageCache
	imageHash maphash.Hash

	lc *llama.Context
}

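// NewInputCache divides the KV cache of kvSize entries evenly across numSlots
// independent slots and allocates one image embedding cache entry per slot.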
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache {
	slots := make([]InputCacheSlot, numSlots)

	for i := range slots {
		slots[i] = InputCacheSlot{
			Id:     i,
			Inputs: make([]input, 0),
		}
	}

	return &InputCache{
		numCtx:         kvSize / numSlots,
		slots:          slots,
		multiUserCache: multiUserCache,
		images:         make([]imageCache, numSlots),
		lc:             lc,
	}
}

// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and llama.Decode
type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}

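// LoadCacheSlot finds a slot for the given prompt and prepares it for decoding:
// any cached prefix is reused when cachePrompt is true, the slot's KV cache is
// trimmed to that prefix, and the remaining (uncached) portion of the prompt is
// returned along with the number of inputs already in the cache.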
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) {
	var slot *InputCacheSlot
	var numPast int
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it reuses the same VRAM over and over again, which is good for
	// GPU performance in situations where we miss the input cache.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache (because it causes
	// GPU L2 cache misses due to spreading out accesses across VRAM).
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, 0, err
	}

	if !cachePrompt {
		numPast = 0
	}

	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == len(prompt) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	if !c.lc.KvCacheSeqRm(slot.Id, numPast, -1) {
		// Some models don't support partial erasure
		c.lc.KvCacheSeqRm(slot.Id, 0, -1)
		numPast = 0
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", len(prompt)-numPast)

	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, numPast, nil
}

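// findLongestCacheSlot returns the slot not currently in use whose cached
// inputs share the longest common prefix with the prompt.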
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		if s.InUse {
			continue
		}

		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}
	}

	if longestSlot == nil {
		return nil, 0, errors.New("no available cache slots")
	}

	return longestSlot, longest, nil
}

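// findBestCacheSlot returns an unused slot whose entire cached contents are a
// prefix of the prompt when one exists; otherwise it evicts the least recently
// used slot, copying ("forking") the longest matching prefix into it from
// another slot when possible.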
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	oldest := time.Now()
	var oldestSlot *InputCacheSlot

	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}

		if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
			oldest = s.lastUsed
			oldestSlot = &c.slots[i]
		}
	}

	if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
		return longestSlot, longest, nil
	}

	if oldestSlot.InUse {
		return nil, 0, errors.New("no available cache slots")
	}

	if len(oldestSlot.Inputs) != 0 {
		slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
			"used", oldestSlot.lastUsed)
	}

	if longest > 0 && longestSlot != oldestSlot {
		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
			len(longestSlot.Inputs))
		oldestSlot.Inputs = make([]input, longest)
		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
		// This is only nil for unit tests
		if c.lc != nil {
			c.lc.KvCacheSeqRm(oldestSlot.Id, 0, -1)
			c.lc.KvCacheSeqCp(longestSlot.Id, oldestSlot.Id, 0, longest)
		}
	}

	return oldestSlot, longest, nil
}

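// countCommonPrefix returns the number of leading inputs that a and b have in common.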
func countCommonPrefix(a []input, b []input) int {
	var count int

	for i := range a {
		if i >= len(b) {
			break
		}

		if !reflect.DeepEqual(a[i], b[i]) {
			break
		}

		count++
	}

	return count
}

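// ShiftCacheSlot discards numDiscard inputs after the first numKeep from the
// slot's KV cache and input list, shifting the later entries down to fill the gap.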
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) {
	// TODO (jessegross): KV cache removal can fail for certain types of models
	// server.cpp doesn't handle this, though we can be more graceful
	c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard)
	c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard)

	for i := numKeep + numDiscard; i < len(slot.Inputs); i++ {
		slot.Inputs[i-numDiscard] = slot.Inputs[i]
	}
	slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
}

// Locking: Lookup and store operations on imageCache require a lock
// to be held that serializes these with each other. Hash does not
// require a lock, nor do these operations need to be serialized with
// InputCacheSlot.
type imageCache struct {
	key      uint64
	val      [][]float32
	lastUsed time.Time
}

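// HashImage returns a hash of the raw image data for use as a key into the
// image embedding cache.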
func (c *InputCache) HashImage(image []byte) uint64 {
	c.imageHash.Reset()
	_, _ = c.imageHash.Write(image)
	return c.imageHash.Sum64()
}

var ErrImageNotFound = errors.New("image not found in cache")

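// FindImage returns the cached embeddings for the given image hash and marks
// the entry as recently used, or ErrImageNotFound if the image is not cached.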
func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
	for i := range c.images {
		if c.images[i].key == hash {
			slog.Debug("loading image embeddings from cache", "entry", i)
			c.images[i].lastUsed = time.Now()
			return c.images[i].val, nil
		}
	}

	return nil, ErrImageNotFound
}

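// AddImage stores image embeddings under the given hash, overwriting an
// existing entry with the same hash or, failing that, the least recently
// used entry.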
func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
	best := time.Now()
	var bestImage int

	for i := range c.images {
		if c.images[i].key == hash {
			bestImage = i
			break
		}

		if c.images[i].lastUsed.Compare(best) < 0 {
			best = c.images[i].lastUsed
			bestImage = i
		}
	}

	slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
	c.images[bestImage].key = hash
	c.images[bestImage].val = embed
	c.images[bestImage].lastUsed = time.Now()
}