server_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package registry
  2. import (
  3. "encoding/json"
  4. "net/http"
  5. "net/http/httptest"
  6. "os"
  7. "regexp"
  8. "strings"
  9. "testing"
  10. "github.com/ollama/ollama/server/internal/cache/blob"
  11. "github.com/ollama/ollama/server/internal/client/ollama"
  12. "github.com/ollama/ollama/server/internal/testutil"
  13. )
  14. type panicTransport struct{}
  15. func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  16. panic("unexpected RoundTrip call")
  17. }
  18. var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
  19. // bytesResetter is an interface for types that can be reset and return a byte
  20. // slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
  21. // etc for the purpose of checking logs.
  22. type bytesResetter interface {
  23. Bytes() []byte
  24. Reset()
  25. }
  26. func newTestServer(t *testing.T) *Local {
  27. t.Helper()
  28. dir := t.TempDir()
  29. err := os.CopyFS(dir, os.DirFS("testdata/models"))
  30. if err != nil {
  31. t.Fatal(err)
  32. }
  33. c, err := blob.Open(dir)
  34. if err != nil {
  35. t.Fatal(err)
  36. }
  37. rc := &ollama.Registry{
  38. Cache: c,
  39. HTTPClient: panicOnRoundTrip,
  40. }
  41. l := &Local{
  42. Client: rc,
  43. Logger: testutil.Slogger(t),
  44. }
  45. return l
  46. }
  47. func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
  48. t.Helper()
  49. req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
  50. return s.sendRequest(t, req)
  51. }
  52. func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
  53. t.Helper()
  54. w := httptest.NewRecorder()
  55. s.ServeHTTP(w, req)
  56. return w
  57. }
  58. type invalidReader struct{}
  59. func (r *invalidReader) Read(p []byte) (int, error) {
  60. return 0, os.ErrInvalid
  61. }
  62. // captureLogs is a helper to capture logs from the server. It returns a
  63. // shallow copy of the server with a new logger and a bytesResetter for the
  64. // logs.
  65. func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
  66. t.Helper()
  67. log, logs := testutil.SlogBuffer()
  68. l := *s // shallow copy
  69. l.Logger = log
  70. return &l, logs
  71. }
  72. func TestServerDelete(t *testing.T) {
  73. check := testutil.Checker(t)
  74. s := newTestServer(t)
  75. _, err := s.Client.ResolveLocal("smol")
  76. check(err)
  77. got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
  78. if got.Code != 200 {
  79. t.Fatalf("Code = %d; want 200", got.Code)
  80. }
  81. _, err = s.Client.ResolveLocal("smol")
  82. if err == nil {
  83. t.Fatal("expected smol to have been deleted")
  84. }
  85. got = s.send(t, "DELETE", "/api/delete", `!`)
  86. checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
  87. got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
  88. checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
  89. got = s.send(t, "DELETE", "/api/delete", ``)
  90. checkErrorResponse(t, got, 400, "bad_request", "empty request body")
  91. got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
  92. checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
  93. got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
  94. checkErrorResponse(t, got, 404, "not_found", "not found")
  95. s, logs := captureLogs(t, s)
  96. req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
  97. got = s.sendRequest(t, req)
  98. checkErrorResponse(t, got, 500, "internal_error", "internal server error")
  99. ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
  100. check(err)
  101. if !ok {
  102. t.Logf("logs:\n%s", logs)
  103. t.Fatalf("expected log to contain ERROR with invalid argument")
  104. }
  105. }
  106. func TestServerUnknownPath(t *testing.T) {
  107. s := newTestServer(t)
  108. got := s.send(t, "DELETE", "/api/unknown", `{}`)
  109. checkErrorResponse(t, got, 404, "not_found", "not found")
  110. }
  111. func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
  112. t.Helper()
  113. var printedBody bool
  114. errorf := func(format string, args ...any) {
  115. t.Helper()
  116. if !printedBody {
  117. t.Logf("BODY:\n%s", got.Body.String())
  118. printedBody = true
  119. }
  120. t.Errorf(format, args...)
  121. }
  122. if got.Code != status {
  123. errorf("Code = %d; want %d", got.Code, status)
  124. }
  125. // unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
  126. var e *ollama.Error
  127. if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
  128. errorf("unmarshal error: %v", err)
  129. t.FailNow()
  130. }
  131. if e.Code != code {
  132. errorf("Code = %q; want %q", e.Code, code)
  133. }
  134. if !strings.Contains(e.Message, msg) {
  135. errorf("Message = %q; want to contain %q", e.Message, msg)
  136. }
  137. }