openai_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/ollama/ollama/api"
  14. "github.com/stretchr/testify/assert"
  15. )
  // Fixtures for the image-content chat test: a base64-encoded image payload
// and the data URL built from it.
//
// NOTE(review): the MIME type in prefix says jpeg, but the payload begins
// with the base64 PNG signature ("iVBORw0KGgo"). The middleware presumably
// only strips the data-URL header; confirm before relying on the MIME type.
const (
	// prefix is the data-URL header placed in front of image.
	prefix = `data:image/jpeg;base64,`

	// image is the base64-encoded image body (a 1x1 PNG per its IHDR bytes).
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

	// imageURL is the complete data URL sent as an image_url content part.
	imageURL = prefix + image
)
  19. func TestMiddlewareRequests(t *testing.T) {
  20. type testCase struct {
  21. Name string
  22. Method string
  23. Path string
  24. Handler func() gin.HandlerFunc
  25. Setup func(t *testing.T, req *http.Request)
  26. Expected func(t *testing.T, req *http.Request)
  27. }
  28. var capturedRequest *http.Request
  29. captureRequestMiddleware := func() gin.HandlerFunc {
  30. return func(c *gin.Context) {
  31. bodyBytes, _ := io.ReadAll(c.Request.Body)
  32. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  33. capturedRequest = c.Request
  34. c.Next()
  35. }
  36. }
  37. testCases := []testCase{
  38. {
  39. Name: "chat handler",
  40. Method: http.MethodPost,
  41. Path: "/api/chat",
  42. Handler: ChatMiddleware,
  43. Setup: func(t *testing.T, req *http.Request) {
  44. body := ChatCompletionRequest{
  45. Model: "test-model",
  46. Messages: []Message{{Role: "user", Content: "Hello"}},
  47. }
  48. bodyBytes, _ := json.Marshal(body)
  49. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  50. req.Header.Set("Content-Type", "application/json")
  51. },
  52. Expected: func(t *testing.T, req *http.Request) {
  53. var chatReq api.ChatRequest
  54. if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
  55. t.Fatal(err)
  56. }
  57. if chatReq.Messages[0].Role != "user" {
  58. t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
  59. }
  60. if chatReq.Messages[0].Content != "Hello" {
  61. t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
  62. }
  63. },
  64. },
  65. {
  66. Name: "completions handler",
  67. Method: http.MethodPost,
  68. Path: "/api/generate",
  69. Handler: CompletionsMiddleware,
  70. Setup: func(t *testing.T, req *http.Request) {
  71. temp := float32(0.8)
  72. body := CompletionRequest{
  73. Model: "test-model",
  74. Prompt: "Hello",
  75. Temperature: &temp,
  76. Stop: []string{"\n", "stop"},
  77. }
  78. bodyBytes, _ := json.Marshal(body)
  79. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  80. req.Header.Set("Content-Type", "application/json")
  81. },
  82. Expected: func(t *testing.T, req *http.Request) {
  83. var genReq api.GenerateRequest
  84. if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
  85. t.Fatal(err)
  86. }
  87. if genReq.Prompt != "Hello" {
  88. t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
  89. }
  90. if genReq.Options["temperature"] != 1.6 {
  91. t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
  92. }
  93. stopTokens, ok := genReq.Options["stop"].([]any)
  94. if !ok {
  95. t.Fatalf("expected stop tokens to be a list")
  96. }
  97. if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
  98. t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
  99. }
  100. },
  101. },
  102. {
  103. Name: "chat handler with image content",
  104. Method: http.MethodPost,
  105. Path: "/api/chat",
  106. Handler: ChatMiddleware,
  107. Setup: func(t *testing.T, req *http.Request) {
  108. body := ChatCompletionRequest{
  109. Model: "test-model",
  110. Messages: []Message{
  111. {
  112. Role: "user", Content: []map[string]any{
  113. {"type": "text", "text": "Hello"},
  114. {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
  115. },
  116. },
  117. },
  118. }
  119. bodyBytes, _ := json.Marshal(body)
  120. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  121. req.Header.Set("Content-Type", "application/json")
  122. },
  123. Expected: func(t *testing.T, req *http.Request) {
  124. var chatReq api.ChatRequest
  125. if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
  126. t.Fatal(err)
  127. }
  128. if chatReq.Messages[0].Role != "user" {
  129. t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
  130. }
  131. if chatReq.Messages[0].Content != "Hello" {
  132. t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
  133. }
  134. img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
  135. if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
  136. t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
  137. }
  138. },
  139. },
  140. {
  141. Name: "embed handler single input",
  142. Method: http.MethodPost,
  143. Path: "/api/embed",
  144. Handler: EmbeddingsMiddleware,
  145. Setup: func(t *testing.T, req *http.Request) {
  146. body := EmbedRequest{
  147. Input: "Hello",
  148. Model: "test-model",
  149. }
  150. bodyBytes, _ := json.Marshal(body)
  151. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  152. req.Header.Set("Content-Type", "application/json")
  153. },
  154. Expected: func(t *testing.T, req *http.Request) {
  155. var embedReq api.EmbedRequest
  156. if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
  157. t.Fatal(err)
  158. }
  159. if embedReq.Input != "Hello" {
  160. t.Fatalf("expected 'Hello', got %s", embedReq.Input)
  161. }
  162. if embedReq.Model != "test-model" {
  163. t.Fatalf("expected 'test-model', got %s", embedReq.Model)
  164. }
  165. },
  166. },
  167. {
  168. Name: "embed handler batch input",
  169. Method: http.MethodPost,
  170. Path: "/api/embed",
  171. Handler: EmbeddingsMiddleware,
  172. Setup: func(t *testing.T, req *http.Request) {
  173. body := EmbedRequest{
  174. Input: []string{"Hello", "World"},
  175. Model: "test-model",
  176. }
  177. bodyBytes, _ := json.Marshal(body)
  178. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  179. req.Header.Set("Content-Type", "application/json")
  180. },
  181. Expected: func(t *testing.T, req *http.Request) {
  182. var embedReq api.EmbedRequest
  183. if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
  184. t.Fatal(err)
  185. }
  186. input, ok := embedReq.Input.([]any)
  187. if !ok {
  188. t.Fatalf("expected input to be a list")
  189. }
  190. if input[0].(string) != "Hello" {
  191. t.Fatalf("expected 'Hello', got %s", input[0])
  192. }
  193. if input[1].(string) != "World" {
  194. t.Fatalf("expected 'World', got %s", input[1])
  195. }
  196. if embedReq.Model != "test-model" {
  197. t.Fatalf("expected 'test-model', got %s", embedReq.Model)
  198. }
  199. },
  200. },
  201. }
  202. gin.SetMode(gin.TestMode)
  203. router := gin.New()
  204. endpoint := func(c *gin.Context) {
  205. c.Status(http.StatusOK)
  206. }
  207. for _, tc := range testCases {
  208. t.Run(tc.Name, func(t *testing.T) {
  209. router = gin.New()
  210. router.Use(captureRequestMiddleware())
  211. router.Use(tc.Handler())
  212. router.Handle(tc.Method, tc.Path, endpoint)
  213. req, _ := http.NewRequest(tc.Method, tc.Path, nil)
  214. if tc.Setup != nil {
  215. tc.Setup(t, req)
  216. }
  217. resp := httptest.NewRecorder()
  218. router.ServeHTTP(resp, req)
  219. tc.Expected(t, capturedRequest)
  220. })
  221. }
  222. }
  223. func TestMiddlewareResponses(t *testing.T) {
  224. type testCase struct {
  225. Name string
  226. Method string
  227. Path string
  228. TestPath string
  229. Handler func() gin.HandlerFunc
  230. Endpoint func(c *gin.Context)
  231. Setup func(t *testing.T, req *http.Request)
  232. Expected func(t *testing.T, resp *httptest.ResponseRecorder)
  233. }
  234. testCases := []testCase{
  235. {
  236. Name: "completions handler error forwarding",
  237. Method: http.MethodPost,
  238. Path: "/api/generate",
  239. TestPath: "/api/generate",
  240. Handler: CompletionsMiddleware,
  241. Endpoint: func(c *gin.Context) {
  242. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  243. },
  244. Setup: func(t *testing.T, req *http.Request) {
  245. body := CompletionRequest{
  246. Model: "test-model",
  247. Prompt: "Hello",
  248. }
  249. bodyBytes, _ := json.Marshal(body)
  250. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  251. req.Header.Set("Content-Type", "application/json")
  252. },
  253. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  254. if resp.Code != http.StatusBadRequest {
  255. t.Fatalf("expected 400, got %d", resp.Code)
  256. }
  257. if !strings.Contains(resp.Body.String(), `"invalid request"`) {
  258. t.Fatalf("error was not forwarded")
  259. }
  260. },
  261. },
  262. {
  263. Name: "list handler",
  264. Method: http.MethodGet,
  265. Path: "/api/tags",
  266. TestPath: "/api/tags",
  267. Handler: ListMiddleware,
  268. Endpoint: func(c *gin.Context) {
  269. c.JSON(http.StatusOK, api.ListResponse{
  270. Models: []api.ListModelResponse{
  271. {
  272. Name: "Test Model",
  273. },
  274. },
  275. })
  276. },
  277. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  278. assert.Equal(t, http.StatusOK, resp.Code)
  279. var listResp ListCompletion
  280. if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
  281. t.Fatal(err)
  282. }
  283. if listResp.Object != "list" {
  284. t.Fatalf("expected list, got %s", listResp.Object)
  285. }
  286. if len(listResp.Data) != 1 {
  287. t.Fatalf("expected 1, got %d", len(listResp.Data))
  288. }
  289. if listResp.Data[0].Id != "Test Model" {
  290. t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
  291. }
  292. },
  293. },
  294. {
  295. Name: "retrieve model",
  296. Method: http.MethodGet,
  297. Path: "/api/show/:model",
  298. TestPath: "/api/show/test-model",
  299. Handler: RetrieveMiddleware,
  300. Endpoint: func(c *gin.Context) {
  301. c.JSON(http.StatusOK, api.ShowResponse{
  302. ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
  303. })
  304. },
  305. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  306. var retrieveResp Model
  307. if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
  308. t.Fatal(err)
  309. }
  310. if retrieveResp.Object != "model" {
  311. t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
  312. }
  313. if retrieveResp.Id != "test-model" {
  314. t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
  315. }
  316. },
  317. },
  318. }
  319. gin.SetMode(gin.TestMode)
  320. router := gin.New()
  321. for _, tc := range testCases {
  322. t.Run(tc.Name, func(t *testing.T) {
  323. router = gin.New()
  324. router.Use(tc.Handler())
  325. router.Handle(tc.Method, tc.Path, tc.Endpoint)
  326. req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
  327. if tc.Setup != nil {
  328. tc.Setup(t, req)
  329. }
  330. resp := httptest.NewRecorder()
  331. router.ServeHTTP(resp, req)
  332. tc.Expected(t, resp)
  333. })
  334. }
  335. }