cache.go

package main

import (
	"errors"
	"log/slog"
	"reflect"
	"time"

	"github.com/ollama/ollama/llama"
)

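// InputCache divides the KV cache into a fixed number of slots, each holding
// the inputs of one sequence, so that prompts sharing a prefix with earlier
// requests can reuse cached entries instead of being reprocessed.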
type InputCache struct {
	// context window size (per slot)
	numCtx int

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	lc *llama.Context
}

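// NewInputCache splits a KV cache of kvSize entries evenly across numSlots
// slots, each of which starts out empty and unused.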
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) *InputCache {
	slots := make([]InputCacheSlot, numSlots)
	for i := range slots {
		slots[i] = InputCacheSlot{
			Id:     i,
			Inputs: make([]input, 0),
		}
	}

	return &InputCache{
		numCtx:         kvSize / numSlots,
		slots:          slots,
		multiUserCache: multiUserCache,
		lc:             lc,
	}
}

// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and llama.Decode

type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}

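// LoadCacheSlot picks a slot for the given prompt, marks it in use, and trims
// both the slot and the prompt to the prefix already present in the KV cache.
// It returns the slot, the uncached remainder of the prompt, and the number of
// inputs that were reused. If cachePrompt is false, the slot's existing
// contents are discarded and the full prompt is returned for processing.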
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, int, error) {
	var slot *InputCacheSlot
	var numPast int
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it reuses the same VRAM over and over again, which is good for
	// GPU performance in situations where we miss the input cache.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache (because it causes
	// GPU L2 cache misses due to spreading out accesses across VRAM).
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, 0, err
	}

	if !cachePrompt {
		numPast = 0
	}

	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == len(prompt) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	if !c.lc.KvCacheSeqRm(slot.Id, numPast, -1) {
		// Some models don't support partial erasure
		c.lc.KvCacheSeqRm(slot.Id, 0, -1)
		numPast = 0
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", len(prompt)-numPast)

	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, numPast, nil
}

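// findLongestCacheSlot returns the slot not currently in use whose cached
// inputs share the longest common prefix with the prompt, along with the
// length of that prefix.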
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		if s.InUse {
			continue
		}

		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}
	}

	if longestSlot == nil {
		return nil, 0, errors.New("no available cache slots")
	}

	return longestSlot, longest, nil
}

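// findBestCacheSlot prefers a free slot that already contains exactly the
// matched prefix. Otherwise it evicts the least recently used free slot and,
// when a longer prefix exists in another slot, forks that prefix into the
// evicted slot so it can still be reused.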
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
	oldest := time.Now()
	var oldestSlot *InputCacheSlot

	longest := -1
	var longestSlot *InputCacheSlot

	for i, s := range c.slots {
		count := countCommonPrefix(s.Inputs, prompt)
		if count > longest {
			longest = count
			longestSlot = &c.slots[i]
		}

		if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
			oldest = s.lastUsed
			oldestSlot = &c.slots[i]
		}
	}

	if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
		return longestSlot, longest, nil
	}

	if oldestSlot.InUse {
		return nil, 0, errors.New("no available cache slots")
	}

	if len(oldestSlot.Inputs) != 0 {
		slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
			"used", oldestSlot.lastUsed)
	}

	if longest > 0 && longestSlot != oldestSlot {
		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
			len(longestSlot.Inputs))
		oldestSlot.Inputs = make([]input, longest)
		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
		// This is only nil for unit tests
		if c.lc != nil {
			c.lc.KvCacheSeqRm(oldestSlot.Id, 0, -1)
			c.lc.KvCacheSeqCp(longestSlot.Id, oldestSlot.Id, 0, longest)
		}
	}

	return oldestSlot, longest, nil
}

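// countCommonPrefix returns the number of leading inputs that a and b have in
// common.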
func countCommonPrefix(a []input, b []input) int {
	var count int

	for i := range a {
		if i >= len(b) {
			break
		}

		if !reflect.DeepEqual(a[i], b[i]) {
			break
		}

		count++
	}

	return count
}

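// ShiftCacheSlot makes room in a full slot by discarding numDiscard inputs
// after the first numKeep, then shifting the remaining KV cache entries and
// slot.Inputs down to close the gap.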
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscard int, numPast int) {
	// TODO (jessegross): KV cache removal can fail for certain types of models
	// server.cpp doesn't handle this, though we can be more graceful
	c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+numDiscard)
	c.lc.KvCacheSeqAdd(slot.Id, numKeep+numDiscard, numPast, -numDiscard)

	for i := numKeep + numDiscard; i < len(slot.Inputs); i++ {
		slot.Inputs[i-numDiscard] = slot.Inputs[i]
	}

	slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
}
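
// Usage sketch (illustrative only, not part of this file): a caller holding
// the cache lock might load a slot for an incoming prompt, decode the uncached
// remainder, and shift the slot once it approaches numCtx. seq.prompt,
// seq.numKeep, and the running numPast below are hypothetical caller-side
// state, not names defined here.
//
//	slot, remaining, numPast, err := cache.LoadCacheSlot(seq.prompt, true)
//	if err != nil {
//		return err
//	}
//	// ...decode `remaining`, appending to slot.Inputs and advancing numPast...
//	if len(slot.Inputs)+1 > cache.numCtx {
//		discard := (cache.numCtx - seq.numKeep) / 2
//		cache.ShiftCacheSlot(slot, seq.numKeep, discard, numPast)
//	}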