openai_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "reflect"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/google/go-cmp/cmp"
  15. "github.com/ollama/ollama/api"
  16. )
  // prefix is the data-URI header the chat tests prepend to image when they
  // build an OpenAI "image_url" content part.
  //
  // NOTE(review): the declared MIME type is JPEG while image decodes to a PNG
  // signature; presumably the middleware does not check the declared type — confirm.
  const prefix = `data:image/jpeg;base64,`

  // image is a base64-encoded 1x1 PNG used as an inline image payload.
  const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

  // False is an addressable false so test cases can write &False for *bool fields.
  var False = false

  // True is an addressable true so test cases can write &True for *bool fields.
  var True = true
  25. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  26. return func(c *gin.Context) {
  27. bodyBytes, _ := io.ReadAll(c.Request.Body)
  28. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  29. err := json.Unmarshal(bodyBytes, capturedRequest)
  30. if err != nil {
  31. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  32. }
  33. c.Next()
  34. }
  35. }
  36. func TestChatMiddleware(t *testing.T) {
  37. type testCase struct {
  38. name string
  39. body string
  40. req api.ChatRequest
  41. err ErrorResponse
  42. }
  43. var capturedRequest *api.ChatRequest
  44. testCases := []testCase{
  45. {
  46. name: "chat handler",
  47. body: `{
  48. "model": "test-model",
  49. "messages": [
  50. {"role": "user", "content": "Hello"}
  51. ]
  52. }`,
  53. req: api.ChatRequest{
  54. Model: "test-model",
  55. Messages: []api.Message{
  56. {
  57. Role: "user",
  58. Content: "Hello",
  59. },
  60. },
  61. Options: map[string]any{
  62. "temperature": 1.0,
  63. "top_p": 1.0,
  64. },
  65. Stream: &False,
  66. },
  67. },
  68. {
  69. name: "chat handler with options",
  70. body: `{
  71. "model": "test-model",
  72. "messages": [
  73. {"role": "user", "content": "Hello"}
  74. ],
  75. "stream": true,
  76. "max_tokens": 999,
  77. "seed": 123,
  78. "stop": ["\n", "stop"],
  79. "temperature": 3.0,
  80. "frequency_penalty": 4.0,
  81. "presence_penalty": 5.0,
  82. "top_p": 6.0,
  83. "response_format": {"type": "json_object"}
  84. }`,
  85. req: api.ChatRequest{
  86. Model: "test-model",
  87. Messages: []api.Message{
  88. {
  89. Role: "user",
  90. Content: "Hello",
  91. },
  92. },
  93. Options: map[string]any{
  94. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  95. "seed": 123.0,
  96. "stop": []any{"\n", "stop"},
  97. "temperature": 3.0,
  98. "frequency_penalty": 4.0,
  99. "presence_penalty": 5.0,
  100. "top_p": 6.0,
  101. },
  102. Format: json.RawMessage(`"json"`),
  103. Stream: &True,
  104. },
  105. },
  106. {
  107. name: "chat handler with image content",
  108. body: `{
  109. "model": "test-model",
  110. "messages": [
  111. {
  112. "role": "user",
  113. "content": [
  114. {
  115. "type": "text",
  116. "text": "Hello"
  117. },
  118. {
  119. "type": "image_url",
  120. "image_url": {
  121. "url": "` + prefix + image + `"
  122. }
  123. }
  124. ]
  125. }
  126. ]
  127. }`,
  128. req: api.ChatRequest{
  129. Model: "test-model",
  130. Messages: []api.Message{
  131. {
  132. Role: "user",
  133. Content: "Hello",
  134. },
  135. {
  136. Role: "user",
  137. Images: []api.ImageData{
  138. func() []byte {
  139. img, _ := base64.StdEncoding.DecodeString(image)
  140. return img
  141. }(),
  142. },
  143. },
  144. },
  145. Options: map[string]any{
  146. "temperature": 1.0,
  147. "top_p": 1.0,
  148. },
  149. Stream: &False,
  150. },
  151. },
  152. {
  153. name: "chat handler with tools",
  154. body: `{
  155. "model": "test-model",
  156. "messages": [
  157. {"role": "user", "content": "What's the weather like in Paris Today?"},
  158. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  159. ]
  160. }`,
  161. req: api.ChatRequest{
  162. Model: "test-model",
  163. Messages: []api.Message{
  164. {
  165. Role: "user",
  166. Content: "What's the weather like in Paris Today?",
  167. },
  168. {
  169. Role: "assistant",
  170. ToolCalls: []api.ToolCall{
  171. {
  172. Function: api.ToolCallFunction{
  173. Name: "get_current_weather",
  174. Arguments: map[string]interface{}{
  175. "location": "Paris, France",
  176. "format": "celsius",
  177. },
  178. },
  179. },
  180. },
  181. },
  182. },
  183. Options: map[string]any{
  184. "temperature": 1.0,
  185. "top_p": 1.0,
  186. },
  187. Stream: &False,
  188. },
  189. },
  190. {
  191. name: "chat handler with streaming tools",
  192. body: `{
  193. "model": "test-model",
  194. "messages": [
  195. {"role": "user", "content": "What's the weather like in Paris?"}
  196. ],
  197. "stream": true,
  198. "tools": [{
  199. "type": "function",
  200. "function": {
  201. "name": "get_weather",
  202. "description": "Get the current weather",
  203. "parameters": {
  204. "type": "object",
  205. "required": ["location"],
  206. "properties": {
  207. "location": {
  208. "type": "string",
  209. "description": "The city and state"
  210. },
  211. "unit": {
  212. "type": "string",
  213. "enum": ["celsius", "fahrenheit"]
  214. }
  215. }
  216. }
  217. }
  218. }]
  219. }`,
  220. req: api.ChatRequest{
  221. Model: "test-model",
  222. Messages: []api.Message{
  223. {
  224. Role: "user",
  225. Content: "What's the weather like in Paris?",
  226. },
  227. },
  228. Tools: []api.Tool{
  229. {
  230. Type: "function",
  231. Function: api.ToolFunction{
  232. Name: "get_weather",
  233. Description: "Get the current weather",
  234. Parameters: struct {
  235. Type string `json:"type"`
  236. Required []string `json:"required"`
  237. Properties map[string]struct {
  238. Type string `json:"type"`
  239. Description string `json:"description"`
  240. Enum []string `json:"enum,omitempty"`
  241. } `json:"properties"`
  242. }{
  243. Type: "object",
  244. Required: []string{"location"},
  245. Properties: map[string]struct {
  246. Type string `json:"type"`
  247. Description string `json:"description"`
  248. Enum []string `json:"enum,omitempty"`
  249. }{
  250. "location": {
  251. Type: "string",
  252. Description: "The city and state",
  253. },
  254. "unit": {
  255. Type: "string",
  256. Enum: []string{"celsius", "fahrenheit"},
  257. },
  258. },
  259. },
  260. },
  261. },
  262. },
  263. Options: map[string]any{
  264. "temperature": 1.0,
  265. "top_p": 1.0,
  266. },
  267. Stream: &True,
  268. },
  269. },
  270. {
  271. name: "chat handler error forwarding",
  272. body: `{
  273. "model": "test-model",
  274. "messages": [
  275. {"role": "user", "content": 2}
  276. ]
  277. }`,
  278. err: ErrorResponse{
  279. Error: Error{
  280. Message: "invalid message content type: float64",
  281. Type: "invalid_request_error",
  282. },
  283. },
  284. },
  285. }
  286. endpoint := func(c *gin.Context) {
  287. c.Status(http.StatusOK)
  288. }
  289. gin.SetMode(gin.TestMode)
  290. router := gin.New()
  291. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  292. router.Handle(http.MethodPost, "/api/chat", endpoint)
  293. for _, tc := range testCases {
  294. t.Run(tc.name, func(t *testing.T) {
  295. req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
  296. req.Header.Set("Content-Type", "application/json")
  297. defer func() { capturedRequest = nil }()
  298. resp := httptest.NewRecorder()
  299. router.ServeHTTP(resp, req)
  300. var errResp ErrorResponse
  301. if resp.Code != http.StatusOK {
  302. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  303. t.Fatal(err)
  304. }
  305. return
  306. }
  307. if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
  308. t.Fatalf("requests did not match: %+v", diff)
  309. }
  310. if diff := cmp.Diff(tc.err, errResp); diff != "" {
  311. t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
  312. }
  313. })
  314. }
  315. }
  316. func TestCompletionsMiddleware(t *testing.T) {
  317. type testCase struct {
  318. name string
  319. body string
  320. req api.GenerateRequest
  321. err ErrorResponse
  322. }
  323. var capturedRequest *api.GenerateRequest
  324. testCases := []testCase{
  325. {
  326. name: "completions handler",
  327. body: `{
  328. "model": "test-model",
  329. "prompt": "Hello",
  330. "temperature": 0.8,
  331. "stop": ["\n", "stop"],
  332. "suffix": "suffix"
  333. }`,
  334. req: api.GenerateRequest{
  335. Model: "test-model",
  336. Prompt: "Hello",
  337. Options: map[string]any{
  338. "frequency_penalty": 0.0,
  339. "presence_penalty": 0.0,
  340. "temperature": 0.8,
  341. "top_p": 1.0,
  342. "stop": []any{"\n", "stop"},
  343. },
  344. Suffix: "suffix",
  345. Stream: &False,
  346. },
  347. },
  348. {
  349. name: "completions handler error forwarding",
  350. body: `{
  351. "model": "test-model",
  352. "prompt": "Hello",
  353. "temperature": null,
  354. "stop": [1, 2],
  355. "suffix": "suffix"
  356. }`,
  357. err: ErrorResponse{
  358. Error: Error{
  359. Message: "invalid type for 'stop' field: float64",
  360. Type: "invalid_request_error",
  361. },
  362. },
  363. },
  364. }
  365. endpoint := func(c *gin.Context) {
  366. c.Status(http.StatusOK)
  367. }
  368. gin.SetMode(gin.TestMode)
  369. router := gin.New()
  370. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  371. router.Handle(http.MethodPost, "/api/generate", endpoint)
  372. for _, tc := range testCases {
  373. t.Run(tc.name, func(t *testing.T) {
  374. req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
  375. req.Header.Set("Content-Type", "application/json")
  376. resp := httptest.NewRecorder()
  377. router.ServeHTTP(resp, req)
  378. var errResp ErrorResponse
  379. if resp.Code != http.StatusOK {
  380. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  381. t.Fatal(err)
  382. }
  383. }
  384. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  385. t.Fatal("requests did not match")
  386. }
  387. if !reflect.DeepEqual(tc.err, errResp) {
  388. t.Fatal("errors did not match")
  389. }
  390. capturedRequest = nil
  391. })
  392. }
  393. }
  394. func TestEmbeddingsMiddleware(t *testing.T) {
  395. type testCase struct {
  396. name string
  397. body string
  398. req api.EmbedRequest
  399. err ErrorResponse
  400. }
  401. var capturedRequest *api.EmbedRequest
  402. testCases := []testCase{
  403. {
  404. name: "embed handler single input",
  405. body: `{
  406. "input": "Hello",
  407. "model": "test-model"
  408. }`,
  409. req: api.EmbedRequest{
  410. Input: "Hello",
  411. Model: "test-model",
  412. },
  413. },
  414. {
  415. name: "embed handler batch input",
  416. body: `{
  417. "input": ["Hello", "World"],
  418. "model": "test-model"
  419. }`,
  420. req: api.EmbedRequest{
  421. Input: []any{"Hello", "World"},
  422. Model: "test-model",
  423. },
  424. },
  425. {
  426. name: "embed handler error forwarding",
  427. body: `{
  428. "model": "test-model"
  429. }`,
  430. err: ErrorResponse{
  431. Error: Error{
  432. Message: "invalid input",
  433. Type: "invalid_request_error",
  434. },
  435. },
  436. },
  437. }
  438. endpoint := func(c *gin.Context) {
  439. c.Status(http.StatusOK)
  440. }
  441. gin.SetMode(gin.TestMode)
  442. router := gin.New()
  443. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  444. router.Handle(http.MethodPost, "/api/embed", endpoint)
  445. for _, tc := range testCases {
  446. t.Run(tc.name, func(t *testing.T) {
  447. req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
  448. req.Header.Set("Content-Type", "application/json")
  449. resp := httptest.NewRecorder()
  450. router.ServeHTTP(resp, req)
  451. var errResp ErrorResponse
  452. if resp.Code != http.StatusOK {
  453. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  454. t.Fatal(err)
  455. }
  456. }
  457. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  458. t.Fatal("requests did not match")
  459. }
  460. if !reflect.DeepEqual(tc.err, errResp) {
  461. t.Fatal("errors did not match")
  462. }
  463. capturedRequest = nil
  464. })
  465. }
  466. }
  467. func TestListMiddleware(t *testing.T) {
  468. type testCase struct {
  469. name string
  470. endpoint func(c *gin.Context)
  471. resp string
  472. }
  473. testCases := []testCase{
  474. {
  475. name: "list handler",
  476. endpoint: func(c *gin.Context) {
  477. c.JSON(http.StatusOK, api.ListResponse{
  478. Models: []api.ListModelResponse{
  479. {
  480. Name: "test-model",
  481. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  482. },
  483. },
  484. })
  485. },
  486. resp: `{
  487. "object": "list",
  488. "data": [
  489. {
  490. "id": "test-model",
  491. "object": "model",
  492. "created": 1686935002,
  493. "owned_by": "library"
  494. }
  495. ]
  496. }`,
  497. },
  498. {
  499. name: "list handler empty output",
  500. endpoint: func(c *gin.Context) {
  501. c.JSON(http.StatusOK, api.ListResponse{})
  502. },
  503. resp: `{
  504. "object": "list",
  505. "data": null
  506. }`,
  507. },
  508. }
  509. gin.SetMode(gin.TestMode)
  510. for _, tc := range testCases {
  511. router := gin.New()
  512. router.Use(ListMiddleware())
  513. router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
  514. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  515. resp := httptest.NewRecorder()
  516. router.ServeHTTP(resp, req)
  517. var expected, actual map[string]any
  518. err := json.Unmarshal([]byte(tc.resp), &expected)
  519. if err != nil {
  520. t.Fatalf("failed to unmarshal expected response: %v", err)
  521. }
  522. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  523. if err != nil {
  524. t.Fatalf("failed to unmarshal actual response: %v", err)
  525. }
  526. if !reflect.DeepEqual(expected, actual) {
  527. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  528. }
  529. }
  530. }
  531. func TestRetrieveMiddleware(t *testing.T) {
  532. type testCase struct {
  533. name string
  534. endpoint func(c *gin.Context)
  535. resp string
  536. }
  537. testCases := []testCase{
  538. {
  539. name: "retrieve handler",
  540. endpoint: func(c *gin.Context) {
  541. c.JSON(http.StatusOK, api.ShowResponse{
  542. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  543. })
  544. },
  545. resp: `{
  546. "id":"test-model",
  547. "object":"model",
  548. "created":1686935002,
  549. "owned_by":"library"}
  550. `,
  551. },
  552. {
  553. name: "retrieve handler error forwarding",
  554. endpoint: func(c *gin.Context) {
  555. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  556. },
  557. resp: `{
  558. "error": {
  559. "code": null,
  560. "message": "model not found",
  561. "param": null,
  562. "type": "api_error"
  563. }
  564. }`,
  565. },
  566. }
  567. gin.SetMode(gin.TestMode)
  568. for _, tc := range testCases {
  569. router := gin.New()
  570. router.Use(RetrieveMiddleware())
  571. router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
  572. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  573. resp := httptest.NewRecorder()
  574. router.ServeHTTP(resp, req)
  575. var expected, actual map[string]any
  576. err := json.Unmarshal([]byte(tc.resp), &expected)
  577. if err != nil {
  578. t.Fatalf("failed to unmarshal expected response: %v", err)
  579. }
  580. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  581. if err != nil {
  582. t.Fatalf("failed to unmarshal actual response: %v", err)
  583. }
  584. if !reflect.DeepEqual(expected, actual) {
  585. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  586. }
  587. }
  588. }