123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499 |
- package openai
- import (
- "bytes"
- "encoding/base64"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "github.com/ollama/ollama/api"
- )
// Fixtures for the image-content chat test.
const (
	// prefix is the data-URL header placed in front of the base64 payload.
	prefix = `data:image/jpeg;base64,`

	// image is a tiny base64-encoded image payload. NOTE(review): the bytes
	// appear to start with a PNG signature even though prefix declares
	// image/jpeg; the tests only compare the decoded bytes, so the mismatch
	// is harmless here.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

	// imageURL is the complete data URL sent as an image_url content part.
	imageURL = prefix + image
)
// prepareRequest JSON-encodes body, installs the result as req's body, and
// sets the Content-Type header to application/json.
//
// It panics if body cannot be marshaled. This helper has no *testing.T to
// report through, and silently sending an empty body would only surface later
// as a confusing failure far from the real cause.
func prepareRequest(req *http.Request, body any) {
	bodyBytes, err := json.Marshal(body)
	if err != nil {
		panic("prepareRequest: marshal request body: " + err.Error())
	}

	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
	req.Header.Set("Content-Type", "application/json")
}
- func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
- return func(c *gin.Context) {
- bodyBytes, _ := io.ReadAll(c.Request.Body)
- c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- err := json.Unmarshal(bodyBytes, capturedRequest)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
- }
- c.Next()
- }
- }
- func TestChatMiddleware(t *testing.T) {
- type testCase struct {
- Name string
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
- }
- var capturedRequest *api.ChatRequest
- testCases := []testCase{
- {
- Name: "chat handler",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{{Role: "user", Content: "Hello"}},
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusOK {
- t.Fatalf("expected 200, got %d", resp.Code)
- }
- if req.Messages[0].Role != "user" {
- t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
- }
- if req.Messages[0].Content != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
- }
- },
- },
- {
- Name: "chat handler with image content",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{
- {
- Role: "user", Content: []map[string]any{
- {"type": "text", "text": "Hello"},
- {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
- },
- },
- },
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusOK {
- t.Fatalf("expected 200, got %d", resp.Code)
- }
- if req.Messages[0].Role != "user" {
- t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
- }
- if req.Messages[0].Content != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
- }
- img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
- if req.Messages[1].Role != "user" {
- t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
- }
- if !bytes.Equal(req.Messages[1].Images[0], img) {
- t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
- }
- },
- },
- {
- Name: "chat handler with tools",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{
- {Role: "user", Content: "What's the weather like in Paris Today?"},
- {Role: "assistant", ToolCalls: []ToolCall{{
- ID: "id",
- Type: "function",
- Function: struct {
- Name string `json:"name"`
- Arguments string `json:"arguments"`
- }{
- Name: "get_current_weather",
- Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
- },
- }}},
- },
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != 200 {
- t.Fatalf("expected 200, got %d", resp.Code)
- }
- if req.Messages[0].Content != "What's the weather like in Paris Today?" {
- t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
- }
- if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
- t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
- }
- if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
- t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
- }
- },
- },
- {
- Name: "chat handler error forwarding",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{{Role: "user", Content: 2}},
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusBadRequest {
- t.Fatalf("expected 400, got %d", resp.Code)
- }
- if !strings.Contains(resp.Body.String(), "invalid message content type") {
- t.Fatalf("error was not forwarded")
- }
- },
- },
- }
- endpoint := func(c *gin.Context) {
- c.Status(http.StatusOK)
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
- router.Handle(http.MethodPost, "/api/chat", endpoint)
- for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
- tc.Setup(t, req)
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- tc.Expected(t, capturedRequest, resp)
- capturedRequest = nil
- })
- }
- }
- func TestCompletionsMiddleware(t *testing.T) {
- type testCase struct {
- Name string
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
- }
- var capturedRequest *api.GenerateRequest
- testCases := []testCase{
- {
- Name: "completions handler",
- Setup: func(t *testing.T, req *http.Request) {
- temp := float32(0.8)
- body := CompletionRequest{
- Model: "test-model",
- Prompt: "Hello",
- Temperature: &temp,
- Stop: []string{"\n", "stop"},
- Suffix: "suffix",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
- if req.Prompt != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Prompt)
- }
- if req.Options["temperature"] != 1.6 {
- t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
- }
- stopTokens, ok := req.Options["stop"].([]any)
- if !ok {
- t.Fatalf("expected stop tokens to be a list")
- }
- if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
- t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
- }
- if req.Suffix != "suffix" {
- t.Fatalf("expected 'suffix', got %s", req.Suffix)
- }
- },
- },
- {
- Name: "completions handler error forwarding",
- Setup: func(t *testing.T, req *http.Request) {
- body := CompletionRequest{
- Model: "test-model",
- Prompt: "Hello",
- Temperature: nil,
- Stop: []int{1, 2},
- Suffix: "suffix",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusBadRequest {
- t.Fatalf("expected 400, got %d", resp.Code)
- }
- if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
- t.Fatalf("error was not forwarded")
- }
- },
- },
- }
- endpoint := func(c *gin.Context) {
- c.Status(http.StatusOK)
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
- router.Handle(http.MethodPost, "/api/generate", endpoint)
- for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
- tc.Setup(t, req)
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- tc.Expected(t, capturedRequest, resp)
- capturedRequest = nil
- })
- }
- }
- func TestEmbeddingsMiddleware(t *testing.T) {
- type testCase struct {
- Name string
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
- }
- var capturedRequest *api.EmbedRequest
- testCases := []testCase{
- {
- Name: "embed handler single input",
- Setup: func(t *testing.T, req *http.Request) {
- body := EmbedRequest{
- Input: "Hello",
- Model: "test-model",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
- if req.Input != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Input)
- }
- if req.Model != "test-model" {
- t.Fatalf("expected 'test-model', got %s", req.Model)
- }
- },
- },
- {
- Name: "embed handler batch input",
- Setup: func(t *testing.T, req *http.Request) {
- body := EmbedRequest{
- Input: []string{"Hello", "World"},
- Model: "test-model",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
- input, ok := req.Input.([]any)
- if !ok {
- t.Fatalf("expected input to be a list")
- }
- if input[0].(string) != "Hello" {
- t.Fatalf("expected 'Hello', got %s", input[0])
- }
- if input[1].(string) != "World" {
- t.Fatalf("expected 'World', got %s", input[1])
- }
- if req.Model != "test-model" {
- t.Fatalf("expected 'test-model', got %s", req.Model)
- }
- },
- },
- {
- Name: "embed handler error forwarding",
- Setup: func(t *testing.T, req *http.Request) {
- body := EmbedRequest{
- Model: "test-model",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusBadRequest {
- t.Fatalf("expected 400, got %d", resp.Code)
- }
- if !strings.Contains(resp.Body.String(), "invalid input") {
- t.Fatalf("error was not forwarded")
- }
- },
- },
- }
- endpoint := func(c *gin.Context) {
- c.Status(http.StatusOK)
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
- router.Handle(http.MethodPost, "/api/embed", endpoint)
- for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
- tc.Setup(t, req)
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- tc.Expected(t, capturedRequest, resp)
- capturedRequest = nil
- })
- }
- }
- func TestMiddlewareResponses(t *testing.T) {
- type testCase struct {
- Name string
- Method string
- Path string
- TestPath string
- Handler func() gin.HandlerFunc
- Endpoint func(c *gin.Context)
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, resp *httptest.ResponseRecorder)
- }
- testCases := []testCase{
- {
- Name: "list handler",
- Method: http.MethodGet,
- Path: "/api/tags",
- TestPath: "/api/tags",
- Handler: ListMiddleware,
- Endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ListResponse{
- Models: []api.ListModelResponse{
- {
- Name: "Test Model",
- },
- },
- })
- },
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var listResp ListCompletion
- if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
- t.Fatal(err)
- }
- if listResp.Object != "list" {
- t.Fatalf("expected list, got %s", listResp.Object)
- }
- if len(listResp.Data) != 1 {
- t.Fatalf("expected 1, got %d", len(listResp.Data))
- }
- if listResp.Data[0].Id != "Test Model" {
- t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
- }
- },
- },
- {
- Name: "retrieve model",
- Method: http.MethodGet,
- Path: "/api/show/:model",
- TestPath: "/api/show/test-model",
- Handler: RetrieveMiddleware,
- Endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ShowResponse{
- ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
- })
- },
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var retrieveResp Model
- if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
- t.Fatal(err)
- }
- if retrieveResp.Object != "model" {
- t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
- }
- if retrieveResp.Id != "test-model" {
- t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
- }
- },
- },
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- router = gin.New()
- router.Use(tc.Handler())
- router.Handle(tc.Method, tc.Path, tc.Endpoint)
- req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
- if tc.Setup != nil {
- tc.Setup(t, req)
- }
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- assert.Equal(t, http.StatusOK, resp.Code)
- tc.Expected(t, resp)
- })
- }
- }
|