openai_test.go

package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
)

const (
	prefix = `data:image/jpeg;base64,`
	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

var (
	False = false
	True  = true
)

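// captureRequestMiddleware unmarshals the (already translated) request body into
// capturedRequest so tests can compare it against the expected api request, then
// restores the body for downstream handlers.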
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
	return func(c *gin.Context) {
		bodyBytes, _ := io.ReadAll(c.Request.Body)
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		err := json.Unmarshal(bodyBytes, capturedRequest)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
		}
		c.Next()
	}
}

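// TestChatMiddleware posts OpenAI-style chat completion bodies through ChatMiddleware
// and checks both the api.ChatRequest seen by the downstream handler and the
// ErrorResponse returned for invalid input.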
func TestChatMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.ChatRequest
		err  ErrorResponse
	}

	var capturedRequest *api.ChatRequest

	testCases := []testCase{
		{
			name: "chat handler",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with options",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream": true,
				"max_tokens": 999,
				"seed": 123,
				"stop": ["\n", "stop"],
				"temperature": 3.0,
				"frequency_penalty": 4.0,
				"presence_penalty": 5.0,
				"top_p": 6.0,
				"response_format": {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       3.0,
					"frequency_penalty": 4.0,
					"presence_penalty":  5.0,
					"top_p":             6.0,
				},
				Format: "json",
				Stream: &True,
			},
		},
		{
			name: "chat handler with streaming usage",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "Hello"}
				],
				"stream": true,
				"stream_options": {"include_usage": true},
				"max_tokens": 999,
				"seed": 123,
				"stop": ["\n", "stop"],
				"temperature": 3.0,
				"frequency_penalty": 4.0,
				"presence_penalty": 5.0,
				"top_p": 6.0,
				"response_format": {"type": "json_object"}
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
				},
				Options: map[string]any{
					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
					"seed":              123.0,
					"stop":              []any{"\n", "stop"},
					"temperature":       3.0,
					"frequency_penalty": 4.0,
					"presence_penalty":  5.0,
					"top_p":             6.0,
				},
				Format: "json",
				Stream: &True,
			},
		},
		{
			name: "chat handler with image content",
			body: `{
				"model": "test-model",
				"messages": [
					{
						"role": "user",
						"content": [
							{
								"type": "text",
								"text": "Hello"
							},
							{
								"type": "image_url",
								"image_url": {
									"url": "` + prefix + image + `"
								}
							}
						]
					}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Hello",
					},
					{
						Role: "user",
						Images: []api.ImageData{
							func() []byte {
								img, _ := base64.StdEncoding.DecodeString(image)
								return img
							}(),
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler with tools",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": "What's the weather like in Paris Today?"},
					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
				]
			}`,
			req: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's the weather like in Paris Today?",
					},
					{
						Role: "assistant",
						ToolCalls: []api.ToolCall{
							{
								Function: api.ToolCallFunction{
									Name: "get_current_weather",
									Arguments: map[string]interface{}{
										"location": "Paris, France",
										"format":   "celsius",
									},
								},
							},
						},
					},
				},
				Options: map[string]any{
					"temperature": 1.0,
					"top_p":       1.0,
				},
				Stream: &False,
			},
		},
		{
			name: "chat handler error forwarding",
			body: `{
				"model": "test-model",
				"messages": [
					{"role": "user", "content": 2}
				]
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid message content type: float64",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/chat", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			defer func() { capturedRequest = nil }()

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}
		})
	}
}

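// TestCompletionsMiddleware does the same for the completions endpoint, checking
// that OpenAI-style completion bodies are translated into api.GenerateRequest
// values and that invalid fields produce an ErrorResponse.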
func TestCompletionsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.GenerateRequest
		err  ErrorResponse
	}

	var capturedRequest *api.GenerateRequest

	testCases := []testCase{
		{
			name: "completions handler",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &False,
			},
		},
		{
			name: "completions handler stream",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"stream": true,
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &True,
			},
		},
		{
			name: "completions handler stream with usage",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"stream": true,
				"stream_options": {"include_usage": true},
				"temperature": 0.8,
				"stop": ["\n", "stop"],
				"suffix": "suffix"
			}`,
			req: api.GenerateRequest{
				Model:  "test-model",
				Prompt: "Hello",
				Options: map[string]any{
					"frequency_penalty": 0.0,
					"presence_penalty":  0.0,
					"temperature":       0.8,
					"top_p":             1.0,
					"stop":              []any{"\n", "stop"},
				},
				Suffix: "suffix",
				Stream: &True,
			},
		},
		{
			name: "completions handler error forwarding",
			body: `{
				"model": "test-model",
				"prompt": "Hello",
				"temperature": null,
				"stop": [1, 2],
				"suffix": "suffix"
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid type for 'stop' field: float64",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/generate", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}

			capturedRequest = nil
		})
	}
}

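// TestEmbeddingsMiddleware covers single-string and batch inputs for the
// embeddings endpoint, plus the error returned when "input" is missing.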
func TestEmbeddingsMiddleware(t *testing.T) {
	type testCase struct {
		name string
		body string
		req  api.EmbedRequest
		err  ErrorResponse
	}

	var capturedRequest *api.EmbedRequest

	testCases := []testCase{
		{
			name: "embed handler single input",
			body: `{
				"input": "Hello",
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: "Hello",
				Model: "test-model",
			},
		},
		{
			name: "embed handler batch input",
			body: `{
				"input": ["Hello", "World"],
				"model": "test-model"
			}`,
			req: api.EmbedRequest{
				Input: []any{"Hello", "World"},
				Model: "test-model",
			},
		},
		{
			name: "embed handler error forwarding",
			body: `{
				"model": "test-model"
			}`,
			err: ErrorResponse{
				Error: Error{
					Message: "invalid input",
					Type:    "invalid_request_error",
				},
			},
		},
	}

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
	router.Handle(http.MethodPost, "/api/embed", endpoint)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
			req.Header.Set("Content-Type", "application/json")

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			var errResp ErrorResponse
			if resp.Code != http.StatusOK {
				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
					t.Fatal(err)
				}
			}

			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
				t.Fatal("requests did not match")
			}

			if !reflect.DeepEqual(tc.err, errResp) {
				t.Fatal("errors did not match")
			}

			capturedRequest = nil
		})
	}
}

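// TestListMiddleware checks that an api.ListResponse is rewritten into the
// OpenAI "list" object shape, including the empty case.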
func TestListMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
		{
			name: "list handler",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
							Name:       "test-model",
							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
						},
					},
				})
			},
			resp: `{
				"object": "list",
				"data": [
					{
						"id": "test-model",
						"object": "model",
						"created": 1686935002,
						"owned_by": "library"
					}
				]
			}`,
		},
		{
			name: "list handler empty output",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{})
			},
			resp: `{
				"object": "list",
				"data": null
			}`,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		router := gin.New()
		router.Use(ListMiddleware())
		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)

		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}

		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
	}
}

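// TestRetrieveMiddleware checks that a single api.ShowResponse is rewritten into
// an OpenAI model object, and that handler errors are forwarded in the OpenAI
// error format.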
func TestRetrieveMiddleware(t *testing.T) {
	type testCase struct {
		name     string
		endpoint func(c *gin.Context)
		resp     string
	}

	testCases := []testCase{
		{
			name: "retrieve handler",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ShowResponse{
					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
				})
			},
			resp: `{
				"id": "test-model",
				"object": "model",
				"created": 1686935002,
				"owned_by": "library"
			}`,
		},
		{
			name: "retrieve handler error forwarding",
			endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
			},
			resp: `{
				"error": {
					"code": null,
					"message": "model not found",
					"param": null,
					"type": "api_error"
				}
			}`,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		router := gin.New()
		router.Use(RetrieveMiddleware())
		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)

		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
		resp := httptest.NewRecorder()
		router.ServeHTTP(resp, req)

		var expected, actual map[string]any
		err := json.Unmarshal([]byte(tc.resp), &expected)
		if err != nil {
			t.Fatalf("failed to unmarshal expected response: %v", err)
		}

		err = json.Unmarshal(resp.Body.Bytes(), &actual)
		if err != nil {
			t.Fatalf("failed to unmarshal actual response: %v", err)
		}

		if !reflect.DeepEqual(expected, actual) {
			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
		}
	}
}