123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- package openai
- import (
- "bytes"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/ollama/ollama/api"
- "github.com/stretchr/testify/assert"
- )
- func TestMiddleware(t *testing.T) {
- type testCase struct {
- Name string
- Method string
- Path string
- TestPath string
- Handler func() gin.HandlerFunc
- Endpoint func(c *gin.Context)
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, resp *httptest.ResponseRecorder)
- }
- testCases := []testCase{
- {
- Name: "chat handler",
- Method: http.MethodPost,
- Path: "/api/chat",
- TestPath: "/api/chat",
- Handler: ChatMiddleware,
- Endpoint: func(c *gin.Context) {
- var chatReq api.ChatRequest
- if err := c.ShouldBindJSON(&chatReq); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
- return
- }
- userMessage := chatReq.Messages[0].Content
- var assistantMessage string
- switch userMessage {
- case "Hello":
- assistantMessage = "Hello!"
- default:
- assistantMessage = "I'm not sure how to respond to that."
- }
- c.JSON(http.StatusOK, api.ChatResponse{
- Message: api.Message{
- Role: "assistant",
- Content: assistantMessage,
- },
- })
- },
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{{Role: "user", Content: "Hello"}},
- }
- bodyBytes, _ := json.Marshal(body)
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- req.Header.Set("Content-Type", "application/json")
- },
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var chatResp ChatCompletion
- if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
- t.Fatal(err)
- }
- if chatResp.Object != "chat.completion" {
- t.Fatalf("expected chat.completion, got %s", chatResp.Object)
- }
- if chatResp.Choices[0].Message.Content != "Hello!" {
- t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
- }
- },
- },
- {
- Name: "list handler",
- Method: http.MethodGet,
- Path: "/api/tags",
- TestPath: "/api/tags",
- Handler: ListMiddleware,
- Endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ListResponse{
- Models: []api.ListModelResponse{
- {
- Name: "Test Model",
- },
- },
- })
- },
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var listResp ListCompletion
- if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
- t.Fatal(err)
- }
- if listResp.Object != "list" {
- t.Fatalf("expected list, got %s", listResp.Object)
- }
- if len(listResp.Data) != 1 {
- t.Fatalf("expected 1, got %d", len(listResp.Data))
- }
- if listResp.Data[0].Id != "Test Model" {
- t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
- }
- },
- },
- {
- Name: "retrieve model",
- Method: http.MethodGet,
- Path: "/api/show/:model",
- TestPath: "/api/show/test-model",
- Handler: RetrieveMiddleware,
- Endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ShowResponse{
- ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
- })
- },
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var retrieveResp Model
- if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
- t.Fatal(err)
- }
- if retrieveResp.Object != "model" {
- t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
- }
- if retrieveResp.Id != "test-model" {
- t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
- }
- },
- },
- }
- gin.SetMode(gin.TestMode)
- router := gin.New()
- for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- router = gin.New()
- router.Use(tc.Handler())
- router.Handle(tc.Method, tc.Path, tc.Endpoint)
- req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
- if tc.Setup != nil {
- tc.Setup(t, req)
- }
- resp := httptest.NewRecorder()
- router.ServeHTTP(resp, req)
- assert.Equal(t, http.StatusOK, resp.Code)
- tc.Expected(t, resp)
- })
- }
- }
|