Selaa lähdekoodia

openai: increase context window when max_tokens is provided

jmorganca 8 kuukautta sitten
vanhempi
commit
9899f18e18
2 muutettua tiedostoa jossa 134 lisäystä ja 117 poistoa
  1. 5 0
      openai/openai.go
  2. 129 117
      openai/openai_test.go

+ 5 - 0
openai/openai.go

@@ -449,6 +449,11 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 
 	if r.MaxTokens != nil {
 		options["num_predict"] = *r.MaxTokens
+
+		// Increase context size up to max_tokens
+		if *r.MaxTokens > 2048 {
+			options["num_ctx"] = *r.MaxTokens
+		}
 	}
 
 	if r.Temperature != nil {

+ 129 - 117
openai/openai_test.go

@@ -1,7 +1,6 @@
 package openai
 
 import (
-	"bytes"
 	"encoding/base64"
 	"encoding/json"
 	"io"
@@ -13,40 +12,28 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
 )
 
-const (
-	prefix = `data:image/jpeg;base64,`
-	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
-)
-
-var False = false
-
-func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
+func capture(req any) gin.HandlerFunc {
 	return func(c *gin.Context) {
-		bodyBytes, _ := io.ReadAll(c.Request.Body)
-		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-		err := json.Unmarshal(bodyBytes, capturedRequest)
-		if err != nil {
-			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
-		}
+		body, _ := io.ReadAll(c.Request.Body)
+		json.Unmarshal(body, req)
 		c.Next()
 	}
 }
 
 func TestChatMiddleware(t *testing.T) {
-	type testCase struct {
+	type test struct {
 		name string
 		body string
 		req  api.ChatRequest
 		err  ErrorResponse
 	}
 
-	var capturedRequest *api.ChatRequest
-
-	testCases := []testCase{
+	tests := []test{
 		{
 			name: "chat handler",
 			body: `{
@@ -67,7 +54,36 @@ func TestChatMiddleware(t *testing.T) {
 					"temperature": 1.0,
 					"top_p":       1.0,
 				},
-				Stream: &False,
+				Stream: func() *bool { f := false; return &f }(),
+			},
+		},
+		{
+			name: "chat handler with large context",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "Hello"}
+				],
+				"max_tokens": 16384
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+				},
+				Options: map[string]any{
+					"temperature": 1.0,
+					"top_p":       1.0,
+
+					// TODO (jmorganca): because we use a map[string]any for options
+					// the values need to be floats for the test comparison to work.
+					"num_predict": 16384.0,
+					"num_ctx":     16384.0,
+				},
+				Stream: func() *bool { f := false; return &f }(),
 			},
 		},
 		{
@@ -85,7 +101,7 @@ func TestChatMiddleware(t *testing.T) {
 							{
 								"type": "image_url",
 								"image_url": {
-									"url": "` + prefix + image + `"
+									"url": "data:image/jpeg;base64,ZGF0YQo="
 								}
 							}
 						]
@@ -103,7 +119,7 @@ func TestChatMiddleware(t *testing.T) {
 						Role: "user",
 						Images: []api.ImageData{
 							func() []byte {
-								img, _ := base64.StdEncoding.DecodeString(image)
+								img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=")
 								return img
 							}(),
 						},
@@ -113,7 +129,7 @@ func TestChatMiddleware(t *testing.T) {
 					"temperature": 1.0,
 					"top_p":       1.0,
 				},
-				Stream: &False,
+				Stream: func() *bool { f := false; return &f }(),
 			},
 		},
 		{
@@ -151,7 +167,7 @@ func TestChatMiddleware(t *testing.T) {
 					"temperature": 1.0,
 					"top_p":       1.0,
 				},
-				Stream: &False,
+				Stream: func() *bool { f := false; return &f }(),
 			},
 		},
 
@@ -172,52 +188,50 @@ func TestChatMiddleware(t *testing.T) {
 		},
 	}
 
-	endpoint := func(c *gin.Context) {
-		c.Status(http.StatusOK)
-	}
-
 	gin.SetMode(gin.TestMode)
-	router := gin.New()
-	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
-	router.Handle(http.MethodPost, "/api/chat", endpoint)
 
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
-			req.Header.Set("Content-Type", "application/json")
+	for _, tt := range tests {
+		var req api.ChatRequest
+
+		router := gin.New()
+		router.Use(ChatMiddleware(), capture(&req))
+		router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) {
+			c.Status(http.StatusOK)
+		})
 
+		t.Run(tt.name, func(t *testing.T) {
+			r, _ := http.NewRequest("POST", "/api/chat", strings.NewReader(tt.body))
+			r.Header.Set("Content-Type", "application/json")
 			resp := httptest.NewRecorder()
-			router.ServeHTTP(resp, req)
+			router.ServeHTTP(resp, r)
 
-			var errResp ErrorResponse
+			var err ErrorResponse
 			if resp.Code != http.StatusOK {
-				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+				if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil {
 					t.Fatal(err)
 				}
 			}
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
+
+			if diff := cmp.Diff(tt.req, req); diff != "" {
+				t.Errorf("mismatch (-want +got):\n%s", diff)
 			}
 
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
+			if diff := cmp.Diff(tt.err, err); diff != "" {
+				t.Errorf("mismatch (-want +got):\n%s", diff)
 			}
-			capturedRequest = nil
 		})
 	}
 }
 
 func TestCompletionsMiddleware(t *testing.T) {
-	type testCase struct {
+	type test struct {
 		name string
 		body string
 		req  api.GenerateRequest
 		err  ErrorResponse
 	}
 
-	var capturedRequest *api.GenerateRequest
-
-	testCases := []testCase{
+	tests := []test{
 		{
 			name: "completions handler",
 			body: `{
@@ -238,7 +252,7 @@ func TestCompletionsMiddleware(t *testing.T) {
 					"stop":              []any{"\n", "stop"},
 				},
 				Suffix: "suffix",
-				Stream: &False,
+				Stream: func() *bool { f := false; return &f }(),
 			},
 		},
 		{
@@ -259,54 +273,51 @@ func TestCompletionsMiddleware(t *testing.T) {
 		},
 	}
 
-	endpoint := func(c *gin.Context) {
-		c.Status(http.StatusOK)
-	}
-
 	gin.SetMode(gin.TestMode)
-	router := gin.New()
-	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
-	router.Handle(http.MethodPost, "/api/generate", endpoint)
 
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
-			req.Header.Set("Content-Type", "application/json")
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var req api.GenerateRequest
 
-			resp := httptest.NewRecorder()
-			router.ServeHTTP(resp, req)
+			router := gin.New()
+			router.Use(CompletionsMiddleware(), capture(&req))
+			router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
+				c.Status(http.StatusOK)
+			})
+
+			r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
+			r.Header.Set("Content-Type", "application/json")
+
+			res := httptest.NewRecorder()
+			router.ServeHTTP(res, r)
 
 			var errResp ErrorResponse
-			if resp.Code != http.StatusOK {
-				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+			if res.Code != http.StatusOK {
+				if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil {
 					t.Fatal(err)
 				}
 			}
 
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
+			if !cmp.Equal(tt.req, req) {
+				t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req))
 			}
 
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
+			if !cmp.Equal(tt.err, errResp) {
+				t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp))
 			}
