causal.go

package kvcache

import (
	"errors"
	"fmt"
	"log/slog"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
)

type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
//
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
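//
// For example (illustrative sizes only, not taken from this file): with a
// head dimension of 128, 8 KV heads, and a 4 token batch, each K or V tensor
// passed to Put has shape [128, 8, 4], and the mask returned by Get has shape
// [cached history length, 4] (before any backend-required padding).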
type Causal struct {
	DType      ml.DType
	Capacity   int32
	windowSize int32

	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// starting location for data storage for this batch
	curLoc int

	// size of the current batch
	curBatchSize int

	// mask of the cache as used by this batch
	curMask ml.Tensor

	// locations in the cache that are needed for this batch
	curCellRange cellRange

	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
	// that reference the data there
	cells []cacheCell

	// maps from sequence to the range of locations where it is stored in the cache
	cellRanges map[int]cellRange

	// ** cache data storage **

	shiftFn shiftFn
	backend ml.Backend
	ctxs    map[int]ml.Context

	keys, values map[int]ml.Tensor
}

type cacheCell struct {
	pos       int32
	sequences []int
}

type cellRange struct {
	min int
	max int
}

func NewCausalCache(shift shiftFn) *Causal {
	return &Causal{
		windowSize: math.MaxInt32,
		shiftFn:    shift,
		ctxs:       make(map[int]ml.Context),
		keys:       make(map[int]ml.Tensor),
		values:     make(map[int]ml.Tensor),
	}
}

func NewSWACache(windowSize int32, shift shiftFn) *Causal {
	return &Causal{
		windowSize: windowSize,
		shiftFn:    shift,
		ctxs:       make(map[int]ml.Context),
		keys:       make(map[int]ml.Tensor),
		values:     make(map[int]ml.Tensor),
	}
}

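// Typical lifecycle, as a minimal sketch. The model, backend, shift function,
// and sizes below are illustrative assumptions, not values taken from this
// file:
//
//	cache := kvcache.NewCausalCache(model.Shift)
//	cache.Init(backend, ml.DTypeF32, 2048)
//	defer cache.Close()
//
//	// once per batch
//	if err := cache.StartForward(ctx, positions, seqs); err != nil {
//		return err
//	}
//
//	// inside each transformer layer
//	cache.SetLayer(layer)
//	cache.Put(ctx, k, v)
//	key, value, mask := cache.Get(ctx)
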
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if c.config.CachePadding == 0 {
		c.config.CachePadding = 1
	}

	if c.config.MaskBatchPadding == 0 {
		c.config.MaskBatchPadding = 1
	}

	if c.config.MaskDType == ml.DTypeOther {
		c.config.MaskDType = ml.DTypeF32
	}

	c.DType = dtype
	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
	c.cells = make([]cacheCell, c.Capacity)
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
}

func (c *Causal) SetConfig(config ml.CacheConfig) {
	if c.config != nil {
		panic("config cannot be changed after being previously set, either by the model or backend")
	}

	c.config = &config
}

func (c *Causal) Close() {
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

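// StartForward reserves a contiguous block of cells for the incoming batch
// (defragmenting the cache if no block is free), records each token's
// position and sequence, and builds the attention mask returned by Get.
// positions and seqs must have one entry per token in the batch.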
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	c.curBatchSize = len(positions)

	var err error
	c.curLoc, err = c.findStartLoc()
	if errors.Is(err, ErrKvCacheFull) {
		c.defrag()
		c.curLoc, err = c.findStartLoc()
	}
	if err != nil {
		return err
	}

	c.curCellRange = newRange()
	for i, pos := range positions {
		seq := seqs[i]

		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

		seqRange, ok := c.cellRanges[seq]
		if !ok {
			seqRange = newRange()
		}

		if c.curLoc+i > seqRange.max {
			seqRange.max = c.curLoc + i
		}
		if seqRange.max > c.curCellRange.max {
			c.curCellRange.max = seqRange.max
		}

		if c.curLoc+i < seqRange.min {
			seqRange.min = c.curLoc + i
		}
		if seqRange.min < c.curCellRange.min {
			c.curCellRange.min = seqRange.min
		}

		c.cellRanges[seq] = seqRange
	}

	c.curMask, err = c.buildMask(ctx, positions, seqs)

	return err
}

func newRange() cellRange {
	return cellRange{
		min: math.MaxInt,
		max: 0,
	}
}

// Find the first contiguous block of at least curBatchSize empty cells
func (c *Causal) findStartLoc() (int, error) {
	var start, count int
	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
			count++
			if count >= c.curBatchSize {
				return start, nil
			}
		} else {
			start = i + 1
			count = 0
		}
	}

	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}

// roundDown rounds length down to the nearest multiple of pad,
// e.g. roundDown(10, 32) == 0.
func roundDown(length, pad int) int {
	return (length / pad) * pad
}

// roundUp rounds length up to the nearest multiple of pad,
// e.g. roundUp(10, 32) == 32.
func roundUp(length, pad int) int {
	return ((length + pad - 1) / pad) * pad
}

// Builds a history x batch mask indicating, for each token in the batch, which
// tokens in the history it may attend to. This is based on both the sequence
// and causality (the position in the history must not be ahead of the token in
// the batch).
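//
// As an illustrative example (not derived from real data): a batch token at
// position 3 in sequence 0, with the visible cache range holding positions
// 0-5 for sequence 0, gets the mask row [0, 0, 0, 0, -Inf, -Inf]; cells at
// positions 4 and 5 are in the future and are masked out, while earlier
// positions stay 0 (subject to windowSize for sliding-window attention).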
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
	// Align and pad the two dimensions as required by the backend
	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)

	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

	length := c.curCellRange.max - c.curCellRange.min + 1
	mask := make([]float32, batchSize*length)

	for i := range c.curBatchSize {
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
				c.cells[j].pos < positions[i]-c.windowSize {
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
	}

	// Mask out any padding tokens we added. For padding that we added to the cache history, this
	// has already been masked out because the sequence doesn't match.
	for i := c.curBatchSize * length; i < len(mask); i++ {
		mask[i] = float32(math.Inf(-1))
	}

	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
	if err != nil {
		return nil, err
	}

	if c.config.MaskDType != ml.DTypeF32 {
		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
		ctx.Forward(maskTensor.Copy(ctx, out))
		maskTensor = out
	}

	return maskTensor, nil
}

func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)

		value := c.values[i]
		var vSrcView, vDstView ml.Tensor
		if c.config.PermutedV {
			vHeadDim := value.Dim(1)
			elemSize := value.Stride(0)

			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
		} else {
			vHeadDim := value.Dim(0)
			rowSize := value.Stride(2)

			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
		}

		ctx.Forward(
			kSrcView.Copy(ctx, kDstView),
			vSrcView.Copy(ctx, vDstView),
		)
	}
}

func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together but most likely the next forward
	// pass will disrupt this anyways, so the real world benefit
	// seems limited at this time.
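	//
	// As a small illustration (not traced from a real run): with cells
	// [A, _, B], where "_" marks a free slot, B's K/V rows are copied from
	// slot 2 to slot 1 and slot 2 is cleared, leaving [A, B, _].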

	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v).
	layers := 0
	for _, key := range c.keys {
		if key == nil {
			continue
		}
		layers++
	}

	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
	src := len(c.cells) - 1

	for dst := 0; dst < src; dst++ {
		if len(c.cells[dst].sequences) == 0 {
			for ; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							pendingSrc = src
							pendingLen++
							break
						} else {
							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		if moves >= maxMoves {
			ctx.Compute()
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute()
	}
	ctx.Close()

	// Reset range metadata
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if slices.Contains(cell.sequences, seq) {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}
}

func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}

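// Get returns views of the cached keys and values for the current layer,
// covering the (padded) cell range needed by this batch, along with the mask
// built by StartForward. The returned tensors share storage with the cache.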
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := key.Dim(0)
	numKVHeads := key.Dim(1)
	rowSize := key.Stride(2)
	cachedSize := c.curMask.Dim(0)

	key = key.View(ctx, rowSize*c.curCellRange.min,
		kHeadDim, key.Stride(1),
		numKVHeads, key.Stride(2),
		cachedSize,
	)

	if c.config.PermutedV {
		vHeadDim := value.Dim(1)
		elemSize := value.Stride(0)

		value = value.View(ctx, elemSize*c.curCellRange.min,
			cachedSize, value.Stride(1),
			vHeadDim, value.Stride(2),
			numKVHeads,
		)
	} else {
		vHeadDim := value.Dim(0)
		rowSize := value.Stride(2)

		value = value.View(ctx, rowSize*c.curCellRange.min,
			vHeadDim, value.Stride(1),
			numKVHeads, value.Stride(2),
			cachedSize,
		)
	}

	return key, value, c.curMask
}

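// Put stores the key and value tensors for the current batch at the location
// reserved by StartForward, allocating this layer's cache tensors on first
// use. key and value are expected to have shape [headDim, kvHeads, batchSize].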
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(0)
	vHeadDim := value.Dim(0)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)

	if c.curBatchSize != batchSize {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
	}

	if _, ok := c.ctxs[c.curLayer]; !ok {
		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
	}

	if _, ok := c.values[c.curLayer]; !ok {
		if c.config.PermutedV {
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
		} else {
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
		}
	}

	rowSize := c.keys[c.curLayer].Stride(2)
	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))

	if c.config.PermutedV {
		elemSize := c.values[c.curLayer].Stride(0)

		value = value.Permute(ctx, 1, 2, 0, 3)
		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
	} else {
		rowSize := c.values[c.curLayer].Stride(2)
		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
	}
}

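// CopyPrefix makes the first len positions of srcSeq visible to dstSeq,
// replacing anything dstSeq previously referenced. No tensor data is copied;
// the two sequences share the underlying K/V entries.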
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	seqRange := newRange()

	for i := range c.cells {
		// Remove the contents of dstSeq so that we only have the copied prefix; the metadata will be reset at the end
		if slices.Contains(c.cells[i].sequences, dstSeq) {
			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
		}

		if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
			c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
			if i < seqRange.min {
				seqRange.min = i
			}
			if i > seqRange.max {
				seqRange.max = i
			}
		}
	}

	c.cellRanges[dstSeq] = seqRange
}

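// shift applies shiftFn (typically a RoPE re-rotation) to the cached keys of
// seq, adding offset to the position of every cell at or beyond beginIndex.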
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	if c.shiftFn == nil {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	seqRange := c.cellRanges[seq]
	size := seqRange.max - seqRange.min + 1

	offsets := make([]int32, size)
	for i := range offsets {
		cell := c.cells[seqRange.min+i]

		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
			offsets[i] = offset
		}
	}

	kShift, err := ctx.FromIntSlice(offsets, len(offsets))
	if err != nil {
		return err
	}

	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		key = key.View(ctx, rowSize*seqRange.min,
			kHeadDim, key.Stride(1),
			numKVHeads, key.Stride(2),
			size,
		)

		roped, err := c.shiftFn(ctx, i, key, kShift)
		if err != nil {
			return err
		}

		ctx.Forward(roped.Copy(ctx, key))
	}

	ctx.Compute()

	return nil
}

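// Remove evicts the cells of seq whose positions fall in [beginIndex, endIndex)
// and shifts the positions of the remaining later cells back by the removed
// amount, re-rotating the affected keys via shiftFn. Passing
// endIndex = math.MaxInt32 drops everything from beginIndex onward without
// shifting.
//
// For example (an illustrative call, not taken from this file):
//
//	// drop positions 5-9 of sequence 0; positions >= 10 shift down by 5
//	err := cache.Remove(0, 5, 10)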
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		if slices.Contains(c.cells[i].sequences, seq) {
			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
			} else {
				if c.cells[i].pos >= endIndex {
					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
						// TODO(jessegross): Need to be careful about data shared between sequences
						return errors.New("shifting on cells shared by multiple sequences not yet implemented")
					}

					c.cells[i].pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	if seqRange == newRange() {
		delete(c.cellRanges, seq)
		return nil
	}

	c.cellRanges[seq] = seqRange

	if endIndex != math.MaxInt32 {
		err := c.shift(seq, endIndex+offset, offset)
		if err != nil {
			return err
		}
	}

	return nil
}