openai_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "github.com/gin-gonic/gin"
  11. "github.com/ollama/ollama/api"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestMiddleware(t *testing.T) {
  15. type testCase struct {
  16. Name string
  17. Method string
  18. Path string
  19. TestPath string
  20. Handler func() gin.HandlerFunc
  21. Endpoint func(c *gin.Context)
  22. Setup func(t *testing.T, req *http.Request)
  23. Expected func(t *testing.T, resp *httptest.ResponseRecorder)
  24. }
  25. testCases := []testCase{
  26. {
  27. Name: "chat handler",
  28. Method: http.MethodPost,
  29. Path: "/api/chat",
  30. TestPath: "/api/chat",
  31. Handler: ChatMiddleware,
  32. Endpoint: func(c *gin.Context) {
  33. var chatReq api.ChatRequest
  34. if err := c.ShouldBindJSON(&chatReq); err != nil {
  35. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  36. return
  37. }
  38. userMessage := chatReq.Messages[0].Content
  39. var assistantMessage string
  40. switch userMessage {
  41. case "Hello":
  42. assistantMessage = "Hello!"
  43. default:
  44. assistantMessage = "I'm not sure how to respond to that."
  45. }
  46. c.JSON(http.StatusOK, api.ChatResponse{
  47. Message: api.Message{
  48. Role: "assistant",
  49. Content: assistantMessage,
  50. },
  51. })
  52. },
  53. Setup: func(t *testing.T, req *http.Request) {
  54. body := ChatCompletionRequest{
  55. Model: "test-model",
  56. Messages: []Message{{Role: "user", Content: "Hello"}},
  57. }
  58. bodyBytes, _ := json.Marshal(body)
  59. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  60. req.Header.Set("Content-Type", "application/json")
  61. },
  62. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  63. var chatResp ChatCompletion
  64. if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
  65. t.Fatal(err)
  66. }
  67. if chatResp.Object != "chat.completion" {
  68. t.Fatalf("expected chat.completion, got %s", chatResp.Object)
  69. }
  70. if chatResp.Choices[0].Message.Content != "Hello!" {
  71. t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
  72. }
  73. },
  74. },
  75. {
  76. Name: "list handler",
  77. Method: http.MethodGet,
  78. Path: "/api/tags",
  79. TestPath: "/api/tags",
  80. Handler: ListMiddleware,
  81. Endpoint: func(c *gin.Context) {
  82. c.JSON(http.StatusOK, api.ListResponse{
  83. Models: []api.ListModelResponse{
  84. {
  85. Name: "Test Model",
  86. },
  87. },
  88. })
  89. },
  90. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  91. var listResp ListCompletion
  92. if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
  93. t.Fatal(err)
  94. }
  95. if listResp.Object != "list" {
  96. t.Fatalf("expected list, got %s", listResp.Object)
  97. }
  98. if len(listResp.Data) != 1 {
  99. t.Fatalf("expected 1, got %d", len(listResp.Data))
  100. }
  101. if listResp.Data[0].Id != "Test Model" {
  102. t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
  103. }
  104. },
  105. },
  106. {
  107. Name: "retrieve model",
  108. Method: http.MethodGet,
  109. Path: "/api/show/:model",
  110. TestPath: "/api/show/test-model",
  111. Handler: RetrieveMiddleware,
  112. Endpoint: func(c *gin.Context) {
  113. c.JSON(http.StatusOK, api.ShowResponse{
  114. ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
  115. })
  116. },
  117. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  118. var retrieveResp Model
  119. if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
  120. t.Fatal(err)
  121. }
  122. if retrieveResp.Object != "model" {
  123. t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
  124. }
  125. if retrieveResp.Id != "test-model" {
  126. t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
  127. }
  128. },
  129. },
  130. }
  131. gin.SetMode(gin.TestMode)
  132. router := gin.New()
  133. for _, tc := range testCases {
  134. t.Run(tc.Name, func(t *testing.T) {
  135. router = gin.New()
  136. router.Use(tc.Handler())
  137. router.Handle(tc.Method, tc.Path, tc.Endpoint)
  138. req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
  139. if tc.Setup != nil {
  140. tc.Setup(t, req)
  141. }
  142. resp := httptest.NewRecorder()
  143. router.ServeHTTP(resp, req)
  144. assert.Equal(t, http.StatusOK, resp.Code)
  145. tc.Expected(t, resp)
  146. })
  147. }
  148. }