|
@@ -10,13 +10,17 @@ import (
|
|
|
"io"
|
|
|
"log/slog"
|
|
|
"math"
|
|
|
+ "math/rand"
|
|
|
+ "mime/multipart"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/netip"
|
|
|
"os"
|
|
|
+ "os/exec"
|
|
|
"os/signal"
|
|
|
"path/filepath"
|
|
|
"slices"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"syscall"
|
|
|
"time"
|
|
@@ -105,7 +109,131 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
|
|
return runner.llama, model, &opts, nil
|
|
|
}
|
|
|
|
|
|
+func runWhisperServer(c *gin.Context, portCh chan int) {
|
|
|
+ whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
|
|
|
+
|
|
|
+ // Find an available port for whisper
|
|
|
+ port := 0
|
|
|
+ params := []string{}
|
|
|
+ if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
|
|
+ var l *net.TCPListener
|
|
|
+ if l, err = net.ListenTCP("tcp", a); err == nil {
|
|
|
+ port = l.Addr().(*net.TCPAddr).Port
|
|
|
+ l.Close()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if port == 0 {
|
|
|
+ slog.Debug("ResolveTCPAddr failed")
|
|
|
+ port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
|
|
+ }
|
|
|
+ finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/ollama/llm/whisper.cpp/models/ggml-base.en.bin")
|
|
|
+
|
|
|
+ cmd := exec.Command(whisperServer, finalParams...)
|
|
|
+ slog.Info("starting whisper server", "cmd", cmd.String())
|
|
|
+ cmd.Stdout = os.Stdout
|
|
|
+ cmd.Stderr = os.Stderr
|
|
|
+ err := cmd.Start()
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("failed to start whisper server", "error", err)
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
|
|
|
+ }
|
|
|
+
|
|
|
+ // wait for server to start
|
|
|
+ time.Sleep(250 * time.Millisecond)
|
|
|
+
|
|
|
+ portCh <- port
|
|
|
+
|
|
|
+ // Wait for the whisper server to exit
|
|
|
+ err = cmd.Wait()
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"})
|
|
|
+ }
|
|
|
+
|
|
|
+ defer func() {
|
|
|
+ err := cmd.Process.Kill()
|
|
|
+ if err != nil {
|
|
|
+ slog.Error("failed to kill whisper server", "error", err)
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to kill whisper server"})
|
|
|
+ }
|
|
|
+ }()
|
|
|
+}
|
|
|
+
|
|
|
+func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
|
|
|
+ // Open the file
|
|
|
+ file, err := os.Open(filePath)
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer file.Close()
|
|
|
+
|
|
|
+ // Create a buffer to hold the multipart form data
|
|
|
+ buffer := &bytes.Buffer{}
|
|
|
+ writer := multipart.NewWriter(buffer)
|
|
|
+
|
|
|
+ // Add the file to the multipart form
|
|
|
+ part, err := writer.CreateFormFile("file", filepath.Base(filePath))
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err := io.Copy(part, file); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add other fields to the form
|
|
|
+ if err := writer.WriteField("temperature", "0.0"); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Close the writer to finalize the multipart form
|
|
|
+ if err := writer.Close(); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))
|
|
|
+
|
|
|
+ serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ serverReq.Header.Set("Content-Type", writer.FormDataContentType())
|
|
|
+
|
|
|
+ res, err := http.DefaultClient.Do(serverReq)
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer res.Body.Close()
|
|
|
+
|
|
|
+ body, err := io.ReadAll(res.Body)
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if res.StatusCode >= 400 {
|
|
|
+ slog.Error("error response from whisper server", "status", res.Status, "body", string(body))
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error response from whisper server"})
|
|
|
+ }
|
|
|
+
|
|
|
+ var w api.WhisperCompletion
|
|
|
+ if err := json.Unmarshal(body, &w); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return &w, nil
|
|
|
+}
|
|
|
+
|
|
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
+ slog.Info("generate request", "method", c.Request.Method, "url", c.Request.URL.String())
|
|
|
checkpointStart := time.Now()
|
|
|
var req api.GenerateRequest
|
|
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
|
@@ -129,6 +257,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
caps = append(caps, CapabilityInsert)
|
|
|
}
|
|
|
|
|
|
+ if req.Audio != "" {
|
|
|
+ port := make(chan int, 1)
|
|
|
+ go runWhisperServer(c, port)
|
|
|
+
|
|
|
+ w, err := whisperInference(c, req.Audio, <-port)
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ req.Prompt = w.Text
|
|
|
+ }
|
|
|
+
|
|
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
|
|
if errors.Is(err, errCapabilityCompletion) {
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|