causal.go

package kvcache

import (
	"errors"
	"fmt"
	"log/slog"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
)

type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens.
//
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
type Causal struct {
	DType      ml.DType
	Capacity   int32
	windowSize int32

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// starting location for data storage for this batch
	curLoc int

	// size of the current batch
	curBatchSize int

	// mask of the cache as used by this batch
	curMask ml.Tensor

	// locations in the cache that are needed for this batch
	curCellRange cellRange

	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
	// that reference the data there
	cells []cacheCell

	// maps from sequence to the range of locations where it is stored in the cache
	cellRanges map[int]cellRange

	// ** cache data storage **

	shiftFn      shiftFn
	backend      ml.Backend
	cacheCtx     ml.Context
	keys, values []ml.Tensor
}

type cacheCell struct {
	pos       int32
	sequences []int
}

type cellRange struct {
	min int
	max int
}
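
// NewCausalCache creates a cache that lets each token attend to every earlier
// position in its own sequence.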
func NewCausalCache(shift shiftFn) *Causal {
	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
}
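
// NewSWACache creates a sliding-window variant that additionally masks out
// positions more than windowSize tokens behind the current one.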
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
	return &Causal{windowSize: windowSize, shiftFn: shift}
}
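
// Init allocates the cell metadata for the requested capacity and creates the
// long-lived context that holds the per-layer cache tensors.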
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	c.DType = dtype
	c.Capacity = capacity
	c.cells = make([]cacheCell, capacity)
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
	c.cacheCtx = backend.NewContext()
}
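
// Close releases the context that owns the cache tensors.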
func (c *Causal) Close() {
	c.cacheCtx.Close()
}
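
// StartForward reserves a contiguous block of cells for the incoming batch
// (defragmenting the cache if necessary), records each token's position and
// sequence, and builds the attention mask returned by Get.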
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	c.curBatchSize = len(positions)

	var err error
	c.curLoc, err = c.findStartLoc()
	if errors.Is(err, ErrKvCacheFull) {
		c.defrag()
		c.curLoc, err = c.findStartLoc()
	}
	if err != nil {
		return err
	}

	c.curCellRange = newRange()
	for i, pos := range positions {
		seq := seqs[i]

		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

		seqRange, ok := c.cellRanges[seq]
		if !ok {
			seqRange = newRange()
		}

		if c.curLoc+i > seqRange.max {
			seqRange.max = c.curLoc + i
		}
		if seqRange.max > c.curCellRange.max {
			c.curCellRange.max = seqRange.max
		}

		if c.curLoc+i < seqRange.min {
			seqRange.min = c.curLoc + i
		}
		if seqRange.min < c.curCellRange.min {
			c.curCellRange.min = seqRange.min
		}

		c.cellRanges[seq] = seqRange
	}

	c.curMask, err = c.buildMask(ctx, positions, seqs)

	return err
}
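
// newRange returns an empty range: min starts above max so the first cell
// added narrows both bounds.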
func newRange() cellRange {
	return cellRange{
		min: math.MaxInt,
		max: 0,
	}
}

// Find the first contiguous block of at least curBatchSize empty cells
func (c *Causal) findStartLoc() (int, error) {
	var start, count int
	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
			count++
			if count >= c.curBatchSize {
				return start, nil
			}
		} else {
			start = i + 1
			count = 0
		}
	}

	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}

// Builds a mask of history x batch indicating, for each token in the batch, which
// tokens in the history it may attend to. Masking is based on sequence membership,
// causality (the history position must not be ahead of the token in the batch), and
// the sliding window, if any.
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
	// TODO(jessegross): This does not do padding, which is required for flash attention
	len := c.curCellRange.max - c.curCellRange.min + 1
	mask := make([]float32, c.curBatchSize*len)

	for i := range c.curBatchSize {
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
				c.cells[j].pos < positions[i]-c.windowSize {
				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
	}

	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
}
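
// moveCell copies len contiguous cells starting at src to dst for each non-nil
// tensor in objs. It only records the copies; the caller is responsible for
// computing the context.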
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
	for _, obj := range objs {
		if obj == nil {
			continue
		}

		srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
		dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)

		ctx.Forward(srcView.Copy(ctx, dstView))
	}
}

func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold, then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together but most likely the next forward
	// pass will disrupt this anyways, so the real world benefit
	// seems limited at this time.

	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v).
	layers := 0
	for _, key := range c.keys {
		if key == nil {
			continue
		}
		layers++
	}

	maxMoves := ctx.MaxTensors() / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
	src := len(c.cells) - 1

	for dst := 0; dst < src; dst++ {
		if len(c.cells[dst].sequences) == 0 {
			for ; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							pendingSrc = src
							pendingLen++
							break
						} else {
							moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
							moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		if moves >= maxMoves {
			ctx.Compute()
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute()
	}
	ctx.Close()

	// Reset range metadata
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if slices.Contains(cell.sequences, seq) {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}
}
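
// SetLayer selects the layer that subsequent Get and Put calls operate on,
// growing the per-layer key and value slices if this is a new layer.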
func (c *Causal) SetLayer(layer int) {
	if layer >= len(c.keys) {
		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
	}

	c.curLayer = layer
}
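
// Get returns views of the current layer's cached keys and values restricted to
// the cells needed by this batch, along with the mask built by StartForward.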
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
		key.Dim(0), key.Stride(1),
		key.Dim(1), key.Stride(2),
		c.curMask.Dim(0),
	)

	value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
		value.Dim(0), value.Stride(1),
		value.Dim(1), value.Stride(2),
		c.curMask.Dim(0),
	)

	return key, value, c.curMask
}
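
// Put stores the batch's keys and values for the current layer at the location
// reserved by StartForward, lazily allocating the backing tensors on first use.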
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	if c.curBatchSize != key.Dim(2) {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v, layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
	}

	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
	}

	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
	ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
}
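
// CopyPrefix makes the first len positions of srcSeq visible to dstSeq,
// replacing anything previously associated with dstSeq.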
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	seqRange := newRange()

	for i := range c.cells {
		// Remove the contents of dstSeq so that we only have the copied prefix;
		// metadata will be reset at the end
		if slices.Contains(c.cells[i].sequences, dstSeq) {
			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
		}

		if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
			c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
			if i < seqRange.min {
				seqRange.min = i
			}
			if i > seqRange.max {
				seqRange.max = i
			}
		}
	}

	c.cellRanges[dstSeq] = seqRange
}
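
// shift adds offset to the stored positions of seq at or after beginIndex by
// running the configured shiftFn (for example, re-applying RoPE) over the
// affected key cells.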
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	if c.shiftFn == nil {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	seqRange := c.cellRanges[seq]
	size := seqRange.max - seqRange.min + 1

	offsets := make([]int32, size)
	for i := range offsets {
		cell := c.cells[seqRange.min+i]

		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
			offsets[i] = offset
		}
	}

	kShift, err := ctx.FromIntSlice(offsets, len(offsets))
	if err != nil {
		return err
	}

	for i, key := range c.keys {
		if key == nil {
			continue
		}

		key = key.View(ctx, key.Stride(2)*seqRange.min,
			key.Dim(0), key.Stride(1),
			key.Dim(1), key.Stride(2),
			size,
		)

		roped, err := c.shiftFn(ctx, i, key, kShift)
		if err != nil {
			return err
		}

		ctx.Forward(roped.Copy(ctx, key))
	}

	ctx.Compute()

	return nil
}
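
// Remove evicts positions in the half-open interval [beginIndex, endIndex) from
// seq. Unless endIndex is math.MaxInt32, the positions remaining after the gap
// are shifted back so the sequence stays contiguous.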
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		if slices.Contains(c.cells[i].sequences, seq) {
			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
			} else {
				if c.cells[i].pos >= endIndex {
					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
						// TODO(jessegross): Need to be careful about data shared between sequences
						return errors.New("shifting on cells shared by multiple sequences not yet implemented")
					}

					c.cells[i].pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	if seqRange == newRange() {
		delete(c.cellRanges, seq)
		return nil
	}

	c.cellRanges[seq] = seqRange

	if endIndex != math.MaxInt32 {
		err := c.shift(seq, endIndex+offset, offset)
		if err != nil {
			return err
		}
	}

	return nil
}