-
-			capturedRequest = nil
 		})
 	}
 }
 
 func TestEmbeddingsMiddleware(t *testing.T) {
-	type testCase struct {
+	type test struct {
 		name string
 		body string
 		req  api.EmbedRequest
 		err  ErrorResponse
 	}
 
-	var capturedRequest *api.EmbedRequest
-
-	testCases := []testCase{
+	tests := []test{
 		{
 			name: "embed handler single input",
 			body: `{
@@ -348,17 +359,20 @@ func TestEmbeddingsMiddleware(t *testing.T) {
 	}
 
 	gin.SetMode(gin.TestMode)
-	router := gin.New()
-	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
-	router.Handle(http.MethodPost, "/api/embed", endpoint)
 
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
-			req.Header.Set("Content-Type", "application/json")
+	for _, tt := range tests {
+		var req api.EmbedRequest
+
+		router := gin.New()
+		router.Use(EmbeddingsMiddleware(), capture(&req))
+		router.Handle(http.MethodPost, "/api/embed", endpoint)
+
+		t.Run(tt.name, func(t *testing.T) {
+			r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body))
+			r.Header.Set("Content-Type", "application/json")
 
 			resp := httptest.NewRecorder()
-			router.ServeHTTP(resp, req)
+			router.ServeHTTP(resp, r)
 
 			var errResp ErrorResponse
 			if resp.Code != http.StatusOK {
@@ -366,41 +380,37 @@ func TestEmbeddingsMiddleware(t *testing.T) {
 					t.Fatal(err)
 				}
 			}
-
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
+			if diff := cmp.Diff(tt.req, req); diff != "" {
+				t.Errorf("request mismatch (-want +got):\n%s", diff)
 			}
 
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
+			if diff := cmp.Diff(tt.err, errResp); diff != "" {
+				t.Errorf("error mismatch (-want +got):\n%s", diff)
 			}
-
-			capturedRequest = nil
 		})
 	}
 }
 
 func TestListMiddleware(t *testing.T) {
-	type testCase struct {
-		name     string
-		endpoint func(c *gin.Context)
-		resp     string
+	type test struct {
+		name    string
+		handler gin.HandlerFunc
+		body    string
 	}
 
-	testCases := []testCase{
+	tests := []test{
 		{
 			name: "list handler",
-			endpoint: func(c *gin.Context) {
+			handler: func(c *gin.Context) {
 				c.JSON(http.StatusOK, api.ListResponse{
 					Models: []api.ListModelResponse{
 						{
 							Name:       "test-model",
 							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
 						},
-					},
-				})
+					}})
 			},
-			resp: `{
+			body: `{
 				"object": "list",
 				"data": [
 					{
@@ -414,10 +424,12 @@ func TestListMiddleware(t *testing.T) {
 		},
 		{
 			name: "list handler empty output",
-			endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusOK, api.ListResponse{})
+			handler: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ListResponse{
+					Models: []api.ListModelResponse{},
+				})
 			},
-			resp: `{
+			body: `{
 				"object": "list",
 				"data": null
 			}`,
@@ -426,17 +438,17 @@ func TestListMiddleware(t *testing.T) {
 
 	gin.SetMode(gin.TestMode)
 
-	for _, tc := range testCases {
+	for _, tt := range tests {
 		router := gin.New()
 		router.Use(ListMiddleware())
-		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
+		router.Handle(http.MethodGet, "/api/tags", tt.handler)
 		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
 
 		resp := httptest.NewRecorder()
 		router.ServeHTTP(resp, req)
 
 		var expected, actual map[string]any
-		err := json.Unmarshal([]byte(tc.resp), &expected)
+		err := json.Unmarshal([]byte(tt.body), &expected)
 		if err != nil {
 			t.Fatalf("failed to unmarshal expected response: %v", err)
 		}
@@ -446,28 +458,28 @@ func TestListMiddleware(t *testing.T) {
 			t.Fatalf("failed to unmarshal actual response: %v", err)
 		}
 
-		if !reflect.DeepEqual(expected, actual) {
-			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		if diff := cmp.Diff(expected, actual); diff != "" {
+			t.Errorf("responses did not match (-want +got):\n%s", diff)
 		}
 	}
 }
 
 func TestRetrieveMiddleware(t *testing.T) {
-	type testCase struct {
-		name     string
-		endpoint func(c *gin.Context)
-		resp     string
+	type test struct {
+		name    string
+		handler gin.HandlerFunc
+		body    string
 	}
 
-	testCases := []testCase{
+	tests := []test{
 		{
 			name: "retrieve handler",
-			endpoint: func(c *gin.Context) {
+			handler: func(c *gin.Context) {
 				c.JSON(http.StatusOK, api.ShowResponse{
 					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
 				})
 			},
-			resp: `{
+			body: `{
 				"id":"test-model",
 				"object":"model",
 				"created":1686935002,
@@ -476,10 +488,10 @@ func TestRetrieveMiddleware(t *testing.T) {
 		},
 		{
 			name: "retrieve handler error forwarding",
-			endpoint: func(c *gin.Context) {
+			handler: func(c *gin.Context) {
 				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
 			},
-			resp: `{
+			body: `{
 				"error": {
 				  "code": null,
 				  "message": "model not found",
@@ -492,17 +504,17 @@ func TestRetrieveMiddleware(t *testing.T) {
 
 	gin.SetMode(gin.TestMode)
 
-	for _, tc := range testCases {
+	for _, tt := range tests {
 		router := gin.New()
 		router.Use(RetrieveMiddleware())
-		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
+		router.Handle(http.MethodGet, "/api/show/:model", tt.handler)
 		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
 
 		resp := httptest.NewRecorder()
 		router.ServeHTTP(resp, req)
 
 		var expected, actual map[string]any
-		err := json.Unmarshal([]byte(tc.resp), &expected)
+		err := json.Unmarshal([]byte(tt.body), &expected)
 		if err != nil {
 			t.Fatalf("failed to unmarshal expected response: %v", err)
 		}