cache.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package newrunner
  2. import (
  3. "errors"
  4. "fmt"
  5. "log/slog"
  6. "math"
  7. "reflect"
  8. "time"
  9. "github.com/ollama/ollama/cache"
  10. "github.com/ollama/ollama/ml"
  11. )
// InputCache partitions a single backing KV cache into per-slot regions so
// that multiple sequences can be processed in parallel. Each slot remembers
// the inputs currently stored for it so prompt prefixes can be reused.
type InputCache struct {
	// context window size (per slot)
	numCtx int32

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	// underlying causal KV cache shared by all slots
	cache cache.Cache
}
  21. func NewInputCache(backend ml.Backend, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
  22. if kvSize/int32(numSlots) < 1 {
  23. return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
  24. }
  25. slots := make([]InputCacheSlot, numSlots)
  26. for i := range slots {
  27. slots[i] = InputCacheSlot{
  28. Id: i,
  29. Inputs: make([]input, 0),
  30. }
  31. }
  32. return &InputCache{
  33. numCtx: kvSize / int32(numSlots),
  34. slots: slots,
  35. multiUserCache: multiUserCache,
  36. cache: cache.NewCausalCache(backend, kvCacheTypeFromStr(kvCacheType), kvSize),
  37. }, nil
  38. }
  39. func kvCacheTypeFromStr(s string) ml.DType {
  40. switch s {
  41. case "q8_0":
  42. panic("kv cache quantization not yet implemented")
  43. case "q4_0":
  44. panic("kv cache quantization not yet implemented")
  45. default:
  46. return ml.DTypeF32
  47. }
  48. }
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and processBatch
type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}
// LoadCacheSlot picks a cache slot for prompt, marks it in use, and trims
// both the slot and the prompt so that only the uncached suffix remains.
// It returns the chosen slot and the remaining prompt inputs to process.
// Callers must hold the lock described on InputCacheSlot.
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
	var slot *InputCacheSlot
	var numPast int32
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache.
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, err
	}

	// Caller asked not to reuse the cached prefix: treat the whole prompt
	// as new input.
	if !cachePrompt {
		numPast = 0
	}

	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == int32(len(prompt)) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	// Drop everything past the reused prefix from the KV cache.
	err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
	if err != nil {
		// Some models don't support partial erasure
		err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
		if err != nil {
			return nil, nil, err
		}
		// Full erase succeeded, so nothing of the prefix survives.
		numPast = 0
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", int32(len(prompt))-numPast)

	// Keep the slot's record and the prompt consistent with the KV cache:
	// slot.Inputs holds exactly the reused prefix, prompt the remainder.
	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, nil
}
  102. func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
  103. longest := int32(-1)
  104. var longestSlot *InputCacheSlot
  105. for i, s := range c.slots {
  106. if s.InUse {
  107. continue
  108. }
  109. count := countCommonPrefix(s.Inputs, prompt)
  110. if count > longest {
  111. longest = count
  112. longestSlot = &c.slots[i]
  113. }
  114. }
  115. if longestSlot == nil {
  116. return nil, 0, errors.New("no available cache slots")
  117. }
  118. return longestSlot, longest, nil
  119. }
  120. func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
  121. oldest := time.Now()
  122. var oldestSlot *InputCacheSlot
  123. longest := int32(-1)
  124. var longestSlot *InputCacheSlot
  125. for i, s := range c.slots {
  126. count := countCommonPrefix(s.Inputs, prompt)
  127. if count > longest {
  128. longest = count
  129. longestSlot = &c.slots[i]
  130. }
  131. if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
  132. oldest = s.lastUsed
  133. oldestSlot = &c.slots[i]
  134. }
  135. }
  136. if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
  137. return longestSlot, longest, nil
  138. }
  139. if oldestSlot.InUse {
  140. return nil, 0, errors.New("no available cache slots")
  141. }
  142. if len(oldestSlot.Inputs) != 0 {
  143. slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
  144. "used", oldestSlot.lastUsed)
  145. }
  146. if longest > 0 && longestSlot != oldestSlot {
  147. slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
  148. len(longestSlot.Inputs))
  149. oldestSlot.Inputs = make([]input, longest)
  150. copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
  151. // This is only nil for unit tests
  152. if c.cache != nil {
  153. c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
  154. }
  155. }
  156. return oldestSlot, longest, nil
  157. }
  158. func countCommonPrefix(a []input, b []input) int32 {
  159. var count int32
  160. for i := range a {
  161. if i >= len(b) {
  162. break
  163. }
  164. if !reflect.DeepEqual(a[i], b[i]) {
  165. break
  166. }
  167. count++
  168. }
  169. return count
  170. }
  171. func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
  172. targetFree := (c.numCtx - numKeep) / 2
  173. targetFree = max(targetFree, 1)
  174. currentFree := c.numCtx - inputLen
  175. discard := targetFree - currentFree
  176. if discard < 0 {
  177. discard = 0
  178. }
  179. return discard
  180. }
  181. // Frees up space in the KV cache by deleting the oldest half of history and shifting
  182. // the newest half into that space (saving numKeep inputs at the beginning).
  183. //
  184. // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
  185. func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
  186. if numKeep >= c.numCtx {
  187. return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
  188. }
  189. inputLen := int32(len(slot.Inputs))
  190. discard := c.ShiftDiscard(inputLen, numKeep)
  191. if discard <= 0 {
  192. return nil
  193. }
  194. slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
  195. "keep", numKeep, "discard", discard)
  196. // TODO (jessegross): KV cache removal can fail for certain types of models
  197. err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
  198. if err != nil {
  199. return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
  200. }
  201. for i := numKeep + discard; i < inputLen; i++ {
  202. slot.Inputs[i-discard] = slot.Inputs[i]
  203. }
  204. slot.Inputs = slot.Inputs[:inputLen-discard]
  205. return nil
  206. }