Sfoglia il codice sorgente

relay CUDA errors to the client (#825)

Bruce MacDonald 1 anno fa
parent
commit
565648f3f7
1 ha cambiato i file con 35 aggiunte e 12 eliminazioni
  1. 35 12
      llm/llama.go

+ 35 - 12
llm/llama.go

@@ -183,12 +183,12 @@ type llamaHyperparameters struct {
 }
 
 type Running struct {
-	Port     int
-	Cmd      *exec.Cmd
-	Cancel   context.CancelFunc
-	exitOnce sync.Once
-	exitCh   chan error // channel to receive the exit status of the subprocess
-	exitErr  error      // error returned by the subprocess
+	Port          int
+	Cmd           *exec.Cmd
+	Cancel        context.CancelFunc
+	exitOnce      sync.Once
+	exitCh        chan error // channel to receive the exit status of the subprocess
+	*StatusWriter            // captures error messages from the llama runner process
 }
 
 type llama struct {
@@ -259,7 +259,8 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
 
 // StatusWriter is a writer that captures error messages from the llama runner process
 type StatusWriter struct {
-	ErrCh chan error
+	ErrCh      chan error
+	LastErrMsg string
 }
 
 func NewStatusWriter() *StatusWriter {
@@ -269,9 +270,18 @@ func NewStatusWriter() *StatusWriter {
 }
 
 func (w *StatusWriter) Write(b []byte) (int, error) {
+	var errMsg string
 	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
-		w.ErrCh <- fmt.Errorf("llama runner: %s", bytes.TrimSpace(after))
+		errMsg = string(bytes.TrimSpace(after))
+	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
+		errMsg = string(bytes.TrimSpace(after))
 	}
+
+	if errMsg != "" {
+		w.LastErrMsg = errMsg
+		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
+	}
+
 	return os.Stderr.Write(b)
 }
 
@@ -359,7 +369,13 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
 		// monitor the llama runner process and signal when it exits
 		go func() {
 			err := llm.Cmd.Wait()
-			llm.exitErr = err
+			// default to printing the exit message of the command process, it will probably just say 'exit staus 1'
+			errMsg := err.Error()
+			// try to set a better error message if llama runner logs captured an error
+			if statusWriter.LastErrMsg != "" {
+				errMsg = statusWriter.LastErrMsg
+			}
+			log.Println(errMsg)
 			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
 			llm.exitOnce.Do(func() {
 				close(llm.exitCh)
@@ -429,10 +445,9 @@ func (llm *llama) Close() {
 
 	// wait for the command to exit to prevent race conditions with the next run
 	<-llm.exitCh
-	err := llm.exitErr
 
-	if err != nil {
-		log.Printf("llama runner stopped with error: %v", err)
+	if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+		log.Printf("llama runner stopped with error: %v", llm.StatusWriter.LastErrMsg)
 	} else {
 		log.Print("llama runner stopped successfully")
 	}
@@ -569,6 +584,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	}
 
 	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			// this means the llama runner subprocess crashed
+			llm.Close()
+			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+			}
+			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
+		}
 		return fmt.Errorf("error reading llm response: %v", err)
 	}