openai_test.go 16 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "reflect"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/google/go-cmp/cmp"
  15. "github.com/ollama/ollama/api"
  16. )
const (
	// prefix is the data-URL header prepended to the base64 test image when
	// building OpenAI-style image_url message content.
	prefix = `data:image/jpeg;base64,`
	// image is a tiny base64-encoded test image used by the image-content
	// chat test case. NOTE(review): the bytes decode to a PNG while prefix
	// declares image/jpeg — the middleware under test appears not to care,
	// but confirm if the mime type ever becomes significant.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

var (
	// False and True are addressable booleans so test expectations can take
	// their address for *bool fields such as api.ChatRequest.Stream.
	False = false
	True  = true
)
  25. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  26. return func(c *gin.Context) {
  27. bodyBytes, _ := io.ReadAll(c.Request.Body)
  28. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  29. err := json.Unmarshal(bodyBytes, capturedRequest)
  30. if err != nil {
  31. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  32. }
  33. c.Next()
  34. }
  35. }
  36. func TestChatMiddleware(t *testing.T) {
  37. type testCase struct {
  38. name string
  39. body string
  40. req api.ChatRequest
  41. err ErrorResponse
  42. }
  43. var capturedRequest *api.ChatRequest
  44. testCases := []testCase{
  45. {
  46. name: "chat handler",
  47. body: `{
  48. "model": "test-model",
  49. "messages": [
  50. {"role": "user", "content": "Hello"}
  51. ]
  52. }`,
  53. req: api.ChatRequest{
  54. Model: "test-model",
  55. Messages: []api.Message{
  56. {
  57. Role: "user",
  58. Content: "Hello",
  59. },
  60. },
  61. Options: map[string]any{
  62. "temperature": 1.0,
  63. "top_p": 1.0,
  64. },
  65. Stream: &False,
  66. },
  67. },
  68. {
  69. name: "chat handler with options",
  70. body: `{
  71. "model": "test-model",
  72. "messages": [
  73. {"role": "user", "content": "Hello"}
  74. ],
  75. "stream": true,
  76. "max_tokens": 999,
  77. "seed": 123,
  78. "stop": ["\n", "stop"],
  79. "temperature": 3.0,
  80. "frequency_penalty": 4.0,
  81. "presence_penalty": 5.0,
  82. "top_p": 6.0,
  83. "response_format": {"type": "json_object"}
  84. }`,
  85. req: api.ChatRequest{
  86. Model: "test-model",
  87. Messages: []api.Message{
  88. {
  89. Role: "user",
  90. Content: "Hello",
  91. },
  92. },
  93. Options: map[string]any{
  94. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  95. "seed": 123.0,
  96. "stop": []any{"\n", "stop"},
  97. "temperature": 3.0,
  98. "frequency_penalty": 4.0,
  99. "presence_penalty": 5.0,
  100. "top_p": 6.0,
  101. },
  102. Format: json.RawMessage(`"json"`),
  103. Stream: &True,
  104. },
  105. },
  106. {
  107. name: "chat handler with streaming usage",
  108. body: `{
  109. "model": "test-model",
  110. "messages": [
  111. {"role": "user", "content": "Hello"}
  112. ],
  113. "stream": true,
  114. "stream_options": {"include_usage": true},
  115. "max_tokens": 999,
  116. "seed": 123,
  117. "stop": ["\n", "stop"],
  118. "temperature": 3.0,
  119. "frequency_penalty": 4.0,
  120. "presence_penalty": 5.0,
  121. "top_p": 6.0,
  122. "response_format": {"type": "json_object"}
  123. }`,
  124. req: api.ChatRequest{
  125. Model: "test-model",
  126. Messages: []api.Message{
  127. {
  128. Role: "user",
  129. Content: "Hello",
  130. },
  131. },
  132. Options: map[string]any{
  133. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  134. "seed": 123.0,
  135. "stop": []any{"\n", "stop"},
  136. "temperature": 3.0,
  137. "frequency_penalty": 4.0,
  138. "presence_penalty": 5.0,
  139. "top_p": 6.0,
  140. },
  141. Format: json.RawMessage(`"json"`),
  142. Stream: &True,
  143. },
  144. },
  145. {
  146. name: "chat handler with image content",
  147. body: `{
  148. "model": "test-model",
  149. "messages": [
  150. {
  151. "role": "user",
  152. "content": [
  153. {
  154. "type": "text",
  155. "text": "Hello"
  156. },
  157. {
  158. "type": "image_url",
  159. "image_url": {
  160. "url": "` + prefix + image + `"
  161. }
  162. }
  163. ]
  164. }
  165. ]
  166. }`,
  167. req: api.ChatRequest{
  168. Model: "test-model",
  169. Messages: []api.Message{
  170. {
  171. Role: "user",
  172. Content: "Hello",
  173. },
  174. {
  175. Role: "user",
  176. Images: []api.ImageData{
  177. func() []byte {
  178. img, _ := base64.StdEncoding.DecodeString(image)
  179. return img
  180. }(),
  181. },
  182. },
  183. },
  184. Options: map[string]any{
  185. "temperature": 1.0,
  186. "top_p": 1.0,
  187. },
  188. Stream: &False,
  189. },
  190. },
  191. {
  192. name: "chat handler with tools",
  193. body: `{
  194. "model": "test-model",
  195. "messages": [
  196. {"role": "user", "content": "What's the weather like in Paris Today?"},
  197. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  198. ]
  199. }`,
  200. req: api.ChatRequest{
  201. Model: "test-model",
  202. Messages: []api.Message{
  203. {
  204. Role: "user",
  205. Content: "What's the weather like in Paris Today?",
  206. },
  207. {
  208. Role: "assistant",
  209. ToolCalls: []api.ToolCall{
  210. {
  211. Function: api.ToolCallFunction{
  212. Name: "get_current_weather",
  213. Arguments: map[string]interface{}{
  214. "location": "Paris, France",
  215. "format": "celsius",
  216. },
  217. },
  218. },
  219. },
  220. },
  221. },
  222. Options: map[string]any{
  223. "temperature": 1.0,
  224. "top_p": 1.0,
  225. },
  226. Stream: &False,
  227. },
  228. },
  229. {
  230. name: "chat handler with streaming tools",
  231. body: `{
  232. "model": "test-model",
  233. "messages": [
  234. {"role": "user", "content": "What's the weather like in Paris?"}
  235. ],
  236. "stream": true,
  237. "tools": [{
  238. "type": "function",
  239. "function": {
  240. "name": "get_weather",
  241. "description": "Get the current weather",
  242. "parameters": {
  243. "type": "object",
  244. "required": ["location"],
  245. "properties": {
  246. "location": {
  247. "type": "string",
  248. "description": "The city and state"
  249. },
  250. "unit": {
  251. "type": "string",
  252. "enum": ["celsius", "fahrenheit"]
  253. }
  254. }
  255. }
  256. }
  257. }]
  258. }`,
  259. req: api.ChatRequest{
  260. Model: "test-model",
  261. Messages: []api.Message{
  262. {
  263. Role: "user",
  264. Content: "What's the weather like in Paris?",
  265. },
  266. },
  267. Tools: []api.Tool{
  268. {
  269. Type: "function",
  270. Function: api.ToolFunction{
  271. Name: "get_weather",
  272. Description: "Get the current weather",
  273. Parameters: struct {
  274. Type string `json:"type"`
  275. Required []string `json:"required"`
  276. Properties map[string]struct {
  277. Type string `json:"type"`
  278. Description string `json:"description"`
  279. Enum []string `json:"enum,omitempty"`
  280. } `json:"properties"`
  281. }{
  282. Type: "object",
  283. Required: []string{"location"},
  284. Properties: map[string]struct {
  285. Type string `json:"type"`
  286. Description string `json:"description"`
  287. Enum []string `json:"enum,omitempty"`
  288. }{
  289. "location": {
  290. Type: "string",
  291. Description: "The city and state",
  292. },
  293. "unit": {
  294. Type: "string",
  295. Enum: []string{"celsius", "fahrenheit"},
  296. },
  297. },
  298. },
  299. },
  300. },
  301. },
  302. Options: map[string]any{
  303. "temperature": 1.0,
  304. "top_p": 1.0,
  305. },
  306. Stream: &True,
  307. },
  308. },
  309. {
  310. name: "chat handler error forwarding",
  311. body: `{
  312. "model": "test-model",
  313. "messages": [
  314. {"role": "user", "content": 2}
  315. ]
  316. }`,
  317. err: ErrorResponse{
  318. Error: Error{
  319. Message: "invalid message content type: float64",
  320. Type: "invalid_request_error",
  321. },
  322. },
  323. },
  324. }
  325. endpoint := func(c *gin.Context) {
  326. c.Status(http.StatusOK)
  327. }
  328. gin.SetMode(gin.TestMode)
  329. router := gin.New()
  330. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  331. router.Handle(http.MethodPost, "/api/chat", endpoint)
  332. for _, tc := range testCases {
  333. t.Run(tc.name, func(t *testing.T) {
  334. req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
  335. req.Header.Set("Content-Type", "application/json")
  336. defer func() { capturedRequest = nil }()
  337. resp := httptest.NewRecorder()
  338. router.ServeHTTP(resp, req)
  339. var errResp ErrorResponse
  340. if resp.Code != http.StatusOK {
  341. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  342. t.Fatal(err)
  343. }
  344. return
  345. }
  346. if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
  347. t.Fatalf("requests did not match: %+v", diff)
  348. }
  349. if diff := cmp.Diff(tc.err, errResp); diff != "" {
  350. t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
  351. }
  352. })
  353. }
  354. }
  355. func TestCompletionsMiddleware(t *testing.T) {
  356. type testCase struct {
  357. name string
  358. body string
  359. req api.GenerateRequest
  360. err ErrorResponse
  361. }
  362. var capturedRequest *api.GenerateRequest
  363. testCases := []testCase{
  364. {
  365. name: "completions handler",
  366. body: `{
  367. "model": "test-model",
  368. "prompt": "Hello",
  369. "temperature": 0.8,
  370. "stop": ["\n", "stop"],
  371. "suffix": "suffix"
  372. }`,
  373. req: api.GenerateRequest{
  374. Model: "test-model",
  375. Prompt: "Hello",
  376. Options: map[string]any{
  377. "frequency_penalty": 0.0,
  378. "presence_penalty": 0.0,
  379. "temperature": 0.8,
  380. "top_p": 1.0,
  381. "stop": []any{"\n", "stop"},
  382. },
  383. Suffix: "suffix",
  384. Stream: &False,
  385. },
  386. },
  387. {
  388. name: "completions handler stream",
  389. body: `{
  390. "model": "test-model",
  391. "prompt": "Hello",
  392. "stream": true,
  393. "temperature": 0.8,
  394. "stop": ["\n", "stop"],
  395. "suffix": "suffix"
  396. }`,
  397. req: api.GenerateRequest{
  398. Model: "test-model",
  399. Prompt: "Hello",
  400. Options: map[string]any{
  401. "frequency_penalty": 0.0,
  402. "presence_penalty": 0.0,
  403. "temperature": 0.8,
  404. "top_p": 1.0,
  405. "stop": []any{"\n", "stop"},
  406. },
  407. Suffix: "suffix",
  408. Stream: &True,
  409. },
  410. },
  411. {
  412. name: "completions handler stream with usage",
  413. body: `{
  414. "model": "test-model",
  415. "prompt": "Hello",
  416. "stream": true,
  417. "stream_options": {"include_usage": true},
  418. "temperature": 0.8,
  419. "stop": ["\n", "stop"],
  420. "suffix": "suffix"
  421. }`,
  422. req: api.GenerateRequest{
  423. Model: "test-model",
  424. Prompt: "Hello",
  425. Options: map[string]any{
  426. "frequency_penalty": 0.0,
  427. "presence_penalty": 0.0,
  428. "temperature": 0.8,
  429. "top_p": 1.0,
  430. "stop": []any{"\n", "stop"},
  431. },
  432. Suffix: "suffix",
  433. Stream: &True,
  434. },
  435. },
  436. {
  437. name: "completions handler error forwarding",
  438. body: `{
  439. "model": "test-model",
  440. "prompt": "Hello",
  441. "temperature": null,
  442. "stop": [1, 2],
  443. "suffix": "suffix"
  444. }`,
  445. err: ErrorResponse{
  446. Error: Error{
  447. Message: "invalid type for 'stop' field: float64",
  448. Type: "invalid_request_error",
  449. },
  450. },
  451. },
  452. }
  453. endpoint := func(c *gin.Context) {
  454. c.Status(http.StatusOK)
  455. }
  456. gin.SetMode(gin.TestMode)
  457. router := gin.New()
  458. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  459. router.Handle(http.MethodPost, "/api/generate", endpoint)
  460. for _, tc := range testCases {
  461. t.Run(tc.name, func(t *testing.T) {
  462. req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
  463. req.Header.Set("Content-Type", "application/json")
  464. resp := httptest.NewRecorder()
  465. router.ServeHTTP(resp, req)
  466. var errResp ErrorResponse
  467. if resp.Code != http.StatusOK {
  468. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  469. t.Fatal(err)
  470. }
  471. }
  472. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  473. t.Fatal("requests did not match")
  474. }
  475. if !reflect.DeepEqual(tc.err, errResp) {
  476. t.Fatal("errors did not match")
  477. }
  478. capturedRequest = nil
  479. })
  480. }
  481. }
  482. func TestEmbeddingsMiddleware(t *testing.T) {
  483. type testCase struct {
  484. name string
  485. body string
  486. req api.EmbedRequest
  487. err ErrorResponse
  488. }
  489. var capturedRequest *api.EmbedRequest
  490. testCases := []testCase{
  491. {
  492. name: "embed handler single input",
  493. body: `{
  494. "input": "Hello",
  495. "model": "test-model"
  496. }`,
  497. req: api.EmbedRequest{
  498. Input: "Hello",
  499. Model: "test-model",
  500. },
  501. },
  502. {
  503. name: "embed handler batch input",
  504. body: `{
  505. "input": ["Hello", "World"],
  506. "model": "test-model"
  507. }`,
  508. req: api.EmbedRequest{
  509. Input: []any{"Hello", "World"},
  510. Model: "test-model",
  511. },
  512. },
  513. {
  514. name: "embed handler error forwarding",
  515. body: `{
  516. "model": "test-model"
  517. }`,
  518. err: ErrorResponse{
  519. Error: Error{
  520. Message: "invalid input",
  521. Type: "invalid_request_error",
  522. },
  523. },
  524. },
  525. }
  526. endpoint := func(c *gin.Context) {
  527. c.Status(http.StatusOK)
  528. }
  529. gin.SetMode(gin.TestMode)
  530. router := gin.New()
  531. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  532. router.Handle(http.MethodPost, "/api/embed", endpoint)
  533. for _, tc := range testCases {
  534. t.Run(tc.name, func(t *testing.T) {
  535. req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
  536. req.Header.Set("Content-Type", "application/json")
  537. resp := httptest.NewRecorder()
  538. router.ServeHTTP(resp, req)
  539. var errResp ErrorResponse
  540. if resp.Code != http.StatusOK {
  541. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  542. t.Fatal(err)
  543. }
  544. }
  545. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  546. t.Fatal("requests did not match")
  547. }
  548. if !reflect.DeepEqual(tc.err, errResp) {
  549. t.Fatal("errors did not match")
  550. }
  551. capturedRequest = nil
  552. })
  553. }
  554. }
  555. func TestListMiddleware(t *testing.T) {
  556. type testCase struct {
  557. name string
  558. endpoint func(c *gin.Context)
  559. resp string
  560. }
  561. testCases := []testCase{
  562. {
  563. name: "list handler",
  564. endpoint: func(c *gin.Context) {
  565. c.JSON(http.StatusOK, api.ListResponse{
  566. Models: []api.ListModelResponse{
  567. {
  568. Name: "test-model",
  569. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  570. },
  571. },
  572. })
  573. },
  574. resp: `{
  575. "object": "list",
  576. "data": [
  577. {
  578. "id": "test-model",
  579. "object": "model",
  580. "created": 1686935002,
  581. "owned_by": "library"
  582. }
  583. ]
  584. }`,
  585. },
  586. {
  587. name: "list handler empty output",
  588. endpoint: func(c *gin.Context) {
  589. c.JSON(http.StatusOK, api.ListResponse{})
  590. },
  591. resp: `{
  592. "object": "list",
  593. "data": null
  594. }`,
  595. },
  596. }
  597. gin.SetMode(gin.TestMode)
  598. for _, tc := range testCases {
  599. router := gin.New()
  600. router.Use(ListMiddleware())
  601. router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
  602. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  603. resp := httptest.NewRecorder()
  604. router.ServeHTTP(resp, req)
  605. var expected, actual map[string]any
  606. err := json.Unmarshal([]byte(tc.resp), &expected)
  607. if err != nil {
  608. t.Fatalf("failed to unmarshal expected response: %v", err)
  609. }
  610. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  611. if err != nil {
  612. t.Fatalf("failed to unmarshal actual response: %v", err)
  613. }
  614. if !reflect.DeepEqual(expected, actual) {
  615. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  616. }
  617. }
  618. }
  619. func TestRetrieveMiddleware(t *testing.T) {
  620. type testCase struct {
  621. name string
  622. endpoint func(c *gin.Context)
  623. resp string
  624. }
  625. testCases := []testCase{
  626. {
  627. name: "retrieve handler",
  628. endpoint: func(c *gin.Context) {
  629. c.JSON(http.StatusOK, api.ShowResponse{
  630. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  631. })
  632. },
  633. resp: `{
  634. "id":"test-model",
  635. "object":"model",
  636. "created":1686935002,
  637. "owned_by":"library"}
  638. `,
  639. },
  640. {
  641. name: "retrieve handler error forwarding",
  642. endpoint: func(c *gin.Context) {
  643. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  644. },
  645. resp: `{
  646. "error": {
  647. "code": null,
  648. "message": "model not found",
  649. "param": null,
  650. "type": "api_error"
  651. }
  652. }`,
  653. },
  654. }
  655. gin.SetMode(gin.TestMode)
  656. for _, tc := range testCases {
  657. router := gin.New()
  658. router.Use(RetrieveMiddleware())
  659. router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
  660. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  661. resp := httptest.NewRecorder()
  662. router.ServeHTTP(resp, req)
  663. var expected, actual map[string]any
  664. err := json.Unmarshal([]byte(tc.resp), &expected)
  665. if err != nil {
  666. t.Fatalf("failed to unmarshal expected response: %v", err)
  667. }
  668. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  669. if err != nil {
  670. t.Fatalf("failed to unmarshal actual response: %v", err)
  671. }
  672. if !reflect.DeepEqual(expected, actual) {
  673. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  674. }
  675. }
  676. }