openai_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "reflect"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/ollama/ollama/api"
  15. )
  16. const (
  17. prefix = `data:image/jpeg;base64,`
  18. image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
  19. )
  20. var (
  21. False = false
  22. True = true
  23. )
  24. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  25. return func(c *gin.Context) {
  26. bodyBytes, _ := io.ReadAll(c.Request.Body)
  27. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  28. err := json.Unmarshal(bodyBytes, capturedRequest)
  29. if err != nil {
  30. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  31. }
  32. c.Next()
  33. }
  34. }
  35. func TestChatMiddleware(t *testing.T) {
  36. type testCase struct {
  37. name string
  38. body string
  39. req api.ChatRequest
  40. err ErrorResponse
  41. }
  42. var capturedRequest *api.ChatRequest
  43. testCases := []testCase{
  44. {
  45. name: "chat handler",
  46. body: `{
  47. "model": "test-model",
  48. "messages": [
  49. {"role": "user", "content": "Hello"}
  50. ]
  51. }`,
  52. req: api.ChatRequest{
  53. Model: "test-model",
  54. Messages: []api.Message{
  55. {
  56. Role: "user",
  57. Content: "Hello",
  58. },
  59. },
  60. Options: map[string]any{
  61. "temperature": 1.0,
  62. "top_p": 1.0,
  63. },
  64. Stream: &False,
  65. },
  66. },
  67. {
  68. name: "chat handler with options",
  69. body: `{
  70. "model": "test-model",
  71. "messages": [
  72. {"role": "user", "content": "Hello"}
  73. ],
  74. "stream": true,
  75. "max_tokens": 999,
  76. "seed": 123,
  77. "stop": ["\n", "stop"],
  78. "temperature": 3.0,
  79. "frequency_penalty": 4.0,
  80. "presence_penalty": 5.0,
  81. "top_p": 6.0,
  82. "response_format": {"type": "json_object"}
  83. }`,
  84. req: api.ChatRequest{
  85. Model: "test-model",
  86. Messages: []api.Message{
  87. {
  88. Role: "user",
  89. Content: "Hello",
  90. },
  91. },
  92. Options: map[string]any{
  93. "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
  94. "seed": 123.0,
  95. "stop": []any{"\n", "stop"},
  96. "temperature": 6.0,
  97. "frequency_penalty": 8.0,
  98. "presence_penalty": 10.0,
  99. "top_p": 6.0,
  100. },
  101. Format: "json",
  102. Stream: &True,
  103. },
  104. },
  105. {
  106. name: "chat handler with image content",
  107. body: `{
  108. "model": "test-model",
  109. "messages": [
  110. {
  111. "role": "user",
  112. "content": [
  113. {
  114. "type": "text",
  115. "text": "Hello"
  116. },
  117. {
  118. "type": "image_url",
  119. "image_url": {
  120. "url": "` + prefix + image + `"
  121. }
  122. }
  123. ]
  124. }
  125. ]
  126. }`,
  127. req: api.ChatRequest{
  128. Model: "test-model",
  129. Messages: []api.Message{
  130. {
  131. Role: "user",
  132. Content: "Hello",
  133. },
  134. {
  135. Role: "user",
  136. Images: []api.ImageData{
  137. func() []byte {
  138. img, _ := base64.StdEncoding.DecodeString(image)
  139. return img
  140. }(),
  141. },
  142. },
  143. },
  144. Options: map[string]any{
  145. "temperature": 1.0,
  146. "top_p": 1.0,
  147. },
  148. Stream: &False,
  149. },
  150. },
  151. {
  152. name: "chat handler with tools",
  153. body: `{
  154. "model": "test-model",
  155. "messages": [
  156. {"role": "user", "content": "What's the weather like in Paris Today?"},
  157. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  158. ]
  159. }`,
  160. req: api.ChatRequest{
  161. Model: "test-model",
  162. Messages: []api.Message{
  163. {
  164. Role: "user",
  165. Content: "What's the weather like in Paris Today?",
  166. },
  167. {
  168. Role: "assistant",
  169. ToolCalls: []api.ToolCall{
  170. {
  171. Function: api.ToolCallFunction{
  172. Name: "get_current_weather",
  173. Arguments: map[string]interface{}{
  174. "location": "Paris, France",
  175. "format": "celsius",
  176. },
  177. },
  178. },
  179. },
  180. },
  181. },
  182. Options: map[string]any{
  183. "temperature": 1.0,
  184. "top_p": 1.0,
  185. },
  186. Stream: &False,
  187. },
  188. },
  189. {
  190. name: "chat handler error forwarding",
  191. body: `{
  192. "model": "test-model",
  193. "messages": [
  194. {"role": "user", "content": 2}
  195. ]
  196. }`,
  197. err: ErrorResponse{
  198. Error: Error{
  199. Message: "invalid message content type: float64",
  200. Type: "invalid_request_error",
  201. },
  202. },
  203. },
  204. }
  205. endpoint := func(c *gin.Context) {
  206. c.Status(http.StatusOK)
  207. }
  208. gin.SetMode(gin.TestMode)
  209. router := gin.New()
  210. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  211. router.Handle(http.MethodPost, "/api/chat", endpoint)
  212. for _, tc := range testCases {
  213. t.Run(tc.name, func(t *testing.T) {
  214. req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
  215. req.Header.Set("Content-Type", "application/json")
  216. defer func() { capturedRequest = nil }()
  217. resp := httptest.NewRecorder()
  218. router.ServeHTTP(resp, req)
  219. var errResp ErrorResponse
  220. if resp.Code != http.StatusOK {
  221. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  222. t.Fatal(err)
  223. }
  224. }
  225. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  226. t.Fatal("requests did not match")
  227. }
  228. if !reflect.DeepEqual(tc.err, errResp) {
  229. t.Fatal("errors did not match")
  230. }
  231. })
  232. }
  233. }
  234. func TestCompletionsMiddleware(t *testing.T) {
  235. type testCase struct {
  236. name string
  237. body string
  238. req api.GenerateRequest
  239. err ErrorResponse
  240. }
  241. var capturedRequest *api.GenerateRequest
  242. testCases := []testCase{
  243. {
  244. name: "completions handler",
  245. body: `{
  246. "model": "test-model",
  247. "prompt": "Hello",
  248. "temperature": 0.8,
  249. "stop": ["\n", "stop"],
  250. "suffix": "suffix"
  251. }`,
  252. req: api.GenerateRequest{
  253. Model: "test-model",
  254. Prompt: "Hello",
  255. Options: map[string]any{
  256. "frequency_penalty": 0.0,
  257. "presence_penalty": 0.0,
  258. "temperature": 0.8,
  259. "top_p": 1.0,
  260. "stop": []any{"\n", "stop"},
  261. },
  262. Suffix: "suffix",
  263. Stream: &False,
  264. },
  265. },
  266. {
  267. name: "completions handler error forwarding",
  268. body: `{
  269. "model": "test-model",
  270. "prompt": "Hello",
  271. "temperature": null,
  272. "stop": [1, 2],
  273. "suffix": "suffix"
  274. }`,
  275. err: ErrorResponse{
  276. Error: Error{
  277. Message: "invalid type for 'stop' field: float64",
  278. Type: "invalid_request_error",
  279. },
  280. },
  281. },
  282. }
  283. endpoint := func(c *gin.Context) {
  284. c.Status(http.StatusOK)
  285. }
  286. gin.SetMode(gin.TestMode)
  287. router := gin.New()
  288. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  289. router.Handle(http.MethodPost, "/api/generate", endpoint)
  290. for _, tc := range testCases {
  291. t.Run(tc.name, func(t *testing.T) {
  292. req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
  293. req.Header.Set("Content-Type", "application/json")
  294. resp := httptest.NewRecorder()
  295. router.ServeHTTP(resp, req)
  296. var errResp ErrorResponse
  297. if resp.Code != http.StatusOK {
  298. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  299. t.Fatal(err)
  300. }
  301. }
  302. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  303. t.Fatal("requests did not match")
  304. }
  305. if !reflect.DeepEqual(tc.err, errResp) {
  306. t.Fatal("errors did not match")
  307. }
  308. capturedRequest = nil
  309. })
  310. }
  311. }
  312. func TestEmbeddingsMiddleware(t *testing.T) {
  313. type testCase struct {
  314. name string
  315. body string
  316. req api.EmbedRequest
  317. err ErrorResponse
  318. }
  319. var capturedRequest *api.EmbedRequest
  320. testCases := []testCase{
  321. {
  322. name: "embed handler single input",
  323. body: `{
  324. "input": "Hello",
  325. "model": "test-model"
  326. }`,
  327. req: api.EmbedRequest{
  328. Input: "Hello",
  329. Model: "test-model",
  330. },
  331. },
  332. {
  333. name: "embed handler batch input",
  334. body: `{
  335. "input": ["Hello", "World"],
  336. "model": "test-model"
  337. }`,
  338. req: api.EmbedRequest{
  339. Input: []any{"Hello", "World"},
  340. Model: "test-model",
  341. },
  342. },
  343. {
  344. name: "embed handler error forwarding",
  345. body: `{
  346. "model": "test-model"
  347. }`,
  348. err: ErrorResponse{
  349. Error: Error{
  350. Message: "invalid input",
  351. Type: "invalid_request_error",
  352. },
  353. },
  354. },
  355. }
  356. endpoint := func(c *gin.Context) {
  357. c.Status(http.StatusOK)
  358. }
  359. gin.SetMode(gin.TestMode)
  360. router := gin.New()
  361. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  362. router.Handle(http.MethodPost, "/api/embed", endpoint)
  363. for _, tc := range testCases {
  364. t.Run(tc.name, func(t *testing.T) {
  365. req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
  366. req.Header.Set("Content-Type", "application/json")
  367. resp := httptest.NewRecorder()
  368. router.ServeHTTP(resp, req)
  369. var errResp ErrorResponse
  370. if resp.Code != http.StatusOK {
  371. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  372. t.Fatal(err)
  373. }
  374. }
  375. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  376. t.Fatal("requests did not match")
  377. }
  378. if !reflect.DeepEqual(tc.err, errResp) {
  379. t.Fatal("errors did not match")
  380. }
  381. capturedRequest = nil
  382. })
  383. }
  384. }
  385. func TestListMiddleware(t *testing.T) {
  386. type testCase struct {
  387. name string
  388. endpoint func(c *gin.Context)
  389. resp string
  390. }
  391. testCases := []testCase{
  392. {
  393. name: "list handler",
  394. endpoint: func(c *gin.Context) {
  395. c.JSON(http.StatusOK, api.ListResponse{
  396. Models: []api.ListModelResponse{
  397. {
  398. Name: "test-model",
  399. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  400. },
  401. },
  402. })
  403. },
  404. resp: `{
  405. "object": "list",
  406. "data": [
  407. {
  408. "id": "test-model",
  409. "object": "model",
  410. "created": 1686935002,
  411. "owned_by": "library"
  412. }
  413. ]
  414. }`,
  415. },
  416. {
  417. name: "list handler empty output",
  418. endpoint: func(c *gin.Context) {
  419. c.JSON(http.StatusOK, api.ListResponse{})
  420. },
  421. resp: `{
  422. "object": "list",
  423. "data": null
  424. }`,
  425. },
  426. }
  427. gin.SetMode(gin.TestMode)
  428. for _, tc := range testCases {
  429. router := gin.New()
  430. router.Use(ListMiddleware())
  431. router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
  432. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  433. resp := httptest.NewRecorder()
  434. router.ServeHTTP(resp, req)
  435. var expected, actual map[string]any
  436. err := json.Unmarshal([]byte(tc.resp), &expected)
  437. if err != nil {
  438. t.Fatalf("failed to unmarshal expected response: %v", err)
  439. }
  440. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  441. if err != nil {
  442. t.Fatalf("failed to unmarshal actual response: %v", err)
  443. }
  444. if !reflect.DeepEqual(expected, actual) {
  445. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  446. }
  447. }
  448. }
  449. func TestRetrieveMiddleware(t *testing.T) {
  450. type testCase struct {
  451. name string
  452. endpoint func(c *gin.Context)
  453. resp string
  454. }
  455. testCases := []testCase{
  456. {
  457. name: "retrieve handler",
  458. endpoint: func(c *gin.Context) {
  459. c.JSON(http.StatusOK, api.ShowResponse{
  460. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  461. })
  462. },
  463. resp: `{
  464. "id":"test-model",
  465. "object":"model",
  466. "created":1686935002,
  467. "owned_by":"library"}
  468. `,
  469. },
  470. {
  471. name: "retrieve handler error forwarding",
  472. endpoint: func(c *gin.Context) {
  473. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  474. },
  475. resp: `{
  476. "error": {
  477. "code": null,
  478. "message": "model not found",
  479. "param": null,
  480. "type": "api_error"
  481. }
  482. }`,
  483. },
  484. }
  485. gin.SetMode(gin.TestMode)
  486. for _, tc := range testCases {
  487. router := gin.New()
  488. router.Use(RetrieveMiddleware())
  489. router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
  490. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  491. resp := httptest.NewRecorder()
  492. router.ServeHTTP(resp, req)
  493. var expected, actual map[string]any
  494. err := json.Unmarshal([]byte(tc.resp), &expected)
  495. if err != nil {
  496. t.Fatalf("failed to unmarshal expected response: %v", err)
  497. }
  498. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  499. if err != nil {
  500. t.Fatalf("failed to unmarshal actual response: %v", err)
  501. }
  502. if !reflect.DeepEqual(expected, actual) {
  503. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  504. }
  505. }
  506. }