瀏覽代碼

just for fun

Roy Han 9 月之前
父節點
當前提交
42009d2974
共有 3 個文件被更改,包括 256 次插入140 次删除
  1. 256 0
      cmd/cmd.go
  2. 0 140
      cmd/interactive.go
  3. 二進制
      plot.png

+ 256 - 0
cmd/cmd.go

@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"archive/zip"
+	"bufio"
 	"bytes"
 	"context"
 	"crypto/ed25519"
@@ -16,6 +17,7 @@ import (
 	"net"
 	"net/http"
 	"os"
+	"os/exec"
 	"os/signal"
 	"path/filepath"
 	"regexp"
@@ -31,6 +33,11 @@ import (
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/term"
+	"gonum.org/v1/gonum/mat"
+	"gonum.org/v1/gonum/stat"
+	"gonum.org/v1/plot"
+	"gonum.org/v1/plot/plotter"
+	"gonum.org/v1/plot/vg"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
@@ -370,6 +377,90 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	return generate(cmd, opts)
 }
 
+// EmbedHandler is the cobra entry point for the "embed" command.
+//
+// args[0] names the model; any further args are joined with spaces into the
+// prompt, with piped standard input (if any) placed in front. Run options
+// are filled from the command flags and from the model's Show metadata; a
+// model the server reports as not found is pulled first. When no prompt
+// text was supplied at all the interactive session is started, otherwise
+// the prompt is handed to embed.
+func EmbedHandler(cmd *cobra.Command, args []string) error {
+	model := args[0]
+	opts := runOptions{
+		Model:    model,
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
+	}
+
+	var err error
+	if opts.Format, err = cmd.Flags().GetString("format"); err != nil {
+		return err
+	}
+
+	keepAliveFlag, err := cmd.Flags().GetString("keepalive")
+	if err != nil {
+		return err
+	}
+	if keepAliveFlag != "" {
+		dur, err := time.ParseDuration(keepAliveFlag)
+		if err != nil {
+			return err
+		}
+		opts.KeepAlive = &api.Duration{Duration: dur}
+	}
+
+	// Piped standard input goes in front of the positional prompt words.
+	prompts := args[1:]
+	piped := !term.IsTerminal(int(os.Stdin.Fd()))
+	if piped {
+		stdin, err := io.ReadAll(os.Stdin)
+		if err != nil {
+			return err
+		}
+		prompts = append([]string{string(stdin)}, prompts...)
+		opts.WordWrap = false
+	}
+	opts.Prompt = strings.Join(prompts, " ")
+
+	// Only a bare "embed MODEL" typed at a terminal is interactive.
+	interactive := !piped && len(prompts) == 0
+
+	// The flag has the final say on wrapping, whatever was chosen above.
+	noWrap, err := cmd.Flags().GetBool("nowordwrap")
+	if err != nil {
+		return err
+	}
+	opts.WordWrap = !noWrap
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	// Look the model up; on a 404 pull it and ask again.
+	info, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model})
+	var statusErr api.StatusError
+	if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
+		if err := PullHandler(cmd, []string{model}); err != nil {
+			return err
+		}
+		info, err = client.Show(cmd.Context(), &api.ShowRequest{Name: model})
+	}
+	if err != nil {
+		return err
+	}
+
+	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
+	opts.ParentModel = info.Details.ParentModel
+	opts.Messages = append(opts.Messages, info.Messages...)
+
+	if !interactive {
+		return embed(cmd, opts)
+	}
+	return generateInteractive(cmd, opts)
+}
+
 func errFromUnknownKey(unknownKeyErr error) error {
 	// find SSH public key in the error message
 	sshKeyPattern := `ssh-\w+ [^\s"]+`
@@ -979,6 +1070,154 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	return &api.Message{Role: role, Content: fullResponse.String()}, nil
 }
 
