浏览代码

more fixes for mllama

Patrick Devine 7 月之前
父节点
当前提交
c48e2cfc0d
共有 6 个文件被更改,包括 85 次插入64 次删除
  1. 0 1
      llm/server.go
  2. 1 5
      server/imageproc/images.go
  3. 53 26
      server/prompt.go
  4. 25 13
      server/routes.go
  5. 6 6
      server/routes_generate_test.go
  6. 0 13
      template/template.go

+ 0 - 1
llm/server.go

@@ -675,7 +675,6 @@ const maxBufferSize = 512 * format.KiloByte
 type ImageData struct {
 type ImageData struct {
 	Data          []byte    `json:"data"`
 	Data          []byte    `json:"data"`
 	ID            int       `json:"id"`
 	ID            int       `json:"id"`
-	ImageData     []float32 `json:"image_data"`
 	AspectRatioID int       `json:"aspect_ratio_id"`
 	AspectRatioID int       `json:"aspect_ratio_id"`
 }
 }
 
 

+ 1 - 5
server/imageproc/images.go

@@ -159,11 +159,7 @@ func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image
 	}
 	}
 
 
 	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
 	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
-	centerX := (paddedSize.X - img.Bounds().Max.X) / 2
-	centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2
-	pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y)
-
-	draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over)
+	draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
 
 
 	return dst
 	return dst
 }
 }

+ 53 - 26
server/prompt.go

@@ -3,7 +3,10 @@ package server
 import (
 import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
+	"encoding/binary"
+	"fmt"
 	"log/slog"
 	"log/slog"
+	"strings"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
@@ -18,6 +21,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // latest message and 2) system messages
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 	var system []api.Message
+
 	// always include the last message
 	// always include the last message
 	n := len(msgs) - 1
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	// in reverse, find all messages that fit into context window
@@ -39,16 +43,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			return "", nil, err
 			return "", nil, err
 		}
 		}
 
 
-		c := len(s)
+		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
 		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
 			for _, m := range msgs[i:] {
 				// images are represented as 768 sized embeddings
 				// images are represented as 768 sized embeddings
 				// TODO: get embedding length from project metadata
 				// TODO: get embedding length from project metadata
-				c += 768 * len(m.Images)
+				ctxLen += 768 * len(m.Images)
 			}
 			}
 		}
 		}
 
 
-		if c > opts.NumCtx {
+		if ctxLen > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 			break
 		} else {
 		} else {
@@ -56,35 +60,58 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 		}
 	}
 	}
 
 
-	// truncate any messages that do not fit into the context window
-	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
-		return "", nil, err
+	currMsgIdx := n
+
+	if checkMllamaModelFamily(m) {
+		lastMsgIdx := len(msgs) - 1
+		if len(msgs[lastMsgIdx].Images) == 1 {
+			data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
+			if err != nil {
+				return "", nil, err
+			}
+
+			buf := new(bytes.Buffer)
+			err = binary.Write(buf, binary.LittleEndian, data)
+			if err != nil {
+				return "", nil, err
+			}
+
+			imgData := llm.ImageData{
+				Data:          buf.Bytes(),
+				AspectRatioID: aspectRatioID,
+			}
+
+			msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
+			images = append(images, imgData)
+		}
 	}
 	}
 
 
-	preprocess := checkMllamaModelFamily(m)
-
-	for _, m := range msgs[n:] {
-		for _, i := range m.Images {
-			if preprocess {
-				data, aspectRatioID, err := imageproc.Preprocess(i)
-				if err != nil {
-					return "", nil, err
-				}
-				images = append(images, llm.ImageData{
-					ID:            len(images),
-					ImageData:     data,
-					AspectRatioID: aspectRatioID,
-				})
-			} else {
-				images = append(images, llm.ImageData{
-					ID:   len(images),
-					Data: i,
-				})
+	for cnt, msg := range msgs[currMsgIdx:] {
+		for _, i := range msg.Images {
+			imgData := llm.ImageData{
+				ID:   len(images),
+				Data: i,
 			}
 			}
+
+			imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
+			prompt := msg.Content
+
+			if !strings.Contains(prompt, "[img]") {
+				prompt = strings.TrimSpace("[img] " + prompt)
+			}
+			prompt = strings.Replace(prompt, "[img]", imageTag, 1)
+			msgs[currMsgIdx+cnt].Content = prompt
+
+			images = append(images, imgData)
 		}
 		}
 	}
 	}
 
 
+	// truncate any messages that do not fit into the context window
+	var b bytes.Buffer
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
+		return "", nil, err
+	}
+
 	return b.String(), images, nil
 	return b.String(), images, nil
 }
 }
 
 

+ 25 - 13
server/routes.go

@@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	model, err := GetModel(req.Model)
+	if err != nil {
+		switch {
+		case os.IsNotExist(err):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+		case err.Error() == "invalid model name":
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+		return
+	}
+
 	// expire the runner
 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		model, err := GetModel(req.Model)
-		if err != nil {
-			switch {
-			case os.IsNotExist(err):
-				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
-			case err.Error() == "invalid model name":
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			}
-			return
-		}
 		s.sched.expireRunner(model)
 		s.sched.expireRunner(model)
 
 
 		c.JSON(http.StatusOK, api.GenerateResponse{
 		c.JSON(http.StatusOK, api.GenerateResponse{
@@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 
 	checkpointLoaded := time.Now()
 	checkpointLoaded := time.Now()
 
 
+	// load the model
 	if req.Prompt == "" {
 	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model:      req.Model,
 			Model:      req.Model,
@@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	isMllama := checkMllamaModelFamily(model)
+	if isMllama && len(req.Images) > 1 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
+		return
+	}
+
 	images := make([]llm.ImageData, len(req.Images))
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
 	for i := range req.Images {
 		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
 		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
@@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 			}
 
 
 			for _, i := range images {
 			for _, i := range images {
-				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+				if isMllama {
+					msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
+				} else {
+					msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+				}
 			}
 			}
 
 
 			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})

+ 6 - 6
server/routes_generate_test.go

@@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {
 
 
 	t.Run("missing body", func(t *testing.T) {
 	t.Run("missing body", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, nil)
 		w := createRequest(t, s.GenerateHandler, nil)
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 		}
 
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 		}
 	})
 	})
 
 
 	t.Run("missing model", func(t *testing.T) {
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 		}
 
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 		}
 	})
 	})

+ 0 - 13
template/template.go

@@ -5,7 +5,6 @@ import (
 	"embed"
 	"embed"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
-	"fmt"
 	"io"
 	"io"
 	"math"
 	"math"
 	"slices"
 	"slices"
@@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 // into a single message. collate also collects and returns all system messages.
 // into a single message. collate also collects and returns all system messages.
 // collate mutates message content adding image tags ([img-%d]) as needed
 // collate mutates message content adding image tags ([img-%d]) as needed
 func collate(msgs []api.Message) (string, []*api.Message) {
 func collate(msgs []api.Message) (string, []*api.Message) {
-	var n int
-
 	var system []string
 	var system []string
 	var collated []*api.Message
 	var collated []*api.Message
 	for i := range msgs {
 	for i := range msgs {
 		msg := msgs[i]
 		msg := msgs[i]
-		for range msg.Images {
-			imageTag := fmt.Sprintf("[img-%d]", n)
-			if !strings.Contains(msg.Content, "[img]") {
-				msg.Content = strings.TrimSpace("[img] " + msg.Content)
-			}
-
-			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
-			n++
-		}
-
 		if msg.Role == "system" {
 		if msg.Role == "system" {
 			system = append(system, msg.Content)
 			system = append(system, msg.Content)
 		}
 		}