causal.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. package kvcache
  2. import (
  3. "errors"
  4. "fmt"
  5. "log/slog"
  6. "math"
  7. "slices"
  8. "github.com/ollama/ollama/ml"
  9. "github.com/ollama/ollama/model/input"
  10. )
// shiftFn applies a positional offset to one layer's cached key tensor;
// shift holds the per-cell position deltas. Used when cache entries are
// logically moved to a new position (see Causal.shift).
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
//
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
type Causal struct {
	DType ml.DType

	// windowSize limits attention to the most recent positions (sliding
	// window); math.MaxInt32 means unbounded.
	windowSize int32

	opts CausalOptions

	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// starting location for data storage for this batch
	curLoc int

	// size of the current batch
	curBatchSize int

	// mask of the cache as used by this batch
	curMask ml.Tensor

	// locations in the cache that are needed for this batch
	curCellRange cellRange

	// curSequences is the sequences corresponding to this pass's entries in the cache
	curSequences []int

	// curPositions is the positions corresponding to this pass's entries in the cache
	curPositions []int32

	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
	// that reference the data there
	cells []cacheCell

	// maps from sequence to the range of locations where it is stored in the cache
	cellRanges map[int]cellRange

	// ** cache data storage **

	// shiftFn, if non-nil, re-encodes keys when their positions change
	shiftFn shiftFn
	backend ml.Backend

	// per-layer allocation contexts and K/V cache tensors, created lazily in Put
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor
}
// cacheCell is the metadata for one slot in the cache: the token position
// stored there and the sequences referencing it. A cell with an empty
// sequences slice is free.
type cacheCell struct {
	pos       int32
	sequences []int
}
// cellRange is an inclusive range of cell indices. The empty range (from
// newRange) has min = math.MaxInt and max = 0 so that any real index
// shrinks min and grows max.
type cellRange struct {
	min int
	max int
}
  58. func NewCausalCache(shift shiftFn) *Causal {
  59. return &Causal{
  60. windowSize: math.MaxInt32,
  61. shiftFn: shift,
  62. ctxs: make(map[int]ml.Context),
  63. keys: make(map[int]ml.Tensor),
  64. values: make(map[int]ml.Tensor),
  65. }
  66. }
  67. func NewSWACache(windowSize int32, shift shiftFn) *Causal {
  68. return &Causal{
  69. windowSize: windowSize,
  70. shiftFn: shift,
  71. ctxs: make(map[int]ml.Context),
  72. keys: make(map[int]ml.Tensor),
  73. values: make(map[int]ml.Tensor),
  74. }
  75. }
// Init sizes and allocates the cache metadata for the given backend. The
// cache holds maxSequences*capacity cells, except that a sliding-window
// cache only needs windowSize+maxBatch cells per sequence when that is
// smaller. The total is rounded up to the backend's CachePadding.
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	// Adopt the backend's cache configuration unless one was already
	// provided via SetConfig.
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	// Fill in defaults for unset (zero-valued) config fields.
	if c.config.CachePadding == 0 {
		c.config.CachePadding = 1
	}
	if c.config.MaskBatchPadding == 0 {
		c.config.MaskBatchPadding = 1
	}
	if c.config.MaskDType == ml.DTypeOther {
		c.config.MaskDType = ml.DTypeF32
	}

	var cacheSize int
	if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
		cacheSize = maxSequences * capacity
	} else {
		cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
	}
	cacheSize = roundUp(cacheSize, c.config.CachePadding)
	c.cells = make([]cacheCell, cacheSize)

	c.DType = dtype
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
}
  105. func (c *Causal) SetConfig(config ml.CacheConfig) {
  106. if c.config != nil {
  107. panic("config cannot be changed after being previously set, either by the model or backend")
  108. }
  109. c.config = &config
  110. }
  111. func (c *Causal) Close() {
  112. for _, ctx := range c.ctxs {
  113. ctx.Close()
  114. }
  115. }
