routes_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/http/httptest"
  10. "os"
  11. "strings"
  12. "testing"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/jmorganca/ollama/api"
  15. "github.com/jmorganca/ollama/parser"
  16. )
  17. func setupServer(t *testing.T) (*Server, error) {
  18. t.Helper()
  19. return NewServer()
  20. }
  21. func Test_Routes(t *testing.T) {
  22. type testCase struct {
  23. Name string
  24. Method string
  25. Path string
  26. Setup func(t *testing.T, req *http.Request)
  27. Expected func(t *testing.T, resp *http.Response)
  28. }
  29. var tempModelFile string
  30. createTestModel := func(t *testing.T, name string) {
  31. f, err := os.CreateTemp("", "ollama-model")
  32. assert.Nil(t, err)
  33. defer os.RemoveAll(f.Name())
  34. modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name()))
  35. commands, err := parser.Parse(modelfile)
  36. assert.Nil(t, err)
  37. fn := func(resp api.ProgressResponse) {}
  38. err = CreateModel(context.TODO(), name, "", commands, fn)
  39. assert.Nil(t, err)
  40. }
  41. testCases := []testCase{
  42. {
  43. Name: "Version Handler",
  44. Method: http.MethodGet,
  45. Path: "/api/version",
  46. Setup: func(t *testing.T, req *http.Request) {
  47. },
  48. Expected: func(t *testing.T, resp *http.Response) {
  49. contentType := resp.Header.Get("Content-Type")
  50. assert.Equal(t, contentType, "application/json; charset=utf-8")
  51. body, err := io.ReadAll(resp.Body)
  52. assert.Nil(t, err)
  53. assert.Equal(t, `{"version":"0.0.0"}`, string(body))
  54. },
  55. },
  56. {
  57. Name: "Tags Handler (no tags)",
  58. Method: http.MethodGet,
  59. Path: "/api/tags",
  60. Expected: func(t *testing.T, resp *http.Response) {
  61. contentType := resp.Header.Get("Content-Type")
  62. assert.Equal(t, contentType, "application/json; charset=utf-8")
  63. body, err := io.ReadAll(resp.Body)
  64. assert.Nil(t, err)
  65. var modelList api.ListResponse
  66. err = json.Unmarshal(body, &modelList)
  67. assert.Nil(t, err)
  68. assert.Equal(t, 0, len(modelList.Models))
  69. },
  70. },
  71. {
  72. Name: "Tags Handler (yes tags)",
  73. Method: http.MethodGet,
  74. Path: "/api/tags",
  75. Setup: func(t *testing.T, req *http.Request) {
  76. createTestModel(t, "test-model")
  77. },
  78. Expected: func(t *testing.T, resp *http.Response) {
  79. contentType := resp.Header.Get("Content-Type")
  80. assert.Equal(t, contentType, "application/json; charset=utf-8")
  81. body, err := io.ReadAll(resp.Body)
  82. assert.Nil(t, err)
  83. var modelList api.ListResponse
  84. err = json.Unmarshal(body, &modelList)
  85. assert.Nil(t, err)
  86. assert.Equal(t, 1, len(modelList.Models))
  87. assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
  88. },
  89. },
  90. {
  91. Name: "Create Model Handler",
  92. Method: http.MethodPost,
  93. Path: "/api/create",
  94. Setup: func(t *testing.T, req *http.Request) {
  95. f, err := os.CreateTemp("", "ollama-model")
  96. assert.Nil(t, err)
  97. tempModelFile = f.Name()
  98. stream := false
  99. createReq := api.CreateRequest{
  100. Name: "t-bone",
  101. Modelfile: fmt.Sprintf("FROM %s", f.Name()),
  102. Stream: &stream,
  103. }
  104. jsonData, err := json.Marshal(createReq)
  105. assert.Nil(t, err)
  106. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  107. },
  108. Expected: func(t *testing.T, resp *http.Response) {
  109. os.RemoveAll(tempModelFile)
  110. contentType := resp.Header.Get("Content-Type")
  111. assert.Equal(t, "application/json", contentType)
  112. _, err := io.ReadAll(resp.Body)
  113. assert.Nil(t, err)
  114. assert.Equal(t, resp.StatusCode, 200)
  115. model, err := GetModel("t-bone")
  116. assert.Nil(t, err)
  117. assert.Equal(t, "t-bone:latest", model.ShortName)
  118. },
  119. },
  120. {
  121. Name: "Copy Model Handler",
  122. Method: http.MethodPost,
  123. Path: "/api/copy",
  124. Setup: func(t *testing.T, req *http.Request) {
  125. createTestModel(t, "hamshank")
  126. copyReq := api.CopyRequest{
  127. Source: "hamshank",
  128. Destination: "beefsteak",
  129. }
  130. jsonData, err := json.Marshal(copyReq)
  131. assert.Nil(t, err)
  132. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  133. },
  134. Expected: func(t *testing.T, resp *http.Response) {
  135. model, err := GetModel("beefsteak")
  136. assert.Nil(t, err)
  137. assert.Equal(t, "beefsteak:latest", model.ShortName)
  138. },
  139. },
  140. }
  141. s, err := setupServer(t)
  142. assert.Nil(t, err)
  143. router := s.GenerateRoutes()
  144. httpSrv := httptest.NewServer(router)
  145. t.Cleanup(httpSrv.Close)
  146. workDir, err := os.MkdirTemp("", "ollama-test")
  147. assert.Nil(t, err)
  148. defer os.RemoveAll(workDir)
  149. os.Setenv("OLLAMA_MODELS", workDir)
  150. for _, tc := range testCases {
  151. u := httpSrv.URL + tc.Path
  152. req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
  153. assert.Nil(t, err)
  154. if tc.Setup != nil {
  155. tc.Setup(t, req)
  156. }
  157. resp, err := httpSrv.Client().Do(req)
  158. defer resp.Body.Close()
  159. assert.Nil(t, err)
  160. if tc.Expected != nil {
  161. tc.Expected(t, resp)
  162. }
  163. }
  164. }