123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337 |
- // Package registry provides an http.Handler for handling local Ollama API
- // requests for performing tasks related to the ollama.com model registry and
- // the local disk cache.
- package registry
- import (
- "cmp"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "net/http"
- "sync"
- "time"
- "github.com/ollama/ollama/server/internal/cache/blob"
- "github.com/ollama/ollama/server/internal/client/ollama"
- )
- // Local is an http.Handler for handling local Ollama API requests for
- // performing tasks related to the ollama.com model registry combined with the
- // local disk cache.
- //
- // It is not concern of Local, or this package, to handle model creation, which
- // proceeds any registry operations for models it produces.
- //
- // NOTE: The package built for dealing with model creation should use
- // [DefaultCache] to access the blob store and not attempt to read or write
- // directly to the blob disk cache.
- type Local struct {
- Client *ollama.Registry // required
- Logger *slog.Logger // required
- // Fallback, if set, is used to handle requests that are not handled by
- // this handler.
- Fallback http.Handler
- // Prune, if set, is called to prune the local disk cache after a model
- // is deleted.
- Prune func() error // optional
- }
// serverError is the handler-side counterpart of ollama.Error: it has the
// same JSON shape plus a Status field holding the HTTP response code.
//
// Status is deliberately kept off ollama.Error. It is never serialized (we
// don't want to leak the status code in errors), so clients would only ever
// observe it as 0, and a field that is always 0 would be confusing.
type serverError struct {
	Status int `json:"-"`

	// TODO(bmizerany): Decide if we want to keep this and maybe
	// bring back later.
	Code string `json:"code"`

	Message string `json:"error"`
}

// Error implements the error interface; it reports only the message.
func (e serverError) Error() string { return e.Message }

// Common API errors
var (
	errMethodNotAllowed = &serverError{
		Status:  http.StatusMethodNotAllowed,
		Code:    "method_not_allowed",
		Message: "method not allowed",
	}
	errNotFound = &serverError{
		Status:  http.StatusNotFound,
		Code:    "not_found",
		Message: "not found",
	}
	errInternalError = &serverError{
		Status:  http.StatusInternalServerError,
		Code:    "internal_error",
		Message: "internal server error",
	}
)
// statusCodeRecorder wraps an http.ResponseWriter and remembers the first
// status code passed to WriteHeader so it can be reported after the request
// has been served.
type statusCodeRecorder struct {
	_status int // use status() to get the status code
	http.ResponseWriter
}

// Compile-time checks that the recorder keeps offering the optional
// interfaces that are implemented below.
var (
	_ http.ResponseWriter = (*statusCodeRecorder)(nil)
	_ http.CloseNotifier  = (*statusCodeRecorder)(nil)
	_ http.Flusher        = (*statusCodeRecorder)(nil)
)

// WriteHeader remembers status if no status has been remembered yet, and
// always forwards the call to the wrapped ResponseWriter.
func (rec *statusCodeRecorder) WriteHeader(status int) {
	if rec._status == 0 {
		rec._status = status
	}
	rec.ResponseWriter.WriteHeader(status)
}

// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (rec *statusCodeRecorder) CloseNotify() <-chan bool {
	notifier := rec.ResponseWriter.(http.CloseNotifier)
	return notifier.CloseNotify()
}

// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a Flusher.
func (rec *statusCodeRecorder) Flush() {
	flusher := rec.ResponseWriter.(http.Flusher)
	flusher.Flush()
}

// status returns the recorded status code, or 200 when WriteHeader was never
// called.
func (rec *statusCodeRecorder) status() int {
	if rec._status != 0 {
		return rec._status
	}
	return http.StatusOK
}
- func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- rec := &statusCodeRecorder{ResponseWriter: w}
- s.serveHTTP(rec, r)
- }
- func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
- var errattr slog.Attr
- proxied, err := func() (bool, error) {
- switch r.URL.Path {
- case "/api/delete":
- return false, s.handleDelete(rec, r)
- case "/api/pull":
- return false, s.handlePull(rec, r)
- default:
- if s.Fallback != nil {
- s.Fallback.ServeHTTP(rec, r)
- return true, nil
- }
- return false, errNotFound
- }
- }()
- if err != nil {
- // We always log the error, so fill in the error log attribute
- errattr = slog.String("error", err.Error())
- var e *serverError
- switch {
- case errors.As(err, &e):
- case errors.Is(err, ollama.ErrNameInvalid):
- e = &serverError{400, "bad_request", err.Error()}
- default:
- e = errInternalError
- }
- data, err := json.Marshal(e)
- if err != nil {
- // unreachable
- panic(err)
- }
- rec.Header().Set("Content-Type", "application/json")
- rec.WriteHeader(e.Status)
- rec.Write(data)
- // fallthrough to log
- }
- if !proxied {
- // we're only responsible for logging if we handled the request
- var level slog.Level
- if rec.status() >= 500 {
- level = slog.LevelError
- } else if rec.status() >= 400 {
- level = slog.LevelWarn
- }
- s.Logger.LogAttrs(r.Context(), level, "http",
- errattr, // report first in line to make it easy to find
- // TODO(bmizerany): Write a test to ensure that we are logging
- // all of this correctly. That also goes for the level+error
- // logic above.
- slog.Int("status", rec.status()),
- slog.String("method", r.Method),
- slog.String("path", r.URL.Path),
- slog.Int64("content-length", r.ContentLength),
- slog.String("remote", r.RemoteAddr),
- slog.String("proto", r.Proto),
- slog.String("query", r.URL.RawQuery),
- )
- }
- }
// params is the JSON request body accepted by the handlers in this package.
type params struct {
	// DeprecatedName and Model both carry the model name; read it through
	// [params.model] rather than from either field directly.
	DeprecatedName string `json:"name"`  // Use [params.model]
	Model          string `json:"model"` // Use [params.model]

	// AllowNonTLS is a flag that indicates a client using HTTP
	// is doing so, deliberately.
	//
	// Deprecated: This field is ignored and only present for this
	// deprecation message. It should be removed in a future release.
	//
	// Users can just use http or https+insecure to show intent to
	// communicate they want to do insecure things, without awkward and
	// confusing flags such as this.
	AllowNonTLS bool `json:"insecure"`

	// ProgressStream is a flag that indicates the client is expecting a
	// stream of progress updates.
	ProgressStream bool `json:"stream"`
}