// StartForward prepares the cache for a new batch: it records the batch's
// positions and sequences, evicts entries outside the sliding window,
// reserves a contiguous block of cells (defragmenting once if the first
// attempt fails), updates the per-sequence ranges, and builds the attention
// mask for this pass.
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
	c.curBatchSize = len(batch.Positions)
	c.curSequences = batch.Sequences
	c.curPositions = batch.Positions

	// Any SetCausal override is reset at the start of each forward pass.
	c.opts.Except = nil

	c.updateSlidingWindow()

	var err error
	c.curLoc, err = c.findStartLoc()
	if errors.Is(err, ErrKvCacheFull) {
		c.defrag()
		c.curLoc, err = c.findStartLoc()
	}
	if err != nil {
		return err
	}

	c.curCellRange = newRange()

	// Claim the reserved cells for this batch and widen each touched
	// sequence's range - and the batch-wide range - to cover them.
	for i, pos := range batch.Positions {
		seq := batch.Sequences[i]

		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

		seqRange, ok := c.cellRanges[seq]
		if !ok {
			seqRange = newRange()
		}

		if c.curLoc+i > seqRange.max {
			seqRange.max = c.curLoc + i
		}
		if seqRange.max > c.curCellRange.max {
			c.curCellRange.max = seqRange.max
		}

		if c.curLoc+i < seqRange.min {
			seqRange.min = c.curLoc + i
		}
		if seqRange.min < c.curCellRange.min {
			c.curCellRange.min = seqRange.min
		}
		c.cellRanges[seq] = seqRange
	}

	c.curMask, err = c.buildMask(ctx)

	return err
}
  156. func newRange() cellRange {
  157. return cellRange{
  158. min: math.MaxInt,
  159. max: 0,
  160. }
  161. }
  162. // Find the first contiguous block of at least curBatchSize
  163. func (c *Causal) findStartLoc() (int, error) {
  164. var start, count int
  165. for i := range c.cells {
  166. if len(c.cells[i].sequences) == 0 {
  167. count++
  168. if count >= c.curBatchSize {
  169. return start, nil
  170. }
  171. } else {
  172. start = i + 1
  173. count = 0
  174. }
  175. }
  176. return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
  177. }
