// openai_test.go
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. )
  16. const (
  17. prefix = `data:image/jpeg;base64,`
  18. image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
  19. )
  20. var (
  21. False = false
  22. True = true
  23. )
  24. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  25. return func(c *gin.Context) {
  26. bodyBytes, _ := io.ReadAll(c.Request.Body)
  27. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  28. err := json.Unmarshal(bodyBytes, capturedRequest)
  29. if err != nil {
  30. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  31. }
  32. c.Next()
  33. }
  34. }
  35. func TestChatMiddleware(t *testing.T) {
  36. type testCase struct {
  37. name string
  38. body string
  39. req api.ChatRequest
  40. err ErrorResponse
  41. }
  42. var capturedRequest *api.ChatRequest
  43. testCases := []testCase{
  44. {
  45. name: "chat handler",
  46. body: `{
  47. "model": "test-model",
  48. "messages": [
  49. {"role": "user", "content": "Hello"}
  50. ]
  51. }`,
  52. req: api.ChatRequest{
  53. Model: "test-model",
  54. Messages: []api.Message{
  55. {
  56. Role: "user",
  57. Content: "Hello",
  58. },
  59. },
  60. Options: map[string]any{
  61. "temperature": 1.0,
  62. "top_p": 1.0,
  63. },
  64. Stream: &False,
  65. },
  66. },
  67. {
  68. name: "chat handler with options",
  69. body: `{
  70. "model": "test-model",
  71. "messages": [
  72. {"role": "user", "content": "Hello"}
  73. ],
  74. "stream": true,
  75. "max_tokens": 999,
  76. "seed": 123,
  77. "stop": ["\n", "stop"],
  78. "temperature": 3.0,
  79. "frequency_penalty": 4.0,
  80. "presence_penalty": 5.0,
  81. "top_p": 6.0,
  82. "response_format": {"type": "json_object"}
  83. }`,
  84. req: api.ChatRequest{
  85. Model: "test-model",
  86. Messages: []api.Message{
  87. {
  88. Role: "user",
  89. Content: "Hello",
  90. },
  91. },
  92. Options: map[string]any{
  93. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  94. "seed": 123.0,
  95. "stop": []any{"\n", "stop"},
  96. "temperature": 3.0,
  97. "frequency_penalty": 4.0,
  98. "presence_penalty": 5.0,
  99. "top_p": 6.0,
  100. },
  101. Format: json.RawMessage(`"json"`),
  102. Stream: &True,
  103. },
  104. },
  105. {
  106. name: "chat handler with streaming usage",
  107. body: `{
  108. "model": "test-model",
  109. "messages": [
  110. {"role": "user", "content": "Hello"}
  111. ],
  112. "stream": true,
  113. "stream_options": {"include_usage": true},
  114. "max_tokens": 999,
  115. "seed": 123,
  116. "stop": ["\n", "stop"],
  117. "temperature": 3.0,
  118. "frequency_penalty": 4.0,
  119. "presence_penalty": 5.0,
  120. "top_p": 6.0,
  121. "response_format": {"type": "json_object"}
  122. }`,
  123. req: api.ChatRequest{
  124. Model: "test-model",
  125. Messages: []api.Message{
  126. {
  127. Role: "user",
  128. Content: "Hello",
  129. },
  130. },
  131. Options: map[string]any{
  132. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  133. "seed": 123.0,
  134. "stop": []any{"\n", "stop"},
  135. "temperature": 3.0,
  136. "frequency_penalty": 4.0,
  137. "presence_penalty": 5.0,
  138. "top_p": 6.0,
  139. },
  140. Format: json.RawMessage(`"json"`),
  141. Stream: &True,
  142. },
  143. },
  144. {
  145. name: "chat handler with image content",
  146. body: `{
  147. "model": "test-model",
  148. "messages": [
  149. {
  150. "role": "user",
  151. "content": [
  152. {
  153. "type": "text",
  154. "text": "Hello"
  155. },
  156. {
  157. "type": "image_url",
  158. "image_url": {
  159. "url": "` + prefix + image + `"
  160. }
  161. }
  162. ]
  163. }
  164. ]
  165. }`,
  166. req: api.ChatRequest{
  167. Model: "test-model",
  168. Messages: []api.Message{
  169. {
  170. Role: "user",
  171. Content: "Hello",
  172. },
  173. {
  174. Role: "user",
  175. Images: []api.ImageData{
  176. func() []byte {
  177. img, _ := base64.StdEncoding.DecodeString(image)
  178. return img
  179. }(),
  180. },
  181. },
  182. },
  183. Options: map[string]any{
  184. "temperature": 1.0,
  185. "top_p": 1.0,
  186. },
  187. Stream: &False,
  188. },
  189. },
  190. {
  191. name: "chat handler with tools",
  192. body: `{
  193. "model": "test-model",
  194. "messages": [
  195. {"role": "user", "content": "What's the weather like in Paris Today?"},
  196. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  197. ]
  198. }`,
  199. req: api.ChatRequest{
  200. Model: "test-model",
  201. Messages: []api.Message{
  202. {
  203. Role: "user",
  204. Content: "What's the weather like in Paris Today?",
  205. },
  206. {
  207. Role: "assistant",
  208. ToolCalls: []api.ToolCall{
  209. {
  210. Function: api.ToolCallFunction{
  211. Name: "get_current_weather",
  212. Arguments: map[string]interface{}{
  213. "location": "Paris, France",
  214. "format": "celsius",
  215. },
  216. },
  217. },
  218. },
  219. },
  220. },
  221. Options: map[string]any{
  222. "temperature": 1.0,
  223. "top_p": 1.0,
  224. },
  225. Stream: &False,
  226. },
  227. },
  228. {
  229. name: "chat handler with streaming tools",
  230. body: `{
  231. "model": "test-model",
  232. "messages": [
  233. {"role": "user", "content": "What's the weather like in Paris?"}
  234. ],
  235. "stream": true,
  236. "tools": [{
  237. "type": "function",
  238. "function": {
  239. "name": "get_weather",
  240. "description": "Get the current weather",
  241. "parameters": {
  242. "type": "object",
  243. "required": ["location"],
  244. "properties": {
  245. "location": {
  246. "type": "string",
  247. "description": "The city and state"
  248. },
  249. "unit": {
  250. "type": "string",
  251. "enum": ["celsius", "fahrenheit"]
  252. }
  253. }
  254. }
  255. }
  256. }]
  257. }`,
  258. req: api.ChatRequest{
  259. Model: "test-model",
  260. Messages: []api.Message{
  261. {
  262. Role: "user",
  263. Content: "What's the weather like in Paris?",
  264. },
  265. },
  266. Tools: []api.Tool{
  267. {
  268. Type: "function",
  269. Function: api.ToolFunction{
  270. Name: "get_weather",
  271. Description: "Get the current weather",
  272. Parameters: struct {
  273. Type string `json:"type"`
  274. Required []string `json:"required"`
  275. Properties map[string]struct {
  276. Type string `json:"type"`
  277. Description string `json:"description"`
  278. Enum []string `json:"enum,omitempty"`
  279. } `json:"properties"`
  280. }{
  281. Type: "object",
  282. Required: []string{"location"},
  283. Properties: map[string]struct {
  284. Type string `json:"type"`
  285. Description string `json:"description"`
  286. Enum []string `json:"enum,omitempty"`
  287. }{
  288. "location": {
  289. Type: "string",
  290. Description: "The city and state",
  291. },
  292. "unit": {
  293. Type: "string",
  294. Enum: []string{"celsius", "fahrenheit"},
  295. },
  296. },
  297. },
  298. },
  299. },
  300. },
  301. Options: map[string]any{
  302. "temperature": 1.0,
  303. "top_p": 1.0,
  304. },
  305. Stream: &True,
  306. },
  307. },
  308. {
  309. name: "chat handler with context_length",
  310. body: `{
  311. "model": "test-model",
  312. "messages": [{"role": "user", "content": "Hello"}],
  313. "context_length": 4096
  314. }`,
  315. req: api.ChatRequest{
  316. Model: "test-model",
  317. Messages: []api.Message{{Role: "user", Content: "Hello"}},
  318. Options: map[string]any{
  319. "num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int
  320. "temperature": 1.0,
  321. "top_p": 1.0,
  322. },
  323. Stream: &False,
  324. },
  325. },
  326. {
  327. name: "chat handler with max_completion_tokens",
  328. body: `{
  329. "model": "test-model",
  330. "messages": [{"role": "user", "content": "Hello"}],
  331. "max_completion_tokens": 2
  332. }`,
  333. req: api.ChatRequest{
  334. Model: "test-model",
  335. Messages: []api.Message{{Role: "user", Content: "Hello"}},
  336. Options: map[string]any{
  337. "num_predict": 2.0, // float because JSON doesn't distinguish between float and int
  338. "temperature": 1.0,
  339. "top_p": 1.0,
  340. },
  341. Stream: &False,
  342. },
  343. },
  344. {
  345. name: "chat handler error forwarding",
  346. body: `{
  347. "model": "test-model",
  348. "messages": [
  349. {"role": "user", "content": 2}
  350. ]
  351. }`,
  352. err: ErrorResponse{
  353. Error: Error{
  354. Message: "invalid message content type: float64",
  355. Type: "invalid_request_error",
  356. },
  357. },
  358. },
  359. }
  360. endpoint := func(c *gin.Context) {
  361. c.Status(http.StatusOK)
  362. }
  363. gin.SetMode(gin.TestMode)
  364. router := gin.New()
  365. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  366. router.Handle(http.MethodPost, "/api/chat", endpoint)
  367. for _, tc := range testCases {
  368. t.Run(tc.name, func(t *testing.T) {
  369. req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
  370. req.Header.Set("Content-Type", "application/json")
  371. defer func() { capturedRequest = nil }()
  372. resp := httptest.NewRecorder()
  373. router.ServeHTTP(resp, req)
  374. var errResp ErrorResponse
  375. if resp.Code != http.StatusOK {
  376. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  377. t.Fatal(err)
  378. }
  379. return
  380. }
  381. if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
  382. t.Fatalf("requests did not match (-want +got):\n%s", diff)
  383. }
  384. if diff := cmp.Diff(tc.err, errResp); diff != "" {
  385. t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
  386. }
  387. })
  388. }
  389. }
  390. func TestCompletionsMiddleware(t *testing.T) {
  391. type testCase struct {
  392. name string
  393. body string
  394. req api.GenerateRequest
  395. err ErrorResponse
  396. }
  397. var capturedRequest *api.GenerateRequest
  398. testCases := []testCase{
  399. {
  400. name: "completions handler",
  401. body: `{
  402. "model": "test-model",
  403. "prompt": "Hello",
  404. "temperature": 0.8,
  405. "stop": ["\n", "stop"],
  406. "suffix": "suffix"
  407. }`,
  408. req: api.GenerateRequest{
  409. Model: "test-model",
  410. Prompt: "Hello",
  411. Options: map[string]any{
  412. "frequency_penalty": 0.0,
  413. "presence_penalty": 0.0,
  414. "temperature": 0.8,
  415. "top_p": 1.0,
  416. "stop": []any{"\n", "stop"},
  417. },
  418. Suffix: "suffix",
  419. Stream: &False,
  420. },
  421. },
  422. {
  423. name: "completions handler stream",
  424. body: `{
  425. "model": "test-model",
  426. "prompt": "Hello",
  427. "stream": true,
  428. "temperature": 0.8,
  429. "stop": ["\n", "stop"],
  430. "suffix": "suffix"
  431. }`,
  432. req: api.GenerateRequest{
  433. Model: "test-model",
  434. Prompt: "Hello",
  435. Options: map[string]any{
  436. "frequency_penalty": 0.0,
  437. "presence_penalty": 0.0,
  438. "temperature": 0.8,
  439. "top_p": 1.0,
  440. "stop": []any{"\n", "stop"},
  441. },
  442. Suffix: "suffix",
  443. Stream: &True,
  444. },
  445. },
  446. {
  447. name: "completions handler stream with usage",
  448. body: `{
  449. "model": "test-model",
  450. "prompt": "Hello",
  451. "stream": true,
  452. "stream_options": {"include_usage": true},
  453. "temperature": 0.8,
  454. "stop": ["\n", "stop"],
  455. "suffix": "suffix"
  456. }`,
  457. req: api.GenerateRequest{
  458. Model: "test-model",
  459. Prompt: "Hello",
  460. Options: map[string]any{
  461. "frequency_penalty": 0.0,
  462. "presence_penalty": 0.0,
  463. "temperature": 0.8,
  464. "top_p": 1.0,
  465. "stop": []any{"\n", "stop"},
  466. },
  467. Suffix: "suffix",
  468. Stream: &True,
  469. },
  470. },
  471. {
  472. name: "completions handler error forwarding",
  473. body: `{
  474. "model": "test-model",
  475. "prompt": "Hello",
  476. "temperature": null,
  477. "stop": [1, 2],
  478. "suffix": "suffix"
  479. }`,
  480. err: ErrorResponse{
  481. Error: Error{
  482. Message: "invalid type for 'stop' field: float64",
  483. Type: "invalid_request_error",
  484. },
  485. },
  486. },
  487. }
  488. endpoint := func(c *gin.Context) {
  489. c.Status(http.StatusOK)
  490. }
  491. gin.SetMode(gin.TestMode)
  492. router := gin.New()
  493. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  494. router.Handle(http.MethodPost, "/api/generate", endpoint)
  495. for _, tc := range testCases {
  496. t.Run(tc.name, func(t *testing.T) {
  497. req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
  498. req.Header.Set("Content-Type", "application/json")
  499. resp := httptest.NewRecorder()
  500. router.ServeHTTP(resp, req)
  501. var errResp ErrorResponse
  502. if resp.Code != http.StatusOK {
  503. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  504. t.Fatal(err)
  505. }
  506. }
  507. if capturedRequest != nil {
  508. if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
  509. t.Fatalf("requests did not match (-want +got):\n%s", diff)
  510. }
  511. }
  512. if diff := cmp.Diff(tc.err, errResp); diff != "" {
  513. t.Fatalf("errors did not match (-want +got):\n%s", diff)
  514. }
  515. capturedRequest = nil
  516. })
  517. }
  518. }
  519. func TestEmbeddingsMiddleware(t *testing.T) {
  520. type testCase struct {
  521. name string
  522. body string
  523. req api.EmbedRequest
  524. err ErrorResponse
  525. }
  526. var capturedRequest *api.EmbedRequest
  527. testCases := []testCase{
  528. {
  529. name: "embed handler single input",
  530. body: `{
  531. "input": "Hello",
  532. "model": "test-model"
  533. }`,
  534. req: api.EmbedRequest{
  535. Input: "Hello",
  536. Model: "test-model",
  537. },
  538. },
  539. {
  540. name: "embed handler batch input",
  541. body: `{
  542. "input": ["Hello", "World"],
  543. "model": "test-model"
  544. }`,
  545. req: api.EmbedRequest{
  546. Input: []any{"Hello", "World"},
  547. Model: "test-model",
  548. },
  549. },
  550. {
  551. name: "embed handler error forwarding",
  552. body: `{
  553. "model": "test-model"
  554. }`,
  555. err: ErrorResponse{
  556. Error: Error{
  557. Message: "invalid input",
  558. Type: "invalid_request_error",
  559. },
  560. },
  561. },
  562. }
  563. endpoint := func(c *gin.Context) {
  564. c.Status(http.StatusOK)
  565. }
  566. gin.SetMode(gin.TestMode)
  567. router := gin.New()
  568. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  569. router.Handle(http.MethodPost, "/api/embed", endpoint)
  570. for _, tc := range testCases {
  571. t.Run(tc.name, func(t *testing.T) {
  572. req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
  573. req.Header.Set("Content-Type", "application/json")
  574. resp := httptest.NewRecorder()
  575. router.ServeHTTP(resp, req)
  576. var errResp ErrorResponse
  577. if resp.Code != http.StatusOK {
  578. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  579. t.Fatal(err)
  580. }
  581. }
  582. if capturedRequest != nil {
  583. if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
  584. t.Fatalf("requests did not match (-want +got):\n%s", diff)
  585. }
  586. }
  587. if diff := cmp.Diff(tc.err, errResp); diff != "" {
  588. t.Fatalf("errors did not match (-want +got):\n%s", diff)
  589. }
  590. capturedRequest = nil
  591. })
  592. }
  593. }
  594. func TestListMiddleware(t *testing.T) {
  595. type testCase struct {
  596. name string
  597. endpoint func(c *gin.Context)
  598. resp string
  599. }
  600. testCases := []testCase{
  601. {
  602. name: "list handler",
  603. endpoint: func(c *gin.Context) {
  604. c.JSON(http.StatusOK, api.ListResponse{
  605. Models: []api.ListModelResponse{
  606. {
  607. Name: "test-model",
  608. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  609. },
  610. },
  611. })
  612. },
  613. resp: `{
  614. "object": "list",
  615. "data": [
  616. {
  617. "id": "test-model",
  618. "object": "model",
  619. "created": 1686935002,
  620. "owned_by": "library"
  621. }
  622. ]
  623. }`,
  624. },
  625. {
  626. name: "list handler empty output",
  627. endpoint: func(c *gin.Context) {
  628. c.JSON(http.StatusOK, api.ListResponse{})
  629. },
  630. resp: `{
  631. "object": "list",
  632. "data": null
  633. }`,
  634. },
  635. }
  636. gin.SetMode(gin.TestMode)
  637. for _, tc := range testCases {
  638. router := gin.New()
  639. router.Use(ListMiddleware())
  640. router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
  641. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  642. resp := httptest.NewRecorder()
  643. router.ServeHTTP(resp, req)
  644. var expected, actual map[string]any
  645. err := json.Unmarshal([]byte(tc.resp), &expected)
  646. if err != nil {
  647. t.Fatalf("failed to unmarshal expected response: %v", err)
  648. }
  649. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  650. if err != nil {
  651. t.Fatalf("failed to unmarshal actual response: %v", err)
  652. }
  653. if diff := cmp.Diff(expected, actual); diff != "" {
  654. t.Errorf("responses did not match (-want +got):\n%s", diff)
  655. }
  656. }
  657. }
  658. func TestRetrieveMiddleware(t *testing.T) {
  659. type testCase struct {
  660. name string
  661. endpoint func(c *gin.Context)
  662. resp string
  663. }
  664. testCases := []testCase{
  665. {
  666. name: "retrieve handler",
  667. endpoint: func(c *gin.Context) {
  668. c.JSON(http.StatusOK, api.ShowResponse{
  669. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  670. })
  671. },
  672. resp: `{
  673. "id":"test-model",
  674. "object":"model",
  675. "created":1686935002,
  676. "owned_by":"library"}
  677. `,
  678. },
  679. {
  680. name: "retrieve handler error forwarding",
  681. endpoint: func(c *gin.Context) {
  682. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  683. },
  684. resp: `{
  685. "error": {
  686. "code": null,
  687. "message": "model not found",
  688. "param": null,
  689. "type": "api_error"
  690. }
  691. }`,
  692. },
  693. }
  694. gin.SetMode(gin.TestMode)
  695. for _, tc := range testCases {
  696. router := gin.New()
  697. router.Use(RetrieveMiddleware())
  698. router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
  699. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  700. resp := httptest.NewRecorder()
  701. router.ServeHTTP(resp, req)
  702. var expected, actual map[string]any
  703. err := json.Unmarshal([]byte(tc.resp), &expected)
  704. if err != nil {
  705. t.Fatalf("failed to unmarshal expected response: %v", err)
  706. }
  707. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  708. if err != nil {
  709. t.Fatalf("failed to unmarshal actual response: %v", err)
  710. }
  711. if diff := cmp.Diff(expected, actual); diff != "" {
  712. t.Errorf("responses did not match (-want +got):\n%s", diff)
  713. }
  714. }
  715. }