Roy Han 8 months ago
parent
commit
89f3bae306
7 changed files with 228 additions and 6 deletions
  1. 3 1
      api/types.go
  2. 39 0
      cmd/cmd.go
  3. 35 0
      cmd/interactive.go
  4. 1 0
      go.mod
  5. 2 0
      go.sum
  6. 137 0
      recorder/recorder.go
  7. 11 5
      server/routes.go

+ 3 - 1
api/types.go

@@ -37,7 +37,7 @@ func (e StatusError) Error() string {
 type ImageData []byte
 
 type WhisperRequest struct {
-	Model      string    `json:"model"`
+	Model      string    `json:"model,omitempty"`
 	Audio      string    `json:"audio,omitempty"`
 	Transcribe bool      `json:"transcribe,omitempty"`
 	KeepAlive  *Duration `json:"keep_alive,omitempty"`
@@ -116,6 +116,8 @@ type ChatRequest struct {
 	Options map[string]interface{} `json:"options"`
 
 	Speech *WhisperRequest `json:"speech,omitempty"`
+
+	RunSpeech bool `json:"run_speech,omitempty"`
 }
 
 type Tools []Tool

+ 39 - 0
cmd/cmd.go

@@ -38,6 +38,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/recorder"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
@@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 			}
 		}
 
+		speech, err := cmd.Flags().GetBool("speech")
+		if err != nil {
+			return err
+		}
+
+		if speech {
+			return generateInteractiveAudio(cmd, opts)
+		}
 		return generateInteractive(cmd, opts)
 	}
 	return generate(cmd, opts)
@@ -862,6 +871,7 @@ type runOptions struct {
 	Options     map[string]interface{}
 	MultiModal  bool
 	KeepAlive   *api.Duration
+	Audio       bool
 }
 
 type displayResponseState struct {
@@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		req.KeepAlive = opts.KeepAlive
 	}
 
+	if opts.Audio {
+		req.RunSpeech = true
+	}
+
 	if err := client.Chat(cancelCtx, req, fn); err != nil {
 		if errors.Is(err, context.Canceled) {
 			return nil, nil
@@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		KeepAlive: opts.KeepAlive,
 	}
 
+	speech, err := cmd.Flags().GetBool("speech")
+	if err != nil {
+		return err
+	}
+
+	// create temp wav file with the recorder package
+	if speech {
+		tempFile, err := os.CreateTemp("", "recording-*.wav")
+		if err != nil {
+			return err
+		}
+		defer os.Remove(tempFile.Name())
+
+		fmt.Print("Speech Mode\n\n")
+
+		err = recorder.RecordAudio(tempFile)
+		if err != nil {
+			return err
+		}
+
+		request.Speech = &api.WhisperRequest{
+			Audio: tempFile.Name(),
+		}
+	}
 	if err := client.Generate(ctx, &request, fn); err != nil {
 		if errors.Is(err, context.Canceled) {
 			return nil
@@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command {
 		RunE:    RunHandler,
 	}
 
+	runCmd.Flags().Bool("speech", false, "Speech to text mode")
 	runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")

+ 35 - 0
cmd/interactive.go

@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
+	"github.com/ollama/ollama/recorder"
 	"github.com/ollama/ollama/types/errtypes"
 )
 
@@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
 	return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
 }
 
+func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
+	for {
+		p := progress.NewProgress(os.Stderr)
+		spinner := progress.NewSpinner("")
+		p.Add("", spinner)
+
+		// create temp wav file with the recorder package
+		tempFile, err := os.CreateTemp("", "recording-*.wav")
+		if err != nil {
+			return err
+		}
+		defer os.Remove(tempFile.Name())
+
+		err = recorder.RecordAudio(tempFile)
+		if err != nil {
+			return err
+		}
+
+		p.StopAndClear()
+
+		newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
+		opts.Audio = true
+		opts.Messages = append(opts.Messages, newMessage)
+
+		assistant, err := chat(cmd, opts)
+		if err != nil {
+			return err
+		}
+		if assistant != nil {
+			opts.Messages = append(opts.Messages, *assistant)
+		}
+	}
+}
+
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")

+ 1 - 0
go.mod

@@ -19,6 +19,7 @@ require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/google/go-cmp v0.6.0
+	github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

+ 2 - 0
go.sum

@@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
+github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=

+ 137 - 0
recorder/recorder.go

