|
@@ -1,7 +1,6 @@
|
|
|
package openai
|
|
|
|
|
|
import (
|
|
|
- "bytes"
|
|
|
"encoding/base64"
|
|
|
"encoding/json"
|
|
|
"io"
|
|
@@ -13,40 +12,28 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "github.com/google/go-cmp/cmp"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
)
|
|
|
|
|
|
-const (
|
|
|
- prefix = `data:image/jpeg;base64,`
|
|
|
- image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
|
|
-)
|
|
|
-
|
|
|
-var False = false
|
|
|
-
|
|
|
-func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
|
|
+func capture(req any) gin.HandlerFunc {
|
|
|
return func(c *gin.Context) {
|
|
|
- bodyBytes, _ := io.ReadAll(c.Request.Body)
|
|
|
- c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
- err := json.Unmarshal(bodyBytes, capturedRequest)
|
|
|
- if err != nil {
|
|
|
- c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
|
|
- }
|
|
|
+ body, _ := io.ReadAll(c.Request.Body)
|
|
|
+ json.Unmarshal(body, req)
|
|
|
c.Next()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestChatMiddleware(t *testing.T) {
|
|
|
- type testCase struct {
|
|
|
+ type test struct {
|
|
|
name string
|
|
|
body string
|
|
|
req api.ChatRequest
|
|
|
err ErrorResponse
|
|
|
}
|
|
|
|
|
|
- var capturedRequest *api.ChatRequest
|
|
|
-
|
|
|
- testCases := []testCase{
|
|
|
+ tests := []test{
|
|
|
{
|
|
|
name: "chat handler",
|
|
|
body: `{
|
|
@@ -67,7 +54,36 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
"temperature": 1.0,
|
|
|
"top_p": 1.0,
|
|
|
},
|
|
|
- Stream: &False,
|
|
|
+ Stream: func() *bool { f := false; return &f }(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "chat handler with large context",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [
|
|
|
+ {"role": "user", "content": "Hello"}
|
|
|
+ ],
|
|
|
+ "max_tokens": 16384
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: "Hello",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ Options: map[string]any{
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+
|
|
|
+ // TODO (jmorganca): because we use a map[string]any for options
|
|
|
+ // the values need to be floats for the test comparison to work.
|
|
|
+ "num_predict": 16384.0,
|
|
|
+ "num_ctx": 16384.0,
|
|
|
+ },
|
|
|
+ Stream: func() *bool { f := false; return &f }(),
|
|
|
},
|
|
|
},
|
|
|
{
|
|
@@ -85,7 +101,7 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
{
|
|
|
"type": "image_url",
|
|
|
"image_url": {
|
|
|
- "url": "` + prefix + image + `"
|
|
|
+ "url": "data:image/jpeg;base64,ZGF0YQo="
|
|
|
}
|
|
|
}
|
|
|
]
|
|
@@ -103,7 +119,7 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
Role: "user",
|
|
|
Images: []api.ImageData{
|
|
|
func() []byte {
|
|
|
- img, _ := base64.StdEncoding.DecodeString(image)
|
|
|
+ img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=")
|
|
|
return img
|
|
|
}(),
|
|
|
},
|
|
@@ -113,7 +129,7 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
"temperature": 1.0,
|
|
|
"top_p": 1.0,
|
|
|
},
|
|
|
- Stream: &False,
|
|
|
+ Stream: func() *bool { f := false; return &f }(),
|
|
|
},
|
|
|
},
|
|
|
{
|
|
@@ -151,7 +167,7 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
"temperature": 1.0,
|
|
|
"top_p": 1.0,
|
|
|
},
|
|
|
- Stream: &False,
|
|
|
+ Stream: func() *bool { f := false; return &f }(),
|
|
|
},
|
|
|
},
|
|
|
|
|
@@ -172,52 +188,50 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- endpoint := func(c *gin.Context) {
|
|
|
- c.Status(http.StatusOK)
|
|
|
- }
|
|
|
-
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
- router := gin.New()
|
|
|
- router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
|
- router.Handle(http.MethodPost, "/api/chat", endpoint)
|
|
|
|
|
|
- for _, tc := range testCases {
|
|
|
- t.Run(tc.name, func(t *testing.T) {
|
|
|
- req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
+ for _, tt := range tests {
|
|
|
+ var req api.ChatRequest
|
|
|
+
|
|
|
+ router := gin.New()
|
|
|
+ router.Use(ChatMiddleware(), capture(&req))
|
|
|
+ router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) {
|
|
|
+ c.Status(http.StatusOK)
|
|
|
+ })
|
|
|
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ r, _ := http.NewRequest("POST", "/api/chat", strings.NewReader(tt.body))
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
resp := httptest.NewRecorder()
|
|
|
- router.ServeHTTP(resp, req)
|
|
|
+ router.ServeHTTP(resp, r)
|
|
|
|
|
|
- var errResp ErrorResponse
|
|
|
+ var err ErrorResponse
|
|
|
if resp.Code != http.StatusOK {
|
|
|
- if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
|
+ if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
}
|
|
|
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
- t.Fatal("requests did not match")
|
|
|
+
|
|
|
+ if diff := cmp.Diff(tt.req, req); diff != "" {
|
|
|
+ t.Errorf("mismatch (-want +got):\n%s", diff)
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
- t.Fatal("errors did not match")
|
|
|
+ if diff := cmp.Diff(tt.err, err); diff != "" {
|
|
|
+ t.Errorf("mismatch (-want +got):\n%s", diff)
|
|
|
}
|
|
|
- capturedRequest = nil
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestCompletionsMiddleware(t *testing.T) {
|
|
|
- type testCase struct {
|
|
|
+ type test struct {
|
|
|
name string
|
|
|
body string
|
|
|
req api.GenerateRequest
|
|
|
err ErrorResponse
|
|
|
}
|
|
|
|
|
|
- var capturedRequest *api.GenerateRequest
|
|
|
-
|
|
|
- testCases := []testCase{
|
|
|
+ tests := []test{
|
|
|
{
|
|
|
name: "completions handler",
|
|
|
body: `{
|
|
@@ -238,7 +252,7 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|
|
"stop": []any{"\n", "stop"},
|
|
|
},
|
|
|
Suffix: "suffix",
|
|
|
- Stream: &False,
|
|
|
+ Stream: func() *bool { f := false; return &f }(),
|
|
|
},
|
|
|
},
|
|
|
{
|
|
@@ -259,54 +273,51 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|
|
},
|
|
|
}
|
|
|
|
|
|
- endpoint := func(c *gin.Context) {
|
|
|
- c.Status(http.StatusOK)
|
|
|
- }
|
|
|
-
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
- router := gin.New()
|
|
|
- router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
|
- router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
|
|
|
|
- for _, tc := range testCases {
|
|
|
- t.Run(tc.name, func(t *testing.T) {
|
|
|
- req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ var req api.GenerateRequest
|
|
|
|
|
|
- resp := httptest.NewRecorder()
|
|
|
- router.ServeHTTP(resp, req)
|
|
|
+ router := gin.New()
|
|
|
+ router.Use(CompletionsMiddleware(), capture(&req))
|
|
|
+ router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
|
|
|
+ c.Status(http.StatusOK)
|
|
|
+ })
|
|
|
+
|
|
|
+ r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
+
|
|
|
+ res := httptest.NewRecorder()
|
|
|
+ router.ServeHTTP(res, r)
|
|
|
|
|
|
var errResp ErrorResponse
|
|
|
- if resp.Code != http.StatusOK {
|
|
|
- if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
|
+ if res.Code != http.StatusOK {
|
|
|
+ if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
- t.Fatal("requests did not match")
|
|
|
+ if !cmp.Equal(tt.req, req) {
|
|
|
+ t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req))
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
- t.Fatal("errors did not match")
|
|
|
+ if !cmp.Equal(tt.err, errResp) {
|
|
|
+ t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp))
|
|
|
}
|
|
|
-
|
|
|
- capturedRequest = nil
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
|
|
- type testCase struct {
|
|
|
+ type test struct {
|
|
|
name string
|
|
|
body string
|
|
|
req api.EmbedRequest
|
|
|
err ErrorResponse
|
|
|
}
|
|
|
|
|
|
- var capturedRequest *api.EmbedRequest
|
|
|
-
|
|
|
- testCases := []testCase{
|
|
|
+ tests := []test{
|
|
|
{
|
|
|
name: "embed handler single input",
|
|
|
body: `{
|
|
@@ -348,17 +359,20 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
- router := gin.New()
|
|
|
- router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
|
|
- router.Handle(http.MethodPost, "/api/embed", endpoint)
|
|
|
|
|
|
- for _, tc := range testCases {
|
|
|
- t.Run(tc.name, func(t *testing.T) {
|
|
|
- req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
+ for _, tt := range tests {
|
|
|
+ var req api.EmbedRequest
|
|
|
+
|
|
|
+ router := gin.New()
|
|
|
+ router.Use(EmbeddingsMiddleware(), capture(&req))
|
|
|
+ router.Handle(http.MethodPost, "/api/embed", endpoint)
|
|
|
+
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body))
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
- router.ServeHTTP(resp, req)
|
|
|
+ router.ServeHTTP(resp, r)
|
|
|
|
|
|
var errResp ErrorResponse
|
|
|
if resp.Code != http.StatusOK {
|
|
@@ -366,41 +380,37 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
- t.Fatal("requests did not match")
|
|
|
+ if diff := cmp.Diff(tt.req, req); diff != "" {
|
|
|
+ t.Errorf("request mismatch (-want +got):\n%s", diff)
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
- t.Fatal("errors did not match")
|
|
|
+ if diff := cmp.Diff(tt.err, errResp); diff != "" {
|
|
|
+ t.Errorf("error mismatch (-want +got):\n%s", diff)
|
|
|
}
|
|
|
-
|
|
|
- capturedRequest = nil
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestListMiddleware(t *testing.T) {
|
|
|
- type testCase struct {
|
|
|
- name string
|
|
|
- endpoint func(c *gin.Context)
|
|
|
- resp string
|
|
|
+ type test struct {
|
|
|
+ name string
|
|
|
+ handler gin.HandlerFunc
|
|
|
+ body string
|
|
|
}
|
|
|
|
|
|
- testCases := []testCase{
|
|
|
+ tests := []test{
|
|
|
{
|
|
|
name: "list handler",
|
|
|
- endpoint: func(c *gin.Context) {
|
|
|
+ handler: func(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.ListResponse{
|
|
|
Models: []api.ListModelResponse{
|
|
|
{
|
|
|
Name: "test-model",
|
|
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
|
|
},
|
|
|
- },
|
|
|
- })
|
|
|
+ }})
|
|
|
},
|
|
|
- resp: `{
|
|
|
+ body: `{
|
|
|
"object": "list",
|
|
|
"data": [
|
|
|
{
|
|
@@ -414,10 +424,12 @@ func TestListMiddleware(t *testing.T) {
|
|
|
},
|
|
|
{
|
|
|
name: "list handler empty output",
|
|
|
- endpoint: func(c *gin.Context) {
|
|
|
- c.JSON(http.StatusOK, api.ListResponse{})
|
|
|
+ handler: func(c *gin.Context) {
|
|
|
+ c.JSON(http.StatusOK, api.ListResponse{
|
|
|
+ Models: []api.ListModelResponse{},
|
|
|
+ })
|
|
|
},
|
|
|
- resp: `{
|
|
|
+ body: `{
|
|
|
"object": "list",
|
|
|
"data": null
|
|
|
}`,
|
|
@@ -426,17 +438,17 @@ func TestListMiddleware(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
- for _, tc := range testCases {
|
|
|
+ for _, tt := range tests {
|
|
|
router := gin.New()
|
|
|
router.Use(ListMiddleware())
|
|
|
- router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
|
|
+ router.Handle(http.MethodGet, "/api/tags", tt.handler)
|
|
|
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
var expected, actual map[string]any
|
|
|
- err := json.Unmarshal([]byte(tc.resp), &expected)
|
|
|
+ err := json.Unmarshal([]byte(tt.body), &expected)
|
|
|
if err != nil {
|
|
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
|
|
}
|
|
@@ -446,28 +458,28 @@ func TestListMiddleware(t *testing.T) {
|
|
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(expected, actual) {
|
|
|
- t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
|
+ if diff := cmp.Diff(expected, actual); diff != "" {
|
|
|
+ t.Errorf("responses did not match (-want +got):\n%s", diff)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestRetrieveMiddleware(t *testing.T) {
|
|
|
- type testCase struct {
|
|
|
- name string
|
|
|
- endpoint func(c *gin.Context)
|
|
|
- resp string
|
|
|
+ type test struct {
|
|
|
+ name string
|
|
|
+ handler gin.HandlerFunc
|
|
|
+ body string
|
|
|
}
|
|
|
|
|
|
- testCases := []testCase{
|
|
|
+ tests := []test{
|
|
|
{
|
|
|
name: "retrieve handler",
|
|
|
- endpoint: func(c *gin.Context) {
|
|
|
+ handler: func(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.ShowResponse{
|
|
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
|
|
})
|
|
|
},
|
|
|
- resp: `{
|
|
|
+ body: `{
|
|
|
"id":"test-model",
|
|
|
"object":"model",
|
|
|
"created":1686935002,
|
|
@@ -476,10 +488,10 @@ func TestRetrieveMiddleware(t *testing.T) {
|
|
|
},
|
|
|
{
|
|
|
name: "retrieve handler error forwarding",
|
|
|
- endpoint: func(c *gin.Context) {
|
|
|
+ handler: func(c *gin.Context) {
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
|
|
},
|
|
|
- resp: `{
|
|
|
+ body: `{
|
|
|
"error": {
|
|
|
"code": null,
|
|
|
"message": "model not found",
|
|
@@ -492,17 +504,17 @@ func TestRetrieveMiddleware(t *testing.T) {
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
|
|
- for _, tc := range testCases {
|
|
|
+ for _, tt := range tests {
|
|
|
router := gin.New()
|
|
|
router.Use(RetrieveMiddleware())
|
|
|
- router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
|
|
+ router.Handle(http.MethodGet, "/api/show/:model", tt.handler)
|
|
|
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
var expected, actual map[string]any
|
|
|
- err := json.Unmarshal([]byte(tc.resp), &expected)
|
|
|
+ err := json.Unmarshal([]byte(tt.body), &expected)
|
|
|
if err != nil {
|
|
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
|
|
}
|