openai_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "reflect"
  10. "strings"
  11. "testing"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "github.com/ollama/ollama/api"
  15. )
  // prefix is the data-URL header placed in front of image in chat request bodies.
// NOTE(review): the media type says JPEG, but image decodes to a PNG signature;
// presumably the middleware only strips the header so this is harmless — confirm.
const prefix = `data:image/jpeg;base64,`

// image is a base64-encoded 1x1 PNG used as inline image content in the chat
// middleware tests (its header decodes to width=1, height=1).
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

// False is an addressable false value so expectations can write Stream: &False.
var False = false
  21. func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
  22. return func(c *gin.Context) {
  23. bodyBytes, _ := io.ReadAll(c.Request.Body)
  24. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  25. err := json.Unmarshal(bodyBytes, capturedRequest)
  26. if err != nil {
  27. c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
  28. }
  29. c.Next()
  30. }
  31. }
  32. func TestChatMiddleware(t *testing.T) {
  33. type testCase struct {
  34. name string
  35. body string
  36. req api.ChatRequest
  37. err ErrorResponse
  38. }
  39. var capturedRequest *api.ChatRequest
  40. testCases := []testCase{
  41. {
  42. name: "chat handler",
  43. body: `{
  44. "model": "test-model",
  45. "messages": [
  46. {"role": "user", "content": "Hello"}
  47. ]
  48. }`,
  49. req: api.ChatRequest{
  50. Model: "test-model",
  51. Messages: []api.Message{
  52. {
  53. Role: "user",
  54. Content: "Hello",
  55. },
  56. },
  57. Options: map[string]any{
  58. "temperature": 1.0,
  59. "top_p": 1.0,
  60. },
  61. Stream: &False,
  62. },
  63. },
  64. {
  65. name: "chat handler with image content",
  66. body: `{
  67. "model": "test-model",
  68. "messages": [
  69. {
  70. "role": "user",
  71. "content": [
  72. {
  73. "type": "text",
  74. "text": "Hello"
  75. },
  76. {
  77. "type": "image_url",
  78. "image_url": {
  79. "url": "` + prefix + image + `"
  80. }
  81. }
  82. ]
  83. }
  84. ]
  85. }`,
  86. req: api.ChatRequest{
  87. Model: "test-model",
  88. Messages: []api.Message{
  89. {
  90. Role: "user",
  91. Content: "Hello",
  92. },
  93. {
  94. Role: "user",
  95. Images: []api.ImageData{
  96. func() []byte {
  97. img, _ := base64.StdEncoding.DecodeString(image)
  98. return img
  99. }(),
  100. },
  101. },
  102. },
  103. Options: map[string]any{
  104. "temperature": 1.0,
  105. "top_p": 1.0,
  106. },
  107. Stream: &False,
  108. },
  109. },
  110. {
  111. name: "chat handler with tools",
  112. body: `{
  113. "model": "test-model",
  114. "messages": [
  115. {"role": "user", "content": "What's the weather like in Paris Today?"},
  116. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  117. ]
  118. }`,
  119. req: api.ChatRequest{
  120. Model: "test-model",
  121. Messages: []api.Message{
  122. {
  123. Role: "user",
  124. Content: "What's the weather like in Paris Today?",
  125. },
  126. {
  127. Role: "assistant",
  128. ToolCalls: []api.ToolCall{
  129. {
  130. Function: api.ToolCallFunction{
  131. Name: "get_current_weather",
  132. Arguments: map[string]interface{}{
  133. "location": "Paris, France",
  134. "format": "celsius",
  135. },
  136. },
  137. },
  138. },
  139. },
  140. },
  141. Options: map[string]any{
  142. "temperature": 1.0,
  143. "top_p": 1.0,
  144. },
  145. Stream: &False,
  146. },
  147. },
  148. {
  149. name: "chat handler error forwarding",
  150. body: `{
  151. "model": "test-model",
  152. "messages": [
  153. {"role": "user", "content": 2}
  154. ]
  155. }`,
  156. err: ErrorResponse{
  157. Error: Error{
  158. Message: "invalid message content type: float64",
  159. Type: "invalid_request_error",
  160. },
  161. },
  162. },
  163. }
  164. endpoint := func(c *gin.Context) {
  165. c.Status(http.StatusOK)
  166. }
  167. gin.SetMode(gin.TestMode)
  168. router := gin.New()
  169. router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
  170. router.Handle(http.MethodPost, "/api/chat", endpoint)
  171. for _, tc := range testCases {
  172. t.Run(tc.name, func(t *testing.T) {
  173. req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
  174. req.Header.Set("Content-Type", "application/json")
  175. resp := httptest.NewRecorder()
  176. router.ServeHTTP(resp, req)
  177. var errResp ErrorResponse
  178. if resp.Code != http.StatusOK {
  179. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  180. t.Fatal(err)
  181. }
  182. }
  183. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  184. t.Fatal("requests did not match")
  185. }
  186. if !reflect.DeepEqual(tc.err, errResp) {
  187. t.Fatal("errors did not match")
  188. }
  189. capturedRequest = nil
  190. })
  191. }
  192. }
  193. func TestCompletionsMiddleware(t *testing.T) {
  194. type testCase struct {
  195. name string
  196. body string
  197. req api.GenerateRequest
  198. err ErrorResponse
  199. }
  200. var capturedRequest *api.GenerateRequest
  201. testCases := []testCase{
  202. {
  203. name: "completions handler",
  204. body: `{
  205. "model": "test-model",
  206. "prompt": "Hello",
  207. "temperature": 0.8,
  208. "stop": ["\n", "stop"],
  209. "suffix": "suffix"
  210. }`,
  211. req: api.GenerateRequest{
  212. Model: "test-model",
  213. Prompt: "Hello",
  214. Options: map[string]any{
  215. "frequency_penalty": 0.0,
  216. "presence_penalty": 0.0,
  217. "temperature": 1.6,
  218. "top_p": 1.0,
  219. "stop": []any{"\n", "stop"},
  220. },
  221. Suffix: "suffix",
  222. Stream: &False,
  223. },
  224. },
  225. {
  226. name: "completions handler error forwarding",
  227. body: `{
  228. "model": "test-model",
  229. "prompt": "Hello",
  230. "temperature": null,
  231. "stop": [1, 2],
  232. "suffix": "suffix"
  233. }`,
  234. err: ErrorResponse{
  235. Error: Error{
  236. Message: "invalid type for 'stop' field: float64",
  237. Type: "invalid_request_error",
  238. },
  239. },
  240. },
  241. }
  242. endpoint := func(c *gin.Context) {
  243. c.Status(http.StatusOK)
  244. }
  245. gin.SetMode(gin.TestMode)
  246. router := gin.New()
  247. router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
  248. router.Handle(http.MethodPost, "/api/generate", endpoint)
  249. for _, tc := range testCases {
  250. t.Run(tc.name, func(t *testing.T) {
  251. req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
  252. req.Header.Set("Content-Type", "application/json")
  253. resp := httptest.NewRecorder()
  254. router.ServeHTTP(resp, req)
  255. var errResp ErrorResponse
  256. if resp.Code != http.StatusOK {
  257. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  258. t.Fatal(err)
  259. }
  260. }
  261. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  262. t.Fatal("requests did not match")
  263. }
  264. if !reflect.DeepEqual(tc.err, errResp) {
  265. t.Fatal("errors did not match")
  266. }
  267. capturedRequest = nil
  268. })
  269. }
  270. }
  271. func TestEmbeddingsMiddleware(t *testing.T) {
  272. type testCase struct {
  273. name string
  274. body string
  275. req api.EmbedRequest
  276. err ErrorResponse
  277. }
  278. var capturedRequest *api.EmbedRequest
  279. testCases := []testCase{
  280. {
  281. name: "embed handler single input",
  282. body: `{
  283. "input": "Hello",
  284. "model": "test-model"
  285. }`,
  286. req: api.EmbedRequest{
  287. Input: "Hello",
  288. Model: "test-model",
  289. },
  290. },
  291. {
  292. name: "embed handler batch input",
  293. body: `{
  294. "input": ["Hello", "World"],
  295. "model": "test-model"
  296. }`,
  297. req: api.EmbedRequest{
  298. Input: []any{"Hello", "World"},
  299. Model: "test-model",
  300. },
  301. },
  302. {
  303. name: "embed handler error forwarding",
  304. body: `{
  305. "model": "test-model"
  306. }`,
  307. err: ErrorResponse{
  308. Error: Error{
  309. Message: "invalid input",
  310. Type: "invalid_request_error",
  311. },
  312. },
  313. },
  314. }
  315. endpoint := func(c *gin.Context) {
  316. c.Status(http.StatusOK)
  317. }
  318. gin.SetMode(gin.TestMode)
  319. router := gin.New()
  320. router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
  321. router.Handle(http.MethodPost, "/api/embed", endpoint)
  322. for _, tc := range testCases {
  323. t.Run(tc.name, func(t *testing.T) {
  324. req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
  325. req.Header.Set("Content-Type", "application/json")
  326. resp := httptest.NewRecorder()
  327. router.ServeHTTP(resp, req)
  328. var errResp ErrorResponse
  329. if resp.Code != http.StatusOK {
  330. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  331. t.Fatal(err)
  332. }
  333. }
  334. if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
  335. t.Fatal("requests did not match")
  336. }
  337. if !reflect.DeepEqual(tc.err, errResp) {
  338. t.Fatal("errors did not match")
  339. }
  340. capturedRequest = nil
  341. })
  342. }
  343. }
  344. func TestListMiddleware(t *testing.T) {
  345. type testCase struct {
  346. name string
  347. endpoint func(c *gin.Context)
  348. resp string
  349. }
  350. testCases := []testCase{
  351. {
  352. name: "list handler",
  353. endpoint: func(c *gin.Context) {
  354. c.JSON(http.StatusOK, api.ListResponse{
  355. Models: []api.ListModelResponse{
  356. {
  357. Name: "test-model",
  358. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  359. },
  360. },
  361. })
  362. },
  363. resp: `{
  364. "object": "list",
  365. "data": [
  366. {
  367. "id": "test-model",
  368. "object": "model",
  369. "created": 1686935002,
  370. "owned_by": "library"
  371. }
  372. ]
  373. }`,
  374. },
  375. {
  376. name: "list handler empty output",
  377. endpoint: func(c *gin.Context) {
  378. c.JSON(http.StatusOK, api.ListResponse{})
  379. },
  380. resp: `{
  381. "object": "list",
  382. "data": null
  383. }`,
  384. },
  385. }
  386. gin.SetMode(gin.TestMode)
  387. for _, tc := range testCases {
  388. router := gin.New()
  389. router.Use(ListMiddleware())
  390. router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
  391. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  392. resp := httptest.NewRecorder()
  393. router.ServeHTTP(resp, req)
  394. var expected, actual map[string]any
  395. err := json.Unmarshal([]byte(tc.resp), &expected)
  396. if err != nil {
  397. t.Fatalf("failed to unmarshal expected response: %v", err)
  398. }
  399. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  400. if err != nil {
  401. t.Fatalf("failed to unmarshal actual response: %v", err)
  402. }
  403. if !reflect.DeepEqual(expected, actual) {
  404. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  405. }
  406. }
  407. }
  408. func TestRetrieveMiddleware(t *testing.T) {
  409. type testCase struct {
  410. name string
  411. endpoint func(c *gin.Context)
  412. resp string
  413. }
  414. testCases := []testCase{
  415. {
  416. name: "retrieve handler",
  417. endpoint: func(c *gin.Context) {
  418. c.JSON(http.StatusOK, api.ShowResponse{
  419. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  420. })
  421. },
  422. resp: `{
  423. "id":"test-model",
  424. "object":"model",
  425. "created":1686935002,
  426. "owned_by":"library"}
  427. `,
  428. },
  429. {
  430. name: "retrieve handler error forwarding",
  431. endpoint: func(c *gin.Context) {
  432. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  433. },
  434. resp: `{
  435. "error": {
  436. "code": null,
  437. "message": "model not found",
  438. "param": null,
  439. "type": "api_error"
  440. }
  441. }`,
  442. },
  443. }
  444. gin.SetMode(gin.TestMode)
  445. for _, tc := range testCases {
  446. router := gin.New()
  447. router.Use(RetrieveMiddleware())
  448. router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
  449. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  450. resp := httptest.NewRecorder()
  451. router.ServeHTTP(resp, req)
  452. var expected, actual map[string]any
  453. err := json.Unmarshal([]byte(tc.resp), &expected)
  454. if err != nil {
  455. t.Fatalf("failed to unmarshal expected response: %v", err)
  456. }
  457. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  458. if err != nil {
  459. t.Fatalf("failed to unmarshal actual response: %v", err)
  460. }
  461. if !reflect.DeepEqual(expected, actual) {
  462. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  463. }
  464. }
  465. }