openai_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. package openai
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/gin-gonic/gin"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/ollama/ollama/api"
  14. )
  15. func capture(req any) gin.HandlerFunc {
  16. return func(c *gin.Context) {
  17. body, _ := io.ReadAll(c.Request.Body)
  18. _ = json.Unmarshal(body, req)
  19. c.Next()
  20. }
  21. }
  22. func TestChatMiddleware(t *testing.T) {
  23. type test struct {
  24. name string
  25. body string
  26. req api.ChatRequest
  27. err ErrorResponse
  28. }
  29. tests := []test{
  30. {
  31. name: "chat handler",
  32. body: `{
  33. "model": "test-model",
  34. "messages": [
  35. {"role": "user", "content": "Hello"}
  36. ]
  37. }`,
  38. req: api.ChatRequest{
  39. Model: "test-model",
  40. Messages: []api.Message{
  41. {
  42. Role: "user",
  43. Content: "Hello",
  44. },
  45. },
  46. Options: map[string]any{
  47. "temperature": 1.0,
  48. "top_p": 1.0,
  49. },
  50. Stream: func() *bool { f := false; return &f }(),
  51. },
  52. },
  53. {
  54. name: "chat handler with large context",
  55. body: `{
  56. "model": "test-model",
  57. "messages": [
  58. {"role": "user", "content": "Hello"}
  59. ],
  60. "max_tokens": 16384
  61. }`,
  62. req: api.ChatRequest{
  63. Model: "test-model",
  64. Messages: []api.Message{
  65. {
  66. Role: "user",
  67. Content: "Hello",
  68. },
  69. },
  70. Options: map[string]any{
  71. "temperature": 1.0,
  72. "top_p": 1.0,
  73. // TODO (jmorganca): because we use a map[string]any for options
  74. // the values need to be floats for the test comparison to work.
  75. "num_predict": 16384.0,
  76. "num_ctx": 16384.0,
  77. },
  78. Stream: func() *bool { f := false; return &f }(),
  79. },
  80. },
  81. {
  82. name: "chat handler with image content",
  83. body: `{
  84. "model": "test-model",
  85. "messages": [
  86. {
  87. "role": "user",
  88. "content": [
  89. {
  90. "type": "text",
  91. "text": "Hello"
  92. },
  93. {
  94. "type": "image_url",
  95. "image_url": {
  96. "url": "data:image/jpeg;base64,ZGF0YQo="
  97. }
  98. }
  99. ]
  100. }
  101. ]
  102. }`,
  103. req: api.ChatRequest{
  104. Model: "test-model",
  105. Messages: []api.Message{
  106. {
  107. Role: "user",
  108. Content: "Hello",
  109. },
  110. {
  111. Role: "user",
  112. Images: []api.ImageData{
  113. func() []byte {
  114. img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=")
  115. return img
  116. }(),
  117. },
  118. },
  119. },
  120. Options: map[string]any{
  121. "temperature": 1.0,
  122. "top_p": 1.0,
  123. },
  124. Stream: func() *bool { f := false; return &f }(),
  125. },
  126. },
  127. {
  128. name: "chat handler with tools",
  129. body: `{
  130. "model": "test-model",
  131. "messages": [
  132. {"role": "user", "content": "What's the weather like in Paris Today?"},
  133. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  134. ]
  135. }`,
  136. req: api.ChatRequest{
  137. Model: "test-model",
  138. Messages: []api.Message{
  139. {
  140. Role: "user",
  141. Content: "What's the weather like in Paris Today?",
  142. },
  143. {
  144. Role: "assistant",
  145. ToolCalls: []api.ToolCall{
  146. {
  147. Function: api.ToolCallFunction{
  148. Name: "get_current_weather",
  149. Arguments: map[string]interface{}{
  150. "location": "Paris, France",
  151. "format": "celsius",
  152. },
  153. },
  154. },
  155. },
  156. },
  157. },
  158. Options: map[string]any{
  159. "temperature": 1.0,
  160. "top_p": 1.0,
  161. },
  162. Stream: func() *bool { f := false; return &f }(),
  163. },
  164. },
  165. {
  166. name: "chat handler error forwarding",
  167. body: `{
  168. "model": "test-model",
  169. "messages": [
  170. {"role": "user", "content": 2}
  171. ]
  172. }`,
  173. err: ErrorResponse{
  174. Error: Error{
  175. Message: "invalid message content type: float64",
  176. Type: "invalid_request_error",
  177. },
  178. },
  179. },
  180. }
  181. gin.SetMode(gin.TestMode)
  182. for _, tt := range tests {
  183. var req api.ChatRequest
  184. router := gin.New()
  185. router.Use(ChatMiddleware(), capture(&req))
  186. router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) {
  187. c.Status(http.StatusOK)
  188. })
  189. t.Run(tt.name, func(t *testing.T) {
  190. r, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tt.body))
  191. r.Header.Set("Content-Type", "application/json")
  192. resp := httptest.NewRecorder()
  193. router.ServeHTTP(resp, r)
  194. var err ErrorResponse
  195. if resp.Code != http.StatusOK {
  196. if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil {
  197. t.Fatal(err)
  198. }
  199. }
  200. if diff := cmp.Diff(tt.req, req); diff != "" {
  201. t.Errorf("mismatch (-want +got):\n%s", diff)
  202. }
  203. if diff := cmp.Diff(tt.err, err); diff != "" {
  204. t.Errorf("mismatch (-want +got):\n%s", diff)
  205. }
  206. })
  207. }
  208. }
  209. func TestCompletionsMiddleware(t *testing.T) {
  210. type test struct {
  211. name string
  212. body string
  213. req api.GenerateRequest
  214. err ErrorResponse
  215. }
  216. tests := []test{
  217. {
  218. name: "completions handler",
  219. body: `{
  220. "model": "test-model",
  221. "prompt": "Hello",
  222. "temperature": 0.8,
  223. "stop": ["\n", "stop"],
  224. "suffix": "suffix"
  225. }`,
  226. req: api.GenerateRequest{
  227. Model: "test-model",
  228. Prompt: "Hello",
  229. Options: map[string]any{
  230. "frequency_penalty": 0.0,
  231. "presence_penalty": 0.0,
  232. "temperature": 1.6,
  233. "top_p": 1.0,
  234. "stop": []any{"\n", "stop"},
  235. },
  236. Suffix: "suffix",
  237. Stream: func() *bool { f := false; return &f }(),
  238. },
  239. },
  240. {
  241. name: "completions handler error forwarding",
  242. body: `{
  243. "model": "test-model",
  244. "prompt": "Hello",
  245. "temperature": null,
  246. "stop": [1, 2],
  247. "suffix": "suffix"
  248. }`,
  249. err: ErrorResponse{
  250. Error: Error{
  251. Message: "invalid type for 'stop' field: float64",
  252. Type: "invalid_request_error",
  253. },
  254. },
  255. },
  256. }
  257. gin.SetMode(gin.TestMode)
  258. for _, tt := range tests {
  259. t.Run(tt.name, func(t *testing.T) {
  260. var req api.GenerateRequest
  261. router := gin.New()
  262. router.Use(CompletionsMiddleware(), capture(&req))
  263. router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
  264. c.Status(http.StatusOK)
  265. })
  266. r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
  267. r.Header.Set("Content-Type", "application/json")
  268. res := httptest.NewRecorder()
  269. router.ServeHTTP(res, r)
  270. var errResp ErrorResponse
  271. if res.Code != http.StatusOK {
  272. if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil {
  273. t.Fatal(err)
  274. }
  275. }
  276. if !cmp.Equal(tt.req, req) {
  277. t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req))
  278. }
  279. if !cmp.Equal(tt.err, errResp) {
  280. t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp))
  281. }
  282. })
  283. }
  284. }
  285. func TestEmbeddingsMiddleware(t *testing.T) {
  286. type test struct {
  287. name string
  288. body string
  289. req api.EmbedRequest
  290. err ErrorResponse
  291. }
  292. tests := []test{
  293. {
  294. name: "embed handler single input",
  295. body: `{
  296. "input": "Hello",
  297. "model": "test-model"
  298. }`,
  299. req: api.EmbedRequest{
  300. Input: "Hello",
  301. Model: "test-model",
  302. },
  303. },
  304. {
  305. name: "embed handler batch input",
  306. body: `{
  307. "input": ["Hello", "World"],
  308. "model": "test-model"
  309. }`,
  310. req: api.EmbedRequest{
  311. Input: []any{"Hello", "World"},
  312. Model: "test-model",
  313. },
  314. },
  315. {
  316. name: "embed handler error forwarding",
  317. body: `{
  318. "model": "test-model"
  319. }`,
  320. err: ErrorResponse{
  321. Error: Error{
  322. Message: "invalid input",
  323. Type: "invalid_request_error",
  324. },
  325. },
  326. },
  327. }
  328. endpoint := func(c *gin.Context) {
  329. c.Status(http.StatusOK)
  330. }
  331. gin.SetMode(gin.TestMode)
  332. for _, tt := range tests {
  333. var req api.EmbedRequest
  334. router := gin.New()
  335. router.Use(EmbeddingsMiddleware(), capture(&req))
  336. router.Handle(http.MethodPost, "/api/embed", endpoint)
  337. t.Run(tt.name, func(t *testing.T) {
  338. r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body))
  339. r.Header.Set("Content-Type", "application/json")
  340. resp := httptest.NewRecorder()
  341. router.ServeHTTP(resp, r)
  342. var errResp ErrorResponse
  343. if resp.Code != http.StatusOK {
  344. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  345. t.Fatal(err)
  346. }
  347. }
  348. if diff := cmp.Diff(tt.req, req); diff != "" {
  349. t.Errorf("request mismatch (-want +got):\n%s", diff)
  350. }
  351. if diff := cmp.Diff(tt.err, errResp); diff != "" {
  352. t.Errorf("error mismatch (-want +got):\n%s", diff)
  353. }
  354. })
  355. }
  356. }
  357. func TestListMiddleware(t *testing.T) {
  358. type test struct {
  359. name string
  360. handler gin.HandlerFunc
  361. body string
  362. }
  363. tests := []test{
  364. {
  365. name: "list handler",
  366. handler: func(c *gin.Context) {
  367. c.JSON(http.StatusOK, api.ListResponse{
  368. Models: []api.ListModelResponse{
  369. {
  370. Name: "test-model",
  371. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  372. },
  373. },
  374. })
  375. },
  376. body: `{
  377. "object": "list",
  378. "data": [
  379. {
  380. "id": "test-model",
  381. "object": "model",
  382. "created": 1686935002,
  383. "owned_by": "library"
  384. }
  385. ]
  386. }`,
  387. },
  388. {
  389. name: "list handler empty output",
  390. handler: func(c *gin.Context) {
  391. c.JSON(http.StatusOK, api.ListResponse{
  392. Models: []api.ListModelResponse{},
  393. })
  394. },
  395. body: `{
  396. "object": "list",
  397. "data": null
  398. }`,
  399. },
  400. }
  401. gin.SetMode(gin.TestMode)
  402. for _, tt := range tests {
  403. router := gin.New()
  404. router.Use(ListMiddleware())
  405. router.Handle(http.MethodGet, "/api/tags", tt.handler)
  406. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  407. resp := httptest.NewRecorder()
  408. router.ServeHTTP(resp, req)
  409. var expected, actual map[string]any
  410. err := json.Unmarshal([]byte(tt.body), &expected)
  411. if err != nil {
  412. t.Fatalf("failed to unmarshal expected response: %v", err)
  413. }
  414. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  415. if err != nil {
  416. t.Fatalf("failed to unmarshal actual response: %v", err)
  417. }
  418. if diff := cmp.Diff(expected, actual); diff != "" {
  419. t.Errorf("responses did not match (-want +got):\n%s", diff)
  420. }
  421. }
  422. }
  423. func TestRetrieveMiddleware(t *testing.T) {
  424. type test struct {
  425. name string
  426. handler gin.HandlerFunc
  427. body string
  428. }
  429. tests := []test{
  430. {
  431. name: "retrieve handler",
  432. handler: func(c *gin.Context) {
  433. c.JSON(http.StatusOK, api.ShowResponse{
  434. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  435. })
  436. },
  437. body: `{
  438. "id":"test-model",
  439. "object":"model",
  440. "created":1686935002,
  441. "owned_by":"library"}
  442. `,
  443. },
  444. {
  445. name: "retrieve handler error forwarding",
  446. handler: func(c *gin.Context) {
  447. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  448. },
  449. body: `{
  450. "error": {
  451. "code": null,
  452. "message": "model not found",
  453. "param": null,
  454. "type": "invalid_request_error"
  455. }
  456. }`,
  457. },
  458. }
  459. gin.SetMode(gin.TestMode)
  460. for _, tt := range tests {
  461. router := gin.New()
  462. router.Use(RetrieveMiddleware())
  463. router.Handle(http.MethodGet, "/api/show/:model", tt.handler)
  464. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  465. resp := httptest.NewRecorder()
  466. router.ServeHTTP(resp, req)
  467. var expected, actual map[string]any
  468. err := json.Unmarshal([]byte(tt.body), &expected)
  469. if err != nil {
  470. t.Fatalf("failed to unmarshal expected response: %v", err)
  471. }
  472. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  473. if err != nil {
  474. t.Fatalf("failed to unmarshal actual response: %v", err)
  475. }
  476. if diff := cmp.Diff(expected, actual); diff != "" {
  477. t.Errorf("responses did not match (-want +got):\n%s", diff)
  478. }
  479. }
  480. }