server_test.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package registry
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "io"
  7. "io/fs"
  8. "net"
  9. "net/http"
  10. "net/http/httptest"
  11. "os"
  12. "regexp"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "github.com/ollama/ollama/server/internal/cache/blob"
  17. "github.com/ollama/ollama/server/internal/client/ollama"
  18. "github.com/ollama/ollama/server/internal/testutil"
  19. "golang.org/x/tools/txtar"
  20. _ "embed"
  21. )
  // panicTransport is an http.RoundTripper that refuses to perform any
// request: every RoundTrip call panics. Tests that must stay offline use it
// so that an accidental network request becomes a loud failure.
type panicTransport struct{}

// RoundTrip implements http.RoundTripper by panicking unconditionally.
func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an HTTP client whose transport panics on use. It is the
// client newTestServer installs when no upstream registry handler is given.
var panicOnRoundTrip = &http.Client{Transport: new(panicTransport)}
  // bytesResetter is the deliberately narrow view of captured log output that
// captureLogs hands back to tests. Tests may read the captured bytes and
// clear them, and nothing more. Keeping the method set this small stops tests
// from reaching for bytes.Buffer's Read/Write methods when inspecting logs.
type bytesResetter interface {
	// Reset clears the captured contents.
	Reset()
	// Bytes returns the captured contents.
	Bytes() []byte
}
  34. func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
  35. t.Helper()
  36. dir := t.TempDir()
  37. err := os.CopyFS(dir, os.DirFS("testdata/models"))
  38. if err != nil {
  39. t.Fatal(err)
  40. }
  41. c, err := blob.Open(dir)
  42. if err != nil {
  43. t.Fatal(err)
  44. }
  45. client := panicOnRoundTrip
  46. if upstreamRegistry != nil {
  47. s := httptest.NewTLSServer(upstreamRegistry)
  48. t.Cleanup(s.Close)
  49. tr := s.Client().Transport.(*http.Transport).Clone()
  50. tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
  51. var d net.Dialer
  52. return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
  53. }
  54. client = &http.Client{Transport: tr}
  55. }
  56. rc := &ollama.Registry{
  57. Cache: c,
  58. HTTPClient: client,
  59. Mask: "example.com/library/_:latest",
  60. }
  61. l := &Local{
  62. Client: rc,
  63. Logger: testutil.Slogger(t),
  64. }
  65. return l
  66. }
  67. func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
  68. t.Helper()
  69. req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
  70. return s.sendRequest(t, req)
  71. }
  72. func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
  73. t.Helper()
  74. w := httptest.NewRecorder()
  75. s.ServeHTTP(w, req)
  76. return w
  77. }
  // invalidReader is an io.Reader whose every Read fails with os.ErrInvalid
// and yields no data. Tests use it as a request body that cannot be read.
type invalidReader struct{}

