cache.go

package cache

import (
	"errors"
	"fmt"
	"log/slog"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
)

var ErrNotSupported = errors.New("model does not support operation")
type Cache interface {
	// ** used by model implementations **

	// Returns an instance of the cache for layer 'i'
	Sub(i int) Cache

	// Returns the history of key and value tensors plus a mask
	//
	// The tensors are of shape embed dim, kv heads, batch size
	// The mask is of shape history size, batch size
	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

	// Stores a batch of key and value in the cache
	//
	// The tensors must be of shape embed dim, kv heads, batch size
	Put(ctx ml.Context, key, value ml.Tensor)

	// ** cache management **

	// Closes the cache and frees resources associated with it
	Close()

	// Called before the start of the model's forward pass. For each
	// token in the coming batch, there must be a corresponding entry
	// in positions and seqs.
	StartForward(ctx ml.Context, positions []int32, seqs []int) error

	// Copies tokens in the range [0, len) from srcSeq to dstSeq
	CopyPrefix(srcSeq, dstSeq int, len int32)

	// Removes tokens in the range [beginIndex, endIndex) from seq. Set
	// endIndex to math.MaxInt32 to remove everything starting at beginIndex
	Remove(seq int, beginIndex, endIndex int32) error
}
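// The sketch below is illustrative only and not part of the package API: it
// shows how a model forward pass might drive the cache for a single layer,
// assuming hypothetical key/value tensors k and v for the current batch.
func exampleLayerForward(ctx ml.Context, cache Cache, positions []int32, seqs []int, k, v ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor, error) {
	// Reserve cache cells for the batch before any layer runs.
	if err := cache.StartForward(ctx, positions, seqs); err != nil {
		return nil, nil, nil, err
	}

	// Each layer operates on its own view of the cache.
	layer := cache.Sub(0)

	// Append this batch's keys/values, then read back the full history
	// plus the mask to use for attention.
	layer.Put(ctx, k, v)
	key, value, mask := layer.Get(ctx)
	return key, value, mask, nil
}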
type Causal struct {
	DType    ml.DType
	Capacity int32

	// current forward pass
	curLayer     int
	curLoc       int
	curBatchSize int
	curMask      ml.Tensor
	curCellRange cellRange

	// metadata
	cells      []cacheCell
	cellRanges map[int]cellRange

	// cache data storage
	backend      ml.Backend
	cacheCtx     ml.Context
	keys, values []ml.Tensor
}

type seqCell struct {
	seq int
	pos int32
}

type cacheCell struct {
	sequences []seqCell
}

type cellRange struct {
	min int
	max int
}

func (cell cacheCell) findSeq(seq int) *seqCell {
	for i := range cell.sequences {
		if cell.sequences[i].seq == seq {
			return &cell.sequences[i]
		}
	}
	return nil
}
func NewCausalCache(backend ml.Backend, dtype ml.DType, capacity int32) Cache {
	return &Causal{
		Capacity:   capacity,
		DType:      dtype,
		cells:      make([]cacheCell, capacity),
		cellRanges: make(map[int]cellRange),
		backend:    backend,
		cacheCtx:   backend.NewContext(),
	}
}

func (c *Causal) Close() {
	c.cacheCtx.Close()
}
var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	if len(positions) != len(seqs) {
		return fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(positions), len(seqs))
	}

	c.curBatchSize = len(positions)
	if c.curBatchSize < 1 {
		return errors.New("batch size cannot be less than 1")
	}

	// Find a contiguous run of free cells large enough for the batch,
	// defragmenting once and retrying if the first attempt fails.
	var err error
	c.curLoc, err = c.findStartLoc()
	if errors.Is(err, ErrKvCacheFull) {
		c.defrag()
		c.curLoc, err = c.findStartLoc()
	}
	if err != nil {
		return err
	}

	// Record each token's (seq, pos) in its cell and widen both the
	// per-sequence ranges and the range covered by this batch.
	c.curCellRange = newRange()
	for i, pos := range positions {
		seq := seqs[i]

		c.cells[c.curLoc+i] = cacheCell{sequences: []seqCell{{seq: seq, pos: pos}}}

		ranges, ok := c.cellRanges[seq]
		if !ok {
			ranges = newRange()
		}

		if c.curLoc+i > ranges.max {
			ranges.max = c.curLoc + i
		}
		if ranges.max > c.curCellRange.max {
			c.curCellRange.max = ranges.max
		}

		if c.curLoc+i < ranges.min {
			ranges.min = c.curLoc + i
		}
		if ranges.min < c.curCellRange.min {
			c.curCellRange.min = ranges.min
		}

		c.cellRanges[seq] = ranges
	}

	c.curMask, err = c.buildMask(ctx, positions, seqs)

	return err
}
func newRange() cellRange {
	return cellRange{
		min: math.MaxInt,
		max: 0,
	}
}
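// findStartLoc returns the index of the first contiguous run of free cells
// large enough to hold the current batch. Worked example (hypothetical
// occupancy): with cells [X _ _ X _ _ _] and a batch of 3, the free run at
// cells 1..2 is too short, so the scan resets after cell 3 and returns 4.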
func (c *Causal) findStartLoc() (int, error) {
	var start, count int
	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
			count++
			if count >= c.curBatchSize {
				return start, nil
			}
		} else {
			start = i + 1
			count = 0
		}
	}

	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
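// buildMask computes the attention mask for the current batch over the cells
// it touches: a cell is masked (-inf) for a token if it does not belong to
// the token's sequence or holds a later position. Worked example
// (hypothetical state): if the covered cells hold seq 0 at positions 0..3
// and the batch is two seq-0 tokens at positions 2 and 3, the two rows are
// [0 0 0 -inf] and [0 0 0 0].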
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
	// TODO(jessegross): This makes a number of simplifications such as no padding,
	// which could be an issue for CUDA graphs and/or flash attention
	len := c.curCellRange.max - c.curCellRange.min + 1
	mask := make([]float32, c.curBatchSize*len)

	for i := range c.curBatchSize {
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			cellSeq := c.cells[j].findSeq(seqs[i])
			if cellSeq == nil || cellSeq.pos > positions[i] {
				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
	}

	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
}
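// moveCell copies the data for a block of len contiguous cells starting at
// src into the block starting at dst, once per tensor (i.e. per layer) in
// objs.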
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
	for _, obj := range objs {
		srcView := obj.View(ctx, int(obj.Stride(2))*src, int(obj.Dim(0)*obj.Dim(1))*len)
		dstView := obj.View(ctx, int(obj.Stride(2))*dst, int(obj.Dim(0)*obj.Dim(1))*len)
		ctx.Forward(srcView.Copy(ctx, dstView))
	}
}
func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold, then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together, but most likely the next forward
	// pass will disrupt this anyways, so the real world benefit
	// seems limited at this time.
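	//
	// For example (hypothetical occupancy): with cells [A _ _ B C], the
	// scan fills cell 1 from cell 4 and then cell 2 from cell 3. The
	// second move is contiguous with the first, so the two are held and
	// issued as a single block copy of length 2.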

	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v). For efficiency, we try to group
	// multiple contiguous blocks into a single move. However, if we
	// exceed the maximum number of tensors then we need to compute
	// what we have and start a new batch.
	maxMoves := ctx.MaxTensors() / (6 * len(c.keys))
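	// Illustrative arithmetic (hypothetical numbers): with 32 layers and
	// a context that holds 4096 tensors, maxMoves = 4096/(6*32) = 21.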
	moves := 0

	var pendingSrc, pendingDst, pendingLen int

	for dst := range c.cells {
		if len(c.cells[dst].sequences) == 0 {
			for src := len(c.cells) - 1; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							pendingSrc = src
							pendingLen++
							break
						} else {
							moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
							moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		if moves >= maxMoves {
			ctx.Compute(nil)
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute(nil)
	}
	ctx.Close()

	// Rebuild the per-sequence range metadata after the moves
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if cell.findSeq(seq) != nil {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}
}
func (c *Causal) Sub(i int) Cache {
	if i >= len(c.keys) {
		c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
		c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
	}

	c.curLayer = i

	return c
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	key = key.View(ctx, int(key.Stride(2))*c.curCellRange.min,
		int(key.Dim(0)), int(key.Stride(1)),
		int(key.Dim(1)), int(key.Stride(2)),
		int(c.curMask.Dim(0)),
	)

	value = value.View(ctx, int(value.Stride(2))*c.curCellRange.min,
		int(value.Dim(0)), int(value.Stride(1)),
		int(value.Dim(1)), int(value.Stride(2)),
		int(c.curMask.Dim(0)),
	)

	return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	if c.curBatchSize != int(key.Dim(2)) {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v, layer batch size: %v)", c.curLayer, c.curBatchSize, int(key.Dim(2))))
	}

	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int64(c.Capacity))
		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int64(c.Capacity))
	}

	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curLoc, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
	ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, int(value.Stride(2))*c.curLoc, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	seqRange := newRange()

	for i := range c.cells {
		srcCellSeq := c.cells[i].findSeq(srcSeq)
		dstCellSeq := c.cells[i].findSeq(dstSeq)

		if dstCellSeq != nil {
			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == dstSeq })
		}

		if srcCellSeq != nil && srcCellSeq.pos < len {
			c.cells[i].sequences = append(c.cells[i].sequences, seqCell{seq: dstSeq, pos: srcCellSeq.pos})
			if i < seqRange.min {
				seqRange.min = i
			}
			if i > seqRange.max {
				seqRange.max = i
			}
		}
	}

	c.cellRanges[dstSeq] = seqRange
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	panic("Shift not yet implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	// A negative offset shifts the positions of the tokens that remain
	// after the removed range; removing to the end needs no shift.
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		cellSeq := c.cells[i].findSeq(seq)
		if cellSeq != nil {
			if cellSeq.pos >= beginIndex && cellSeq.pos < endIndex {
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == seq })
			} else {
				if cellSeq.pos >= endIndex {
					cellSeq.pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	if endIndex != math.MaxInt32 {
		err := c.shift(seq, endIndex, offset)
		if err != nil {
			return err
		}
	}

	c.cellRanges[seq] = seqRange

	return nil
}
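
// The sketch below is illustrative only (sequence numbers and token counts
// are hypothetical): typical cache management calls from a runner that
// forks a shared prompt and later truncates one branch.
func exampleManage(cache Cache) error {
	// Share the first 16 processed tokens of seq 0 with seq 1.
	cache.CopyPrefix(0, 1, 16)

	// Truncate seq 1 back to 8 tokens. Passing math.MaxInt32 removes
	// everything from position 8 onward, which avoids the not yet
	// implemented shift path.
	return cache.Remove(1, 8, math.MaxInt32)
}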