openai_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/ollama/ollama/api"
  14. "github.com/stretchr/testify/assert"
  15. )
  16. const prefix = `data:image/jpeg;base64,`
  17. const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
  18. const imageURL = prefix + image
  19. func TestMiddlewareRequests(t *testing.T) {
  20. type testCase struct {
  21. Name string
  22. Method string
  23. Path string
  24. Handler func() gin.HandlerFunc
  25. Setup func(t *testing.T, req *http.Request)
  26. Expected func(t *testing.T, req *http.Request)
  27. }
  28. var capturedRequest *http.Request
  29. captureRequestMiddleware := func() gin.HandlerFunc {
  30. return func(c *gin.Context) {
  31. bodyBytes, _ := io.ReadAll(c.Request.Body)
  32. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  33. capturedRequest = c.Request
  34. c.Next()
  35. }
  36. }
  37. testCases := []testCase{
  38. {
  39. Name: "chat handler",
  40. Method: http.MethodPost,
  41. Path: "/api/chat",
  42. Handler: ChatMiddleware,
  43. Setup: func(t *testing.T, req *http.Request) {
  44. body := ChatCompletionRequest{
  45. Model: "test-model",
  46. Messages: []Message{{Role: "user", Content: "Hello"}},
  47. }
  48. bodyBytes, _ := json.Marshal(body)
  49. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  50. req.Header.Set("Content-Type", "application/json")
  51. },
  52. Expected: func(t *testing.T, req *http.Request) {
  53. var chatReq api.ChatRequest
  54. if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
  55. t.Fatal(err)
  56. }
  57. if chatReq.Messages[0].Role != "user" {
  58. t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
  59. }
  60. if chatReq.Messages[0].Content != "Hello" {
  61. t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
  62. }
  63. },
  64. },
  65. {
  66. Name: "completions handler",
  67. Method: http.MethodPost,
  68. Path: "/api/generate",
  69. Handler: CompletionsMiddleware,
  70. Setup: func(t *testing.T, req *http.Request) {
  71. temp := float32(0.8)
  72. body := CompletionRequest{
  73. Model: "test-model",
  74. Prompt: "Hello",
  75. Temperature: &temp,
  76. Stop: []string{"\n", "stop"},
  77. Suffix: "suffix",
  78. }
  79. bodyBytes, _ := json.Marshal(body)
  80. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  81. req.Header.Set("Content-Type", "application/json")
  82. },
  83. Expected: func(t *testing.T, req *http.Request) {
  84. var genReq api.GenerateRequest
  85. if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
  86. t.Fatal(err)
  87. }
  88. if genReq.Prompt != "Hello" {
  89. t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
  90. }
  91. if genReq.Options["temperature"] != 1.6 {
  92. t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
  93. }
  94. stopTokens, ok := genReq.Options["stop"].([]any)
  95. if !ok {
  96. t.Fatalf("expected stop tokens to be a list")
  97. }
  98. if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
  99. t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
  100. }
  101. if genReq.Suffix != "suffix" {
  102. t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
  103. }
  104. },
  105. },
  106. {
  107. Name: "chat handler with image content",
  108. Method: http.MethodPost,
  109. Path: "/api/chat",
  110. Handler: ChatMiddleware,
  111. Setup: func(t *testing.T, req *http.Request) {
  112. body := ChatCompletionRequest{
  113. Model: "test-model",
  114. Messages: []Message{
  115. {
  116. Role: "user", Content: []map[string]any{
  117. {"type": "text", "text": "Hello"},
  118. {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
  119. },
  120. },
  121. },
  122. }
  123. bodyBytes, _ := json.Marshal(body)
  124. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  125. req.Header.Set("Content-Type", "application/json")
  126. },
  127. Expected: func(t *testing.T, req *http.Request) {
  128. var chatReq api.ChatRequest
  129. if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
  130. t.Fatal(err)
  131. }
  132. if chatReq.Messages[0].Role != "user" {
  133. t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
  134. }
  135. if chatReq.Messages[0].Content != "Hello" {
  136. t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
  137. }
  138. img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
  139. if chatReq.Messages[1].Role != "user" {
  140. t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
  141. }
  142. if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
  143. t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
  144. }
  145. },
  146. },
  147. {
  148. Name: "embed handler single input",
  149. Method: http.MethodPost,
  150. Path: "/api/embed",
  151. Handler: EmbeddingsMiddleware,
  152. Setup: func(t *testing.T, req *http.Request) {
  153. body := EmbedRequest{
  154. Input: "Hello",
  155. Model: "test-model",
  156. }
  157. bodyBytes, _ := json.Marshal(body)
  158. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  159. req.Header.Set("Content-Type", "application/json")
  160. },
  161. Expected: func(t *testing.T, req *http.Request) {
  162. var embedReq api.EmbedRequest
  163. if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
  164. t.Fatal(err)
  165. }
  166. if embedReq.Input != "Hello" {
  167. t.Fatalf("expected 'Hello', got %s", embedReq.Input)
  168. }
  169. if embedReq.Model != "test-model" {
  170. t.Fatalf("expected 'test-model', got %s", embedReq.Model)
  171. }
  172. },
  173. },
  174. {
  175. Name: "embed handler batch input",
  176. Method: http.MethodPost,
  177. Path: "/api/embed",
  178. Handler: EmbeddingsMiddleware,
  179. Setup: func(t *testing.T, req *http.Request) {
  180. body := EmbedRequest{
  181. Input: []string{"Hello", "World"},
  182. Model: "test-model",
  183. }
  184. bodyBytes, _ := json.Marshal(body)
  185. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  186. req.Header.Set("Content-Type", "application/json")
  187. },
  188. Expected: func(t *testing.T, req *http.Request) {
  189. var embedReq api.EmbedRequest
  190. if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
  191. t.Fatal(err)
  192. }
  193. input, ok := embedReq.Input.([]any)
  194. if !ok {
  195. t.Fatalf("expected input to be a list")
  196. }
  197. if input[0].(string) != "Hello" {
  198. t.Fatalf("expected 'Hello', got %s", input[0])
  199. }
  200. if input[1].(string) != "World" {
  201. t.Fatalf("expected 'World', got %s", input[1])
  202. }
  203. if embedReq.Model != "test-model" {
  204. t.Fatalf("expected 'test-model', got %s", embedReq.Model)
  205. }
  206. },
  207. },
  208. }
  209. gin.SetMode(gin.TestMode)
  210. router := gin.New()
  211. endpoint := func(c *gin.Context) {
  212. c.Status(http.StatusOK)
  213. }
  214. for _, tc := range testCases {
  215. t.Run(tc.Name, func(t *testing.T) {
  216. router = gin.New()
  217. router.Use(captureRequestMiddleware())
  218. router.Use(tc.Handler())
  219. router.Handle(tc.Method, tc.Path, endpoint)
  220. req, _ := http.NewRequest(tc.Method, tc.Path, nil)
  221. if tc.Setup != nil {
  222. tc.Setup(t, req)
  223. }
  224. resp := httptest.NewRecorder()
  225. router.ServeHTTP(resp, req)
  226. tc.Expected(t, capturedRequest)
  227. })
  228. }
  229. }
  230. func TestMiddlewareResponses(t *testing.T) {
  231. type testCase struct {
  232. Name string
  233. Method string
  234. Path string
  235. TestPath string
  236. Handler func() gin.HandlerFunc
  237. Endpoint func(c *gin.Context)
  238. Setup func(t *testing.T, req *http.Request)
  239. Expected func(t *testing.T, resp *httptest.ResponseRecorder)
  240. }
  241. testCases := []testCase{
  242. {
  243. Name: "completions handler error forwarding",
  244. Method: http.MethodPost,
  245. Path: "/api/generate",
  246. TestPath: "/api/generate",
  247. Handler: CompletionsMiddleware,
  248. Endpoint: func(c *gin.Context) {
  249. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  250. },
  251. Setup: func(t *testing.T, req *http.Request) {
  252. body := CompletionRequest{
  253. Model: "test-model",
  254. Prompt: "Hello",
  255. }
  256. bodyBytes, _ := json.Marshal(body)
  257. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  258. req.Header.Set("Content-Type", "application/json")
  259. },
  260. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  261. if resp.Code != http.StatusBadRequest {
  262. t.Fatalf("expected 400, got %d", resp.Code)
  263. }
  264. if !strings.Contains(resp.Body.String(), `"invalid request"`) {
  265. t.Fatalf("error was not forwarded")
  266. }
  267. },
  268. },
  269. {
  270. Name: "list handler",
  271. Method: http.MethodGet,
  272. Path: "/api/tags",
  273. TestPath: "/api/tags",
  274. Handler: ListMiddleware,
  275. Endpoint: func(c *gin.Context) {
  276. c.JSON(http.StatusOK, api.ListResponse{
  277. Models: []api.ListModelResponse{
  278. {
  279. Name: "Test Model",
  280. },
  281. },
  282. })
  283. },
  284. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  285. assert.Equal(t, http.StatusOK, resp.Code)
  286. var listResp ListCompletion
  287. if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
  288. t.Fatal(err)
  289. }
  290. if listResp.Object != "list" {
  291. t.Fatalf("expected list, got %s", listResp.Object)
  292. }
  293. if len(listResp.Data) != 1 {
  294. t.Fatalf("expected 1, got %d", len(listResp.Data))
  295. }
  296. if listResp.Data[0].Id != "Test Model" {
  297. t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
  298. }
  299. },
  300. },
  301. {
  302. Name: "retrieve model",
  303. Method: http.MethodGet,
  304. Path: "/api/show/:model",
  305. TestPath: "/api/show/test-model",
  306. Handler: RetrieveMiddleware,
  307. Endpoint: func(c *gin.Context) {
  308. c.JSON(http.StatusOK, api.ShowResponse{
  309. ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
  310. })
  311. },
  312. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  313. var retrieveResp Model
  314. if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
  315. t.Fatal(err)
  316. }
  317. if retrieveResp.Object != "model" {
  318. t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
  319. }
  320. if retrieveResp.Id != "test-model" {
  321. t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
  322. }
  323. },
  324. },
  325. }
  326. gin.SetMode(gin.TestMode)
  327. router := gin.New()
  328. for _, tc := range testCases {
  329. t.Run(tc.Name, func(t *testing.T) {
  330. router = gin.New()
  331. router.Use(tc.Handler())
  332. router.Handle(tc.Method, tc.Path, tc.Endpoint)
  333. req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
  334. if tc.Setup != nil {
  335. tc.Setup(t, req)
  336. }
  337. resp := httptest.NewRecorder()
  338. router.ServeHTTP(resp, req)
  339. tc.Expected(t, resp)
  340. })
  341. }
  342. }