// Read always reports zero bytes read and os.ErrInvalid.
func (*invalidReader) Read([]byte) (n int, err error) {
	return 0, os.ErrInvalid
}
  82. // captureLogs is a helper to capture logs from the server. It returns a
  83. // shallow copy of the server with a new logger and a bytesResetter for the
  84. // logs.
  85. func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
  86. t.Helper()
  87. log, logs := testutil.SlogBuffer()
  88. l := *s // shallow copy
  89. l.Logger = log
  90. return &l, logs
  91. }
  92. func TestServerDelete(t *testing.T) {
  93. check := testutil.Checker(t)
  94. s := newTestServer(t, nil)
  95. _, err := s.Client.ResolveLocal("smol")
  96. check(err)
  97. got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
  98. if got.Code != 200 {
  99. t.Fatalf("Code = %d; want 200", got.Code)
  100. }
  101. _, err = s.Client.ResolveLocal("smol")
  102. if err == nil {
  103. t.Fatal("expected smol to have been deleted")
  104. }
  105. got = s.send(t, "DELETE", "/api/delete", `!`)
  106. checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
  107. got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
  108. checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
  109. got = s.send(t, "DELETE", "/api/delete", ``)
  110. checkErrorResponse(t, got, 400, "bad_request", "empty request body")
  111. got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
  112. checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
  113. got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
  114. checkErrorResponse(t, got, 404, "not_found", "not found")
  115. s, logs := captureLogs(t, s)
  116. req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
  117. got = s.sendRequest(t, req)
  118. checkErrorResponse(t, got, 500, "internal_error", "internal server error")
  119. ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
  120. check(err)
  121. if !ok {
  122. t.Logf("logs:\n%s", logs)
  123. t.Fatalf("expected log to contain ERROR with invalid argument")
  124. }
  125. }
  126. //go:embed testdata/registry.txt
  127. var registryTXT []byte
  128. var registryFS = sync.OnceValue(func() fs.FS {
  129. // Txtar gets hung up on \r\n line endings, so we need to convert them
  130. // to \n when parsing the txtar on Windows.
  131. data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
  132. a := txtar.Parse(data)
  133. fsys, err := txtar.FS(a)
  134. if err != nil {
  135. panic(err)
  136. }
  137. return fsys
  138. })
  139. func TestServerPull(t *testing.T) {
  140. modelsHandler := http.FileServerFS(registryFS())
  141. s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
  142. switch r.URL.Path {
  143. case "/v2/library/BOOM/manifests/latest":
  144. w.WriteHeader(999)
  145. io.WriteString(w, `{"error": "boom"}`)
  146. case "/v2/library/unknown/manifests/latest":
  147. w.WriteHeader(404)
  148. io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
  149. default:
  150. t.Logf("serving blob: %s", r.URL.Path)
  151. modelsHandler.ServeHTTP(w, r)
  152. }
  153. })
  154. checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
  155. t.Helper()
  156. if got.Code != 200 {
  157. t.Errorf("Code = %d; want 200", got.Code)
  158. }
  159. gotlines := got.Body.String()
  160. t.Logf("got:\n%s", gotlines)
  161. for want := range strings.Lines(wantlines) {
  162. want = strings.TrimSpace(want)
  163. want, unwanted := strings.CutPrefix(want, "!")
  164. want = strings.TrimSpace(want)
  165. if !unwanted && !strings.Contains(gotlines, want) {
  166. t.Errorf("! missing %q in body", want)
  167. }
  168. if unwanted && strings.Contains(gotlines, want) {
  169. t.Errorf("! unexpected %q in body", want)
  170. }
  171. }
  172. }
  173. got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
  174. checkResponse(got, `
  175. {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
  176. `)
  177. got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
  178. checkResponse(got, `
  179. {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
  180. {"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
  181. {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
  182. {"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
  183. `)
  184. got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
  185. checkResponse(got, `
  186. {"status":"error: model \"unknown\" not found"}
  187. `)
  188. got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
  189. checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
  190. got = s.send(t, "POST", "/api/pull", `!`)
  191. checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
  192. got = s.send(t, "POST", "/api/pull", ``)
  193. checkErrorResponse(t, got, 400, "bad_request", "empty request body")
  194. got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
  195. checkResponse(got, `
  196. {"status":"error: invalid or missing name: \"\""}
  197. `)
  198. // Non-streaming pulls
  199. got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
  200. checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
  201. got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
  202. checkResponse(got, `
  203. {"status":"success"}
  204. !digest
  205. !total
  206. !completed
  207. `)
  208. got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
  209. checkErrorResponse(t, got, 404, "not_found", "model not found")
  210. }
  211. func TestServerUnknownPath(t *testing.T) {
  212. s := newTestServer(t, nil)
  213. got := s.send(t, "DELETE", "/api/unknown", `{}`)
  214. checkErrorResponse(t, got, 404, "not_found", "not found")
  215. var fellback bool
  216. s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  217. fellback = true
  218. })
  219. got = s.send(t, "DELETE", "/api/unknown", `{}`)
  220. if !fellback {
  221. t.Fatal("expected Fallback to be called")
  222. }
  223. if got.Code != 200 {
  224. t.Fatalf("Code = %d; want 200", got.Code)
  225. }
  226. }
  227. func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
  228. t.Helper()
  229. var printedBody bool
  230. errorf := func(format string, args ...any) {
  231. t.Helper()
  232. if !printedBody {
  233. t.Logf("BODY:\n%s", got.Body.String())
  234. printedBody = true
  235. }
  236. t.Errorf(format, args...)
  237. }
  238. if got.Code != status {
  239. errorf("Code = %d; want %d", got.Code, status)
  240. }
  241. // unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
  242. var e *ollama.Error
  243. if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
  244. errorf("unmarshal error: %v", err)
  245. t.FailNow()
  246. }
  247. if e.Code != code {
  248. errorf("Code = %q; want %q", e.Code, code)
  249. }
  250. if !strings.Contains(e.Message, msg) {
  251. errorf("Message = %q; want to contain %q", e.Message, msg)
  252. }
  253. }