// updateSlidingWindow drops cache entries that have fallen outside the
// attention window of each sequence present in the current batch, shrinking
// those sequences' cell ranges accordingly. No-op for unbounded caches.
func (c *Causal) updateSlidingWindow() {
	if c.windowSize == math.MaxInt32 {
		return
	}

	// create a map of unique sequences to the lowest position in that sequence
	lowestPos := make(map[int]int32)
	for i := range c.curPositions {
		seq := c.curSequences[i]

		pos, ok := lowestPos[seq]
		if !ok {
			pos = c.curPositions[i]
		} else if c.curPositions[i] < pos {
			pos = c.curPositions[i]
		}

		lowestPos[seq] = pos
	}

	// delete any entries that are beyond the window of the oldest position in the sequence
	for seq, pos := range lowestPos {
		oldRange, ok := c.cellRanges[seq]
		if !ok {
			continue
		}

		newRange := newRange()

		for i := oldRange.min; i <= oldRange.max; i++ {
			if slices.Contains(c.cells[i].sequences, seq) {
				if c.cells[i].pos < pos-c.windowSize {
					// Too old to ever be attended to again: detach this
					// sequence from the cell (the cell becomes free once no
					// sequence references it).
					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
				} else {
					newRange.min = min(newRange.min, i)
					newRange.max = max(newRange.max, i)
				}
			}
		}

		c.cellRanges[seq] = newRange
	}
}
  214. func roundDown(length, pad int) int {
  215. return (length / pad) * pad
  216. }
  217. func roundUp(length, pad int) int {
  218. return ((length + pad - 1) / pad) * pad
  219. }
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
	// Align and pad the two dimensions as required by the backend
	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)

	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
	length := c.curCellRange.max - c.curCellRange.min + 1

	// Row-major: one row of `length` cache cells per batch token. A cell is
	// blocked (-inf) when it belongs to a different sequence, is causally
	// ahead of the token (unless the index is listed in opts.Except), or
	// has slid out of the attention window. Allowed cells stay 0.
	mask := make([]float32, batchSize*length)

	for i := range c.curBatchSize {
		enabled := !slices.Contains(c.opts.Except, i)
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
				(enabled && c.cells[j].pos > c.curPositions[i]) ||
				c.cells[j].pos < c.curPositions[i]-c.windowSize {
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
	}

	// Mask out any padding tokens we added. For padding that we added to the cache history, this
	// has already been masked out because the sequence doesn't match.
	for i := c.curBatchSize * length; i < len(mask); i++ {
		mask[i] = float32(math.Inf(-1))
	}

	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
	if err != nil {
		return nil, err
	}

	// Convert to the backend's preferred mask dtype when it isn't F32.
	if c.config.MaskDType != ml.DTypeF32 {
		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
		ctx.Forward(maskTensor.Copy(ctx, out))
		maskTensor = out
	}

	return maskTensor, nil
}
// moveCells copies `length` consecutive cache entries from slot src to slot
// dst for every allocated layer, issuing the tensor copies on ctx. Only the
// tensor data is moved; the caller is responsible for the cell metadata.
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)

		value := c.values[i]
		var vSrcView, vDstView ml.Tensor
		if c.config.PermutedV {
			// Permuted layout stores cells on the innermost dimension, so a
			// move is a strided 2D view rather than a contiguous row copy.
			vHeadDim := value.Dim(1)
			elemSize := value.Stride(0)

			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
		} else {
			vHeadDim := value.Dim(0)
			rowSize := value.Stride(2)

			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
		}

		ctx.Forward(
			kSrcView.Copy(ctx, kDstView),
			vSrcView.Copy(ctx, vDstView),
		)
	}
}
// defrag compacts the cache by moving occupied cells from the end of the
// cache into free holes at the beginning, making room for a contiguous
// block on the next allocation attempt.
func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together but most likely the next forward
	// pass will disrupt this anyways, so the real world benefit
	// seems limited as this time.

	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v). We also need to refer to the original
	// k and v cache tensors - once per layer, not per move.
	layers := 0
	for _, key := range c.keys {
		if key == nil {
			continue
		}
		layers++
	}

	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
	src := len(c.cells) - 1

	for dst := 0; dst < src; dst++ {
		if len(c.cells[dst].sequences) == 0 {
			// dst is a hole - pull the last occupied cell down into it.
			for ; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					// Metadata moves immediately; the tensor copy is
					// deferred so contiguous moves can be coalesced.
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							// Contiguous with the pending move - extend it.
							pendingSrc = src
							pendingLen++
							break
						} else {
							// Not contiguous - flush the pending move first.
							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		// Graph is full - compute what we have and start a fresh context.
		if moves >= maxMoves {
			ctx.Compute()
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute()
	}
	ctx.Close()

	// Reset range metadata
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if slices.Contains(cell.sequences, seq) {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}
}
// SetLayer selects the layer that subsequent Get and Put calls operate on.
func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}
type CausalOptions struct {
	// Except lists batch indices that are exempt from the causal restriction
	// when the mask is built: those tokens may attend to cache entries ahead
	// of their own position (sequence and window masking still apply).
	Except []int
}
// SetCausal disables causal mask generation for a particular range of indices in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
// When ctx is non-nil the current mask is rebuilt immediately with the new options.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
	if !slices.Equal(c.opts.Except, opts.Except) {
		c.opts = opts
		if ctx != nil {
			var err error
			c.curMask, err = c.buildMask(ctx)
			if err != nil {
				// This error should never occur because we have previously built a mask with the same shape
				panic(fmt.Errorf("SetCausal: %w", err))
			}
		}
	}
}
// Get returns views of the key and value cache for the active layer covering
// the cell range needed by the current batch, plus the attention mask built
// in StartForward.
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := key.Dim(0)
	numKVHeads := key.Dim(1)
	rowSize := key.Stride(2)
	// The mask's first dimension is the (padded) number of visible cache
	// cells, so it doubles as the view length.
	cachedSize := c.curMask.Dim(0)

	key = key.View(ctx, rowSize*c.curCellRange.min,
		kHeadDim, key.Stride(1),
		numKVHeads, key.Stride(2),
		cachedSize,
	)

	if c.config.PermutedV {
		// Permuted layout keeps cells on the innermost dimension.
		vHeadDim := value.Dim(1)
		elemSize := value.Stride(0)

		value = value.View(ctx, elemSize*c.curCellRange.min,
			cachedSize, value.Stride(1),
			vHeadDim, value.Stride(2),
			numKVHeads,
		)
	} else {
		vHeadDim := value.Dim(0)
		rowSize := value.Stride(2)

		value = value.View(ctx, rowSize*c.curCellRange.min,
			vHeadDim, value.Stride(1),
			numKVHeads, value.Stride(2),
			cachedSize,
		)
	}

	return key, value, c.curMask
}
// Put stores this batch's key and value tensors for the active layer at the
// cells reserved by StartForward, lazily allocating the layer's context and
// cache tensors on first use.
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(0)
	vHeadDim := value.Dim(0)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)

	if c.curBatchSize != batchSize {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
	}

	if _, ok := c.ctxs[c.curLayer]; !ok {
		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
	}

	if _, ok := c.values[c.curLayer]; !ok {
		if c.config.PermutedV {
			// Permuted layout: cells on the innermost dimension.
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
		} else {
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
		}
	}

	rowSize := c.keys[c.curLayer].Stride(2)
	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))

	if c.config.PermutedV {
		// Permute the incoming values to match the transposed cache layout
		// before copying into the strided destination view.
		elemSize := c.values[c.curLayer].Stride(0)

		value = value.Permute(ctx, 1, 2, 0, 3)
		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
	} else {
		rowSize := c.values[c.curLayer].Stride(2)

		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
	}
}
  454. func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
  455. seqRange := newRange()
  456. for i := range c.cells {
  457. // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
  458. if slices.Contains(c.cells[i].sequences, dstSeq) {
  459. c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
  460. }
  461. if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
  462. c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
  463. if i < seqRange.min {
  464. seqRange.min = i
  465. }
  466. if i > seqRange.max {
  467. seqRange.max = i
  468. }
  469. }
  470. }
  471. c.cellRanges[dstSeq] = seqRange
  472. }
