openai_test.go 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. package openai
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/ollama/ollama/api"
  14. "github.com/stretchr/testify/assert"
  15. )
  // Fixtures for the chat test case that carries an inline image.
  const (
  	// prefix is the data-URL header placed in front of the base64 payload.
  	prefix = `data:image/jpeg;base64,`

  	// image is the base64-encoded payload. Its decoded bytes begin with the
  	// PNG signature even though prefix declares image/jpeg.
  	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

  	// imageURL is the full data URL sent as an image_url content part.
  	imageURL = prefix + image
  )
  19. func TestMiddlewareRequests(t *testing.T) {
  20. type testCase struct {
  21. Name string
  22. Method string
  23. Path string
  24. Handler func() gin.HandlerFunc
  25. Setup func(t *testing.T, req *http.Request)
  26. Expected func(t *testing.T, req *http.Request)
  27. }
  28. var capturedRequest *http.Request
  29. captureRequestMiddleware := func() gin.HandlerFunc {
  30. return func(c *gin.Context) {
  31. bodyBytes, _ := io.ReadAll(c.Request.Body)
  32. c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  33. capturedRequest = c.Request
  34. c.Next()
  35. }
  36. }
  37. testCases := []testCase{
  38. {
  39. Name: "chat handler",
  40. Method: http.MethodPost,
  41. Path: "/api/chat",
  42. Handler: ChatMiddleware,
  43. Setup: func(t *testing.T, req *http.Request) {
  44. body := ChatCompletionRequest{
  45. Model: "test-model",
  46. Messages: []Message{{Role: "user", Content: "Hello"}},
  47. }
  48. bodyBytes, _ := json.Marshal(body)
  49. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  50. req.Header.Set("Content-Type", "application/json")
  51. },
  52. Expected: func(t *testing.T, req *http.Request) {
  53. var chatReq api.ChatRequest
  54. if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
  55. t.Fatal(err)
  56. }
  57. if chatReq.Messages[0].Role != "user" {
  58. t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
  59. }
  60. if chatReq.Messages[0].Content != "Hello" {
  61. t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
  62. }
  63. },
  64. },
  65. {
  66. Name: "completions handler",
  67. Method: http.MethodPost,
  68. Path: "/api/generate",
  69. Handler: CompletionsMiddleware,
  70. Setup: func(t *testing.T, req *http.Request) {
  71. temp := float32(0.8)
  72. body := CompletionRequest{
  73. Model: "test-model",
  74. Prompt: "Hello",
  75. Temperature: &temp,
  76. Stop: []string{"\n", "stop"},
  77. Suffix: "suffix",
  78. }
  79. bodyBytes, _ := json.Marshal(body)
  80. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  81. req.Header.Set("Content-Type", "application/json")
  82. },
  83. Expected: func(t *testing.T, req *http.Request) {
  84. var genReq api.GenerateRequest
  85. if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
  86. t.Fatal(err)
  87. }
  88. if genReq.Prompt != "Hello" {
  89. t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
  90. }
  91. if genReq.Options["temperature"] != 1.6 {
  92. t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
  93. }
  94. stopTokens, ok := genReq.Options["stop"].([]any)
  95. if !ok {
  96. t.Fatalf("expected stop tokens to be a list")
  97. }
  98. if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
  99. t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
  100. }
  101. if genReq.Suffix != "suffix" {
  102. t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
  103. }
  104. },
  105. },
  106. {
  107. Name: "chat handler with image content",
  108. Method: http.MethodPost,
  109. Path: "/api/chat",
  110. Handler: ChatMiddleware,
  111. Setup: func(t *testing.T, req *http.Request) {
  112. body := ChatCompletionRequest{
  113. Model: "test-model",
  114. Messages: []Message{
  115. {
  116. Role: "user", Content: []map[string]any{
  117. {"type": "text", "text": "Hello"},
  118. {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
  119. },
  120. },
  121. },
  122. }
  123. bodyBytes, _ := json.Marshal(body)
  124. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  125. req.Header.Set("Content-Type", "application/json")
  126. },
  127. Expected: func(t *testing.T, req *http.Request) {
  128. var chatReq api.ChatRequest
  129. if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
  130. t.Fatal(err)
  131. }
  132. if chatReq.Messages[0].Role != "user" {
  133. t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
  134. }
  135. if chatReq.Messages[0].Content != "Hello" {
  136. t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
  137. }
  138. img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
  139. if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
  140. t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
  141. }
  142. },
  143. },
  144. {
  145. Name: "embed handler single input",
  146. Method: http.MethodPost,
  147. Path: "/api/embed",
  148. Handler: EmbeddingsMiddleware,
  149. Setup: func(t *testing.T, req *http.Request) {
  150. body := EmbedRequest{
  151. Input: "Hello",
  152. Model: "test-model",
  153. }
  154. bodyBytes, _ := json.Marshal(body)
  155. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  156. req.Header.Set("Content-Type", "application/json")
  157. },
  158. Expected: func(t *testing.T, req *http.Request) {
  159. var embedReq api.EmbedRequest
  160. if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
  161. t.Fatal(err)
  162. }
  163. if embedReq.Input != "Hello" {
  164. t.Fatalf("expected 'Hello', got %s", embedReq.Input)
  165. }
  166. if embedReq.Model != "test-model" {
  167. t.Fatalf("expected 'test-model', got %s", embedReq.Model)
  168. }
  169. },
  170. },
  171. {
  172. Name: "embed handler batch input",
  173. Method: http.MethodPost,
  174. Path: "/api/embed",
  175. Handler: EmbeddingsMiddleware,
  176. Setup: func(t *testing.T, req *http.Request) {
  177. body := EmbedRequest{
  178. Input: []string{"Hello", "World"},
  179. Model: "test-model",
  180. }
  181. bodyBytes, _ := json.Marshal(body)
  182. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  183. req.Header.Set("Content-Type", "application/json")
  184. },
  185. Expected: func(t *testing.T, req *http.Request) {
  186. var embedReq api.EmbedRequest
  187. if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
  188. t.Fatal(err)
  189. }
  190. input, ok := embedReq.Input.([]any)
  191. if !ok {
  192. t.Fatalf("expected input to be a list")
  193. }
  194. if input[0].(string) != "Hello" {
  195. t.Fatalf("expected 'Hello', got %s", input[0])
  196. }
  197. if input[1].(string) != "World" {
  198. t.Fatalf("expected 'World', got %s", input[1])
  199. }
  200. if embedReq.Model != "test-model" {
  201. t.Fatalf("expected 'test-model', got %s", embedReq.Model)
  202. }
  203. },
  204. },
  205. }
  206. gin.SetMode(gin.TestMode)
  207. router := gin.New()
  208. endpoint := func(c *gin.Context) {
  209. c.Status(http.StatusOK)
  210. }
  211. for _, tc := range testCases {
  212. t.Run(tc.Name, func(t *testing.T) {
  213. router = gin.New()
  214. router.Use(captureRequestMiddleware())
  215. router.Use(tc.Handler())
  216. router.Handle(tc.Method, tc.Path, endpoint)
  217. req, _ := http.NewRequest(tc.Method, tc.Path, nil)
  218. if tc.Setup != nil {
  219. tc.Setup(t, req)
  220. }
  221. resp := httptest.NewRecorder()
  222. router.ServeHTTP(resp, req)
  223. tc.Expected(t, capturedRequest)
  224. })
  225. }
  226. }
  227. func TestMiddlewareResponses(t *testing.T) {
  228. type testCase struct {
  229. Name string
  230. Method string
  231. Path string
  232. TestPath string
  233. Handler func() gin.HandlerFunc
  234. Endpoint func(c *gin.Context)
  235. Setup func(t *testing.T, req *http.Request)
  236. Expected func(t *testing.T, resp *httptest.ResponseRecorder)
  237. }
  238. testCases := []testCase{
  239. {
  240. Name: "completions handler error forwarding",
  241. Method: http.MethodPost,
  242. Path: "/api/generate",
  243. TestPath: "/api/generate",
  244. Handler: CompletionsMiddleware,
  245. Endpoint: func(c *gin.Context) {
  246. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
  247. },
  248. Setup: func(t *testing.T, req *http.Request) {
  249. body := CompletionRequest{
  250. Model: "test-model",
  251. Prompt: "Hello",
  252. }
  253. bodyBytes, _ := json.Marshal(body)
  254. req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
  255. req.Header.Set("Content-Type", "application/json")
  256. },
  257. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  258. if resp.Code != http.StatusBadRequest {
  259. t.Fatalf("expected 400, got %d", resp.Code)
  260. }
  261. if !strings.Contains(resp.Body.String(), `"invalid request"`) {
  262. t.Fatalf("error was not forwarded")
  263. }
  264. },
  265. },
  266. {
  267. Name: "list handler",
  268. Method: http.MethodGet,
  269. Path: "/api/tags",
  270. TestPath: "/api/tags",
  271. Handler: ListMiddleware,
  272. Endpoint: func(c *gin.Context) {
  273. c.JSON(http.StatusOK, api.ListResponse{
  274. Models: []api.ListModelResponse{
  275. {
  276. Name: "Test Model",
  277. },
  278. },
  279. })
  280. },
  281. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  282. assert.Equal(t, http.StatusOK, resp.Code)
  283. var listResp ListCompletion
  284. if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
  285. t.Fatal(err)
  286. }
  287. if listResp.Object != "list" {
  288. t.Fatalf("expected list, got %s", listResp.Object)
  289. }
  290. if len(listResp.Data) != 1 {
  291. t.Fatalf("expected 1, got %d", len(listResp.Data))
  292. }
  293. if listResp.Data[0].Id != "Test Model" {
  294. t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
  295. }
  296. },
  297. },
  298. {
  299. Name: "retrieve model",
  300. Method: http.MethodGet,
  301. Path: "/api/show/:model",
  302. TestPath: "/api/show/test-model",
  303. Handler: RetrieveMiddleware,
  304. Endpoint: func(c *gin.Context) {
  305. c.JSON(http.StatusOK, api.ShowResponse{
  306. ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
  307. })
  308. },
  309. Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
  310. var retrieveResp Model
  311. if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
  312. t.Fatal(err)
  313. }
  314. if retrieveResp.Object != "model" {
  315. t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
  316. }
  317. if retrieveResp.Id != "test-model" {
  318. t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
  319. }
  320. },
  321. },
  322. }
  323. gin.SetMode(gin.TestMode)
  324. router := gin.New()
  325. for _, tc := range testCases {
  326. t.Run(tc.Name, func(t *testing.T) {
  327. router = gin.New()
  328. router.Use(tc.Handler())
  329. router.Handle(tc.Method, tc.Path, tc.Endpoint)
  330. req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
  331. if tc.Setup != nil {
  332. tc.Setup(t, req)
  333. }
  334. resp := httptest.NewRecorder()
  335. router.ServeHTTP(resp, req)
  336. tc.Expected(t, resp)
  337. })
  338. }
  339. }