openai_test.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/ollama/ollama/api"
  14. "github.com/stretchr/testify/assert"
  15. )
  16. func TestMiddleware(t *testing.T) {
  17. type testCase struct {
  18. Name string
  19. Method string
  20. Path string
  21. TestPath string
  22. Handler func() gin.HandlerFunc
  23. Endpoint func(c *gin.Context)
  24. Setup func(t *testing.T, req *http.Request)
  25. Expected func(t *testing.T, resp *httptest.ResponseRecorder)
  26. }
  27. testCases := []testCase{
  28. {
  29. Name: "chat handler",
  30. Method: http.MethodPost,
  31. Path: "/api/chat",
  32. TestPath: "/api/chat",
  33. Handler: ChatMiddleware,
  34. Endpoint: func(c *gin.Context) {
  35. var chatReq api.ChatRequest
  36. if err := c.ShouldBindJSON(&chatReq); err != nil {
  37. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  38. return
  39. }
  40. userMessage := chatReq.Messages[0].Content
  41. var assistantMessage string
  42. switch userMessage {
  43. case "Hello":
  44. assistantMessage = "Hello!"
  45. default:
  46. assistantMessage = "I'm not sure how to respond to that."
  47. }
  48. c.JSON(http.StatusOK, api.ChatResponse{
  49. Message: api.Message{
  50. Role: "assistant",
  51. Content: assistantMessage,
  52. },
  53. })
  54. },
  55. Setup: func(t *testing.T, req *http.Request) {
  56. body := ChatCompletionRequest{
  57. Model: "test-model",
  58. Messages: []Message{{Role: "user", Content: "Hello"}},
  59. }
  60. bodyBytes, _ := json.Marshal(body)
  61. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  62. req.Header.Set("Content-Type", "application/json")
  63. },
  64. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  65. assert.Equal(t, http.StatusOK, resp.Code)
  66. var chatResp ChatCompletion
  67. if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
  68. t.Fatal(err)
  69. }
  70. if chatResp.Object != "chat.completion" {
  71. t.Fatalf("expected chat.completion, got %s", chatResp.Object)
  72. }
  73. if chatResp.Choices[0].Message.Content != "Hello!" {
  74. t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
  75. }
  76. },
  77. },
  78. {
  79. Name: "completions handler",
  80. Method: http.MethodPost,
  81. Path: "/api/generate",
  82. TestPath: "/api/generate",
  83. Handler: CompletionsMiddleware,
  84. Endpoint: func(c *gin.Context) {
  85. c.JSON(http.StatusOK, api.GenerateResponse{
  86. Response: "Hello!",
  87. })
  88. },
  89. Setup: func(t *testing.T, req *http.Request) {
  90. body := CompletionRequest{
  91. Model: "test-model",
  92. Prompt: "Hello",
  93. }
  94. bodyBytes, _ := json.Marshal(body)
  95. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  96. req.Header.Set("Content-Type", "application/json")
  97. },
  98. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  99. assert.Equal(t, http.StatusOK, resp.Code)
  100. var completionResp Completion
  101. if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
  102. t.Fatal(err)
  103. }
  104. if completionResp.Object != "text_completion" {
  105. t.Fatalf("expected text_completion, got %s", completionResp.Object)
  106. }
  107. if completionResp.Choices[0].Text != "Hello!" {
  108. t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
  109. }
  110. },
  111. },
  112. {
  113. Name: "completions handler with params",
  114. Method: http.MethodPost,
  115. Path: "/api/generate",
  116. TestPath: "/api/generate",
  117. Handler: CompletionsMiddleware,
  118. Endpoint: func(c *gin.Context) {
  119. var generateReq api.GenerateRequest
  120. if err := c.ShouldBindJSON(&generateReq); err != nil {
  121. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  122. return
  123. }
  124. temperature := generateReq.Options["temperature"].(float64)
  125. var assistantMessage string
  126. switch temperature {
  127. case 1.6:
  128. assistantMessage = "Received temperature of 1.6"
  129. default:
  130. assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
  131. }
  132. c.JSON(http.StatusOK, api.GenerateResponse{
  133. Response: assistantMessage,
  134. })
  135. },
  136. Setup: func(t *testing.T, req *http.Request) {
  137. temp := float32(0.8)
  138. body := CompletionRequest{
  139. Model: "test-model",
  140. Prompt: "Hello",
  141. Temperature: &temp,
  142. }
  143. bodyBytes, _ := json.Marshal(body)
  144. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  145. req.Header.Set("Content-Type", "application/json")
  146. },
  147. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  148. assert.Equal(t, http.StatusOK, resp.Code)
  149. var completionResp Completion
  150. if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
  151. t.Fatal(err)
  152. }
  153. if completionResp.Object != "text_completion" {
  154. t.Fatalf("expected text_completion, got %s", completionResp.Object)
  155. }
  156. if completionResp.Choices[0].Text != "Received temperature of 1.6" {
  157. t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
  158. }
  159. },
  160. },
  161. {
  162. Name: "completions handler with error",
  163. Method: http.MethodPost,
  164. Path: "/api/generate",
  165. TestPath: "/api/generate",
  166. Handler: CompletionsMiddleware,
  167. Endpoint: func(c *gin.Context) {
  168. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  169. },
  170. Setup: func(t *testing.T, req *http.Request) {
  171. body := CompletionRequest{
  172. Model: "test-model",
  173. Prompt: "Hello",
  174. }
  175. bodyBytes, _ := json.Marshal(body)
  176. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  177. req.Header.Set("Content-Type", "application/json")
  178. },
  179. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  180. if resp.Code != http.StatusBadRequest {
  181. t.Fatalf("expected 400, got %d", resp.Code)
  182. }
  183. if !strings.Contains(resp.Body.String(), `"invalid request"`) {
  184. t.Fatalf("error was not forwarded")
  185. }
  186. },
  187. },
  188. {
  189. Name: "list handler",
  190. Method: http.MethodGet,
  191. Path: "/api/tags",
  192. TestPath: "/api/tags",
  193. Handler: ListMiddleware,
  194. Endpoint: func(c *gin.Context) {
  195. c.JSON(http.StatusOK, api.ListResponse{
  196. Models: []api.ListModelResponse{
  197. {
  198. Name: "Test Model",
  199. },
  200. },
  201. })
  202. },
  203. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  204. assert.Equal(t, http.StatusOK, resp.Code)
  205. var listResp ListCompletion
  206. if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
  207. t.Fatal(err)
  208. }
  209. if listResp.Object != "list" {
  210. t.Fatalf("expected list, got %s", listResp.Object)
  211. }
  212. if len(listResp.Data) != 1 {
  213. t.Fatalf("expected 1, got %d", len(listResp.Data))
  214. }
  215. if listResp.Data[0].Id != "Test Model" {
  216. t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
  217. }
  218. },
  219. },
  220. {
  221. Name: "retrieve model",
  222. Method: http.MethodGet,
  223. Path: "/api/show/:model",
  224. TestPath: "/api/show/test-model",
  225. Handler: RetrieveMiddleware,
  226. Endpoint: func(c *gin.Context) {
  227. c.JSON(http.StatusOK, api.ShowResponse{
  228. ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
  229. })
  230. },
  231. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  232. var retrieveResp Model
  233. if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
  234. t.Fatal(err)
  235. }
  236. if retrieveResp.Object != "model" {
  237. t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
  238. }
  239. if retrieveResp.Id != "test-model" {
  240. t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
  241. }
  242. },
  243. },
  244. }
  245. gin.SetMode(gin.TestMode)
  246. router := gin.New()
  247. for _, tc := range testCases {
  248. t.Run(tc.Name, func(t *testing.T) {
  249. router = gin.New()
  250. router.Use(tc.Handler())
  251. router.Handle(tc.Method, tc.Path, tc.Endpoint)
  252. req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
  253. if tc.Setup != nil {
  254. tc.Setup(t, req)
  255. }
  256. resp := httptest.NewRecorder()
  257. router.ServeHTTP(resp, req)
  258. tc.Expected(t, resp)
  259. })
  260. }
  261. }