|
@@ -17,7 +17,9 @@ import (
|
|
|
"os/exec"
|
|
|
"os/signal"
|
|
|
"path/filepath"
|
|
|
+ "regexp"
|
|
|
"runtime"
|
|
|
+ "slices"
|
|
|
"strings"
|
|
|
"syscall"
|
|
|
"time"
|
|
@@ -36,6 +38,8 @@ import (
|
|
|
"github.com/jmorganca/ollama/version"
|
|
|
)
|
|
|
|
|
|
+type ImageData []byte
|
|
|
+
|
|
|
func CreateHandler(cmd *cobra.Command, args []string) error {
|
|
|
filename, _ := cmd.Flags().GetString("file")
|
|
|
filename, err := filepath.Abs(filename)
|
|
@@ -418,6 +422,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
|
|
Model: args[0],
|
|
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
|
|
Options: map[string]interface{}{},
|
|
|
+ Images: []ImageData{},
|
|
|
}
|
|
|
|
|
|
format, err := cmd.Flags().GetString("format")
|
|
@@ -427,7 +432,6 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
|
|
opts.Format = format
|
|
|
|
|
|
prompts := args[1:]
|
|
|
-
|
|
|
// prepend stdin to the prompt if provided
|
|
|
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
|
|
in, err := io.ReadAll(os.Stdin)
|
|
@@ -466,6 +470,7 @@ type generateOptions struct {
|
|
|
Format string
|
|
|
System string
|
|
|
Template string
|
|
|
+ Images []ImageData
|
|
|
Options map[string]interface{}
|
|
|
}
|
|
|
|
|
@@ -551,6 +556,10 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+ images := make([]api.ImageData, 0)
|
|
|
+ for _, i := range opts.Images {
|
|
|
+ images = append(images, api.ImageData(i))
|
|
|
+ }
|
|
|
request := api.GenerateRequest{
|
|
|
Model: opts.Model,
|
|
|
Prompt: opts.Prompt,
|
|
@@ -559,6 +568,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
|
|
System: opts.System,
|
|
|
Template: opts.Template,
|
|
|
Options: opts.Options,
|
|
|
+ Images: images,
|
|
|
}
|
|
|
|
|
|
if err := client.Generate(ctx, &request, fn); err != nil {
|
|
@@ -585,7 +595,9 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
|
|
|
latest.Summary()
|
|
|
}
|
|
|
|
|
|
- cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
|
|
|
+ ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
|
|
|
+ cmd.SetContext(ctx)
|
|
|
+
|
|
|
return nil
|
|
|
}
|
|
|
|
|
@@ -598,11 +610,31 @@ const (
|
|
|
MultilineTemplate
|
|
|
)
|
|
|
|
|
|
+func modelIsMultiModal(cmd *cobra.Command, name string) bool {
|
|
|
+ // get model details
|
|
|
+ client, err := api.ClientFromEnvironment()
|
|
|
+ if err != nil {
|
|
|
+ fmt.Println("error: couldn't connect to ollama server")
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ req := api.ShowRequest{Name: name}
|
|
|
+ resp, err := client.Show(cmd.Context(), &req)
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ return slices.Contains(resp.Details.Families, "clip")
|
|
|
+}
|
|
|
+
|
|
|
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
|
|
+ multiModal := modelIsMultiModal(cmd, opts.Model)
|
|
|
+
|
|
|
// load the model
|
|
|
loadOpts := generateOptions{
|
|
|
Model: opts.Model,
|
|
|
Prompt: "",
|
|
|
+ Images: []ImageData{},
|
|
|
}
|
|
|
if err := generate(cmd, loadOpts); err != nil {
|
|
|
return err
|
|
@@ -902,6 +934,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
|
|
|
|
|
if len(prompt) > 0 && multiline == MultilineNone {
|
|
|
opts.Prompt = prompt
|
|
|
+ if multiModal {
|
|
|
+ newPrompt, images, err := extractFileNames(prompt)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ opts.Prompt = newPrompt
|
|
|
+
|
|
|
+ // reset the context if we find another image
|
|
|
+ if len(images) > 0 {
|
|
|
+ opts.Images = images
|
|
|
+ ctx := cmd.Context()
|
|
|
+ ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
|
|
|
+ cmd.SetContext(ctx)
|
|
|
+ }
|
|
|
+ if len(opts.Images) == 0 {
|
|
|
+ fmt.Println("This model requires you to add a jpeg, png, or svg image.\n")
|
|
|
+ prompt = ""
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
if err := generate(cmd, opts); err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -911,6 +963,57 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func normalizeFilePath(fp string) string {
|
|
|
+ // Define a map of escaped characters and their replacements
|
|
|
+ replacements := map[string]string{
|
|
|
+ "\\ ": " ", // Escaped space
|
|
|
+ "\\(": "(", // Escaped left parenthesis
|
|
|
+ "\\)": ")", // Escaped right parenthesis
|
|
|
+ "\\[": "[", // Escaped left square bracket
|
|
|
+ "\\]": "]", // Escaped right square bracket
|
|
|
+ "\\{": "{", // Escaped left curly brace
|
|
|
+ "\\}": "}", // Escaped right curly brace
|
|
|
+ "\\$": "$", // Escaped dollar sign
|
|
|
+ "\\&": "&", // Escaped ampersand
|
|
|
+ "\\;": ";", // Escaped semicolon
|
|
|
+ "\\'": "'", // Escaped single quote
|
|
|
+ "\\\\": "\\", // Escaped backslash
|
|
|
+ "\\*": "*", // Escaped asterisk
|
|
|
+ "\\?": "?", // Escaped question mark
|
|
|
+ }
|
|
|
+
|
|
|
+ for escaped, actual := range replacements {
|
|
|
+ fp = strings.ReplaceAll(fp, escaped, actual)
|
|
|
+ }
|
|
|
+ return fp
|
|
|
+}
|
|
|
+
|
|
|
+func extractFileNames(input string) (string, []ImageData, error) {
|
|
|
+ // Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20)
|
|
|
+ // and followed by more characters and a file extension
|
|
|
+ regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
|
|
|
+ re := regexp.MustCompile(regexPattern)
|
|
|
+
|
|
|
+ filePaths := re.FindAllString(input, -1)
|
|
|
+ var imgs []ImageData
|
|
|
+
|
|
|
+ for _, fp := range filePaths {
|
|
|
+ nfp := normalizeFilePath(fp)
|
|
|
+ data, err := getImageData(nfp)
|
|
|
+ if err != nil {
|
|
|
+ if os.IsNotExist(err) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ fmt.Printf("Couldn't process image: %q\n", err)
|
|
|
+ return "", imgs, err
|
|
|
+ }
|
|
|
+ fmt.Printf("Added image '%s'\n", nfp)
|
|
|
+ input = strings.ReplaceAll(input, fp, "")
|
|
|
+ imgs = append(imgs, data)
|
|
|
+ }
|
|
|
+ return input, imgs, nil
|
|
|
+}
|
|
|
+
|
|
|
func RunServer(cmd *cobra.Command, _ []string) error {
|
|
|
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
|
|
|
if err != nil {
|
|
@@ -937,6 +1040,50 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
|
|
return server.Serve(ln, origins)
|
|
|
}
|
|
|
|
|
|
+func getImageData(filePath string) ([]byte, error) {
|
|
|
+ file, err := os.Open(filePath)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer file.Close()
|
|
|
+
|
|
|
+ buf := make([]byte, 512)
|
|
|
+ _, err = file.Read(buf)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ contentType := http.DetectContentType(buf)
|
|
|
+ allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
|
|
|
+ if !slices.Contains(allowedTypes, contentType) {
|
|
|
+ return nil, fmt.Errorf("invalid image type: %s", contentType)
|
|
|
+ }
|
|
|
+
|
|
|
+ info, err := file.Stat()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check if the file size exceeds 100MB
|
|
|
+ var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
|
|
|
+ if info.Size() > maxSize {
|
|
|
+ return nil, fmt.Errorf("file size exceeds maximum limit (100MB).")
|
|
|
+ }
|
|
|
+
|
|
|
+ buf = make([]byte, info.Size())
|
|
|
+ _, err = file.Seek(0, 0)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = io.ReadFull(file, buf)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return buf, nil
|
|
|
+}
|
|
|
+
|
|
|
func initializeKeypair() error {
|
|
|
home, err := os.UserHomeDir()
|
|
|
if err != nil {
|