@@ -0,0 +1,137 @@
+package recorder
+
+import (
+	"encoding/binary"
+	"fmt"
+	"os"
+	"os/signal"
+	"syscall"
+
+	"golang.org/x/sys/unix"
+	"golang.org/x/term"
+
+	"github.com/gordonklaus/portaudio"
+)
+
+const (
+	sampleRate    = 16000
+	numChannels   = 1
+	bitsPerSample = 16
+)
+
+func RecordAudio(f *os.File) error {
+	fmt.Print("Recording. Press any key to stop.\n\n")
+
+	sig := make(chan os.Signal, 1)
+	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
+
+	portaudio.Initialize()
+	defer portaudio.Terminate()
+
+	in := make([]int16, 64)
+	stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
+	if err != nil {
+		return err
+	}
+	defer stream.Close()
+
+	err = stream.Start()
+	if err != nil {
+		return err
+	}
+
+	// Write WAV header with placeholder sizes
+	writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
+
+	var totalSamples uint32
+
+	// Set up terminal input reading
+	oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
+	if err != nil {
+		return err
+	}
+	defer term.Restore(int(os.Stdin.Fd()), oldState)
+
+	// Create a channel to handle the stop signal
+	stop := make(chan struct{})
+
+	go func() {
+		_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
+		if err != nil {
+			fmt.Println("Error reading from stdin:", err)
+			return
+		}
+		// Send signal to stop recording
+		stop <- struct{}{}
+	}()
+
+loop:
+	for {
+		err = stream.Read()
+		if err != nil {
+			return err
+		}
+
+		err = binary.Write(f, binary.LittleEndian, in)
+		if err != nil {
+			return err
+		}
+		totalSamples += uint32(len(in))
+
+		select {
+		case <-stop:
+			break loop
+		case <-sig:
+			break loop
+		default:
+		}
+	}
+
+	err = stream.Stop()
+	if err != nil {
+		return err
+	}
+
+	// Update WAV header with actual sizes
+	updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
+
+	return nil
+}
+
+func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
+	subchunk1Size := 16
+	audioFormat := 1
+	byteRate := sampleRate * numChannels * (bitsPerSample / 8)
+	blockAlign := numChannels * (bitsPerSample / 8)
+
+	// Write the RIFF header
+	f.Write([]byte("RIFF"))
+	binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
+	f.Write([]byte("WAVE"))
+
+	// Write the fmt subchunk
+	f.Write([]byte("fmt "))
+	binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
+	binary.Write(f, binary.LittleEndian, uint16(audioFormat))
+	binary.Write(f, binary.LittleEndian, uint16(numChannels))
+	binary.Write(f, binary.LittleEndian, uint32(sampleRate))
+	binary.Write(f, binary.LittleEndian, uint32(byteRate))
+	binary.Write(f, binary.LittleEndian, uint16(blockAlign))
+	binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
+
+	// Write the data subchunk header
+	f.Write([]byte("data"))
+	binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
+}
+
+func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
+	fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
+	dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
+
+	// Seek to the start of the file and write updated sizes
+	f.Seek(4, 0)
+	binary.Write(f, binary.LittleEndian, uint32(fileSize))
+
+	f.Seek(40, 0)
+	binary.Write(f, binary.LittleEndian, uint32(dataSize))
+}

+ 11 - 5
server/routes.go

@@ -110,7 +110,12 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 }
 
 func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
-	modelPath := speech.Model
+	var modelPath string
+	if speech.Model == "" {
+		modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
+	} else {
+		modelPath = speech.Model
+	}
 
 	// default to 5 minutes
 	var sessionDuration time.Duration
@@ -130,7 +135,7 @@ func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan er
 		return
 	}
 
-	whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
+	whisperServer := "/Users/royhan-ollama/.ollama/server"
 
 	// Find an available port for whisper
 	port := 0
@@ -1510,8 +1515,9 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 }
 
 func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
-	if req.Model == "" {
-		return nil
+	slog.Info("processing audio")
+	if req == nil {
+		req = &api.WhisperRequest{}
 	}
 	portCh := make(chan int, 1)
 	errCh := make(chan error, 1)
@@ -1583,7 +1589,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
 	}
 
-	if req.Speech != nil {
+	if req.Speech != nil || req.RunSpeech {
 		if err := processAudio(c, s, msgs, req.Speech); err != nil {
 			slog.Error("failed to process audio", "error", err)
 			return