openai_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "reflect"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/ollama/ollama/api"
  15. )
  // Fixtures shared by the middleware tests in this file.
const (
	// prefix is the data-URL header that precedes image inside chat request
	// bodies (see the "chat handler with image content" case).
	prefix = `data:image/jpeg;base64,`

	// image is a base64-encoded 1x1 PNG used as inline image content.
	// NOTE(review): the bytes are PNG even though prefix advertises
	// image/jpeg; the tests appear to rely only on the prefix being stripped —
	// confirm before changing either value.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

// False and True are addressable booleans so expected requests can populate
// *bool fields such as Stream (for example, Stream: &False).
var False, True = false, true
  24. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  25. return func(c *gin.Context) {
  26. bodyBytes, _ := io.ReadAll(c.Request.Body)
  27. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  28. err := json.Unmarshal(bodyBytes, capturedRequest)
  29. if err != nil {
  30. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  31. }
  32. c.Next()
  33. }
  34. }
  35. func TestChatMiddleware(t *testing.T) {
  36. type testCase struct {
  37. name string
  38. body string
  39. req api.ChatRequest
  40. err ErrorResponse
  41. }
  42. var capturedRequest *api.ChatRequest
  43. testCases := []testCase{
  44. {
  45. name: "chat handler",
  46. body: `{
  47. "model": "test-model",
  48. "messages": [
  49. {"role": "user", "content": "Hello"}
  50. ]
  51. }`,
  52. req: api.ChatRequest{
  53. Model: "test-model",
  54. Messages: []api.Message{
  55. {
  56. Role: "user",
  57. Content: "Hello",
  58. },
  59. },
  60. Options: map[string]any{
  61. "temperature": 1.0,
  62. "top_p": 1.0,
  63. },
  64. Stream: &False,
  65. },
  66. },
  67. {
  68. name: "chat handler with options",
  69. body: `{
  70. "model": "test-model",
  71. "messages": [
  72. {"role": "user", "content": "Hello"}
  73. ],
  74. "stream": true,
  75. "max_tokens": 999,
  76. "seed": 123,
  77. "stop": ["\n", "stop"],
  78. "temperature": 3.0,
  79. "frequency_penalty": 4.0,
  80. "presence_penalty": 5.0,
  81. "top_p": 6.0,
  82. "response_format": {"type": "json_object"}
  83. }`,
  84. req: api.ChatRequest{
  85. Model: "test-model",
  86. Messages: []api.Message{
  87. {
  88. Role: "user",
  89. Content: "Hello",
  90. },
  91. },
  92. Options: map[string]any{
  93. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  94. "seed": 123.0,
  95. "stop": []any{"\n", "stop"},
  96. "temperature": 3.0,
  97. "frequency_penalty": 4.0,
  98. "presence_penalty": 5.0,
  99. "top_p": 6.0,
  100. },
  101. Format: "json",
  102. Stream: &True,
  103. },
  104. },
  105. {
  106. name: "chat handler with image content",
  107. body: `{
  108. "model": "test-model",
  109. "messages": [
  110. {
  111. "role": "user",
  112. "content": [
  113. {
  114. "type": "text",
  115. "text": "Hello"
  116. },
  117. {
  118. "type": "image_url",
  119. "image_url": {
  120. "url": "` + prefix + image + `"
  121. }
  122. }
  123. ]
  124. }
  125. ]
  126. }`,
  127. req: api.ChatRequest{
  128. Model: "test-model",
  129. Messages: []api.Message{
  130. {
  131. Role: "user",
  132. Content: "Hello",
  133. },
  134. {
  135. Role: "user",
  136. Images: []api.ImageData{
  137. func() []byte {
  138. img, _ := base64.StdEncoding.DecodeString(image)
  139. return img
  140. }(),
  141. },
  142. },
  143. },
  144. Options: map[string]any{
  145. "temperature": 1.0,
  146. "top_p": 1.0,
  147. },
  148. Stream: &False,
  149. },
  150. },
  151. {
  152. name: "chat handler with tools",
  153. body: `{
  154. "model": "test-model",
  155. "messages": [
  156. {"role": "user", "content": "What's the weather like in Paris Today?"},
  157. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  158. ]
  159. }`,
  160. req: api.ChatRequest{
  161. Model: "test-model",
  162. Messages: []api.Message{
  163. {
  164. Role: "user",
  165. Content: "What's the weather like in Paris Today?",
  166. },
  167. {
  168. Role: "assistant",
  169. ToolCalls: []api.ToolCall{
  170. {
  171. Function: api.ToolCallFunction{
  172. Name: "get_current_weather",
  173. Arguments: map[string]interface{}{
  174. "location": "Paris, France",
  175. "format": "celsius",
  176. },
  177. },
  178. },
  179. },
  180. },
  181. },
  182. Options: map[string]any{
  183. "temperature": 1.0,
  184. "top_p": 1.0,
  185. },
  186. Stream: &False,
  187. },
  188. },
  189. {
  190. name: "chat handler with streaming tools",
  191. body: `{
  192. "model": "test-model",
  193. "messages": [
  194. {"role": "user", "content": "What's the weather like in Paris?"}
  195. ],
  196. "stream": true,
  197. "tools": [{
  198. "type": "function",
  199. "function": {
  200. "name": "get_weather",
  201. "description": "Get the current weather",
  202. "parameters": {
  203. "type": "object",
  204. "required": ["location"],
  205. "properties": {
  206. "location": {
  207. "type": "string",
  208. "description": "The city and state"
  209. },
  210. "unit": {
  211. "type": "string",
  212. "enum": ["celsius", "fahrenheit"]
  213. }
  214. }
  215. }
  216. }
  217. }]
  218. }`,
  219. req: api.ChatRequest{
  220. Model: "test-model",
  221. Messages: []api.Message{
  222. {
  223. Role: "user",
  224. Content: "What's the weather like in Paris?",
  225. },
  226. },
  227. Tools: []api.Tool{
  228. {
  229. Type: "function",
  230. Function: api.ToolFunction{
  231. Name: "get_weather",
  232. Description: "Get the current weather",
  233. Parameters: struct {
  234. Type string `json:"type"`
  235. Required []string `json:"required"`
  236. Properties map[string]struct {
  237. Type string `json:"type"`
  238. Description string `json:"description"`
  239. Enum []string `json:"enum,omitempty"`
  240. } `json:"properties"`
  241. }{
  242. Type: "object",
  243. Required: []string{"location"},
  244. Properties: map[string]struct {
  245. Type string `json:"type"`
  246. Description string `json:"description"`
  247. Enum []string `json:"enum,omitempty"`
  248. }{
  249. "location": {
  250. Type: "string",
  251. Description: "The city and state",
  252. },
  253. "unit": {
  254. Type: "string",
  255. Enum: []string{"celsius", "fahrenheit"},
  256. },
  257. },
  258. },
  259. },
  260. },
  261. },
  262. Options: map[string]any{
  263. "temperature": 1.0,
  264. "top_p": 1.0,
  265. },
  266. Stream: &True,
  267. },
  268. },
  269. {
  270. name: "chat handler error forwarding",
  271. body: `{
  272. "model": "test-model",
  273. "messages": [
  274. {"role": "user", "content": 2}
  275. ]
  276. }`,
  277. err: ErrorResponse{
  278. Error: Error{
  279. Message: "invalid message content type: float64",
  280. Type: "invalid_request_error",
  281. },
  282. },
  283. },
  284. }
  285. endpoint := func(c *gin.Context) {
  286. c.Status(http.StatusOK)
  287. }
  288. gin.SetMode(gin.TestMode)
  289. router := gin.New()
  290. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  291. router.Handle(http.MethodPost, "/api/chat", endpoint)
  292. for _, tc := range testCases {
  293. t.Run(tc.name, func(t *testing.T) {
  294. req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
  295. req.Header.Set("Content-Type", "application/json")
  296. defer func() { capturedRequest = nil }()
  297. resp := httptest.NewRecorder()
  298. router.ServeHTTP(resp, req)
  299. var errResp ErrorResponse
  300. if resp.Code != http.StatusOK {
  301. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  302. t.Fatal(err)
  303. }
  304. }
  305. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  306. t.Fatal("requests did not match")
  307. }
  308. if !reflect.DeepEqual(tc.err, errResp) {
  309. t.Fatal("errors did not match")
  310. }
  311. })
  312. }
  313. }
  314. func TestCompletionsMiddleware(t *testing.T) {
  315. type testCase struct {
  316. name string
  317. body string
  318. req api.GenerateRequest
  319. err ErrorResponse
  320. }
  321. var capturedRequest *api.GenerateRequest
  322. testCases := []testCase{
  323. {
  324. name: "completions handler",
  325. body: `{
  326. "model": "test-model",
  327. "prompt": "Hello",
  328. "temperature": 0.8,
  329. "stop": ["\n", "stop"],
  330. "suffix": "suffix"
  331. }`,
  332. req: api.GenerateRequest{
  333. Model: "test-model",
  334. Prompt: "Hello",
  335. Options: map[string]any{
  336. "frequency_penalty": 0.0,
  337. "presence_penalty": 0.0,
  338. "temperature": 0.8,
  339. "top_p": 1.0,
  340. "stop": []any{"\n", "stop"},
  341. },
  342. Suffix: "suffix",
  343. Stream: &False,
  344. },
  345. },
  346. {
  347. name: "completions handler error forwarding",
  348. body: `{
  349. "model": "test-model",
  350. "prompt": "Hello",
  351. "temperature": null,
  352. "stop": [1, 2],
  353. "suffix": "suffix"
  354. }`,
  355. err: ErrorResponse{
  356. Error: Error{
  357. Message: "invalid type for 'stop' field: float64",
  358. Type: "invalid_request_error",
  359. },
  360. },
  361. },
  362. }
  363. endpoint := func(c *gin.Context) {
  364. c.Status(http.StatusOK)
  365. }
  366. gin.SetMode(gin.TestMode)
  367. router := gin.New()
  368. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  369. router.Handle(http.MethodPost, "/api/generate", endpoint)
  370. for _, tc := range testCases {
  371. t.Run(tc.name, func(t *testing.T) {
  372. req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
  373. req.Header.Set("Content-Type", "application/json")
  374. resp := httptest.NewRecorder()
  375. router.ServeHTTP(resp, req)
  376. var errResp ErrorResponse
  377. if resp.Code != http.StatusOK {
  378. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  379. t.Fatal(err)
  380. }
  381. }
  382. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  383. t.Fatal("requests did not match")
  384. }
  385. if !reflect.DeepEqual(tc.err, errResp) {
  386. t.Fatal("errors did not match")
  387. }
  388. capturedRequest = nil
  389. })
  390. }
  391. }
  392. func TestEmbeddingsMiddleware(t *testing.T) {
  393. type testCase struct {
  394. name string
  395. body string
  396. req api.EmbedRequest
  397. err ErrorResponse
  398. }
  399. var capturedRequest *api.EmbedRequest
  400. testCases := []testCase{
  401. {
  402. name: "embed handler single input",
  403. body: `{
  404. "input": "Hello",
  405. "model": "test-model"
  406. }`,
  407. req: api.EmbedRequest{
  408. Input: "Hello",
  409. Model: "test-model",
  410. },
  411. },
  412. {
  413. name: "embed handler batch input",
  414. body: `{
  415. "input": ["Hello", "World"],
  416. "model": "test-model"
  417. }`,
  418. req: api.EmbedRequest{
  419. Input: []any{"Hello", "World"},
  420. Model: "test-model",
  421. },
  422. },
  423. {
  424. name: "embed handler error forwarding",
  425. body: `{
  426. "model": "test-model"
  427. }`,
  428. err: ErrorResponse{
  429. Error: Error{
  430. Message: "invalid input",
  431. Type: "invalid_request_error",
  432. },
  433. },
  434. },
  435. }
  436. endpoint := func(c *gin.Context) {
  437. c.Status(http.StatusOK)
  438. }
  439. gin.SetMode(gin.TestMode)
  440. router := gin.New()
  441. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  442. router.Handle(http.MethodPost, "/api/embed", endpoint)
  443. for _, tc := range testCases {
  444. t.Run(tc.name, func(t *testing.T) {
  445. req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
  446. req.Header.Set("Content-Type", "application/json")
  447. resp := httptest.NewRecorder()
  448. router.ServeHTTP(resp, req)
  449. var errResp ErrorResponse
  450. if resp.Code != http.StatusOK {
  451. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  452. t.Fatal(err)
  453. }
  454. }
  455. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  456. t.Fatal("requests did not match")
  457. }
  458. if !reflect.DeepEqual(tc.err, errResp) {
  459. t.Fatal("errors did not match")
  460. }
  461. capturedRequest = nil
  462. })
  463. }
  464. }
  465. func TestListMiddleware(t *testing.T) {
  466. type testCase struct {
  467. name string
  468. endpoint func(c *gin.Context)
  469. resp string
  470. }
  471. testCases := []testCase{
  472. {
  473. name: "list handler",
  474. endpoint: func(c *gin.Context) {
  475. c.JSON(http.StatusOK, api.ListResponse{
  476. Models: []api.ListModelResponse{
  477. {
  478. Name: "test-model",
  479. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  480. },
  481. },
  482. })
  483. },
  484. resp: `{
  485. "object": "list",
  486. "data": [
  487. {
  488. "id": "test-model",
  489. "object": "model",
  490. "created": 1686935002,
  491. "owned_by": "library"
  492. }
  493. ]
  494. }`,
  495. },
  496. {
  497. name: "list handler empty output",
  498. endpoint: func(c *gin.Context) {
  499. c.JSON(http.StatusOK, api.ListResponse{})
  500. },
  501. resp: `{
  502. "object": "list",
  503. "data": null
  504. }`,
  505. },
  506. }
  507. gin.SetMode(gin.TestMode)
  508. for _, tc := range testCases {
  509. router := gin.New()
  510. router.Use(ListMiddleware())
  511. router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
  512. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  513. resp := httptest.NewRecorder()
  514. router.ServeHTTP(resp, req)
  515. var expected, actual map[string]any
  516. err := json.Unmarshal([]byte(tc.resp), &expected)
  517. if err != nil {
  518. t.Fatalf("failed to unmarshal expected response: %v", err)
  519. }
  520. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  521. if err != nil {
  522. t.Fatalf("failed to unmarshal actual response: %v", err)
  523. }
  524. if !reflect.DeepEqual(expected, actual) {
  525. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  526. }
  527. }
  528. }
  529. func TestRetrieveMiddleware(t *testing.T) {
  530. type testCase struct {
  531. name string
  532. endpoint func(c *gin.Context)
  533. resp string
  534. }
  535. testCases := []testCase{
  536. {
  537. name: "retrieve handler",
  538. endpoint: func(c *gin.Context) {
  539. c.JSON(http.StatusOK, api.ShowResponse{
  540. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  541. })
  542. },
  543. resp: `{
  544. "id":"test-model",
  545. "object":"model",
  546. "created":1686935002,
  547. "owned_by":"library"}
  548. `,
  549. },
  550. {
  551. name: "retrieve handler error forwarding",
  552. endpoint: func(c *gin.Context) {
  553. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  554. },
  555. resp: `{
  556. "error": {
  557. "code": null,
  558. "message": "model not found",
  559. "param": null,
  560. "type": "api_error"
  561. }
  562. }`,
  563. },
  564. }
  565. gin.SetMode(gin.TestMode)
  566. for _, tc := range testCases {
  567. router := gin.New()
  568. router.Use(RetrieveMiddleware())
  569. router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
  570. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  571. resp := httptest.NewRecorder()
  572. router.ServeHTTP(resp, req)
  573. var expected, actual map[string]any
  574. err := json.Unmarshal([]byte(tc.resp), &expected)
  575. if err != nil {
  576. t.Fatalf("failed to unmarshal expected response: %v", err)
  577. }
  578. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  579. if err != nil {
  580. t.Fatalf("failed to unmarshal actual response: %v", err)
  581. }
  582. if !reflect.DeepEqual(expected, actual) {
  583. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  584. }
  585. }
  586. }