server_test.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package registry
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "io/fs"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "os"
  13. "regexp"
  14. "strings"
  15. "sync"
  16. "testing"
  17. "github.com/ollama/ollama/server/internal/cache/blob"
  18. "github.com/ollama/ollama/server/internal/client/ollama"
  19. "github.com/ollama/ollama/server/internal/testutil"
  20. "golang.org/x/tools/txtar"
  21. _ "embed"
  22. )
  // panicTransport is an http.RoundTripper that refuses to do any work: every
  // RoundTrip call panics. It guards tests that must never reach the network.
  type panicTransport struct{}

  // RoundTrip implements http.RoundTripper by panicking unconditionally; the
  // request is never inspected.
  func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
  	panic("unexpected RoundTrip call")
  }

  // panicOnRoundTrip is an *http.Client whose every request panics. It is the
  // client used by test servers that have no upstream registry.
  var panicOnRoundTrip = &http.Client{Transport: new(panicTransport)}
  // bytesResetter exposes only what log assertions need from a captured log
  // buffer: reading its accumulated bytes and clearing it. Keeping the surface
  // this narrow stops tests from accidentally calling bytes.Buffer's Read/Write
  // style methods (which would consume or alter the captured logs).
  type bytesResetter interface {
  	Reset()
  	Bytes() []byte
  }
  35. func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
  36. t.Helper()
  37. dir := t.TempDir()
  38. err := os.CopyFS(dir, os.DirFS("testdata/models"))
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. c, err := blob.Open(dir)
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. client := panicOnRoundTrip
  47. if upstreamRegistry != nil {
  48. s := httptest.NewTLSServer(upstreamRegistry)
  49. t.Cleanup(s.Close)
  50. tr := s.Client().Transport.(*http.Transport).Clone()
  51. tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
  52. var d net.Dialer
  53. return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
  54. }
  55. client = &http.Client{Transport: tr}
  56. }
  57. rc := &ollama.Registry{
  58. Cache: c,
  59. HTTPClient: client,
  60. Mask: "example.com/library/_:latest",
  61. }
  62. l := &Local{
  63. Client: rc,
  64. Logger: testutil.Slogger(t),
  65. }
  66. return l
  67. }
  68. func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
  69. t.Helper()
  70. req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
  71. return s.sendRequest(t, req)
  72. }
  73. func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
  74. t.Helper()
  75. w := httptest.NewRecorder()
  76. s.ServeHTTP(w, req)
  77. return w
  78. }
  // invalidReader is an io.Reader whose every Read fails immediately with
  // os.ErrInvalid, reading zero bytes. It simulates an unreadable request body.
  type invalidReader struct{}

  // Read ignores its buffer and always reports (0, os.ErrInvalid).
  func (*invalidReader) Read([]byte) (int, error) {
  	return 0, os.ErrInvalid
  }
  83. // captureLogs is a helper to capture logs from the server. It returns a
  84. // shallow copy of the server with a new logger and a bytesResetter for the
  85. // logs.
  86. func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
  87. t.Helper()
  88. log, logs := testutil.SlogBuffer()
  89. l := *s // shallow copy
  90. l.Logger = log
  91. return &l, logs
  92. }
  93. func TestServerDelete(t *testing.T) {
  94. check := testutil.Checker(t)
  95. s := newTestServer(t, nil)
  96. _, err := s.Client.ResolveLocal("smol")
  97. check(err)
  98. got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
  99. if got.Code != 200 {
  100. t.Fatalf("Code = %d; want 200", got.Code)
  101. }
  102. _, err = s.Client.ResolveLocal("smol")
  103. if err == nil {
  104. t.Fatal("expected smol to have been deleted")
  105. }
  106. got = s.send(t, "DELETE", "/api/delete", `!`)
  107. checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
  108. got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
  109. checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
  110. got = s.send(t, "DELETE", "/api/delete", ``)
  111. checkErrorResponse(t, got, 400, "bad_request", "empty request body")
  112. got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
  113. checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
  114. got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
  115. checkErrorResponse(t, got, 404, "not_found", "not found")
  116. s, logs := captureLogs(t, s)
  117. req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
  118. got = s.sendRequest(t, req)
  119. checkErrorResponse(t, got, 500, "internal_error", "internal server error")
  120. ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
  121. check(err)
  122. if !ok {
  123. t.Logf("logs:\n%s", logs)
  124. t.Fatalf("expected log to contain ERROR with invalid argument")
  125. }
  126. }
  127. //go:embed testdata/registry.txt
  128. var registryTXT []byte
  129. var registryFS = sync.OnceValue(func() fs.FS {
  130. // Txtar gets hung up on \r\n line endings, so we need to convert them
  131. // to \n when parsing the txtar on Windows.
  132. data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
  133. a := txtar.Parse(data)
  134. fmt.Printf("%q\n", a.Comment)
  135. fsys, err := txtar.FS(a)
  136. if err != nil {
  137. panic(err)
  138. }
  139. return fsys
  140. })
  141. func TestServerPull(t *testing.T) {
  142. modelsHandler := http.FileServerFS(registryFS())
  143. s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
  144. switch r.URL.Path {
  145. case "/v2/library/BOOM/manifests/latest":
  146. w.WriteHeader(999)
  147. io.WriteString(w, `{"error": "boom"}`)
  148. case "/v2/library/unknown/manifests/latest":
  149. w.WriteHeader(404)
  150. io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
  151. default:
  152. t.Logf("serving file: %s", r.URL.Path)
  153. modelsHandler.ServeHTTP(w, r)
  154. }
  155. })
  156. checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
  157. t.Helper()
  158. if got.Code != 200 {
  159. t.Fatalf("Code = %d; want 200", got.Code)
  160. }
  161. gotlines := got.Body.String()
  162. t.Logf("got:\n%s", gotlines)
  163. for want := range strings.Lines(wantlines) {
  164. want = strings.TrimSpace(want)
  165. want, unwanted := strings.CutPrefix(want, "!")
  166. want = strings.TrimSpace(want)
  167. if !unwanted && !strings.Contains(gotlines, want) {
  168. t.Fatalf("! missing %q in body", want)
  169. }
  170. if unwanted && strings.Contains(gotlines, want) {
  171. t.Fatalf("! unexpected %q in body", want)
  172. }
  173. }
  174. }
  175. got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
  176. checkResponse(got, `
  177. {"status":"pulling manifest"}
  178. {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
  179. `)
  180. got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
  181. checkResponse(got, `
  182. {"status":"pulling manifest"}
  183. {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
  184. {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
  185. {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
  186. {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
  187. {"status":"verifying layers"}
  188. {"status":"writing manifest"}
  189. {"status":"success"}
  190. `)
  191. got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
  192. checkResponse(got, `
  193. {"status":"pulling manifest"}
  194. {"status":"error: model \"unknown\" not found"}
  195. `)
  196. got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
  197. checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
  198. got = s.send(t, "POST", "/api/pull", `!`)
  199. checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
  200. got = s.send(t, "POST", "/api/pull", ``)
  201. checkErrorResponse(t, got, 400, "bad_request", "empty request body")
  202. got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
  203. checkResponse(got, `
  204. {"status":"pulling manifest"}
  205. {"status":"error: invalid or missing name: \"\""}
  206. !verifying
  207. !writing
  208. !success
  209. `)
  210. }
  211. func TestServerUnknownPath(t *testing.T) {
  212. s := newTestServer(t, nil)
  213. got := s.send(t, "DELETE", "/api/unknown", `{}`)
  214. checkErrorResponse(t, got, 404, "not_found", "not found")
  215. }
  216. func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
  217. t.Helper()
  218. var printedBody bool
  219. errorf := func(format string, args ...any) {
  220. t.Helper()
  221. if !printedBody {
  222. t.Logf("BODY:\n%s", got.Body.String())
  223. printedBody = true
  224. }
  225. t.Errorf(format, args...)
  226. }
  227. if got.Code != status {
  228. errorf("Code = %d; want %d", got.Code, status)
  229. }
  230. // unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
  231. var e *ollama.Error
  232. if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
  233. errorf("unmarshal error: %v", err)
  234. t.FailNow()
  235. }
  236. if e.Code != code {
  237. errorf("Code = %q; want %q", e.Code, code)
  238. }
  239. if !strings.Contains(e.Message, msg) {
  240. errorf("Message = %q; want to contain %q", e.Message, msg)
  241. }
  242. }