Roy Han 8 miesięcy temu
rodzic
commit
d503f04b32
5 zmienionych plików z 150 dodań i 58 usunięć
  1. 9 6
      api/types.go
  2. 73 0
      docs/speech.md
  3. 0 20
      docs/whisper.md
  4. 50 20
      server/routes.go
  5. 18 12
      server/sched.go

+ 9 - 6
api/types.go

@@ -36,6 +36,13 @@ func (e StatusError) Error() string {
 // ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
+type WhisperRequest struct {
+	Model      string    `json:"model"`
+	Audio      string    `json:"audio,omitempty"`
+	Transcribe bool      `json:"transcribe,omitempty"`
+	KeepAlive  *Duration `json:"keep_alive,omitempty"`
+}
+
 // GenerateRequest describes a request sent by [Client.Generate]. While you
 // have to specify the Model and Prompt fields, all the other fields have
 // reasonable defaults for basic uses.
@@ -81,11 +88,7 @@ type GenerateRequest struct {
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 
-	WhisperModel string `json:"whisper_model,omitempty"`
-
-	Audio string `json:"audio,omitempty"`
-
-	Transcribe bool `json:"transcribe,omitempty"`
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -112,7 +115,7 @@ type ChatRequest struct {
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 
-	WhisperModel string `json:"whisper_model,omitempty"`
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 type Tools []Tool

+ 73 - 0
docs/speech.md

@@ -0,0 +1,73 @@
+# Speech to Text Prototype
+
+### To run
+`make {/path/to/whisper.cpp/server}`
+
+### Update routes.go
+- replace `whisperServer` with path to server
+
+## api/generate
+### Request fields
+- `speech` (required):
+    - `audio` (required): path to audio file
+    - `model` (required): path to whisper model
+    - `transcribe` (optional): if true, will transcribe and return the audio file
+    - `keep_alive`: (optional): sets how long the model is stored in memory (default: `5m`)
+- `prompt` (optional): if not null, passed in with the transcribed audio
+
+#### Transcription
+```
+curl http://localhost:11434/api/generate -d '{
+    "speech": {
+        "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
+        "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
+        "transcribe": true,
+        "keep_alive": "1m"
+    },
+    "stream": false
+}' | jq
+```
+
+#### Response Generation
+```
+curl http://localhost:11434/api/generate -d '{
+    "model": "llama3",
+    "prompt": "What do you think about this quote?",
+    "speech": {
+        "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
+        "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
+        "keep_alive": "1m"
+    },
+    "stream": false
+}' | jq
+```
+
+## api/chat
+### Request fields
+- `model` (required): language model to chat with
+- `speech` (required):
+    - `model` (required): path to whisper model
+    - `keep_alive`: (optional): sets how long the model is stored in memory (default: `5m`)
+- `messages`/`message`/`audio` (required): path to audio file
+
+```
+curl http://localhost:11434/api/chat -d '{
+    "model": "llama3",
+    "speech": {
+        "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
+        "keep_alive": "10m"
+    },
+    "messages": [
+        {
+            "role": "system",
+            "content": "You are a Canadian Nationalist"
+        },
+        {
+            "role": "user",
+            "content": "What do you think about this quote?",
+            "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
+        }
+    ],
+    "stream": false
+}' | jq
+```

+ 0 - 20
docs/whisper.md

@@ -1,20 +0,0 @@
-# Whisper Prototype
-
-### To run
-`make {/path/to/whisper.cpp/server}`
-
-### Update routes.go
-- replace `whisperServer` with path to server
-
-## api/generate
-### Request fields
-    - "audio" (required): path to audio file
-    - "whisper_model" (required): path to whisper model
-    - "transcribe" (optional): if true, will transcribe and return the audio file
-    - "prompt" (optional): if not null, passed in with the transcribed audio
-
-## api/chat
-### Request fields
-    - "whisper_model" (required): path to whisper model
-    - "message" object
-        - "audio" (required): contains path to audio file

+ 50 - 20
server/routes.go

@@ -109,11 +109,23 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	return runner.llama, model, &opts, nil
 }
 
-func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) {
+func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
+	modelPath := speech.Model
+
+	// default to 5 minutes
+	var sessionDuration time.Duration
+	if speech.KeepAlive != nil {
+		sessionDuration = speech.KeepAlive.Duration
+	} else {
+		sessionDuration = 5 * time.Minute
+	}
+
 	s.sched.whisperMu.Lock()
 	if s.sched.whisperLoaded[modelPath] != nil {
 		slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
 		portCh <- *s.sched.whisperLoaded[modelPath]
+		// Renew the expiration time
+		s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
 		s.sched.whisperMu.Unlock()
 		return
 	}
@@ -149,36 +161,52 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
 
 	// Wait for server connection
 	retries := 10
+	var connErr error
 	for range retries {
 		time.Sleep(50 * time.Millisecond)
 		conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
 		if err == nil {
 			conn.Close()
+			connErr = nil
 			break
 		}
+		connErr = err
 	}
 
-	if err != nil {
-		slog.Error("failed to connect to whisper server", "error", err)
-		errCh <- err
+	if connErr != nil {
+		slog.Error("failed to connect to whisper server", "error", connErr)
+		errCh <- connErr
 		return
 	}
 
 	portCh <- port
 	s.sched.whisperLoaded[modelPath] = &port
+	s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
 
 	s.sched.whisperMu.Unlock()
 
 	// Wait for the whisper server to exit
 	defer func() {
-		err = cmd.Wait()
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
-			return
+		ticker := time.NewTicker(5 * time.Second)
+		defer ticker.Stop()
+		for range ticker.C {
+			s.sched.whisperMu.Lock()
+			if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
+				slog.Info("exiting whisper server")
+				delete(s.sched.whisperLoaded, modelPath)
+				delete(s.sched.whisperExpiresAt, modelPath)
+				s.sched.whisperMu.Unlock()
+
+				if err := cmd.Process.Kill(); err != nil {
+					c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
+					return
+				}
+
+				slog.Debug("whisper server stopped")
+				return
+			}
+			s.sched.whisperMu.Unlock()
 		}
-		s.sched.whisperMu.Lock()
-		delete(s.sched.whisperLoaded, modelPath)
-		s.sched.whisperMu.Unlock()
 	}()
 }
 
@@ -280,10 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		caps = append(caps, CapabilityInsert)
 	}
 
