Browse Source

working poc

Roy Han 9 months ago
parent
commit
65483180b9
2 changed files with 147 additions and 0 deletions
  1. 6 0
      api/types.go
  2. 141 0
      server/routes.go

+ 6 - 0
api/types.go

@@ -80,6 +80,8 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
+
+	Audio string `json:"audio,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -450,6 +452,10 @@ type GenerateResponse struct {
 	Metrics
 }
 
+// WhisperCompletion is the JSON payload decoded from the whisper server's
+// inference response.
+type WhisperCompletion struct {
+	// Text is the value of the "text" key in the response.
+	Text string `json:"text"`
+}
+
 // ModelDetails provides details about a model.
 type ModelDetails struct {
 	ParentModel       string   `json:"parent_model"`

+ 141 - 0
server/routes.go

@@ -10,13 +10,17 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"math/rand"
+	"mime/multipart"
 	"net"
 	"net/http"
 	"net/netip"
 	"os"
+	"os/exec"
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -105,7 +109,131 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	return runner.llama, model, &opts, nil
 }
 
+func runWhisperServer(c *gin.Context, portCh chan int) {
+	whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
+
+	// Find an available port for whisper
+	port := 0
+	params := []string{}
+	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
+		var l *net.TCPListener
+		if l, err = net.ListenTCP("tcp", a); err == nil {
+			port = l.Addr().(*net.TCPAddr).Port
+			l.Close()
+		}
+	}
+	if port == 0 {
+		slog.Debug("ResolveTCPAddr failed")
+		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
+	}
+	finalParams := append(params, "--port", strconv.Itoa(port), "--model", "/Users/royhan-ollama/ollama/llm/whisper.cpp/models/ggml-base.en.bin")
+
+	cmd := exec.Command(whisperServer, finalParams...)
+	slog.Info("starting whisper server", "cmd", cmd.String())
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+	err := cmd.Start()
+	if err != nil {
+		slog.Error("failed to start whisper server", "error", err)
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
+	}
+
+	// wait for server to start
+	time.Sleep(250 * time.Millisecond)
+
+	portCh <- port
+
+	// Wait for the whisper server to exit
+	err = cmd.Wait()
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "whisper server exited"})
+	}
+
+	defer func() {
+		err := cmd.Process.Kill()
+		if err != nil {
+			slog.Error("failed to kill whisper server", "error", err)
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to kill whisper server"})
+		}
+	}()
+}
+
+// whisperInference uploads the file at filePath as a multipart form to the
+// whisper server listening on localhost:port and decodes the JSON response.
+//
+// On any failure it writes a 500 JSON error to c and returns a non-nil
+// error; the returned completion is only valid when the error is nil.
+//
+// NOTE(review): filePath is taken from the request body by the caller and is
+// opened on the server's filesystem as-is; it should be validated or replaced
+// by uploaded audio data before this is exposed to untrusted clients.
+func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
+	// Open the file
+	file, err := os.Open(filePath)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
+		return nil, err
+	}
+	defer file.Close()
+
+	// Create a buffer to hold the multipart form data
+	buffer := &bytes.Buffer{}
+	writer := multipart.NewWriter(buffer)
+
+	// Add the file to the multipart form
+	part, err := writer.CreateFormFile("file", filepath.Base(filePath))
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
+		return nil, err
+	}
+
+	if _, err := io.Copy(part, file); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
+		return nil, err
+	}
+
+	// Add other fields to the form
+	if err := writer.WriteField("temperature", "0.0"); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
+		return nil, err
+	}
+
+	// Close the writer to finalize the multipart form (writes the trailing
+	// boundary); the body is incomplete without it.
+	if err := writer.Close(); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
+		return nil, err
+	}
+
+	endpoint := fmt.Sprintf("http://localhost:%d/inference", port)
+
+	// The request inherits the client's context so it is abandoned when the
+	// client goes away.
+	serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
+		return nil, err
+	}
+
+	serverReq.Header.Set("Content-Type", writer.FormDataContentType())
+
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
+		return nil, err
+	}
+
+	// An error status must stop here: decoding an error body could yield an
+	// empty completion with a nil error.
+	if res.StatusCode >= 400 {
+		slog.Error("error response from whisper server", "status", res.Status, "body", string(body))
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error response from whisper server"})
+		return nil, fmt.Errorf("whisper server returned %s", res.Status)
+	}
+
+	var w api.WhisperCompletion
+	if err := json.Unmarshal(body, &w); err != nil {
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
+		return nil, err
+	}
+
+	return &w, nil
+}
+
 func (s *Server) GenerateHandler(c *gin.Context) {
+	slog.Info("generate request", "method", c.Request.Method, "url", c.Request.URL.String())
 	checkpointStart := time.Now()
 	var req api.GenerateRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -129,6 +257,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		caps = append(caps, CapabilityInsert)
 	}
 
+	if req.Audio != "" {
+		port := make(chan int, 1)
+		go runWhisperServer(c, port)
+
+		w, err := whisperInference(c, req.Audio, <-port)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
+			return
+		}
+
+		req.Prompt = w.Text
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})