openai_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/ollama/ollama/api"
  14. "github.com/stretchr/testify/assert"
  15. )
  16. const prefix = `data:image/jpeg;base64,`
  17. const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
  18. const imageURL = prefix + image
  19. func prepareRequest(req *http.Request, body any) {
  20. bodyBytes, _ := json.Marshal(body)
  21. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  22. req.Header.Set("Content-Type", "application/json")
  23. }
  24. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  25. return func(c *gin.Context) {
  26. bodyBytes, _ := io.ReadAll(c.Request.Body)
  27. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  28. err := json.Unmarshal(bodyBytes, capturedRequest)
  29. if err != nil {
  30. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  31. }
  32. c.Next()
  33. }
  34. }
  35. func TestChatMiddleware(t *testing.T) {
  36. type testCase struct {
  37. Name string
  38. Setup func(t *testing.T, req *http.Request)
  39. Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
  40. }
  41. var capturedRequest *api.ChatRequest
  42. testCases := []testCase{
  43. {
  44. Name: "chat handler",
  45. Setup: func(t *testing.T, req *http.Request) {
  46. body := ChatCompletionRequest{
  47. Model: "test-model",
  48. Messages: []Message{{Role: "user", Content: "Hello"}},
  49. }
  50. prepareRequest(req, body)
  51. },
  52. Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
  53. if resp.Code != http.StatusOK {
  54. t.Fatalf("expected 200, got %d", resp.Code)
  55. }
  56. if req.Messages[0].Role != "user" {
  57. t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
  58. }
  59. if req.Messages[0].Content != "Hello" {
  60. t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
  61. }
  62. },
  63. },
  64. {
  65. Name: "chat handler with image content",
  66. Setup: func(t *testing.T, req *http.Request) {
  67. body := ChatCompletionRequest{
  68. Model: "test-model",
  69. Messages: []Message{
  70. {
  71. Role: "user", Content: []map[string]any{
  72. {"type": "text", "text": "Hello"},
  73. {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
  74. },
  75. },
  76. },
  77. }
  78. prepareRequest(req, body)
  79. },
  80. Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
  81. if resp.Code != http.StatusOK {
  82. t.Fatalf("expected 200, got %d", resp.Code)
  83. }
  84. if req.Messages[0].Role != "user" {
  85. t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
  86. }
  87. if req.Messages[0].Content != "Hello" {
  88. t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
  89. }
  90. img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
  91. if req.Messages[1].Role != "user" {
  92. t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
  93. }
  94. if !bytes.Equal(req.Messages[1].Images[0], img) {
  95. t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
  96. }
  97. },
  98. },
  99. {
  100. Name: "chat handler with tools",
  101. Setup: func(t *testing.T, req *http.Request) {
  102. body := ChatCompletionRequest{
  103. Model: "test-model",
  104. Messages: []Message{
  105. {Role: "user", Content: "What's the weather like in Paris Today?"},
  106. {Role: "assistant", ToolCalls: []ToolCall{{
  107. ID: "id",
  108. Type: "function",
  109. Function: struct {
  110. Name string `json:"name"`
  111. Arguments string `json:"arguments"`
  112. }{
  113. Name: "get_current_weather",
  114. Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
  115. },
  116. }}},
  117. },
  118. }
  119. prepareRequest(req, body)
  120. },
  121. Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
  122. if resp.Code != 200 {
  123. t.Fatalf("expected 200, got %d", resp.Code)
  124. }
  125. if req.Messages[0].Content != "What's the weather like in Paris Today?" {
  126. t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
  127. }
  128. if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
  129. t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
  130. }
  131. if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
  132. t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
  133. }
  134. },
  135. },
  136. {
  137. Name: "chat handler error forwarding",
  138. Setup: func(t *testing.T, req *http.Request) {
  139. body := ChatCompletionRequest{
  140. Model: "test-model",
  141. Messages: []Message{{Role: "user", Content: 2}},
  142. }
  143. prepareRequest(req, body)
  144. },
  145. Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
  146. if resp.Code != http.StatusBadRequest {
  147. t.Fatalf("expected 400, got %d", resp.Code)
  148. }
  149. if !strings.Contains(resp.Body.String(), "invalid message content type") {
  150. t.Fatalf("error was not forwarded")
  151. }
  152. },
  153. },
  154. }
  155. endpoint := func(c *gin.Context) {
  156. c.Status(http.StatusOK)
  157. }
  158. gin.SetMode(gin.TestMode)
  159. router := gin.New()
  160. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  161. router.Handle(http.MethodPost, "/api/chat", endpoint)
  162. for _, tc := range testCases {
  163. t.Run(tc.Name, func(t *testing.T) {
  164. req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
  165. tc.Setup(t, req)
  166. resp := httptest.NewRecorder()
  167. router.ServeHTTP(resp, req)
  168. tc.Expected(t, capturedRequest, resp)
  169. capturedRequest = nil
  170. })
  171. }
  172. }
  173. func TestCompletionsMiddleware(t *testing.T) {
  174. type testCase struct {
  175. Name string
  176. Setup func(t *testing.T, req *http.Request)
  177. Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
  178. }
  179. var capturedRequest *api.GenerateRequest
  180. testCases := []testCase{
  181. {
  182. Name: "completions handler",
  183. Setup: func(t *testing.T, req *http.Request) {
  184. temp := float32(0.8)
  185. body := CompletionRequest{
  186. Model: "test-model",
  187. Prompt: "Hello",
  188. Temperature: &temp,
  189. Stop: []string{"\n", "stop"},
  190. Suffix: "suffix",
  191. }
  192. prepareRequest(req, body)
  193. },
  194. Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
  195. if req.Prompt != "Hello" {
  196. t.Fatalf("expected 'Hello', got %s", req.Prompt)
  197. }
  198. if req.Options["temperature"] != 1.6 {
  199. t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
  200. }
  201. stopTokens, ok := req.Options["stop"].([]any)
  202. if !ok {
  203. t.Fatalf("expected stop tokens to be a list")
  204. }
  205. if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
  206. t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
  207. }
  208. if req.Suffix != "suffix" {
  209. t.Fatalf("expected 'suffix', got %s", req.Suffix)
  210. }
  211. },
  212. },
  213. {
  214. Name: "completions handler error forwarding",
  215. Setup: func(t *testing.T, req *http.Request) {
  216. body := CompletionRequest{
  217. Model: "test-model",
  218. Prompt: "Hello",
  219. Temperature: nil,
  220. Stop: []int{1, 2},
  221. Suffix: "suffix",
  222. }
  223. prepareRequest(req, body)
  224. },
  225. Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
  226. if resp.Code != http.StatusBadRequest {
  227. t.Fatalf("expected 400, got %d", resp.Code)
  228. }
  229. if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
  230. t.Fatalf("error was not forwarded")
  231. }
  232. },
  233. },
  234. }
  235. endpoint := func(c *gin.Context) {
  236. c.Status(http.StatusOK)
  237. }
  238. gin.SetMode(gin.TestMode)
  239. router := gin.New()
  240. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  241. router.Handle(http.MethodPost, "/api/generate", endpoint)
  242. for _, tc := range testCases {
  243. t.Run(tc.Name, func(t *testing.T) {
  244. req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
  245. tc.Setup(t, req)
  246. resp := httptest.NewRecorder()
  247. router.ServeHTTP(resp, req)
  248. tc.Expected(t, capturedRequest, resp)
  249. capturedRequest = nil
  250. })
  251. }
  252. }
  253. func TestEmbeddingsMiddleware(t *testing.T) {
  254. type testCase struct {
  255. Name string
  256. Setup func(t *testing.T, req *http.Request)
  257. Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
  258. }
  259. var capturedRequest *api.EmbedRequest
  260. testCases := []testCase{
  261. {
  262. Name: "embed handler single input",
  263. Setup: func(t *testing.T, req *http.Request) {
  264. body := EmbedRequest{
  265. Input: "Hello",
  266. Model: "test-model",
  267. }
  268. prepareRequest(req, body)
  269. },
  270. Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
  271. if req.Input != "Hello" {
  272. t.Fatalf("expected 'Hello', got %s", req.Input)
  273. }
  274. if req.Model != "test-model" {
  275. t.Fatalf("expected 'test-model', got %s", req.Model)
  276. }
  277. },
  278. },
  279. {
  280. Name: "embed handler batch input",
  281. Setup: func(t *testing.T, req *http.Request) {
  282. body := EmbedRequest{
  283. Input: []string{"Hello", "World"},
  284. Model: "test-model",
  285. }
  286. prepareRequest(req, body)
  287. },
  288. Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
  289. input, ok := req.Input.([]any)
  290. if !ok {
  291. t.Fatalf("expected input to be a list")
  292. }
  293. if input[0].(string) != "Hello" {
  294. t.Fatalf("expected 'Hello', got %s", input[0])
  295. }
  296. if input[1].(string) != "World" {
  297. t.Fatalf("expected 'World', got %s", input[1])
  298. }
  299. if req.Model != "test-model" {
  300. t.Fatalf("expected 'test-model', got %s", req.Model)
  301. }
  302. },
  303. },
  304. {
  305. Name: "embed handler error forwarding",
  306. Setup: func(t *testing.T, req *http.Request) {
  307. body := EmbedRequest{
  308. Model: "test-model",
  309. }
  310. prepareRequest(req, body)
  311. },
  312. Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
  313. if resp.Code != http.StatusBadRequest {
  314. t.Fatalf("expected 400, got %d", resp.Code)
  315. }
  316. if !strings.Contains(resp.Body.String(), "invalid input") {
  317. t.Fatalf("error was not forwarded")
  318. }
  319. },
  320. },
  321. }
  322. endpoint := func(c *gin.Context) {
  323. c.Status(http.StatusOK)
  324. }
  325. gin.SetMode(gin.TestMode)
  326. router := gin.New()
  327. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  328. router.Handle(http.MethodPost, "/api/embed", endpoint)
  329. for _, tc := range testCases {
  330. t.Run(tc.Name, func(t *testing.T) {
  331. req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
  332. tc.Setup(t, req)
  333. resp := httptest.NewRecorder()
  334. router.ServeHTTP(resp, req)
  335. tc.Expected(t, capturedRequest, resp)
  336. capturedRequest = nil
  337. })
  338. }
  339. }
  340. func TestMiddlewareResponses(t *testing.T) {
  341. type testCase struct {
  342. Name string
  343. Method string
  344. Path string
  345. TestPath string
  346. Handler func() gin.HandlerFunc
  347. Endpoint func(c *gin.Context)
  348. Setup func(t *testing.T, req *http.Request)
  349. Expected func(t *testing.T, resp *httptest.ResponseRecorder)
  350. }
  351. testCases := []testCase{
  352. {
  353. Name: "list handler",
  354. Method: http.MethodGet,
  355. Path: "/api/tags",
  356. TestPath: "/api/tags",
  357. Handler: ListMiddleware,
  358. Endpoint: func(c *gin.Context) {
  359. c.JSON(http.StatusOK, api.ListResponse{
  360. Models: []api.ListModelResponse{
  361. {
  362. Name: "Test Model",
  363. },
  364. },
  365. })
  366. },
  367. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  368. var listResp ListCompletion
  369. if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
  370. t.Fatal(err)
  371. }
  372. if listResp.Object != "list" {
  373. t.Fatalf("expected list, got %s", listResp.Object)
  374. }
  375. if len(listResp.Data) != 1 {
  376. t.Fatalf("expected 1, got %d", len(listResp.Data))
  377. }
  378. if listResp.Data[0].Id != "Test Model" {
  379. t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
  380. }
  381. },
  382. },
  383. {
  384. Name: "retrieve model",
  385. Method: http.MethodGet,
  386. Path: "/api/show/:model",
  387. TestPath: "/api/show/test-model",
  388. Handler: RetrieveMiddleware,
  389. Endpoint: func(c *gin.Context) {
  390. c.JSON(http.StatusOK, api.ShowResponse{
  391. ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
  392. })
  393. },
  394. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  395. var retrieveResp Model
  396. if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
  397. t.Fatal(err)
  398. }
  399. if retrieveResp.Object != "model" {
  400. t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
  401. }
  402. if retrieveResp.Id != "test-model" {
  403. t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
  404. }
  405. },
  406. },
  407. }
  408. gin.SetMode(gin.TestMode)
  409. router := gin.New()
  410. for _, tc := range testCases {
  411. t.Run(tc.Name, func(t *testing.T) {
  412. router = gin.New()
  413. router.Use(tc.Handler())
  414. router.Handle(tc.Method, tc.Path, tc.Endpoint)
  415. req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
  416. if tc.Setup != nil {
  417. tc.Setup(t, req)
  418. }
  419. resp := httptest.NewRecorder()
  420. router.ServeHTTP(resp, req)
  421. assert.Equal(t, http.StatusOK, resp.Code)
  422. tc.Expected(t, resp)
  423. })
  424. }
  425. }