+// embed fetches embeddings for opts.Prompt from opts.Model, projects them
+// onto two principal components, prints the resulting coordinates and shows
+// them as a labelled scatter plot.
+//
+// The prompt is split on blank lines ("\n\n"); each piece is sent as its own
+// input and is labelled with its 1-based position. The plot is written to
+// plot.svg in the working directory, handed to the "open" command, and
+// removed again when the function returns.
+//
+// An error is returned when the client cannot be created, the request
+// fails, the response cannot be projected (two components need at least two
+// inputs of at least two dimensions each), or the plot cannot be built,
+// saved or opened.
+func embed(cmd *cobra.Command, opts runOptions) error {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return fmt.Errorf("connecting to ollama server: %w", err)
+	}
+
+	// Request one embedding per blank-line-separated piece of the prompt.
+	inputs := strings.Split(opts.Prompt, "\n\n")
+	resp, err := client.Embed(cmd.Context(), &api.EmbedRequest{
+		Model: opts.Model,
+		Input: inputs,
+	})
+	if err != nil {
+		return fmt.Errorf("getting embeddings: %w", err)
+	}
+
+	// Everything below indexes the response by input position, so reject a
+	// response that does not line up with the request (this also covers an
+	// empty response, because strings.Split never returns zero pieces here).
+	embeddings := resp.Embeddings
+	if len(embeddings) != len(inputs) {
+		return fmt.Errorf("got %d embeddings for %d inputs", len(embeddings), len(inputs))
+	}
+
+	// Two components cannot come out of fewer than two rows or columns, and
+	// mat.NewDense panics on a zero dimension, so stop here in those cases.
+	r, c := len(embeddings), len(embeddings[0])
+	if r < 2 || c < 2 {
+		return fmt.Errorf("need at least two inputs of at least two dimensions, got %d of %d", r, c)
+	}
+
+	// Flatten into the row-major layout mat.NewDense expects.
+	data := make([]float64, r*c)
+	for i, e := range embeddings {
+		if len(e) != c {
+			return fmt.Errorf("embedding %d has %d dimensions, want %d", i+1, len(e), c)
+		}
+		for j, v := range e {
+			data[i*c+j] = float64(v)
+		}
+	}
+	X := mat.NewDense(r, c, data)
+
+	// Run PCA on the embedding matrix.
+	var pca stat.PC
+	if !pca.PrincipalComponents(X, nil) {
+		return fmt.Errorf("PCA failed")
+	}
+
+	// Keep the first two component directions as a c x 2 projection matrix,
+	// checking that they exist before slicing them out.
+	var vectors mat.Dense
+	pca.VectorsTo(&vectors)
+	if _, k := vectors.Dims(); k < 2 {
+		return fmt.Errorf("PCA produced %d components, need 2", k)
+	}
+	W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
+
+	// Project every embedding onto those two directions.
+	var reducedData mat.Dense
+	reducedData.Mul(X, W)
+
+	// Print each input with its 2D coordinates and collect the same values
+	// for the plot; labels are the 1-based input numbers printed here.
+	points := make(plotter.XYs, r)
+	labels := make([]string, r)
+	for i, s := range inputs {
+		x, y := reducedData.At(i, 0), reducedData.At(i, 1)
+		fmt.Print(i+1, ". ", s, "\n")
+		fmt.Printf("[%v, %v]\n\n", x, y)
+		points[i].X, points[i].Y = x, y
+		labels[i] = fmt.Sprint(i + 1)
+	}
+
+	p := plot.New()
+	p.Title.Text = "Embedding Map"
+
+	// Scatter plot of the projected points.
+	scatter, err := plotter.NewScatter(points)
+	if err != nil {
+		return fmt.Errorf("building scatter plot: %w", err)
+	}
+	p.Add(scatter)
+
+	// Numeric labels at the same positions.
+	text, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
+	if err != nil {
+		return fmt.Errorf("building plot labels: %w", err)
+	}
+	p.Add(text)
+
+	// Fixed [-1, 1] range with identical ticks on both axes.
+	p.X.Min = -1
+	p.X.Max = 1
+	p.Y.Min = -1
+	p.Y.Max = 1
+	ticks := plot.ConstantTicks([]plot.Tick{
+		{Value: -1, Label: "-1"},
+		{Value: -0.5, Label: "-0.5"},
+		{Value: 0, Label: "0"},
+		{Value: 0.5, Label: "0.5"},
+		{Value: 1, Label: "1"},
+	})
+	p.X.Tick.Marker = ticks
+	p.Y.Tick.Marker = ticks
+
+	// Save the plot to an SVG file.
+	const plotFile = "plot.svg"
+	if err := p.Save(6*vg.Inch, 6*vg.Inch, plotFile); err != nil {
+		return fmt.Errorf("saving plot: %w", err)
+	}
+
+	// Register the cleanup as soon as the file exists so it is removed on
+	// every return path, including a failed "open". os.Remove avoids
+	// spawning an external "rm" process.
+	defer func() {
+		if err := os.Remove(plotFile); err != nil {
+			fmt.Println("error: couldn't delete plot")
+		}
+	}()
+
+	// NOTE(review): "open" is presumably the macOS launcher; other
+	// platforms would need a different command -- confirm intended targets.
+	if err := exec.Command("open", plotFile).Run(); err != nil {
+		return fmt.Errorf("opening plot: %w", err)
+	}
+
+	// Keep the file around until the user is done looking at it. The read
+	// result and error are ignored on purpose: this is only a pause.
+	fmt.Print("Press 'Enter' to continue")
+	_, _ = bufio.NewReader(os.Stdin).ReadString('\n')
+
+	return nil
+}
+
 func generate(cmd *cobra.Command, opts runOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -1247,11 +1486,26 @@ func NewCLI() *cobra.Command {
 		RunE:    RunHandler,
 	}
 
+	embedCmd := &cobra.Command{
+		Use:     "embed MODEL [PROMPT]",
+		Short:   "Embed a model",
+		Args:    cobra.MinimumNArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    EmbedHandler,
+	}
+
 	runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+
+	embedCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
+	embedCmd.Flags().Bool("verbose", false, "Show timings for response")
+	embedCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+	embedCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
+	embedCmd.Flags().String("format", "", "Response format (e.g. json)")
+
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
@@ -1326,6 +1580,7 @@ func NewCLI() *cobra.Command {
 		copyCmd,
 		deleteCmd,
 		serveCmd,
+		embedCmd,
 	} {
 		switch cmd {
 		case runCmd:
@@ -1361,6 +1616,7 @@ func NewCLI() *cobra.Command {
 		psCmd,
 		copyCmd,
 		deleteCmd,
+		embedCmd,
 	)
 
 	return rootCmd

+ 0 - 140
cmd/interactive.go

@@ -1,7 +1,6 @@
 package cmd
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -14,11 +13,6 @@ import (
 	"strings"
 
 	"github.com/spf13/cobra"
-	"gonum.org/v1/gonum/mat"
-	"gonum.org/v1/gonum/stat"
-	"gonum.org/v1/plot"
-	"gonum.org/v1/plot/plotter"
-	"gonum.org/v1/plot/vg"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
@@ -453,140 +447,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			}
 		case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
 			return nil
-		case strings.HasPrefix(line, "/embed"):
-			line = strings.TrimPrefix(line, "/embed")
-			client, err := api.ClientFromEnvironment()
-			if err != nil {
-				fmt.Println("error: couldn't connect to ollama server")
-				return err
-			}
-
-			var strArray []string
-			fmt.Printf("line is %s\n", line)
-			err = json.Unmarshal([]byte(line), &strArray)
-			if err != nil {
-				fmt.Println("error: couldn't parse input")
-				return err
-			}
-
-			for i, s := range strArray {
-				fmt.Printf("strArray[%d] is %s\n", i, s)
-			}
-
-			req := &api.EmbedRequest{
-				Model: opts.Model,
-				Input: strArray,
-			}
-
-			resp, err := client.Embed(cmd.Context(), req)
-			if err != nil {
-				fmt.Println("error: couldn't get embeddings")
-				return err
-			}
-
-			embeddings := resp.Embeddings
-
-			r, c := len(embeddings), len(embeddings[0])
-			data := make([]float64, r*c)
-			for i := 0; i < r; i++ {
-				for j := 0; j < c; j++ {
-					data[i*c+j] = float64(embeddings[i][j])
-				}
-			}
-
-			X := mat.NewDense(r, c, data)
-
-			// Initialize PCA
-			var pca stat.PC
-
-			// Perform PCA
-			if !pca.PrincipalComponents(X, nil) {
-				return fmt.Errorf("PCA failed")
-			}
-
-			// Extract principal component vectors
-			var vectors mat.Dense
-			pca.VectorsTo(&vectors)
-
-			// // Extract variances of the principal components
-			// var variances []float64
-			// variances = pca.VarsTo(variances)
-
-			W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
-
-			// Perform PCA reduction
-			var reducedData mat.Dense
-			reducedData.Mul(X, W)
-
-			// Print the projected 2D points
-			fmt.Println("Reduced embeddings to 2D:")
-			for i := 0; i < reducedData.RawMatrix().Rows; i++ {
-				row := reducedData.RowView(i)
-				fmt.Printf("[%v, %v]\n", row.AtVec(0), row.AtVec(1))
-			}
-
-			points := make(plotter.XYs, reducedData.RawMatrix().Rows)
-			for i := 0; i < reducedData.RawMatrix().Rows; i++ {
-				row := reducedData.RowView(i)
-				points[i].X = row.AtVec(0)
-				points[i].Y = row.AtVec(1)
-			}
-
-			// Create a new plot
-			p := plot.New()
-
-			// Set plot title and axis labels
-			p.Title.Text = "2D Data Plot"
-			p.X.Label.Text = "X"
-			p.Y.Label.Text = "Y"
-
-			// Create a scatter plot of the points
-			s, err := plotter.NewScatter(points)
-			if err != nil {
-				panic(err)
-			}
-			p.Add(s)
-
-			/// Create labels plotter and add it to the plot
-
-			labels := make([]string, reducedData.RawMatrix().Rows)
-			for i := 0; i < reducedData.RawMatrix().Rows; i++ {
-				labels[i] = fmt.Sprintf("%d", i+1)
-			}
-
-			l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
-			if err != nil {
-				panic(err)
-			}
-			p.Add(l)
-
-			// Make the grid square
-			p.X.Min = -1
-			p.X.Max = 1
-			p.Y.Min = -1
-			p.Y.Max = 1
-
-			// Set the aspect ratio to be 1:1
-			p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
-				{Value: -1, Label: "-1"},
-				{Value: -0.5, Label: "-0.5"},
-				{Value: 0, Label: "0"},
-				{Value: 0.5, Label: "0.5"},
-				{Value: 1, Label: "1"},
-			})
-			p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
-				{Value: -1, Label: "-1"},
-				{Value: -0.5, Label: "-0.5"},
-				{Value: 0, Label: "0"},
-				{Value: 0.5, Label: "0.5"},
-				{Value: 1, Label: "1"},
-			})
-
-			// Save the plot to a PNG file
-			if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.png"); err != nil {
-				panic(err)
-			}
-
 		case strings.HasPrefix(line, "/"):
 			args := strings.Fields(line)
 			isFile := false

二進制
plot.png