// cache.go
  1. package ollamarunner
  2. import (
  3. "errors"
  4. "fmt"
  5. "log/slog"
  6. "math"
  7. "time"
  8. "github.com/ollama/ollama/kvcache"
  9. "github.com/ollama/ollama/ml"
  10. "github.com/ollama/ollama/model"
  11. )
// InputCache manages the runner's KV cache, dividing it into per-sequence
// slots and tracking which slots are in use.
type InputCache struct {
	// context window size (per slot)
	numCtx int32

	// does the cache store data or do we need to always send the full input?
	// note that when enabled is false the underlying cache may either be nil
	// or a non-nil dummy that doesn't actually store anything
	enabled bool

	// individual KV caches
	slots []InputCacheSlot

	// optimize cache eviction for multiple users
	multiUserCache bool

	// underlying model KV cache; may be nil (see enabled above)
	cache kvcache.Cache
}
  25. func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
  26. if kvSize/int32(numSlots) < 1 {
  27. return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
  28. }
  29. slots := make([]InputCacheSlot, numSlots)
  30. for i := range slots {
  31. slots[i] = InputCacheSlot{Id: i}
  32. }
  33. cache := model.Config().Cache
  34. if cache != nil {
  35. cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
  36. }
  37. return &InputCache{
  38. numCtx: kvSize / int32(numSlots),
  39. enabled: cache != nil,
  40. slots: slots,
  41. multiUserCache: multiUserCache,
  42. cache: cache,
  43. }, nil
  44. }
  45. func kvCacheTypeFromStr(s string) ml.DType {
  46. switch s {
  47. case "q8_0":
  48. panic("kv cache quantization not yet implemented")
  49. case "q4_0":
  50. panic("kv cache quantization not yet implemented")
  51. default:
  52. return ml.DTypeF16
  53. }
  54. }
  55. func (c *InputCache) Close() {
  56. c.cache.Close()
  57. }
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and processBatch

// InputCacheSlot is the per-sequence portion of the KV cache, recording
// which inputs are currently stored for that sequence.
type InputCacheSlot struct {
	// Index in the KV cache
	Id int

	// Inputs that are stored in the KV cache
	Inputs []model.Input

	// is this cache actively being processed as part of a sequence?
	InUse bool

	// last time this cache was used (as of start of processing)
	lastUsed time.Time
}
// LoadCacheSlot selects a cache slot for prompt, marks it in use, and trims
// both the slot's stored inputs and the KV cache down to the reusable prefix.
// It returns the chosen slot and the suffix of prompt that still needs to be
// processed. Requires the lock described above InputCacheSlot to be held.
func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
	var slot *InputCacheSlot
	var numPast int32
	var err error

	// In single-user scenarios, the longest cache slot works fine for getting good input
	// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
	// For multiple users, the "best" cache slot produces better input cache hit rates
	// at the cost of worse performance when we miss the input cache.
	if !c.multiUserCache {
		slot, numPast, err = c.findLongestCacheSlot(prompt)
	} else {
		slot, numPast, err = c.findBestCacheSlot(prompt)
	}
	if err != nil {
		return nil, nil, err
	}

	// Caller asked us not to reuse any cached prefix state.
	if !cachePrompt {
		numPast = 0
	}

	slot.InUse = true
	slot.lastUsed = time.Now()

	if numPast == int32(len(prompt)) {
		// Leave one input to sample so we can get a response
		numPast--
	}

	if c.cache != nil {
		// Drop cached entries beyond the shared prefix.
		err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
		if err != nil {
			// Some models don't support partial erasure
			err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
			if err != nil {
				return nil, nil, err
			}
			// Full erasure succeeded, so nothing is reusable anymore.
			numPast = 0
		}
	}

	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
		"used", numPast, "remaining", int32(len(prompt))-numPast)

	// Keep only the unprocessed suffix of the prompt and the reused prefix
	// of the slot's recorded inputs.
	prompt = prompt[numPast:]
	slot.Inputs = slot.Inputs[:numPast]

	return slot, prompt, nil
}
  113. func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
  114. longest := int32(-1)
  115. var longestSlot *InputCacheSlot
  116. for i, s := range c.slots {
  117. if s.InUse {
  118. continue
  119. }
  120. count := countCommonPrefix(s.Inputs, prompt)
  121. if count > longest {
  122. longest = count
  123. longestSlot = &c.slots[i]
  124. }
  125. }
  126. if longestSlot == nil {
  127. return nil, 0, errors.New("no available cache slots")
  128. }
  129. return longestSlot, longest, nil
  130. }
  131. func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
  132. oldest := time.Now()
  133. var oldestSlot *InputCacheSlot
  134. longest := int32(-1)
  135. var longestSlot *InputCacheSlot
  136. for i, s := range c.slots {
  137. count := countCommonPrefix(s.Inputs, prompt)
  138. if count > longest {
  139. longest = count
  140. longestSlot = &c.slots[i]
  141. }
  142. if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
  143. oldest = s.lastUsed
  144. oldestSlot = &c.slots[i]
  145. }
  146. }
  147. if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
  148. return longestSlot, longest, nil
  149. }
  150. if oldestSlot.InUse {
  151. return nil, 0, errors.New("no available cache slots")
  152. }
  153. if len(oldestSlot.Inputs) != 0 {
  154. slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
  155. "used", oldestSlot.lastUsed)
  156. }
  157. if longest > 0 && longestSlot != oldestSlot {
  158. slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
  159. len(longestSlot.Inputs))
  160. oldestSlot.Inputs = make([]model.Input, longest)
  161. copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
  162. if c.cache != nil {
  163. c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
  164. }
  165. }
  166. return oldestSlot, longest, nil
  167. }
  168. func countCommonPrefix(a []model.Input, b []model.Input) int32 {
  169. var count int32
  170. for i := range a {
  171. if i >= len(b) {
  172. break
  173. }
  174. if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
  175. break
  176. }
  177. count++
  178. }
  179. return count
  180. }
  181. func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
  182. targetFree := (c.numCtx - numKeep) / 2
  183. targetFree = max(targetFree, 1)
  184. currentFree := c.numCtx - inputLen
  185. discard := targetFree - currentFree
  186. if discard < 0 {
  187. discard = 0
  188. }
  189. return discard
  190. }
  191. // Frees up space in the KV cache by deleting the oldest half of history and shifting
  192. // the newest half into that space (saving numKeep inputs at the beginning).
  193. //
  194. // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
  195. func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
  196. if numKeep >= c.numCtx {
  197. return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
  198. }
  199. inputLen := int32(len(slot.Inputs))
  200. discard := c.ShiftDiscard(inputLen, numKeep)
  201. if discard <= 0 {
  202. return nil
  203. }
  204. slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
  205. "keep", numKeep, "discard", discard)
  206. // TODO (jessegross): KV cache removal can fail for certain types of models
  207. if c.cache != nil {
  208. err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
  209. if err != nil {
  210. return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
  211. }
  212. }
  213. for i := numKeep + discard; i < inputLen; i++ {
  214. slot.Inputs[i-discard] = slot.Inputs[i]
  215. }
  216. slot.Inputs = slot.Inputs[:inputLen-discard]
  217. return nil
  218. }