123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- package registry
- import (
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "os"
- "regexp"
- "strings"
- "testing"
- "github.com/ollama/ollama/server/internal/cache/blob"
- "github.com/ollama/ollama/server/internal/client/ollama"
- "github.com/ollama/ollama/server/internal/testutil"
- )
// panicTransport is an http.RoundTripper that refuses to do any work: every
// attempt to send a request through it panics. Tests install it to prove an
// operation is satisfied entirely from local state without touching the
// network.
type panicTransport struct{}

// RoundTrip unconditionally panics; the request is never inspected.
func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an HTTP client wired to panicTransport, so any request
// issued through it panics instead of reaching the network.
var panicOnRoundTrip = &http.Client{Transport: new(panicTransport)}
// bytesResetter is the deliberately narrow view of a log buffer handed to
// tests: it can only return the captured bytes and clear them. Restricting the
// type this way stops tests from accidentally reading from or writing to the
// underlying buffer (for example via bytes.Buffer's Read/Write methods) when
// all they should do is inspect and reset captured logs.
type bytesResetter interface {
	Reset()
	Bytes() []byte
}
- func newTestServer(t *testing.T) *Local {
- t.Helper()
- dir := t.TempDir()
- err := os.CopyFS(dir, os.DirFS("testdata/models"))
- if err != nil {
- t.Fatal(err)
- }
- c, err := blob.Open(dir)
- if err != nil {
- t.Fatal(err)
- }
- rc := &ollama.Registry{
- Cache: c,
- HTTPClient: panicOnRoundTrip,
- }
- l := &Local{
- Client: rc,
- Logger: testutil.Slogger(t),
- }
- return l
- }
- func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
- t.Helper()
- req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
- return s.sendRequest(t, req)
- }
- func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
- t.Helper()
- w := httptest.NewRecorder()
- s.ServeHTTP(w, req)
- return w
- }
// invalidReader is a reader that can never be read: every call to Read fails
// with os.ErrInvalid. Tests use it as a request body to simulate a failure
// while the server reads the body.
type invalidReader struct{}

// Read reports zero bytes and os.ErrInvalid without touching the buffer.
func (*invalidReader) Read([]byte) (n int, err error) {
	return 0, os.ErrInvalid
}
- // captureLogs is a helper to capture logs from the server. It returns a
- // shallow copy of the server with a new logger and a bytesResetter for the
- // logs.
- func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
- t.Helper()
- log, logs := testutil.SlogBuffer()
- l := *s // shallow copy
- l.Logger = log
- return &l, logs
- }
- func TestServerDelete(t *testing.T) {
- check := testutil.Checker(t)
- s := newTestServer(t)
- _, err := s.Client.ResolveLocal("smol")
- check(err)
- got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
- if got.Code != 200 {
- t.Fatalf("Code = %d; want 200", got.Code)
- }
- _, err = s.Client.ResolveLocal("smol")
- if err == nil {
- t.Fatal("expected smol to have been deleted")
- }
- got = s.send(t, "DELETE", "/api/delete", `!`)
- checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
- got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
- checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
- got = s.send(t, "DELETE", "/api/delete", ``)
- checkErrorResponse(t, got, 400, "bad_request", "empty request body")
- got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
- checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
- got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
- checkErrorResponse(t, got, 404, "not_found", "not found")
- s, logs := captureLogs(t, s)
- req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
- got = s.sendRequest(t, req)
- checkErrorResponse(t, got, 500, "internal_error", "internal server error")
- ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
- check(err)
- if !ok {
- t.Logf("logs:\n%s", logs)
- t.Fatalf("expected log to contain ERROR with invalid argument")
- }
- }
- func TestServerUnknownPath(t *testing.T) {
- s := newTestServer(t)
- got := s.send(t, "DELETE", "/api/unknown", `{}`)
- checkErrorResponse(t, got, 404, "not_found", "not found")
- }
- func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
- t.Helper()
- var printedBody bool
- errorf := func(format string, args ...any) {
- t.Helper()
- if !printedBody {
- t.Logf("BODY:\n%s", got.Body.String())
- printedBody = true
- }
- t.Errorf(format, args...)
- }
- if got.Code != status {
- errorf("Code = %d; want %d", got.Code, status)
- }
- // unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
- var e *ollama.Error
- if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
- errorf("unmarshal error: %v", err)
- t.FailNow()
- }
- if e.Code != code {
- errorf("Code = %q; want %q", e.Code, code)
- }
- if !strings.Contains(e.Message, msg) {
- errorf("Message = %q; want to contain %q", e.Message, msg)
- }
- }
|