Explorar el Código

fix(mllama): sync backend between batches

Michael Yang hace 5 meses
padre
commit
5b3393b6a2
Se han modificado 2 ficheros con 11 adiciones y 0 borrados
  1. 4 0
      llama/llama.go
  2. 7 0
      llama/runner/runner.go

+ 4 - 0
llama/llama.go

@@ -598,6 +598,10 @@ func (c *Context) SetCrossAttention(state bool) {
 	C.llama_set_cross_attention(c.c, C.bool(state))
 	C.llama_set_cross_attention(c.c, C.bool(state))
 }
 }
 
 
+func (c *Context) Synchronize() {
+	C.llama_synchronize(c.c)
+}
+
 // sampling
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {
 type SamplingContext struct {

+ 7 - 0
llama/runner/runner.go

@@ -427,6 +427,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return
 		return
 	}
 	}
 
 
+	if crossAttention {
+		// synchronize state to ensure the cross attention batch is complete.
+		// needed specifically for multi-GPU systems otherwise an inflight
+		// task may be incorrectly invalidated causing a crash
+		s.lc.Synchronize()
+	}
+
 	for i, seq := range s.seqs {
 	for i, seq := range s.seqs {
 		if seq == nil {
 		if seq == nil {
 			continue
 			continue