causal.go

package kvcache

import (
	"errors"
	"fmt"
	"log/slog"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

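// shiftFn is a model-provided callback that applies a positional shift to a
// view of cached keys (for example, re-applying rotary position embeddings
// with the given per-cell offsets).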
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens.
//
// The tensors are of shape embed dim, kv heads, batch size.
// The mask is of shape history size, batch size.
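//
// A typical lifecycle looks roughly like the sketch below (illustrative only;
// the backend, batch, layer and tensor values are hypothetical):
//
//	cache := NewCausalCache(shift)
//	cache.Init(backend, ml.DTypeF16, 2048)
//
//	// once per batch
//	if err := cache.StartForward(ctx, batch); err != nil {
//		return err
//	}
//
//	// once per layer
//	cache.SetLayer(layer)
//	cache.Put(ctx, k, v)               // store this batch's projected K/V
//	key, value, mask := cache.Get(ctx) // history view plus attention mask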
type Causal struct {
	DType      ml.DType
	Capacity   int32
	windowSize int32

	opts CausalOptions

	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// starting location for data storage for this batch
	curLoc int

	// size of the current batch
	curBatchSize int

	// mask of the cache as used by this batch
	curMask ml.Tensor

	// locations in the cache that are needed for this batch
	curCellRange cellRange

	// curSequences is the sequences corresponding to this pass's entries in the cache
	curSequences []int

	// curPositions is the positions corresponding to this pass's entries in the cache
	curPositions []int32

	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
	// that reference the data there
	cells []cacheCell

	// maps from sequence to the range of locations where it is stored in the cache
	cellRanges map[int]cellRange

	// ** cache data storage **

	shiftFn      shiftFn
	backend      ml.Backend
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor
}

type cacheCell struct {
	pos       int32
	sequences []int
}

type cellRange struct {
	min int
	max int
}

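// NewCausalCache creates a cache for models that attend to all previous
// tokens in a sequence (standard causal attention).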
func NewCausalCache(shift shiftFn) *Causal {
	return &Causal{
		windowSize: math.MaxInt32,
		shiftFn:    shift,
		ctxs:       make(map[int]ml.Context),
		keys:       make(map[int]ml.Tensor),
		values:     make(map[int]ml.Tensor),
	}
}

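// NewSWACache creates a cache for models that use sliding window attention:
// each token attends only to positions within windowSize of itself.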
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
	return &Causal{
		windowSize: windowSize,
		shiftFn:    shift,
		ctxs:       make(map[int]ml.Context),
		keys:       make(map[int]ml.Tensor),
		values:     make(map[int]ml.Tensor),
	}
}

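// Init allocates the cache metadata for the given backend, data type and
// capacity. The capacity is rounded up to the backend's cache padding.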
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if c.config.CachePadding == 0 {
		c.config.CachePadding = 1
	}

	if c.config.MaskBatchPadding == 0 {
		c.config.MaskBatchPadding = 1
	}

	if c.config.MaskDType == ml.DTypeOther {
		c.config.MaskDType = ml.DTypeF32
	}

	c.DType = dtype
	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
	c.cells = make([]cacheCell, c.Capacity)
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
}

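// SetConfig allows a model to override the backend's cache configuration. It
// must be called before Init and may only be called once.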
func (c *Causal) SetConfig(config ml.CacheConfig) {
	if c.config != nil {
		panic("config cannot be changed after being previously set, either by the model or backend")
	}

	c.config = &config
}

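// Close releases the per-layer contexts that own the cache tensors.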
func (c *Causal) Close() {
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

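// StartForward reserves a contiguous block of cells for the batch, records
// each entry's position and sequence, and builds the attention mask. If no
// contiguous block is free, the cache is defragmented and the search retried.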
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
	c.curBatchSize = len(batch.Positions)
	c.curSequences = batch.Sequences
	c.curPositions = batch.Positions
	c.opts.Except = nil

	var err error
	c.curLoc, err = c.findStartLoc()
	if errors.Is(err, ErrKvCacheFull) {
		c.defrag()
		c.curLoc, err = c.findStartLoc()
	}
	if err != nil {
		return err
	}

	c.curCellRange = newRange()
	for i, pos := range batch.Positions {
		seq := batch.Sequences[i]

		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

		seqRange, ok := c.cellRanges[seq]
		if !ok {
			seqRange = newRange()
		}

		if c.curLoc+i > seqRange.max {
			seqRange.max = c.curLoc + i
		}
		if seqRange.max > c.curCellRange.max {
			c.curCellRange.max = seqRange.max
		}

		if c.curLoc+i < seqRange.min {
			seqRange.min = c.curLoc + i
		}
		if seqRange.min < c.curCellRange.min {
			c.curCellRange.min = seqRange.min
		}

		c.cellRanges[seq] = seqRange
	}

	c.curMask, err = c.buildMask(ctx)

	return err
}

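// newRange returns an empty range (min > max) that grows as cells are added
// to it.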
func newRange() cellRange {
	return cellRange{
		min: math.MaxInt,
		max: 0,
	}
}

// findStartLoc finds the first contiguous block of at least curBatchSize
// empty cells.
func (c *Causal) findStartLoc() (int, error) {
	var start, count int
	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
			count++
			if count >= c.curBatchSize {
				return start, nil
			}
		} else {
			start = i + 1
			count = 0
		}
	}

	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}

func roundDown(length, pad int) int {
	return (length / pad) * pad
}

func roundUp(length, pad int) int {
	return ((length + pad - 1) / pad) * pad
}

// buildMask builds a mask of history x batch indicating, for each token in the
// batch, whether each token in the history may be attended to. This is based on
// both the sequence and causality (the position in the history must not be
// ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
	// Align and pad the two dimensions as required by the backend
	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)

	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

	length := c.curCellRange.max - c.curCellRange.min + 1
	mask := make([]float32, batchSize*length)

	for i := range c.curBatchSize {
		enabled := !slices.Contains(c.opts.Except, i)
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
				(enabled && c.cells[j].pos > c.curPositions[i]) ||
				c.cells[j].pos < c.curPositions[i]-c.windowSize {
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
	}

	// Mask out any padding tokens we added. For padding that we added to the cache history, this
	// has already been masked out because the sequence doesn't match.
	for i := c.curBatchSize * length; i < len(mask); i++ {
		mask[i] = float32(math.Inf(-1))
	}

	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
	if err != nil {
		return nil, err
	}

	if c.config.MaskDType != ml.DTypeF32 {
		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
		ctx.Forward(maskTensor.Copy(ctx, out))
		maskTensor = out
	}

	return maskTensor, nil
}

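// moveCells copies len cells of cached keys and values from src to dst for
// every layer that has been allocated.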
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)

		value := c.values[i]
		var vSrcView, vDstView ml.Tensor
		if c.config.PermutedV {
			vHeadDim := value.Dim(1)
			elemSize := value.Stride(0)

			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
		} else {
			vHeadDim := value.Dim(0)
			rowSize := value.Stride(2)

			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
		}

		ctx.Forward(
			kSrcView.Copy(ctx, kDstView),
			vSrcView.Copy(ctx, vDstView),
		)
	}
}

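// defrag compacts the cache by moving in-use cells toward the start, freeing
// a contiguous block at the end for future batches.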
func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold, then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together, but most likely the next forward
	// pass will disrupt this anyway, so the real world benefit
	// seems limited at this time.

	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v). We also need to refer to the original
	// k and v cache tensors - once per layer, not per move.
	layers := 0
	for _, key := range c.keys {
		if key == nil {
			continue
		}
		layers++
	}

	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
	src := len(c.cells) - 1

	for dst := 0; dst < src; dst++ {
		if len(c.cells[dst].sequences) == 0 {
			for ; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							pendingSrc = src
							pendingLen++
							break
						} else {
							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		if moves >= maxMoves {
			ctx.Compute()
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute()
	}
	ctx.Close()

	// Reset range metadata
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if slices.Contains(cell.sequences, seq) {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}
}

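// SetLayer selects the layer that subsequent Get and Put calls operate on.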
func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}

type CausalOptions struct {
	// Except lists batch indices for which causal masking is not applied;
	// those tokens may attend to any position in their sequence (subject to
	// the window size).
	Except []int
}

// SetCausal disables causal mask generation for a particular range of indices
// in the current batch for subsequent calls to Get. The state resets for the
// next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
	if !slices.Equal(c.opts.Except, opts.Except) {
		c.opts = opts
		if ctx != nil {
			var err error
			c.curMask, err = c.buildMask(ctx)
			if err != nil {
				// This error should never occur because we have previously built a mask with the same shape
				panic(fmt.Errorf("SetCausal: %w", err))
			}
		}
	}
}

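// Get returns the cached keys, values and attention mask for the current
// layer, viewed over the range of cells needed by this batch.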
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := key.Dim(0)
	numKVHeads := key.Dim(1)
	rowSize := key.Stride(2)
	cachedSize := c.curMask.Dim(0)

	key = key.View(ctx, rowSize*c.curCellRange.min,
		kHeadDim, key.Stride(1),
		numKVHeads, key.Stride(2),
		cachedSize,
	)

	if c.config.PermutedV {
		vHeadDim := value.Dim(1)
		elemSize := value.Stride(0)

		value = value.View(ctx, elemSize*c.curCellRange.min,
			cachedSize, value.Stride(1),
			vHeadDim, value.Stride(2),
			numKVHeads,
		)
	} else {
		vHeadDim := value.Dim(0)
		rowSize := value.Stride(2)

		value = value.View(ctx, rowSize*c.curCellRange.min,
			vHeadDim, value.Stride(1),
			numKVHeads, value.Stride(2),
			cachedSize,
		)
	}

	return key, value, c.curMask
}

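// Put stores this batch's keys and values for the current layer, allocating
// the layer's cache tensors on first use.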
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(0)
	vHeadDim := value.Dim(0)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)

	if c.curBatchSize != batchSize {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
	}

	if _, ok := c.ctxs[c.curLayer]; !ok {
		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
	}

	if _, ok := c.values[c.curLayer]; !ok {
		if c.config.PermutedV {
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
		} else {
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
		}
	}

	rowSize := c.keys[c.curLayer].Stride(2)
	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))

	if c.config.PermutedV {
		elemSize := c.values[c.curLayer].Stride(0)

		value = value.Permute(ctx, 1, 2, 0, 3)
		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
	} else {
		rowSize := c.values[c.curLayer].Stride(2)

		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
	}
}

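// CopyPrefix makes the first len positions of srcSeq also visible to dstSeq,
// discarding any entries previously owned by dstSeq. The underlying data is
// shared rather than duplicated.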
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	seqRange := newRange()

	for i := range c.cells {
		// Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
		if slices.Contains(c.cells[i].sequences, dstSeq) {
			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
		}

		if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
			c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
			if i < seqRange.min {
				seqRange.min = i
			}
			if i > seqRange.max {
				seqRange.max = i
			}
		}
	}

	c.cellRanges[dstSeq] = seqRange
}

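// shift re-positions the cached keys for seq by applying the model's shiftFn
// with the given offset to every cell at or after beginIndex.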
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	if c.shiftFn == nil {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	seqRange := c.cellRanges[seq]
	size := seqRange.max - seqRange.min + 1

	offsets := make([]int32, size)
	for i := range offsets {
		cell := c.cells[seqRange.min+i]

		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
			offsets[i] = offset
		}
	}

	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
	if err != nil {
		return err
	}

	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		key = key.View(ctx, rowSize*seqRange.min,
			kHeadDim, key.Stride(1),
			numKVHeads, key.Stride(2),
			size,
		)

		roped, err := c.shiftFn(ctx, i, key, kShift)
		if err != nil {
			return err
		}

		ctx.Forward(roped.Copy(ctx, key))
	}

	ctx.Compute()

	return nil
}

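// Remove evicts positions [beginIndex, endIndex) of seq from the cache. If
// endIndex is not math.MaxInt32, later entries are shifted down so that
// positions remain contiguous, which requires a shiftFn and is not supported
// for cells shared with other sequences.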
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		if slices.Contains(c.cells[i].sequences, seq) {
			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
			} else {
				if c.cells[i].pos >= endIndex {
					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
						// TODO(jessegross): Need to be careful about data shared between sequences
						return errors.New("shifting on cells shared by multiple sequences not yet implemented")
					}
					c.cells[i].pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	if seqRange == newRange() {
		delete(c.cellRanges, seq)
		return nil
	}

	c.cellRanges[seq] = seqRange

	if endIndex != math.MaxInt32 {
		err := c.shift(seq, endIndex+offset, offset)
		if err != nil {
			return err
		}
	}

	return nil
}