|
@@ -118,7 +118,12 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
|
|
c.config.MaskDType = ml.DTypeF32
|
|
|
}
|
|
|
|
|
|
- cacheSize := maxSequences * capacity
|
|
|
+ var cacheSize int
|
|
|
+ if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
|
|
+ cacheSize = maxSequences * capacity
|
|
|
+ } else {
|
|
|
+ cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
|
|
+ }
|
|
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
|
|
c.cells = make([]cacheCell, cacheSize)
|
|
|
|
|
@@ -147,6 +152,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
|
|
c.curPositions = batch.Positions
|
|
|
c.opts.Except = nil
|
|
|
|
|
|
+ c.updateSlidingWindow()
|
|
|
+
|
|
|
var err error
|
|
|
c.curLoc, err = c.findStartLoc()
|
|
|
if errors.Is(err, ErrKvCacheFull) {
|
|
@@ -214,6 +221,50 @@ func (c *Causal) findStartLoc() (int, error) {
|
|
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
|
|
}
|
|
|
|
|
|
+func (c *Causal) updateSlidingWindow() {
|
|
|
+ if c.windowSize == math.MaxInt32 {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // create a map of unique sequences to the lowest position in that sequence
|
|
|
+ lowestPos := make(map[int]int32)
|
|
|
+ for i := range c.curPositions {
|
|
|
+ seq := c.curSequences[i]
|
|
|
+
|
|
|
+ pos, ok := lowestPos[seq]
|
|
|
+ if !ok {
|
|
|
+ pos = c.curPositions[i]
|
|
|
+ } else if c.curPositions[i] < pos {
|
|
|
+ pos = c.curPositions[i]
|
|
|
+ }
|
|
|
+
|
|
|
+ lowestPos[seq] = pos
|
|
|
+ }
|
|
|
+
|
|
|
+ // delete any entries that are beyond the window of the oldest position in the sequence
|
|
|
+ for seq, pos := range lowestPos {
|
|
|
+ oldRange, ok := c.cellRanges[seq]
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ newRange := newRange()
|
|
|
+
|
|
|
+ for i := oldRange.min; i <= oldRange.max; i++ {
|
|
|
+ if slices.Contains(c.cells[i].sequences, seq) {
|
|
|
+ if c.cells[i].pos < pos-c.windowSize {
|
|
|
+ c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
|
|
+ } else {
|
|
|
+ newRange.min = min(newRange.min, i)
|
|
|
+ newRange.max = max(newRange.max, i)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ c.cellRanges[seq] = newRange
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func roundDown(length, pad int) int {
|
|
|
return (length / pad) * pad
|
|
|
}
|