|
@@ -119,10 +119,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
|
|
}
|
|
|
|
|
|
var cacheSize int
|
|
|
- if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
|
|
+ if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
|
|
cacheSize = maxSequences * capacity
|
|
|
} else {
|
|
|
- cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
|
|
+ cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
|
|
}
|
|
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
|
|
c.cells = make([]cacheCell, cacheSize)
|