openai_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. package openai
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "reflect"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. )
  16. func capture(req any) gin.HandlerFunc {
  17. return func(c *gin.Context) {
  18. body, _ := io.ReadAll(c.Request.Body)
  19. json.Unmarshal(body, req)
  20. c.Next()
  21. }
  22. }
  23. func TestChatMiddleware(t *testing.T) {
  24. type test struct {
  25. name string
  26. body string
  27. req api.ChatRequest
  28. err ErrorResponse
  29. }
  30. tests := []test{
  31. {
  32. name: "chat handler",
  33. body: `{
  34. "model": "test-model",
  35. "messages": [
  36. {"role": "user", "content": "Hello"}
  37. ]
  38. }`,
  39. req: api.ChatRequest{
  40. Model: "test-model",
  41. Messages: []api.Message{
  42. {
  43. Role: "user",
  44. Content: "Hello",
  45. },
  46. },
  47. Options: map[string]any{
  48. "temperature": 1.0,
  49. "top_p": 1.0,
  50. },
  51. Stream: func() *bool { f := false; return &f }(),
  52. },
  53. },
  54. {
  55. name: "chat handler with large context",
  56. body: `{
  57. "model": "test-model",
  58. "messages": [
  59. {"role": "user", "content": "Hello"}
  60. ],
  61. "max_tokens": 16384
  62. }`,
  63. req: api.ChatRequest{
  64. Model: "test-model",
  65. Messages: []api.Message{
  66. {
  67. Role: "user",
  68. Content: "Hello",
  69. },
  70. },
  71. Options: map[string]any{
  72. "temperature": 1.0,
  73. "top_p": 1.0,
  74. // TODO (jmorganca): because we use a map[string]any for options
  75. // the values need to be floats for the test comparison to work.
  76. "num_predict": 16384.0,
  77. "num_ctx": 16384.0,
  78. },
  79. Stream: func() *bool { f := false; return &f }(),
  80. },
  81. },
  82. {
  83. name: "chat handler with image content",
  84. body: `{
  85. "model": "test-model",
  86. "messages": [
  87. {
  88. "role": "user",
  89. "content": [
  90. {
  91. "type": "text",
  92. "text": "Hello"
  93. },
  94. {
  95. "type": "image_url",
  96. "image_url": {
  97. "url": "data:image/jpeg;base64,ZGF0YQo="
  98. }
  99. }
  100. ]
  101. }
  102. ]
  103. }`,
  104. req: api.ChatRequest{
  105. Model: "test-model",
  106. Messages: []api.Message{
  107. {
  108. Role: "user",
  109. Content: "Hello",
  110. },
  111. {
  112. Role: "user",
  113. Images: []api.ImageData{
  114. func() []byte {
  115. img, _ := base64.StdEncoding.DecodeString("ZGF0YQo=")
  116. return img
  117. }(),
  118. },
  119. },
  120. },
  121. Options: map[string]any{
  122. "temperature": 1.0,
  123. "top_p": 1.0,
  124. },
  125. Stream: func() *bool { f := false; return &f }(),
  126. },
  127. },
  128. {
  129. name: "chat handler with tools",
  130. body: `{
  131. "model": "test-model",
  132. "messages": [
  133. {"role": "user", "content": "What's the weather like in Paris Today?"},
  134. {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
  135. ]
  136. }`,
  137. req: api.ChatRequest{
  138. Model: "test-model",
  139. Messages: []api.Message{
  140. {
  141. Role: "user",
  142. Content: "What's the weather like in Paris Today?",
  143. },
  144. {
  145. Role: "assistant",
  146. ToolCalls: []api.ToolCall{
  147. {
  148. Function: api.ToolCallFunction{
  149. Name: "get_current_weather",
  150. Arguments: map[string]interface{}{
  151. "location": "Paris, France",
  152. "format": "celsius",
  153. },
  154. },
  155. },
  156. },
  157. },
  158. },
  159. Options: map[string]any{
  160. "temperature": 1.0,
  161. "top_p": 1.0,
  162. },
  163. Stream: func() *bool { f := false; return &f }(),
  164. },
  165. },
  166. {
  167. name: "chat handler error forwarding",
  168. body: `{
  169. "model": "test-model",
  170. "messages": [
  171. {"role": "user", "content": 2}
  172. ]
  173. }`,
  174. err: ErrorResponse{
  175. Error: Error{
  176. Message: "invalid message content type: float64",
  177. Type: "invalid_request_error",
  178. },
  179. },
  180. },
  181. }
  182. gin.SetMode(gin.TestMode)
  183. for _, tt := range tests {
  184. var req api.ChatRequest
  185. router := gin.New()
  186. router.Use(ChatMiddleware(), capture(&req))
  187. router.Handle(http.MethodPost, "/api/chat", func(c *gin.Context) {
  188. c.Status(http.StatusOK)
  189. })
  190. t.Run(tt.name, func(t *testing.T) {
  191. r, _ := http.NewRequest("POST", "/api/chat", strings.NewReader(tt.body))
  192. r.Header.Set("Content-Type", "application/json")
  193. resp := httptest.NewRecorder()
  194. router.ServeHTTP(resp, r)
  195. var err ErrorResponse
  196. if resp.Code != http.StatusOK {
  197. if err := json.Unmarshal(resp.Body.Bytes(), &err); err != nil {
  198. t.Fatal(err)
  199. }
  200. }
  201. if diff := cmp.Diff(tt.req, req); diff != "" {
  202. t.Errorf("mismatch (-want +got):\n%s", diff)
  203. }
  204. if diff := cmp.Diff(tt.err, err); diff != "" {
  205. t.Errorf("mismatch (-want +got):\n%s", diff)
  206. }
  207. })
  208. }
  209. }
  210. func TestCompletionsMiddleware(t *testing.T) {
  211. type test struct {
  212. name string
  213. body string
  214. req api.GenerateRequest
  215. err ErrorResponse
  216. }
  217. tests := []test{
  218. {
  219. name: "completions handler",
  220. body: `{
  221. "model": "test-model",
  222. "prompt": "Hello",
  223. "temperature": 0.8,
  224. "stop": ["\n", "stop"],
  225. "suffix": "suffix"
  226. }`,
  227. req: api.GenerateRequest{
  228. Model: "test-model",
  229. Prompt: "Hello",
  230. Options: map[string]any{
  231. "frequency_penalty": 0.0,
  232. "presence_penalty": 0.0,
  233. "temperature": 1.6,
  234. "top_p": 1.0,
  235. "stop": []any{"\n", "stop"},
  236. },
  237. Suffix: "suffix",
  238. Stream: func() *bool { f := false; return &f }(),
  239. },
  240. },
  241. {
  242. name: "completions handler error forwarding",
  243. body: `{
  244. "model": "test-model",
  245. "prompt": "Hello",
  246. "temperature": null,
  247. "stop": [1, 2],
  248. "suffix": "suffix"
  249. }`,
  250. err: ErrorResponse{
  251. Error: Error{
  252. Message: "invalid type for 'stop' field: float64",
  253. Type: "invalid_request_error",
  254. },
  255. },
  256. },
  257. }
  258. gin.SetMode(gin.TestMode)
  259. for _, tt := range tests {
  260. t.Run(tt.name, func(t *testing.T) {
  261. var req api.GenerateRequest
  262. router := gin.New()
  263. router.Use(CompletionsMiddleware(), capture(&req))
  264. router.Handle(http.MethodPost, "/api/generate", func(c *gin.Context) {
  265. c.Status(http.StatusOK)
  266. })
  267. r, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tt.body))
  268. r.Header.Set("Content-Type", "application/json")
  269. res := httptest.NewRecorder()
  270. router.ServeHTTP(res, r)
  271. var errResp ErrorResponse
  272. if res.Code != http.StatusOK {
  273. if err := json.Unmarshal(res.Body.Bytes(), &errResp); err != nil {
  274. t.Fatal(err)
  275. }
  276. }
  277. if !cmp.Equal(tt.req, req) {
  278. t.Fatalf("requests did not match:\n%s", cmp.Diff(tt.req, req))
  279. }
  280. if !cmp.Equal(tt.err, errResp) {
  281. t.Fatalf("errors did not match:\n%s", cmp.Diff(tt.err, errResp))
  282. }
  283. })
  284. }
  285. }
  286. func TestEmbeddingsMiddleware(t *testing.T) {
  287. type test struct {
  288. name string
  289. body string
  290. req api.EmbedRequest
  291. err ErrorResponse
  292. }
  293. tests := []test{
  294. {
  295. name: "embed handler single input",
  296. body: `{
  297. "input": "Hello",
  298. "model": "test-model"
  299. }`,
  300. req: api.EmbedRequest{
  301. Input: "Hello",
  302. Model: "test-model",
  303. },
  304. },
  305. {
  306. name: "embed handler batch input",
  307. body: `{
  308. "input": ["Hello", "World"],
  309. "model": "test-model"
  310. }`,
  311. req: api.EmbedRequest{
  312. Input: []any{"Hello", "World"},
  313. Model: "test-model",
  314. },
  315. },
  316. {
  317. name: "embed handler error forwarding",
  318. body: `{
  319. "model": "test-model"
  320. }`,
  321. err: ErrorResponse{
  322. Error: Error{
  323. Message: "invalid input",
  324. Type: "invalid_request_error",
  325. },
  326. },
  327. },
  328. }
  329. endpoint := func(c *gin.Context) {
  330. c.Status(http.StatusOK)
  331. }
  332. gin.SetMode(gin.TestMode)
  333. for _, tt := range tests {
  334. var req api.EmbedRequest
  335. router := gin.New()
  336. router.Use(EmbeddingsMiddleware(), capture(&req))
  337. router.Handle(http.MethodPost, "/api/embed", endpoint)
  338. t.Run(tt.name, func(t *testing.T) {
  339. r, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tt.body))
  340. r.Header.Set("Content-Type", "application/json")
  341. resp := httptest.NewRecorder()
  342. router.ServeHTTP(resp, r)
  343. var errResp ErrorResponse
  344. if resp.Code != http.StatusOK {
  345. if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
  346. t.Fatal(err)
  347. }
  348. }
  349. if diff := cmp.Diff(tt.req, req); diff != "" {
  350. t.Errorf("request mismatch (-want +got):\n%s", diff)
  351. }
  352. if diff := cmp.Diff(tt.err, errResp); diff != "" {
  353. t.Errorf("error mismatch (-want +got):\n%s", diff)
  354. }
  355. })
  356. }
  357. }
  358. func TestListMiddleware(t *testing.T) {
  359. type test struct {
  360. name string
  361. handler gin.HandlerFunc
  362. body string
  363. }
  364. tests := []test{
  365. {
  366. name: "list handler",
  367. handler: func(c *gin.Context) {
  368. c.JSON(http.StatusOK, api.ListResponse{
  369. Models: []api.ListModelResponse{
  370. {
  371. Name: "test-model",
  372. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  373. },
  374. }})
  375. },
  376. body: `{
  377. "object": "list",
  378. "data": [
  379. {
  380. "id": "test-model",
  381. "object": "model",
  382. "created": 1686935002,
  383. "owned_by": "library"
  384. }
  385. ]
  386. }`,
  387. },
  388. {
  389. name: "list handler empty output",
  390. handler: func(c *gin.Context) {
  391. c.JSON(http.StatusOK, api.ListResponse{
  392. Models: []api.ListModelResponse{},
  393. })
  394. },
  395. body: `{
  396. "object": "list",
  397. "data": null
  398. }`,
  399. },
  400. }
  401. gin.SetMode(gin.TestMode)
  402. for _, tt := range tests {
  403. router := gin.New()
  404. router.Use(ListMiddleware())
  405. router.Handle(http.MethodGet, "/api/tags", tt.handler)
  406. req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
  407. resp := httptest.NewRecorder()
  408. router.ServeHTTP(resp, req)
  409. var expected, actual map[string]any
  410. err := json.Unmarshal([]byte(tt.body), &expected)
  411. if err != nil {
  412. t.Fatalf("failed to unmarshal expected response: %v", err)
  413. }
  414. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  415. if err != nil {
  416. t.Fatalf("failed to unmarshal actual response: %v", err)
  417. }
  418. if diff := cmp.Diff(expected, actual); diff != "" {
  419. t.Errorf("responses did not match (-want +got):\n%s", diff)
  420. }
  421. }
  422. }
  423. func TestRetrieveMiddleware(t *testing.T) {
  424. type test struct {
  425. name string
  426. handler gin.HandlerFunc
  427. body string
  428. }
  429. tests := []test{
  430. {
  431. name: "retrieve handler",
  432. handler: func(c *gin.Context) {
  433. c.JSON(http.StatusOK, api.ShowResponse{
  434. ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
  435. })
  436. },
  437. body: `{
  438. "id":"test-model",
  439. "object":"model",
  440. "created":1686935002,
  441. "owned_by":"library"}
  442. `,
  443. },
  444. {
  445. name: "retrieve handler error forwarding",
  446. handler: func(c *gin.Context) {
  447. c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
  448. },
  449. body: `{
  450. "error": {
  451. "code": null,
  452. "message": "model not found",
  453. "param": null,
  454. "type": "api_error"
  455. }
  456. }`,
  457. },
  458. }
  459. gin.SetMode(gin.TestMode)
  460. for _, tt := range tests {
  461. router := gin.New()
  462. router.Use(RetrieveMiddleware())
  463. router.Handle(http.MethodGet, "/api/show/:model", tt.handler)
  464. req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
  465. resp := httptest.NewRecorder()
  466. router.ServeHTTP(resp, req)
  467. var expected, actual map[string]any
  468. err := json.Unmarshal([]byte(tt.body), &expected)
  469. if err != nil {
  470. t.Fatalf("failed to unmarshal expected response: %v", err)
  471. }
  472. err = json.Unmarshal(resp.Body.Bytes(), &actual)
  473. if err != nil {
  474. t.Fatalf("failed to unmarshal actual response: %v", err)
  475. }
  476. if !reflect.DeepEqual(expected, actual) {
  477. t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
  478. }
  479. }
  480. }