server_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package registry
  2. import (
  3. "encoding/json"
  4. "net/http"
  5. "net/http/httptest"
  6. "os"
  7. "regexp"
  8. "strings"
  9. "testing"
  10. "github.com/ollama/ollama/server/internal/cache/blob"
  11. "github.com/ollama/ollama/server/internal/client/ollama"
  12. "github.com/ollama/ollama/server/internal/testutil"
  13. )
  14. type panicTransport struct{}
  15. func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  16. panic("unexpected RoundTrip call")
  17. }
  18. var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
  19. // bytesResetter is an interface for types that can be reset and return a byte
  20. // slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
  21. // etc for the purpose of checking logs.
  22. type bytesResetter interface {
  23. Bytes() []byte
  24. Reset()
  25. }
  26. func newTestServer(t *testing.T) *Local {
  27. t.Helper()
  28. dir := t.TempDir()
  29. err := os.CopyFS(dir, os.DirFS("testdata/models"))
  30. if err != nil {
  31. t.Fatal(err)
  32. }
  33. c, err := blob.Open(dir)
  34. if err != nil {
  35. t.Fatal(err)
  36. }
  37. rc := &ollama.Registry{
  38. HTTPClient: panicOnRoundTrip,
  39. }
  40. l := &Local{
  41. Cache: c,
  42. Client: rc,
  43. Logger: testutil.Slogger(t),
  44. }
  45. return l
  46. }
  47. func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
  48. t.Helper()
  49. req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
  50. return s.sendRequest(t, req)
  51. }
  52. func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
  53. t.Helper()
  54. w := httptest.NewRecorder()
  55. s.ServeHTTP(w, req)
  56. return w
  57. }
  58. type invalidReader struct{}
  59. func (r *invalidReader) Read(p []byte) (int, error) {
  60. return 0, os.ErrInvalid
  61. }
  62. // captureLogs is a helper to capture logs from the server. It returns a
  63. // shallow copy of the server with a new logger and a bytesResetter for the
  64. // logs.
  65. func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
  66. t.Helper()
  67. log, logs := testutil.SlogBuffer()
  68. l := *s // shallow copy
  69. l.Logger = log
  70. return &l, logs
  71. }
  72. func TestServerDelete(t *testing.T) {
  73. check := testutil.Checker(t)
  74. s := newTestServer(t)
  75. _, err := s.Client.ResolveLocal(s.Cache, "smol")
  76. check(err)
  77. got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
  78. if got.Code != 200 {
  79. t.Fatalf("Code = %d; want 200", got.Code)
  80. }
  81. _, err = s.Client.ResolveLocal(s.Cache, "smol")
  82. if err == nil {
  83. t.Fatal("expected smol to have been deleted")
  84. }
  85. got = s.send(t, "DELETE", "/api/delete", `!`)
  86. checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
  87. got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
  88. checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
  89. got = s.send(t, "DELETE", "/api/delete", ``)
  90. checkErrorResponse(t, got, 400, "bad_request", "empty request body")
  91. got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
  92. checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
  93. got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
  94. checkErrorResponse(t, got, 400, "bad_request", "invalid name")
  95. got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
  96. checkErrorResponse(t, got, 404, "not_found", "not found")
  97. s, logs := captureLogs(t, s)
  98. req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
  99. got = s.sendRequest(t, req)
  100. checkErrorResponse(t, got, 500, "internal_error", "internal server error")
  101. ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
  102. check(err)
  103. if !ok {
  104. t.Logf("logs:\n%s", logs)
  105. t.Fatalf("expected log to contain ERROR with invalid argument")
  106. }
  107. }
  108. func TestServerUnknownPath(t *testing.T) {
  109. s := newTestServer(t)
  110. got := s.send(t, "DELETE", "/api/unknown", `{}`)
  111. checkErrorResponse(t, got, 404, "not_found", "not found")
  112. }
  113. func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
  114. t.Helper()
  115. var printedBody bool
  116. errorf := func(format string, args ...any) {
  117. t.Helper()
  118. if !printedBody {
  119. t.Logf("BODY:\n%s", got.Body.String())
  120. printedBody = true
  121. }
  122. t.Errorf(format, args...)
  123. }
  124. if got.Code != status {
  125. errorf("Code = %d; want %d", got.Code, status)
  126. }
  127. // unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
  128. var e *ollama.Error
  129. if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
  130. errorf("unmarshal error: %v", err)
  131. t.FailNow()
  132. }
  133. if e.Code != code {
  134. errorf("Code = %q; want %q", e.Code, code)
  135. }
  136. if !strings.Contains(e.Message, msg) {
  137. errorf("Message = %q; want to contain %q", e.Message, msg)
  138. }
  139. }