cache.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package ollamarunner
  2. import (
  3. "errors"
  4. "fmt"
  5. "log/slog"
  6. "math"
  7. "time"
  8. "github.com/ollama/ollama/kvcache"
  9. "github.com/ollama/ollama/ml"
  10. "github.com/ollama/ollama/model"
  11. "github.com/ollama/ollama/model/input"
  12. )
// InputCache tracks the KV cache shared by all parallel sequences and the
// per-sequence slots that partition it.
type InputCache struct {
	// context window size (per slot)
	numCtx int32

	// does the cache store data or do we need to always send the full input?
	// note that when enabled is false the underlying cache may either be nil
	// or a non-nil dummy that doesn't actually store anything
	enabled bool

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	// underlying model KV cache; may be nil (see enabled above), so all
	// uses must be nil-guarded
	cache kvcache.Cache
}
  26. func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
  27. if kvSize/int32(numSlots) < 1 {
  28. return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
  29. }
  30. slots := make([]InputCacheSlot, numSlots)
  31. for i := range slots {
  32. slots[i] = InputCacheSlot{Id: i}
  33. }
  34. cache := model.Config().Cache
  35. if cache != nil {
  36. cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
  37. }
  38. return &InputCache{
  39. numCtx: kvSize / int32(numSlots),
  40. enabled: cache != nil,
  41. slots: slots,
  42. multiUserCache: multiUserCache,
  43. cache: cache,
  44. }, nil
  45. }
  46. func kvCacheTypeFromStr(s string) ml.DType {
  47. switch s {
  48. case "q8_0":
  49. return ml.DTypeQ80
  50. case "q4_0":
  51. return ml.DTypeQ40
  52. default:
  53. return ml.DTypeF16
  54. }
  55. }
  56. func (c *InputCache) Close() {
  57. c.cache.Close()
  58. }
// InputCacheSlot is one sequence's share of the KV cache.
//
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and processBatch
type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []input.Input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}
// LoadCacheSlot selects a cache slot for the given prompt, trims the slot's
// KV cache to the reusable prefix, and returns the slot together with the
// suffix of prompt that still needs to be processed.
//
// The returned slot is marked InUse; callers are responsible for clearing
// it when done (see the locking comment on InputCacheSlot). When
// cachePrompt is false, no cached prefix is reused.
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
	var slot *InputCacheSlot
	var numPast int32
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache.
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, err
	}

	// Ignore any cached prefix when prompt caching is disabled.
	if !cachePrompt {
		numPast = 0
	}

	// NOTE(review): InUse is set before the fallible Remove calls below, so
	// an error return leaves the slot marked in use — confirm the caller
	// releases it on error.
	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == int32(len(prompt)) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	if c.cache != nil {
		// Drop everything past the reusable prefix from the KV cache.
		err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
		if err != nil {
			// Some models don't support partial erasure
			err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
			if err != nil {
				return nil, nil, err
			}
			numPast = 0
		}
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", int32(len(prompt))-numPast)

	// Hand back only the uncached suffix; the slot keeps the cached prefix.
	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, nil
}
  114. func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
  115. longest := int32(-1)
  116. var longestSlot *InputCacheSlot
  117. for i, s := range c.slots {
  118. if s.InUse {
  119. continue
  120. }
  121. count := countCommonPrefix(s.Inputs, prompt)
  122. if count > longest {
  123. longest = count
  124. longestSlot = &c.slots[i]
  125. }
  126. }
  127. if longestSlot == nil {
  128. return nil, 0, errors.New("no available cache slots")
  129. }
  130. return longestSlot, longest, nil
  131. }
  132. func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
  133. oldest := time.Now()
  134. var oldestSlot *InputCacheSlot
  135. longest := int32(-1)
  136. var longestSlot *InputCacheSlot
  137. for i, s := range c.slots {
  138. count := countCommonPrefix(s.Inputs, prompt)
  139. if count > longest {
  140. longest = count
  141. longestSlot = &c.slots[i]
  142. }
  143. if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
  144. oldest = s.lastUsed
  145. oldestSlot = &c.slots[i]
  146. }
  147. }
  148. if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
  149. return longestSlot, longest, nil
  150. }
  151. if oldestSlot.InUse {
  152. return nil, 0, errors.New("no available cache slots")
  153. }
  154. if len(oldestSlot.Inputs) != 0 {
  155. slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
  156. "used", oldestSlot.lastUsed)
  157. }
  158. if longest > 0 && longestSlot != oldestSlot {
  159. slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
  160. len(longestSlot.Inputs))
  161. oldestSlot.Inputs = make([]input.Input, longest)
  162. copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
  163. if c.cache != nil {
  164. c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
  165. }
  166. }
  167. return oldestSlot, longest, nil
  168. }
  169. func countCommonPrefix(a []input.Input, b []input.Input) int32 {
  170. var count int32
  171. for i := range a {
  172. if i >= len(b) {
  173. break
  174. }
  175. if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
  176. break
  177. }
  178. count++
  179. }
  180. return count
  181. }
  182. func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
  183. targetFree := (c.numCtx - numKeep) / 2
  184. targetFree = max(targetFree, 1)
  185. currentFree := c.numCtx - inputLen
  186. discard := targetFree - currentFree
  187. if discard < 0 {
  188. discard = 0
  189. }
  190. return discard
  191. }
  192. // Frees up space in the KV cache by deleting the oldest half of history and shifting
  193. // the newest half into that space (saving numKeep inputs at the beginning).
  194. //
  195. // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
  196. func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
  197. if numKeep >= c.numCtx {
  198. return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
  199. }
  200. inputLen := int32(len(slot.Inputs))
  201. discard := c.ShiftDiscard(inputLen, numKeep)
  202. if discard <= 0 {
  203. return nil
  204. }
  205. slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
  206. "keep", numKeep, "discard", discard)
  207. // TODO (jessegross): KV cache removal can fail for certain types of models
  208. if c.cache != nil {
  209. err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
  210. if err != nil {
  211. return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
  212. }
  213. }
  214. for i := numKeep + discard; i < inputLen; i++ {
  215. slot.Inputs[i-discard] = slot.Inputs[i]
  216. }
  217. slot.Inputs = slot.Inputs[:inputLen-discard]
  218. return nil
  219. }