-	if req.Audio != "" {
+	if req.Speech != nil {
 		portCh := make(chan int, 1)
 		errCh := make(chan error, 1)
-		go s.runWhisperServer(c, portCh, errCh, req.WhisperModel)
+		go s.runWhisperServer(c, portCh, errCh, req.Speech)
 
 		var port int
 
@@ -294,19 +322,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			return
 		}
 
-		w, err := whisperInference(c, req.Audio, port)
+		w, err := whisperInference(c, req.Speech.Audio, port)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
 			return
 		}
 
-		if req.Transcribe {
+		if req.Speech.Transcribe {
 			c.JSON(http.StatusOK, api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Response:   w.Text,
 				Done:       true,
-				DoneReason: "stop",
+				DoneReason: "transcribe",
 			})
 			return
 		}
@@ -1481,13 +1509,13 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
-func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) {
-	if model == "" {
+func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) {
+	if req.Model == "" {
 		return
 	}
 	portCh := make(chan int, 1)
 	errCh := make(chan error, 1)
-	go s.runWhisperServer(c, portCh, errCh, model)
+	go s.runWhisperServer(c, portCh, errCh, req)
 
 	var port int
 	select {
@@ -1554,7 +1582,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
 	}
 
-	processAudio(c, s, msgs, req.WhisperModel)
+	if req.Speech != nil {
+		processAudio(c, s, msgs, req.Speech)
+	}
 
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {

+ 18 - 12
server/sched.go

@@ -47,8 +47,9 @@ type Scheduler struct {
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
 
-	whisperLoaded map[string]*int
-	whisperMu     sync.Mutex
+	whisperLoaded    map[string]*int
+	whisperExpiresAt map[string]time.Time
+	whisperMu        sync.Mutex
 }
 
 // Default automatic value for number of models we allow per GPU
@@ -66,16 +67,17 @@ var ErrMaxQueue = errors.New("server busy, please try again.  maximum pending re
 func InitScheduler(ctx context.Context) *Scheduler {
 	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueue),
-		finishedReqCh: make(chan *LlmRequest, maxQueue),
-		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
-		loaded:        make(map[string]*runnerRef),
-		newServerFn:   llm.NewLlamaServer,
-		getGpuFn:      gpu.GetGPUInfo,
-		getCpuFn:      gpu.GetCPUInfo,
-		reschedDelay:  250 * time.Millisecond,
-		whisperLoaded: make(map[string]*int),
+		pendingReqCh:     make(chan *LlmRequest, maxQueue),
+		finishedReqCh:    make(chan *LlmRequest, maxQueue),
+		expiredCh:        make(chan *runnerRef, maxQueue),
+		unloadedCh:       make(chan interface{}, maxQueue),
+		loaded:           make(map[string]*runnerRef),
+		newServerFn:      llm.NewLlamaServer,
+		getGpuFn:         gpu.GetGPUInfo,
+		getCpuFn:         gpu.GetCPUInfo,
+		reschedDelay:     250 * time.Millisecond,
+		whisperLoaded:    make(map[string]*int),
+		whisperExpiresAt: make(map[string]time.Time),
 	}
 	sched.loadFn = sched.load
 	return sched
@@ -114,6 +116,10 @@ func (s *Scheduler) Run(ctx context.Context) {
 	go func() {
 		s.processCompleted(ctx)
 	}()
+
+	// go func() {
+	// 	could clean up whisper servers in init thread
+	// }
 }
 
 func (s *Scheduler) processPending(ctx context.Context) {