server.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. // Package registry provides an http.Handler for handling local Ollama API
  2. // requests for performing tasks related to the ollama.com model registry and
  3. // the local disk cache.
  4. package registry
  5. import (
  6. "cmp"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log/slog"
  12. "net/http"
  13. "sync"
  14. "time"
  15. "github.com/ollama/ollama/server/internal/cache/blob"
  16. "github.com/ollama/ollama/server/internal/client/ollama"
  17. )
  18. // Local is an http.Handler for handling local Ollama API requests for
  19. // performing tasks related to the ollama.com model registry combined with the
  20. // local disk cache.
  21. //
  22. // It is not concern of Local, or this package, to handle model creation, which
  23. // proceeds any registry operations for models it produces.
  24. //
  25. // NOTE: The package built for dealing with model creation should use
  26. // [DefaultCache] to access the blob store and not attempt to read or write
  27. // directly to the blob disk cache.
  28. type Local struct {
  29. Client *ollama.Registry // required
  30. Logger *slog.Logger // required
  31. // Fallback, if set, is used to handle requests that are not handled by
  32. // this handler.
  33. Fallback http.Handler
  34. // Prune, if set, is called to prune the local disk cache after a model
  35. // is deleted.
  36. Prune func() error // optional
  37. }
  38. // serverError is like ollama.Error, but with a Status field for the HTTP
  39. // response code. We want to avoid adding that field to ollama.Error because it
  40. // would always be 0 to clients (we don't want to leak the status code in
  41. // errors), and so it would be confusing to have a field that is always 0.
  42. type serverError struct {
  43. Status int `json:"-"`
  44. // TODO(bmizerany): Decide if we want to keep this and maybe
  45. // bring back later.
  46. Code string `json:"code"`
  47. Message string `json:"error"`
  48. }
  49. func (e serverError) Error() string {
  50. return e.Message
  51. }
  52. // Common API errors
  53. var (
  54. errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
  55. errNotFound = &serverError{404, "not_found", "not found"}
  56. errInternalError = &serverError{500, "internal_error", "internal server error"}
  57. )
  58. type statusCodeRecorder struct {
  59. _status int // use status() to get the status code
  60. http.ResponseWriter
  61. }
  62. func (r *statusCodeRecorder) WriteHeader(status int) {
  63. if r._status == 0 {
  64. r._status = status
  65. }
  66. r.ResponseWriter.WriteHeader(status)
  67. }
  68. var (
  69. _ http.ResponseWriter = (*statusCodeRecorder)(nil)
  70. _ http.CloseNotifier = (*statusCodeRecorder)(nil)
  71. _ http.Flusher = (*statusCodeRecorder)(nil)
  72. )
  73. // CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
  74. //
  75. // It panics if the underlying ResponseWriter is not a CloseNotifier.
  76. func (r *statusCodeRecorder) CloseNotify() <-chan bool {
  77. return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
  78. }
  79. // Flush implements the http.Flusher interface, for Gin. Remove with Gin.
  80. //
  81. // It panics if the underlying ResponseWriter is not a Flusher.
  82. func (r *statusCodeRecorder) Flush() {
  83. r.ResponseWriter.(http.Flusher).Flush()
  84. }
  85. func (r *statusCodeRecorder) status() int {
  86. return cmp.Or(r._status, 200)
  87. }
  88. func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  89. rec := &statusCodeRecorder{ResponseWriter: w}
  90. s.serveHTTP(rec, r)
  91. }
  92. func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
  93. var errattr slog.Attr
  94. proxied, err := func() (bool, error) {
  95. switch r.URL.Path {
  96. case "/api/delete":
  97. return false, s.handleDelete(rec, r)
  98. case "/api/pull":
  99. return false, s.handlePull(rec, r)
  100. default:
  101. if s.Fallback != nil {
  102. s.Fallback.ServeHTTP(rec, r)
  103. return true, nil
  104. }
  105. return false, errNotFound
  106. }
  107. }()
  108. if err != nil {
  109. // We always log the error, so fill in the error log attribute
  110. errattr = slog.String("error", err.Error())
  111. var e *serverError
  112. switch {
  113. case errors.As(err, &e):
  114. case errors.Is(err, ollama.ErrNameInvalid):
  115. e = &serverError{400, "bad_request", err.Error()}
  116. default:
  117. e = errInternalError
  118. }
  119. data, err := json.Marshal(e)
  120. if err != nil {
  121. // unreachable
  122. panic(err)
  123. }
  124. rec.Header().Set("Content-Type", "application/json")
  125. rec.WriteHeader(e.Status)
  126. rec.Write(data)
  127. // fallthrough to log
  128. }
  129. if !proxied {
  130. // we're only responsible for logging if we handled the request
  131. var level slog.Level
  132. if rec.status() >= 500 {
  133. level = slog.LevelError
  134. } else if rec.status() >= 400 {
  135. level = slog.LevelWarn
  136. }
  137. s.Logger.LogAttrs(r.Context(), level, "http",
  138. errattr, // report first in line to make it easy to find
  139. // TODO(bmizerany): Write a test to ensure that we are logging
  140. // all of this correctly. That also goes for the level+error
  141. // logic above.
  142. slog.Int("status", rec.status()),
  143. slog.String("method", r.Method),
  144. slog.String("path", r.URL.Path),
  145. slog.Int64("content-length", r.ContentLength),
  146. slog.String("remote", r.RemoteAddr),
  147. slog.String("proto", r.Proto),
  148. slog.String("query", r.URL.RawQuery),
  149. )
  150. }
  151. }
  152. type params struct {
  153. DeprecatedName string `json:"name"` // Use [params.model]
  154. Model string `json:"model"` // Use [params.model]
  155. // AllowNonTLS is a flag that indicates a client using HTTP
  156. // is doing so, deliberately.
  157. //
  158. // Deprecated: This field is ignored and only present for this
  159. // deprecation message. It should be removed in a future release.
  160. //
  161. // Users can just use http or https+insecure to show intent to
  162. // communicate they want to do insecure things, without awkward and
  163. // confusing flags such as this.
  164. AllowNonTLS bool `json:"insecure"`
  165. // ProgressStream is a flag that indicates the client is expecting a stream of
  166. // progress updates.
  167. ProgressStream bool `json:"stream"`
  168. }
  169. // model returns the model name for both old and new API requests.
  170. func (p params) model() string {
  171. return cmp.Or(p.Model, p.DeprecatedName)
  172. }
  173. func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
  174. if r.Method != "DELETE" {
  175. return errMethodNotAllowed
  176. }
  177. p, err := decodeUserJSON[*params](r.Body)
  178. if err != nil {
  179. return err
  180. }
  181. ok, err := s.Client.Unlink(p.model())
  182. if err != nil {
  183. return err
  184. }
  185. if !ok {
  186. return &serverError{404, "not_found", "model not found"}
  187. }
  188. if s.Prune == nil {
  189. return nil
  190. }
  191. return s.Prune()
  192. }
  193. type progressUpdateJSON struct {
  194. Status string `json:"status"`
  195. Digest blob.Digest `json:"digest,omitempty,omitzero"`
  196. Total int64 `json:"total,omitempty,omitzero"`
  197. Completed int64 `json:"completed,omitempty,omitzero"`
  198. }
  199. func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
  200. if r.Method != "POST" {
  201. return errMethodNotAllowed
  202. }
  203. p, err := decodeUserJSON[*params](r.Body)
  204. if err != nil {
  205. return err
  206. }
  207. maybeFlush := func() {
  208. fl, _ := w.(http.Flusher)
  209. if fl != nil {
  210. fl.Flush()
  211. }
  212. }
  213. defer maybeFlush()
  214. var mu sync.Mutex
  215. enc := json.NewEncoder(w)
  216. enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
  217. ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
  218. Update: func(l *ollama.Layer, n int64, err error) {
  219. mu.Lock()
  220. defer mu.Unlock()
  221. // TODO(bmizerany): coalesce these updates; writing per
  222. // update is expensive
  223. enc.Encode(progressUpdateJSON{
  224. Digest: l.Digest,
  225. Status: "pulling",
  226. Total: l.Size,
  227. Completed: n,
  228. })
  229. },
  230. })
  231. done := make(chan error, 1)
  232. go func() {
  233. // TODO(bmizerany): continue to support non-streaming responses
  234. done <- s.Client.Pull(ctx, p.model())
  235. }()
  236. func() {
  237. t := time.NewTicker(100 * time.Millisecond)
  238. defer t.Stop()
  239. for {
  240. select {
  241. case <-t.C:
  242. mu.Lock()
  243. maybeFlush()
  244. mu.Unlock()
  245. case err := <-done:
  246. if err != nil {
  247. var status string
  248. if errors.Is(err, ollama.ErrModelNotFound) {
  249. status = fmt.Sprintf("error: model %q not found", p.model())
  250. enc.Encode(progressUpdateJSON{Status: status})
  251. } else {
  252. status = fmt.Sprintf("error: %v", err)
  253. enc.Encode(progressUpdateJSON{Status: status})
  254. }
  255. return
  256. }
  257. // These final updates are not strictly necessary, because they have
  258. // already happened at this point. Our pull handler code used to do
  259. // these steps after, not during, the pull, and they were slow, so we
  260. // wanted to provide feedback to users what was happening. For now, we
  261. // keep them to not jar users who are used to seeing them. We can phase
  262. // them out with a new and nicer UX later. One without progress bars
  263. // and digests that no one cares about.
  264. enc.Encode(progressUpdateJSON{Status: "verifying layers"})
  265. enc.Encode(progressUpdateJSON{Status: "writing manifest"})
  266. enc.Encode(progressUpdateJSON{Status: "success"})
  267. return
  268. }
  269. }
  270. }()
  271. return nil
  272. }
  273. func decodeUserJSON[T any](r io.Reader) (T, error) {
  274. var v T
  275. err := json.NewDecoder(r).Decode(&v)
  276. if err == nil {
  277. return v, nil
  278. }
  279. var zero T
  280. // Not sure why, but I can't seem to be able to use:
  281. //
  282. // errors.As(err, &json.UnmarshalTypeError{})
  283. //
  284. // This is working fine in stdlib, so I'm not sure what rules changed
  285. // and why this no longer works here. So, we do it the verbose way.
  286. var a *json.UnmarshalTypeError
  287. var b *json.SyntaxError
  288. if errors.As(err, &a) || errors.As(err, &b) {
  289. err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
  290. }
  291. if errors.Is(err, io.EOF) {
  292. err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
  293. }
  294. return zero, err
  295. }