Browse Source

model flexibility

Roy Han 8 tháng trước cách đây
mục cha
commit
2a9feb0707
4 tập tin đã thay đổi với 31 bổ sung14 xóa
  1. 2 0
      api/types.go
  2. 14 0
      docs/whisper.md
  3. 13 12
      server/routes.go
  4. 2 2
      server/sched.go

+ 2 - 0
api/types.go

@@ -81,6 +81,8 @@ type GenerateRequest struct {
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 
+	WhisperModel string `json:"whisper_model,omitempty"`
+
 	Audio string `json:"audio,omitempty"`
 
 	Transcribe bool `json:"transcribe,omitempty"`

+ 14 - 0
docs/whisper.md

@@ -0,0 +1,14 @@
+# Whisper Prototype
+
+### To run
+`make {/path/to/whisper.cpp/server}`
+
+### Update routes.go
+- replace `whisperServer` with path to server
+
+## api/generate
+### Request fields
+    - "audio" (required): path to audio file
+    - "whisper_model" (required): path to whisper model
+    - "transcribe" (optional): if true, will transcribe and return the audio file
+    - "prompt" (optional): if not null, passed in with the transcribed audio

+ 13 - 12
server/routes.go

@@ -109,11 +109,11 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	return runner.llama, model, &opts, nil
 }
 
-func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
+func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) {
 	s.sched.whisperMu.Lock()
-	if s.sched.whisperPort != nil {
-		slog.Info("whisper server already running", "port", *s.sched.whisperPort)
-		portCh <- *s.sched.whisperPort
+	if s.sched.whisperLoaded[modelPath] != nil {
+		slog.Info("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath])
+		portCh <- *s.sched.whisperLoaded[modelPath]
 		s.sched.whisperMu.Unlock()
 		return
 	}
@@ -134,7 +134,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
 		slog.Debug("ResolveTCPAddr failed")
 		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	}
-	finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin")
+	finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
 
 	cmd := exec.Command(whisperServer, finalParams...)
 	slog.Info("starting whisper server", "cmd", cmd.String())
@@ -146,6 +146,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
 		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
 	}
 
+	// Wait for server connection
 	retries := 10
 	for range retries {
 		time.Sleep(25 * time.Millisecond)
@@ -162,7 +163,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
 	}
 
 	portCh <- port
-	s.sched.whisperPort = &port
+	s.sched.whisperLoaded[modelPath] = &port
 
 	s.sched.whisperMu.Unlock()
 
@@ -170,12 +171,11 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int) {
 	defer func() {
 		err = cmd.Wait()
 		if err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"})
-		}
-		err := cmd.Process.Kill()
-		if err != nil {
-			slog.Error("failed to kill whisper server", "error", err)
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
 		}
+		s.sched.whisperMu.Lock()
+		delete(s.sched.whisperLoaded, modelPath)
+		s.sched.whisperMu.Unlock()
 	}()
 }
 
@@ -279,7 +279,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	if req.Audio != "" {
 		port := make(chan int, 1)
-		go s.runWhisperServer(c, port)
+		go s.runWhisperServer(c, port, req.WhisperModel)
 
 		w, err := whisperInference(c, req.Audio, <-port)
 		if err != nil {
@@ -295,6 +295,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				Done:       true,
 				DoneReason: "stop",
 			})
+			return
 		}
 
 		req.Prompt += w.Text

+ 2 - 2
server/sched.go

@@ -47,8 +47,8 @@ type Scheduler struct {
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
 
-	whisperPort *int
-	whisperMu   sync.Mutex
+	whisperLoaded map[string]*int
+	whisperMu     sync.Mutex
 }
 
 // Default automatic value for number of models we allow per GPU