cache.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package llamarunner
  2. import (
  3. "errors"
  4. "fmt"
  5. "log/slog"
  6. "reflect"
  7. "time"
  8. "github.com/ollama/ollama/llama"
  9. )
// InputCache manages a fixed set of per-sequence KV cache slots, with the
// total KV cache divided evenly between parallel sequences.
type InputCache struct {
	// context window size (per slot)
	numCtx int

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	// underlying llama.cpp context; nil only in unit tests
	lc *llama.Context
}
  19. func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
  20. if kvSize/numSlots < 1 {
  21. return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
  22. }
  23. slots := make([]InputCacheSlot, numSlots)
  24. for i := range slots {
  25. slots[i] = InputCacheSlot{
  26. Id: i,
  27. Inputs: make([]input, 0),
  28. }
  29. }
  30. return &InputCache{
  31. numCtx: kvSize / numSlots,
  32. slots: slots,
  33. multiUserCache: multiUserCache,
  34. lc: lc,
  35. }, nil
  36. }
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and llama.Decode
type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}
// LoadCacheSlot selects a cache slot for the given prompt, marks it in use,
// and trims both the prompt and the slot's stored inputs down to their
// shared prefix so that only the uncached suffix needs to be processed.
//
// It returns the chosen slot and the remaining (uncached) portion of the
// prompt. When cachePrompt is false, the slot's entire cache is discarded.
// Callers must hold the lock described on InputCacheSlot.
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
	var slot *InputCacheSlot
	var numPast int
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it reuses the same VRAM over and over again, which is good for
	// GPU performance in situations where we miss the input cache.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache (because it causes
	// GPU L2 cache misses due to spreading out accesses across VRAM).
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, err
	}

	// Ignore any cached prefix when prompt caching is disabled.
	if !cachePrompt {
		numPast = 0
	}

	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == len(prompt) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	// Erase KV cache entries past the reused prefix.
	if !c.lc.KvCacheSeqRm(slot.Id, numPast, -1) {
		// Some models don't support partial erasure
		c.lc.KvCacheSeqRm(slot.Id, 0, -1)
		numPast = 0
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", len(prompt)-numPast)

	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, nil
}
  88. func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
  89. longest := -1
  90. var longestSlot *InputCacheSlot
  91. for i, s := range c.slots {
  92. if s.InUse {
  93. continue
  94. }
  95. count := countCommonPrefix(s.Inputs, prompt)
  96. if count > longest {
  97. longest = count
  98. longestSlot = &c.slots[i]
  99. }
  100. }
  101. if longestSlot == nil {
  102. return nil, 0, errors.New("no available cache slots")
  103. }
  104. return longestSlot, longest, nil
  105. }
  106. func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
  107. oldest := time.Now()
  108. var oldestSlot *InputCacheSlot
  109. longest := -1
  110. var longestSlot *InputCacheSlot
  111. for i, s := range c.slots {
  112. count := countCommonPrefix(s.Inputs, prompt)
  113. if count > longest {
  114. longest = count
  115. longestSlot = &c.slots[i]
  116. }
  117. if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
  118. oldest = s.lastUsed
  119. oldestSlot = &c.slots[i]
  120. }
  121. }
  122. if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
  123. return longestSlot, longest, nil
  124. }
  125. if oldestSlot.InUse {
  126. return nil, 0, errors.New("no available cache slots")
  127. }
  128. if len(oldestSlot.Inputs) != 0 {
  129. slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
  130. "used", oldestSlot.lastUsed)
  131. }
  132. if longest > 0 && longestSlot != oldestSlot {
  133. slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
  134. len(longestSlot.Inputs))
  135. oldestSlot.Inputs = make([]input, longest)
  136. copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
  137. // This is only nil for unit tests
  138. if c.lc != nil {
  139. c.lc.KvCacheSeqRm(oldestSlot.Id, 0, -1)
  140. c.lc.KvCacheSeqCp(longestSlot.Id, oldestSlot.Id, 0, longest)
  141. }
  142. }
  143. return oldestSlot, longest, nil
  144. }
  145. func countCommonPrefix(a []input, b []input) int {
  146. var count int
  147. for i := range a {
  148. if i >= len(b) {
  149. break
  150. }
  151. if !reflect.DeepEqual(a[i], b[i]) {
  152. break
  153. }
  154. count++
  155. }
  156. return count
  157. }
  158. func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
  159. targetFree := (c.numCtx - numKeep) / 2
  160. targetFree = max(targetFree, 1)
  161. currentFree := c.numCtx - inputLen
  162. discard := targetFree - currentFree
  163. if discard < 0 {
  164. discard = 0
  165. }
  166. return discard
  167. }
  168. // Frees up space in the KV cache by deleting the oldest half of history and shifting
  169. // the newest half into that space (saving numKeep inputs at the beginning).
  170. //
  171. // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
  172. func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
  173. if numKeep >= c.numCtx {
  174. return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
  175. }
  176. discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
  177. if discard <= 0 {
  178. return nil
  179. }
  180. slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
  181. "keep", numKeep, "discard", discard)
  182. // TODO (jessegross): KV cache removal can fail for certain types of models
  183. if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
  184. return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
  185. }
  186. c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
  187. for i := numKeep + discard; i < len(slot.Inputs); i++ {
  188. slot.Inputs[i-discard] = slot.Inputs[i]
  189. }
  190. slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
  191. return nil
  192. }