|
@@ -2,6 +2,7 @@ package cmd
|
|
|
|
|
|
import (
|
|
|
"archive/zip"
|
|
|
+ "bufio"
|
|
|
"bytes"
|
|
|
"context"
|
|
|
"crypto/ed25519"
|
|
@@ -16,6 +17,7 @@ import (
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
+ "os/exec"
|
|
|
"os/signal"
|
|
|
"path/filepath"
|
|
|
"regexp"
|
|
@@ -31,6 +33,11 @@ import (
|
|
|
"github.com/spf13/cobra"
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
"golang.org/x/term"
|
|
|
+ "gonum.org/v1/gonum/mat"
|
|
|
+ "gonum.org/v1/gonum/stat"
|
|
|
+ "gonum.org/v1/plot"
|
|
|
+ "gonum.org/v1/plot/plotter"
|
|
|
+ "gonum.org/v1/plot/vg"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
"github.com/ollama/ollama/auth"
|
|
@@ -370,6 +377,90 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|
|
return generate(cmd, opts)
|
|
|
}
|
|
|
|
|
|
+func EmbedHandler(cmd *cobra.Command, args []string) error {
|
|
|
+ interactive := true
|
|
|
+
|
|
|
+ opts := runOptions{
|
|
|
+ Model: args[0],
|
|
|
+ WordWrap: os.Getenv("TERM") == "xterm-256color",
|
|
|
+ Options: map[string]interface{}{},
|
|
|
+ }
|
|
|
+
|
|
|
+ format, err := cmd.Flags().GetString("format")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ opts.Format = format
|
|
|
+
|
|
|
+ keepAlive, err := cmd.Flags().GetString("keepalive")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if keepAlive != "" {
|
|
|
+ d, err := time.ParseDuration(keepAlive)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ opts.KeepAlive = &api.Duration{Duration: d}
|
|
|
+ }
|
|
|
+
|
|
|
+ prompts := args[1:]
|
|
|
+ // prepend stdin to the prompt if provided
|
|
|
+ if !term.IsTerminal(int(os.Stdin.Fd())) {
|
|
|
+ in, err := io.ReadAll(os.Stdin)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ prompts = append([]string{string(in)}, prompts...)
|
|
|
+ opts.WordWrap = false
|
|
|
+ interactive = false
|
|
|
+ }
|
|
|
+ opts.Prompt = strings.Join(prompts, " ")
|
|
|
+ if len(prompts) > 0 {
|
|
|
+ interactive = false
|
|
|
+ }
|
|
|
+
|
|
|
+ nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ opts.WordWrap = !nowrap
|
|
|
+
|
|
|
+ // Fill out the rest of the options based on information about the
|
|
|
+ // model.
|
|
|
+ client, err := api.ClientFromEnvironment()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ name := args[0]
|
|
|
+ info, err := func() (*api.ShowResponse, error) {
|
|
|
+ showReq := &api.ShowRequest{Name: name}
|
|
|
+ info, err := client.Show(cmd.Context(), showReq)
|
|
|
+ var se api.StatusError
|
|
|
+ if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
|
|
+ if err := PullHandler(cmd, []string{name}); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
|
|
+ }
|
|
|
+ return info, err
|
|
|
+ }()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ opts.MultiModal = slices.Contains(info.Details.Families, "clip")
|
|
|
+ opts.ParentModel = info.Details.ParentModel
|
|
|
+ opts.Messages = append(opts.Messages, info.Messages...)
|
|
|
+
|
|
|
+ if interactive {
|
|
|
+ return generateInteractive(cmd, opts)
|
|
|
+ }
|
|
|
+ return embed(cmd, opts)
|
|
|
+}
|
|
|
+
|
|
|
func errFromUnknownKey(unknownKeyErr error) error {
|
|
|
// find SSH public key in the error message
|
|
|
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
|
@@ -979,6 +1070,154 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
|
|
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
|
|
}
|
|
|
|
|
|
+func embed(cmd *cobra.Command, opts runOptions) error {
|
|
|
+ line := opts.Prompt
|
|
|
+ client, err := api.ClientFromEnvironment()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("error: couldn't connect to ollama server")
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ inputs := strings.Split(line, "\n\n")
|
|
|
+
|
|
|
+ req := &api.EmbedRequest{
|
|
|
+ Model: opts.Model,
|
|
|
+ Input: inputs,
|
|
|
+ }
|
|
|
+
|
|
|
+ resp, err := client.Embed(cmd.Context(), req)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("error: couldn't get embeddings")
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ embeddings := resp.Embeddings
|
|
|
+
|
|
|
+ r, c := len(embeddings), len(embeddings[0])
|
|
|
+ data := make([]float64, r*c)
|
|
|
+ for i := range r {
|
|
|
+ for j := range c {
|
|
|
+ data[i*c+j] = float64(embeddings[i][j])
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ X := mat.NewDense(r, c, data)
|
|
|
+
|
|
|
+ // Initialize PCA
|
|
|
+ var pca stat.PC
|
|
|
+
|
|
|
+ // Perform PCA
|
|
|
+ if !pca.PrincipalComponents(X, nil) {
|
|
|
+ return fmt.Errorf("PCA failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Extract principal component vectors
|
|
|
+ var vectors mat.Dense
|
|
|
+ pca.VectorsTo(&vectors)
|
|
|
+
|
|
|
+ // // Extract variances of the principal components
|
|
|
+ // var variances []float64
|
|
|
+ // variances = pca.VarsTo(variances)
|
|
|
+
|
|
|
+ W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
|
|
|
+
|
|
|
+ // Perform PCA reduction
|
|
|
+ var reducedData mat.Dense
|
|
|
+ reducedData.Mul(X, W)
|
|
|
+
|
|
|
+ for i, s := range inputs {
|
|
|
+ row := reducedData.RowView(i)
|
|
|
+ fmt.Print(i+1, ". ", s, "\n")
|
|
|
+ fmt.Printf("[%v, %v]\n\n", row.AtVec(0), row.AtVec(1))
|
|
|
+ }
|
|
|
+
|
|
|
+ points := make(plotter.XYs, reducedData.RawMatrix().Rows)
|
|
|
+ for i := range len(points) {
|
|
|
+ row := reducedData.RowView(i)
|
|
|
+ points[i].X = row.AtVec(0)
|
|
|
+ points[i].Y = row.AtVec(1)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a new plot
|
|
|
+ p := plot.New()
|
|
|
+
|
|
|
+ // Set plot title and axis labels
|
|
|
+ p.Title.Text = "Embedding Map"
|
|
|
+
|
|
|
+ // Create a scatter plot of the points
|
|
|
+ s, err := plotter.NewScatter(points)
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+ p.Add(s)
|
|
|
+
|
|
|
+ /// Create labels plotter and add it to the plot
|
|
|
+
|
|
|
+ labels := make([]string, reducedData.RawMatrix().Rows)
|
|
|
+ for i := range len(labels) {
|
|
|
+ labels[i] = fmt.Sprintf("%d", i+1)
|
|
|
+ }
|
|
|
+
|
|
|
+ // plotter := plotter
|
|
|
+
|
|
|
+ l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+ p.Add(l)
|
|
|
+
|
|
|
+ // Make the grid square
|
|
|
+ p.X.Min = -1
|
|
|
+ p.X.Max = 1
|
|
|
+ p.Y.Min = -1
|
|
|
+ p.Y.Max = 1
|
|
|
+
|
|
|
+ // Set the aspect ratio to be 1:1
|
|
|
+ p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
|
|
+ {Value: -1, Label: "-1"},
|
|
|
+ {Value: -0.5, Label: "-0.5"},
|
|
|
+ {Value: 0, Label: "0"},
|
|
|
+ {Value: 0.5, Label: "0.5"},
|
|
|
+ {Value: 1, Label: "1"},
|
|
|
+ })
|
|
|
+ p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
|
|
+ {Value: -1, Label: "-1"},
|
|
|
+ {Value: -0.5, Label: "-0.5"},
|
|
|
+ {Value: 0, Label: "0"},
|
|
|
+ {Value: 0.5, Label: "0.5"},
|
|
|
+ {Value: 1, Label: "1"},
|
|
|
+ })
|
|
|
+
|
|
|
+ // Save the plot to a svg file
|
|
|
+ if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.svg"); err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // open the plot
|
|
|
+ open := exec.Command("open", "plot.svg")
|
|
|
+ err = open.Run()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("error: couldn't open plot")
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Wait for Enter key press
|
|
|
+ fmt.Print("Press 'Enter' to continue")
|
|
|
+ reader := bufio.NewReader(os.Stdin)
|
|
|
+ _, _ = reader.ReadString('\n')
|
|
|
+
|
|
|
+ // close and delete the plot (defer this)
|
|
|
+ defer func() {
|
|
|
+ delete := exec.Command("rm", "plot.svg")
|
|
|
+ err = delete.Run()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("error: couldn't delete plot")
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func generate(cmd *cobra.Command, opts runOptions) error {
|
|
|
client, err := api.ClientFromEnvironment()
|
|
|
if err != nil {
|
|
@@ -1247,11 +1486,26 @@ func NewCLI() *cobra.Command {
|
|
|
RunE: RunHandler,
|
|
|
}
|
|
|
|
|
|
+ embedCmd := &cobra.Command{
|
|
|
+ Use: "embed MODEL [PROMPT]",
|
|
|
+ Short: "Embed a model",
|
|
|
+ Args: cobra.MinimumNArgs(1),
|
|
|
+ PreRunE: checkServerHeartbeat,
|
|
|
+ RunE: EmbedHandler,
|
|
|
+ }
|
|
|
+
|
|
|
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
|
|
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
|
|
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
|
|
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
|
|
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
|
|
+
|
|
|
+ embedCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
|
|
+ embedCmd.Flags().Bool("verbose", false, "Show timings for response")
|
|
|
+ embedCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
|
|
+ embedCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
|
|
+ embedCmd.Flags().String("format", "", "Response format (e.g. json)")
|
|
|
+
|
|
|
serveCmd := &cobra.Command{
|
|
|
Use: "serve",
|
|
|
Aliases: []string{"start"},
|
|
@@ -1326,6 +1580,7 @@ func NewCLI() *cobra.Command {
|
|
|
copyCmd,
|
|
|
deleteCmd,
|
|
|
serveCmd,
|
|
|
+ embedCmd,
|
|
|
} {
|
|
|
switch cmd {
|
|
|
case runCmd:
|
|
@@ -1361,6 +1616,7 @@ func NewCLI() *cobra.Command {
|
|
|
psCmd,
|
|
|
copyCmd,
|
|
|
deleteCmd,
|
|
|
+ embedCmd,
|
|
|
)
|
|
|
|
|
|
return rootCmd
|