// model returns the model name for both old and new API requests: the
// "model" field when it is non-empty, otherwise the deprecated "name" field.
func (p params) model() string {
	if p.Model != "" {
		return p.Model
	}
	return p.DeprecatedName
}
- func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
- if r.Method != "DELETE" {
- return errMethodNotAllowed
- }
- p, err := decodeUserJSON[*params](r.Body)
- if err != nil {
- return err
- }
- ok, err := s.Client.Unlink(p.model())
- if err != nil {
- return err
- }
- if !ok {
- return &serverError{404, "not_found", "model not found"}
- }
- if s.Prune == nil {
- return nil
- }
- return s.Prune()
- }
- type progressUpdateJSON struct {
- Status string `json:"status"`
- Digest blob.Digest `json:"digest,omitempty,omitzero"`
- Total int64 `json:"total,omitempty,omitzero"`
- Completed int64 `json:"completed,omitempty,omitzero"`
- }
- func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
- if r.Method != "POST" {
- return errMethodNotAllowed
- }
- p, err := decodeUserJSON[*params](r.Body)
- if err != nil {
- return err
- }
- maybeFlush := func() {
- fl, _ := w.(http.Flusher)
- if fl != nil {
- fl.Flush()
- }
- }
- defer maybeFlush()
- var mu sync.Mutex
- enc := json.NewEncoder(w)
- enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
- ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
- Update: func(l *ollama.Layer, n int64, err error) {
- mu.Lock()
- defer mu.Unlock()
- // TODO(bmizerany): coalesce these updates; writing per
- // update is expensive
- enc.Encode(progressUpdateJSON{
- Digest: l.Digest,
- Status: "pulling",
- Total: l.Size,
- Completed: n,
- })
- },
- })
- done := make(chan error, 1)
- go func() {
- // TODO(bmizerany): continue to support non-streaming responses
- done <- s.Client.Pull(ctx, p.model())
- }()
- func() {
- t := time.NewTicker(100 * time.Millisecond)
- defer t.Stop()
- for {
- select {
- case <-t.C:
- mu.Lock()
- maybeFlush()
- mu.Unlock()
- case err := <-done:
- if err != nil {
- var status string
- if errors.Is(err, ollama.ErrModelNotFound) {
- status = fmt.Sprintf("error: model %q not found", p.model())
- enc.Encode(progressUpdateJSON{Status: status})
- } else {
- status = fmt.Sprintf("error: %v", err)
- enc.Encode(progressUpdateJSON{Status: status})
- }
- return
- }
- // These final updates are not strictly necessary, because they have
- // already happened at this point. Our pull handler code used to do
- // these steps after, not during, the pull, and they were slow, so we
- // wanted to provide feedback to users what was happening. For now, we
- // keep them to not jar users who are used to seeing them. We can phase
- // them out with a new and nicer UX later. One without progress bars
- // and digests that no one cares about.
- enc.Encode(progressUpdateJSON{Status: "verifying layers"})
- enc.Encode(progressUpdateJSON{Status: "writing manifest"})
- enc.Encode(progressUpdateJSON{Status: "success"})
- return
- }
- }
- }()
- return nil
- }
- func decodeUserJSON[T any](r io.Reader) (T, error) {
- var v T
- err := json.NewDecoder(r).Decode(&v)
- if err == nil {
- return v, nil
- }
- var zero T
- // Not sure why, but I can't seem to be able to use:
- //
- // errors.As(err, &json.UnmarshalTypeError{})
- //
- // This is working fine in stdlib, so I'm not sure what rules changed
- // and why this no longer works here. So, we do it the verbose way.
- var a *json.UnmarshalTypeError
- var b *json.SyntaxError
- if errors.As(err, &a) || errors.As(err, &b) {
- err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
- }
- if errors.Is(err, io.EOF) {
- err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
- }
- return zero, err
- }
|