// shift applies a position offset to the cached keys of seq for cells with
// pos >= beginIndex, by building a per-cell offset tensor and running
// shiftFn over each layer's affected key range. Returns ErrNotSupported
// when no shift function was provided.
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	if c.shiftFn == nil {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	seqRange := c.cellRanges[seq]
	size := seqRange.max - seqRange.min + 1

	offsets := make([]int32, size)
	for i := range offsets {
		cell := c.cells[seqRange.min+i]

		// Cells outside seq, or before beginIndex, keep a zero offset.
		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
			offsets[i] = offset
		}
	}

	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
	if err != nil {
		return err
	}

	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		key = key.View(ctx, rowSize*seqRange.min,
			kHeadDim, key.Stride(1),
			numKVHeads, key.Stride(2),
			size,
		)

		roped, err := c.shiftFn(ctx, i, key, kShift)
		if err != nil {
			return err
		}

		// Write the re-encoded keys back in place.
		ctx.Forward(roped.Copy(ctx, key))
	}

	ctx.Compute()

	return nil
}
// Remove deletes positions [beginIndex, endIndex) of seq from the cache.
// Unless endIndex is math.MaxInt32 (remove-to-end), the surviving entries at
// pos >= endIndex are shifted down to close the gap, which requires a
// shiftFn and fails on cells shared with other sequences.
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	// offset is the (negative) amount surviving tail positions move down;
	// zero when removing through the end of the sequence.
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		if slices.Contains(c.cells[i].sequences, seq) {
			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
				// Inside the removed range: detach this sequence.
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
			} else {
				if c.cells[i].pos >= endIndex {
					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
						// TODO(jessegross): Need to be careful about data shared between sequences
						return errors.New("shifting on cells shared by multiple sequences not yet implemented")
					}
					c.cells[i].pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	// Nothing left in the sequence: drop its range entirely.
	if seqRange == newRange() {
		delete(c.cellRanges, seq)
		return nil
	}

	c.cellRanges[seq] = seqRange

	if endIndex != math.MaxInt32 {
		// Re-encode the keys of the shifted tail at their new positions.
		err := c.shift(seq, endIndex+offset, offset)
		if err != nil {
			return err
		}
	}

	return nil
}