123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640 |
- package openai
- import (
- "bytes"
- "encoding/base64"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "reflect"
- "strings"
- "testing"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/ollama/ollama/api"
- )
// prefix is the data-URL header placed in front of the base64 payload when an
// image is embedded in an OpenAI-style "image_url" content part.
const prefix = `data:image/jpeg;base64,`

// image is a small base64-encoded image used as a fixture. The payload starts
// with the PNG signature, so it appears to be a tiny PNG even though prefix
// labels it as JPEG.
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

// False and True are addressable booleans. The tests take their addresses
// (&False, &True) to populate *bool fields such as Stream.
var (
	False = false
	True  = true
)
- func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
- return func(c *gin.Context) {
- bodyBytes, _ := io.ReadAll(c.Request.Body)
- c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- err := json.Unmarshal(bodyBytes, capturedRequest)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
- }
- c.Next()
- }
- }
- func TestChatMiddleware(t *testing.T) {
- type testCase struct {
- name string
- body string
- req api.ChatRequest
- err ErrorResponse
- }
- var capturedRequest *api.ChatRequest
- testCases := []testCase{
- {
- name: "chat handler",
- body: `{
- "model": "test-model",
- "messages": [
- {"role": "user", "content": "Hello"}
- ]
- }`,
- req: api.ChatRequest{
- Model: "test-model",
- Messages: []api.Message{
- {
- Role: "user",
- Content: "Hello",
- },
- },
- Options: map[string]any{
- "temperature": 1.0,
- "top_p": 1.0,
- },
- Stream: &False,
- },
- },
- {
- name: "chat handler with options",
- body: `{
- "model": "test-model",
- "messages": [
- {"role": "user", "content": "Hello"}
- ],
- "stream": true,
- "max_tokens": 999,
- "seed": 123,
- "stop": ["\n", "stop"],
- "temperature": 3.0,
- "frequency_penalty": 4.0,
- "presence_penalty": 5.0,
- "top_p": 6.0,
- "response_format": {"type": "json_object"}
- }`,
- req: api.ChatRequest{
- Model: "test-model",
- Messages: []api.Message{
- {
- Role: "user",
- Content: "Hello",
- },
- },
- Options: map[string]any{
- "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
- "seed": 123.0,
- "stop": []any{"\n", "stop"},
- "temperature": 3.0,
- "frequency_penalty": 4.0,
- "presence_penalty": 5.0,
- "top_p": 6.0,
- },
- Format: "json",
- Stream: &True,
- },
- },
- {
- name: "chat handler with image content",
- body: `{
- "model": "test-model",
- "messages": [
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": "Hello"
- },
- {
- "type": "image_url",
- "image_url": {
- "url": "` + prefix + image + `"
- }
- }
- ]
- }
- ]
- }`,
- req: api.ChatRequest{
- Model: "test-model",
- Messages: []api.Message{
- {
- Role: "user",
- Content: "Hello",
- },
- {
- Role: "user",
- Images: []api.ImageData{
- func() []byte {
- img, _ := base64.StdEncoding.DecodeString(image)
- return img
- }(),
- },
- },
- },
- Options: map[string]any{
- "temperature": 1.0,
- "top_p": 1.0,
- },
- Stream: &False,
- },
- },
- {
- name: "chat handler with tools",
- body: `{
- "model": "test-model",
- "messages": [
- {"role": "user", "content": "What's the weather like in Paris Today?"},
- {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
- ]
- }`,
- req: api.ChatRequest{
- Model: "test-model",
- Messages: []api.Message{
- {
- Role: "user",
- Content: "What's the weather like in Paris Today?",
- },
- {
- Role: "assistant",
- ToolCalls: []api.ToolCall{
- {
- Function: api.ToolCallFunction{
- Name: "get_current_weather",
- Arguments: map[string]interface{}{
- "location": "Paris, France",
- "format": "celsius",
- },
- },
- },
- },
- },
- },
- Options: map[string]any{
- "temperature": 1.0,
- "top_p": 1.0,
- },
- Stream: &False,
- },
- },
- {
- name: "chat handler with streaming tools",
- body: `{
- "model": "test-model",
- "messages": [
- {"role": "user", "content": "What's the weather like in Paris?"}
- ],
- "stream": true,
- "tools": [{
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get the current weather",
- "parameters": {
- "type": "object",
- "required": ["location"],
- "properties": {
- "location": {
- "type": "string",
- "description": "The city and state"
- },
- "unit": {
- "type": "string",
- "enum": ["celsius", "fahrenheit"]
- }
- }
- }
- }
- }]
- }`,
- req: api.ChatRequest{
- Model: "test-model",
- Messages: []api.Message{
- {
- Role: "user",
- Content: "What's the weather like in Paris?",
- },
- },
- Tools: []api.Tool{
- {
- Type: "function",
- Function: api.ToolFunction{
- Name: "get_weather",
- Description: "Get the current weather",
- Parameters: struct {
- Type string `json:"type"`
- Required []string `json:"required"`
- Properties map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- } `json:"properties"`
- }{
- Type: "object",
- Required: []string{"location"},
- Properties: map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- }{
- "location": {
- Type: "string",
- Description: "The city and state",
- },
- "unit": {
- Type: "string",
- Enum: []string{"celsius", "fahrenheit"},
- },
- },
- },
- },
- },
- },
- Options: map[string]any{
- "temperature": 1.0,
- "top_p": 1.0,
- },
- Stream: &True,
- },
- },
- {
- name: "chat handler error forwarding",
- body: `{
- "model": "test-model",
- "messages": [
- {"role": "user", "content": 2}
- ]
- }`,
- err: ErrorResponse{
- Error: Error{
- Message: "invalid message content type: float64",
- Type: "invalid_request_error",
- },
- },
- },
- }
- endpoint := func(c *gin.Context) {
- c.Status(http.StatusOK)
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
- router.Handle(http.MethodPost, "/api/chat", endpoint)
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
- req.Header.Set("Content-Type", "application/json")
- defer func() { capturedRequest = nil }()
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- var errResp ErrorResponse
- if resp.Code != http.StatusOK {
- if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
- t.Fatal(err)
- }
- }
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
- t.Fatal("requests did not match")
- }
- if !reflect.DeepEqual(tc.err, errResp) {
- t.Fatal("errors did not match")
- }
- })
- }
- }
- func TestCompletionsMiddleware(t *testing.T) {
- type testCase struct {
- name string
- body string
- req api.GenerateRequest
- err ErrorResponse
- }
- var capturedRequest *api.GenerateRequest
- testCases := []testCase{
- {
- name: "completions handler",
- body: `{
- "model": "test-model",
- "prompt": "Hello",
- "temperature": 0.8,
- "stop": ["\n", "stop"],
- "suffix": "suffix"
- }`,
- req: api.GenerateRequest{
- Model: "test-model",
- Prompt: "Hello",
- Options: map[string]any{
- "frequency_penalty": 0.0,
- "presence_penalty": 0.0,
- "temperature": 0.8,
- "top_p": 1.0,
- "stop": []any{"\n", "stop"},
- },
- Suffix: "suffix",
- Stream: &False,
- },
- },
- {
- name: "completions handler error forwarding",
- body: `{
- "model": "test-model",
- "prompt": "Hello",
- "temperature": null,
- "stop": [1, 2],
- "suffix": "suffix"
- }`,
- err: ErrorResponse{
- Error: Error{
- Message: "invalid type for 'stop' field: float64",
- Type: "invalid_request_error",
- },
- },
- },
- }
- endpoint := func(c *gin.Context) {
- c.Status(http.StatusOK)
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
- router.Handle(http.MethodPost, "/api/generate", endpoint)
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
- req.Header.Set("Content-Type", "application/json")
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- var errResp ErrorResponse
- if resp.Code != http.StatusOK {
- if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
- t.Fatal(err)
- }
- }
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
- t.Fatal("requests did not match")
- }
- if !reflect.DeepEqual(tc.err, errResp) {
- t.Fatal("errors did not match")
- }
- capturedRequest = nil
- })
- }
- }
- func TestEmbeddingsMiddleware(t *testing.T) {
- type testCase struct {
- name string
- body string
- req api.EmbedRequest
- err ErrorResponse
- }
- var capturedRequest *api.EmbedRequest
- testCases := []testCase{
- {
- name: "embed handler single input",
- body: `{
- "input": "Hello",
- "model": "test-model"
- }`,
- req: api.EmbedRequest{
- Input: "Hello",
- Model: "test-model",
- },
- },
- {
- name: "embed handler batch input",
- body: `{
- "input": ["Hello", "World"],
- "model": "test-model"
- }`,
- req: api.EmbedRequest{
- Input: []any{"Hello", "World"},
- Model: "test-model",
- },
- },
- {
- name: "embed handler error forwarding",
- body: `{
- "model": "test-model"
- }`,
- err: ErrorResponse{
- Error: Error{
- Message: "invalid input",
- Type: "invalid_request_error",
- },
- },
- },
- }
- endpoint := func(c *gin.Context) {
- c.Status(http.StatusOK)
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
- router.Handle(http.MethodPost, "/api/embed", endpoint)
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
- req.Header.Set("Content-Type", "application/json")
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- var errResp ErrorResponse
- if resp.Code != http.StatusOK {
- if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
- t.Fatal(err)
- }
- }
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
- t.Fatal("requests did not match")
- }
- if !reflect.DeepEqual(tc.err, errResp) {
- t.Fatal("errors did not match")
- }
- capturedRequest = nil
- })
- }
- }
- func TestListMiddleware(t *testing.T) {
- type testCase struct {
- name string
- endpoint func(c *gin.Context)
- resp string
- }
- testCases := []testCase{
- {
- name: "list handler",
- endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ListResponse{
- Models: []api.ListModelResponse{
- {
- Name: "test-model",
- ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
- },
- },
- })
- },
- resp: `{
- "object": "list",
- "data": [
- {
- "id": "test-model",
- "object": "model",
- "created": 1686935002,
- "owned_by": "library"
- }
- ]
- }`,
- },
- {
- name: "list handler empty output",
- endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ListResponse{})
- },
- resp: `{
- "object": "list",
- "data": null
- }`,
- },
- }
- gin.SetMode(gin.TestMode)
- for _, tc := range testCases {
- router := gin.New()
- router.Use(ListMiddleware())
- router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
- req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- var expected, actual map[string]any
- err := json.Unmarshal([]byte(tc.resp), &expected)
- if err != nil {
- t.Fatalf("failed to unmarshal expected response: %v", err)
- }
- err = json.Unmarshal(resp.Body.Bytes(), &actual)
- if err != nil {
- t.Fatalf("failed to unmarshal actual response: %v", err)
- }
- if !reflect.DeepEqual(expected, actual) {
- t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
- }
- }
- }
- func TestRetrieveMiddleware(t *testing.T) {
- type testCase struct {
- name string
- endpoint func(c *gin.Context)
- resp string
- }
- testCases := []testCase{
- {
- name: "retrieve handler",
- endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ShowResponse{
- ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
- })
- },
- resp: `{
- "id":"test-model",
- "object":"model",
- "created":1686935002,
- "owned_by":"library"}
- `,
- },
- {
- name: "retrieve handler error forwarding",
- endpoint: func(c *gin.Context) {
- c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
- },
- resp: `{
- "error": {
- "code": null,
- "message": "model not found",
- "param": null,
- "type": "api_error"
- }
- }`,
- },
- }
- gin.SetMode(gin.TestMode)
- for _, tc := range testCases {
- router := gin.New()
- router.Use(RetrieveMiddleware())
- router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
- req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- var expected, actual map[string]any
- err := json.Unmarshal([]byte(tc.resp), &expected)
- if err != nil {
- t.Fatalf("failed to unmarshal expected response: %v", err)
- }
- err = json.Unmarshal(resp.Body.Bytes(), &actual)
- if err != nil {
- t.Fatalf("failed to unmarshal actual response: %v", err)
- }
- if !reflect.DeepEqual(expected, actual) {
- t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
- }
- }
- }
|