@@ -598,6 +598,10 @@ func (c *Context) SetCrossAttention(state bool) {
C.llama_set_cross_attention(c.c, C.bool(state))
}
+func (c *Context) Synchronize() {
+ C.llama_synchronize(c.c)
+}
+
// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
@@ -427,6 +427,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return
+ if crossAttention {
+ // synchronize state to ensure the cross attention batch is complete.
+ // needed specifically for multi-GPU systems otherwise an inflight
+ // task may be incorrectly invalidated causing a crash
+ s.lc.Synchronize()
+ }
for i, seq := range s.seqs {
if seq == nil {
continue