|
@@ -109,11 +109,23 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
|
|
return runner.llama, model, &opts, nil
|
|
|
}
|
|
|
|
|
|
-func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, modelPath string) {
|
|
|
+func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
|
|
|
+ modelPath := speech.Model
|
|
|
+
|
|
|
+ // keep-alive for the whisper server: use speech.KeepAlive when set, otherwise default to 5 minutes
|
|
|
+ var sessionDuration time.Duration
|
|
|
+ if speech.KeepAlive != nil {
|
|
|
+ sessionDuration = speech.KeepAlive.Duration
|
|
|
+ } else {
|
|
|
+ sessionDuration = 5 * time.Minute
|
|
|
+ }
|
|
|
+
|
|
|
s.sched.whisperMu.Lock()
|
|
|
if s.sched.whisperLoaded[modelPath] != nil {
|
|
|
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
|
|
|
portCh <- *s.sched.whisperLoaded[modelPath]
|
|
|
+ // Renew the expiration time
|
|
|
+ s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
|
|
|
s.sched.whisperMu.Unlock()
|
|
|
return
|
|
|
}
|
|
@@ -149,36 +161,52 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
|
|
|
|
|
|
// Wait for server connection
|
|
|
retries := 10
|
|
|
+ var connErr error
|
|
|
for range retries {
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
|
|
|
if err == nil {
|
|
|
conn.Close()
|
|
|
+ connErr = nil
|
|
|
break
|
|
|
}
|
|
|
+ connErr = err
|
|
|
}
|
|
|
|
|
|
- if err != nil {
|
|
|
- slog.Error("failed to connect to whisper server", "error", err)
|
|
|
- errCh <- err
|
|
|
+ if connErr != nil {
|
|
|
+ slog.Error("failed to connect to whisper server", "error", connErr)
|
|
|
+ errCh <- connErr
|
|
|
return
|
|
|
}
|
|
|
|
|
|
portCh <- port
|
|
|
s.sched.whisperLoaded[modelPath] = &port
|
|
|
+ s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
|
|
|
|
|
|
s.sched.whisperMu.Unlock()
|
|
|
|
|
|
// Wait for the whisper server to exit
|
|
|
defer func() {
|
|
|
- err = cmd.Wait()
|
|
|
- if err != nil {
|
|
|
- c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
|
|
- return
|
|
|
+ ticker := time.NewTicker(5 * time.Second)
|
|
|
+ defer ticker.Stop()
|
|
|
+ for range ticker.C {
|
|
|
+ s.sched.whisperMu.Lock()
|
|
|
+ if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
|
|
|
+ slog.Info("exiting whisper server")
|
|
|
+ delete(s.sched.whisperLoaded, modelPath)
|
|
|
+ delete(s.sched.whisperExpiresAt, modelPath)
|
|
|
+ s.sched.whisperMu.Unlock()
|
|
|
+
|
|
|
+ if err := cmd.Process.Kill(); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ slog.Debug("whisper server stopped")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ s.sched.whisperMu.Unlock()
|
|
|
}
|
|
|
- s.sched.whisperMu.Lock()
|
|
|
- delete(s.sched.whisperLoaded, modelPath)
|
|
|
- s.sched.whisperMu.Unlock()
|
|
|
}()
|
|
|
}
|
|
|
|
|
@@ -280,10 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
caps = append(caps, CapabilityInsert)
|
|
|
}
|
|
|
|
|
|
- if req.Audio != "" {
|
|
|
+ if req.Speech != nil {
|
|
|
portCh := make(chan int, 1)
|
|
|
errCh := make(chan error, 1)
|
|
|
- go s.runWhisperServer(c, portCh, errCh, req.WhisperModel)
|
|
|
+ go s.runWhisperServer(c, portCh, errCh, req.Speech)
|
|
|
|
|
|
var port int
|
|
|
|
|
@@ -294,19 +322,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- w, err := whisperInference(c, req.Audio, port)
|
|
|
+ w, err := whisperInference(c, req.Speech.Audio, port)
|
|
|
if err != nil {
|
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if req.Transcribe {
|
|
|
+ if req.Speech.Transcribe {
|
|
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
Model: req.Model,
|
|
|
CreatedAt: time.Now().UTC(),
|
|
|
Response: w.Text,
|
|
|
Done: true,
|
|
|
- DoneReason: "stop",
|
|
|
+ DoneReason: "transcribe",
|
|
|
})
|
|
|
return
|
|
|
}
|
|
@@ -1481,13 +1509,13 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
|
|
}
|
|
|
|
|
|
-func processAudio(c *gin.Context, s *Server, msgs []api.Message, model string) {
|
|
|
- if model == "" {
|
|
|
+func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) {
|
|
|
+ if req.Model == "" {
|
|
|
return
|
|
|
}
|
|
|
portCh := make(chan int, 1)
|
|
|
errCh := make(chan error, 1)
|
|
|
- go s.runWhisperServer(c, portCh, errCh, model)
|
|
|
+ go s.runWhisperServer(c, portCh, errCh, req)
|
|
|
|
|
|
var port int
|
|
|
select {
|
|
@@ -1554,7 +1582,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
|
|
}
|
|
|
|
|
|
- processAudio(c, s, msgs, req.WhisperModel)
|
|
|
+ if req.Speech != nil {
|
|
|
+ processAudio(c, s, msgs, req.Speech)
|
|
|
+ }
|
|
|
|
|
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
|
|
if err != nil {
|