server.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. // Package registry implements an http.Handler for handling local Ollama API
  2. // model management requests. See [Local] for details.
  3. package registry
  4. import (
  5. "cmp"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "sync"
  14. "time"
  15. "github.com/ollama/ollama/server/internal/cache/blob"
  16. "github.com/ollama/ollama/server/internal/client/ollama"
  17. )
  18. // Local implements an http.Handler for handling local Ollama API model
  19. // management requests, such as pushing, pulling, and deleting models.
  20. //
  21. // It can be arranged for all unknown requests to be passed through to a
  22. // fallback handler, if one is provided.
  23. type Local struct {
  24. Client *ollama.Registry // required
  25. Logger *slog.Logger // required
  26. // Fallback, if set, is used to handle requests that are not handled by
  27. // this handler.
  28. Fallback http.Handler
  29. // Prune, if set, is called to prune the local disk cache after a model
  30. // is deleted.
  31. Prune func() error // optional
  32. }
  33. // serverError is like ollama.Error, but with a Status field for the HTTP
  34. // response code. We want to avoid adding that field to ollama.Error because it
  35. // would always be 0 to clients (we don't want to leak the status code in
  36. // errors), and so it would be confusing to have a field that is always 0.
  37. type serverError struct {
  38. Status int `json:"-"`
  39. // TODO(bmizerany): Decide if we want to keep this and maybe
  40. // bring back later.
  41. Code string `json:"code"`
  42. Message string `json:"error"`
  43. }
  44. func (e serverError) Error() string {
  45. return e.Message
  46. }
  47. // Common API errors
  48. var (
  49. errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
  50. errNotFound = &serverError{404, "not_found", "not found"}
  51. errModelNotFound = &serverError{404, "not_found", "model not found"}
  52. errInternalError = &serverError{500, "internal_error", "internal server error"}
  53. )
  54. type statusCodeRecorder struct {
  55. _status int // use status() to get the status code
  56. http.ResponseWriter
  57. }
  58. func (r *statusCodeRecorder) WriteHeader(status int) {
  59. if r._status == 0 {
  60. r._status = status
  61. }
  62. r.ResponseWriter.WriteHeader(status)
  63. }
  64. var (
  65. _ http.ResponseWriter = (*statusCodeRecorder)(nil)
  66. _ http.CloseNotifier = (*statusCodeRecorder)(nil)
  67. _ http.Flusher = (*statusCodeRecorder)(nil)
  68. )
  69. // CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
  70. //
  71. // It panics if the underlying ResponseWriter is not a CloseNotifier.
  72. func (r *statusCodeRecorder) CloseNotify() <-chan bool {
  73. return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
  74. }
  75. // Flush implements the http.Flusher interface, for Gin. Remove with Gin.
  76. //
  77. // It panics if the underlying ResponseWriter is not a Flusher.
  78. func (r *statusCodeRecorder) Flush() {
  79. r.ResponseWriter.(http.Flusher).Flush()
  80. }
  81. func (r *statusCodeRecorder) status() int {
  82. return cmp.Or(r._status, 200)
  83. }
  84. func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  85. rec := &statusCodeRecorder{ResponseWriter: w}
  86. s.serveHTTP(rec, r)
  87. }
  88. func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
  89. var errattr slog.Attr
  90. proxied, err := func() (bool, error) {
  91. switch r.URL.Path {
  92. case "/api/delete":
  93. return false, s.handleDelete(rec, r)
  94. case "/api/pull":
  95. return false, s.handlePull(rec, r)
  96. default:
  97. if s.Fallback != nil {
  98. s.Fallback.ServeHTTP(rec, r)
  99. return true, nil
  100. }
  101. return false, errNotFound
  102. }
  103. }()
  104. if err != nil {
  105. // We always log the error, so fill in the error log attribute
  106. errattr = slog.String("error", err.Error())
  107. var e *serverError
  108. switch {
  109. case errors.As(err, &e):
  110. case errors.Is(err, ollama.ErrNameInvalid):
  111. e = &serverError{400, "bad_request", err.Error()}
  112. default:
  113. e = errInternalError
  114. }
  115. data, err := json.Marshal(e)
  116. if err != nil {
  117. // unreachable
  118. panic(err)
  119. }
  120. rec.Header().Set("Content-Type", "application/json")
  121. rec.WriteHeader(e.Status)
  122. rec.Write(data)
  123. // fallthrough to log
  124. }
  125. if !proxied {
  126. // we're only responsible for logging if we handled the request
  127. var level slog.Level
  128. if rec.status() >= 500 {
  129. level = slog.LevelError
  130. } else if rec.status() >= 400 {
  131. level = slog.LevelWarn
  132. }
  133. s.Logger.LogAttrs(r.Context(), level, "http",
  134. errattr, // report first in line to make it easy to find
  135. // TODO(bmizerany): Write a test to ensure that we are logging
  136. // all of this correctly. That also goes for the level+error
  137. // logic above.
  138. slog.Int("status", rec.status()),
  139. slog.String("method", r.Method),
  140. slog.String("path", r.URL.Path),
  141. slog.Int64("content-length", r.ContentLength),
  142. slog.String("remote", r.RemoteAddr),
  143. slog.String("proto", r.Proto),
  144. slog.String("query", r.URL.RawQuery),
  145. )
  146. }
  147. }
  148. type params struct {
  149. // DeprecatedName is the name of the model to push, pull, or delete,
  150. // but is deprecated. New clients should use [Model] instead.
  151. //
  152. // Use [model()] to get the model name for both old and new API requests.
  153. DeprecatedName string `json:"name"`
  154. // Model is the name of the model to push, pull, or delete.
  155. //
  156. // Use [model()] to get the model name for both old and new API requests.
  157. Model string `json:"model"`
  158. // AllowNonTLS is a flag that indicates a client using HTTP
  159. // is doing so, deliberately.
  160. //
  161. // Deprecated: This field is ignored and only present for this
  162. // deprecation message. It should be removed in a future release.
  163. //
  164. // Users can just use http or https+insecure to show intent to
  165. // communicate they want to do insecure things, without awkward and
  166. // confusing flags such as this.
  167. AllowNonTLS bool `json:"insecure"`
  168. // Stream, if true, will make the server send progress updates in a
  169. // streaming of JSON objects. If false, the server will send a single
  170. // JSON object with the final status as "success", or an error object
  171. // if an error occurred.
  172. //
  173. // Unfortunately, this API was designed to be a bit awkward. Stream is
  174. // defined to default to true if not present, so we need a way to check
  175. // if the client decisively set it to false. So, we use a pointer to a
  176. // bool. Gross.
  177. //
  178. // Use [stream()] to get the correct value for this field.
  179. Stream *bool `json:"stream"`
  180. }
  181. // model returns the model name for both old and new API requests.
  182. func (p params) model() string {
  183. return cmp.Or(p.Model, p.DeprecatedName)
  184. }
  185. func (p params) stream() bool {
  186. if p.Stream == nil {
  187. return true
  188. }
  189. return *p.Stream
  190. }
  191. func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
  192. if r.Method != "DELETE" {
  193. return errMethodNotAllowed
  194. }
  195. p, err := decodeUserJSON[*params](r.Body)
  196. if err != nil {
  197. return err
  198. }
  199. ok, err := s.Client.Unlink(p.model())
  200. if err != nil {
  201. return err
  202. }
  203. if !ok {
  204. return errModelNotFound
  205. }
  206. if s.Prune != nil {
  207. return s.Prune()
  208. }
  209. return nil
  210. }
  211. type progressUpdateJSON struct {
  212. Status string `json:"status,omitempty,omitzero"`
  213. Digest blob.Digest `json:"digest,omitempty,omitzero"`
  214. Total int64 `json:"total,omitempty,omitzero"`
  215. Completed int64 `json:"completed,omitempty,omitzero"`
  216. }
  217. func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
  218. if r.Method != "POST" {
  219. return errMethodNotAllowed
  220. }
  221. p, err := decodeUserJSON[*params](r.Body)
  222. if err != nil {
  223. return err
  224. }
  225. enc := json.NewEncoder(w)
  226. if !p.stream() {
  227. if err := s.Client.Pull(r.Context(), p.model()); err != nil {
  228. if errors.Is(err, ollama.ErrModelNotFound) {
  229. return errModelNotFound
  230. }
  231. return err
  232. }
  233. return enc.Encode(progressUpdateJSON{Status: "success"})
  234. }
  235. maybeFlush := func() {
  236. fl, _ := w.(http.Flusher)
  237. if fl != nil {
  238. fl.Flush()
  239. }
  240. }
  241. defer maybeFlush()
  242. var mu sync.Mutex
  243. progress := make(map[*ollama.Layer]int64)
  244. progressCopy := make(map[*ollama.Layer]int64, len(progress))
  245. flushProgress := func() {
  246. defer maybeFlush()
  247. // TODO(bmizerany): Flushing every layer in one update doesn't
  248. // scale well. We could flush only the modified layers or track
  249. // the full download. Needs further consideration, though it's
  250. // fine for now.
  251. mu.Lock()
  252. maps.Copy(progressCopy, progress)
  253. mu.Unlock()
  254. for l, n := range progressCopy {
  255. enc.Encode(progressUpdateJSON{
  256. Digest: l.Digest,
  257. Total: l.Size,
  258. Completed: n,
  259. })
  260. }
  261. }
  262. defer flushProgress()
  263. t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
  264. start := sync.OnceFunc(func() {
  265. flushProgress() // flush initial state
  266. t.Reset(100 * time.Millisecond)
  267. })
  268. ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
  269. Update: func(l *ollama.Layer, n int64, err error) {
  270. if n > 0 {
  271. // Block flushing progress updates until every
  272. // layer is accounted for. Clients depend on a
  273. // complete model size to calculate progress
  274. // correctly; if they use an incomplete total,
  275. // progress indicators would erratically jump
  276. // as new layers are registered.
  277. start()
  278. }
  279. mu.Lock()
  280. progress[l] += n
  281. mu.Unlock()
  282. },
  283. })
  284. done := make(chan error, 1)
  285. go func() {
  286. done <- s.Client.Pull(ctx, p.model())
  287. }()
  288. for {
  289. select {
  290. case <-t.C:
  291. flushProgress()
  292. case err := <-done:
  293. flushProgress()
  294. if err != nil {
  295. var status string
  296. if errors.Is(err, ollama.ErrModelNotFound) {
  297. status = fmt.Sprintf("error: model %q not found", p.model())
  298. } else {
  299. status = fmt.Sprintf("error: %v", err)
  300. }
  301. enc.Encode(progressUpdateJSON{Status: status})
  302. }
  303. return nil
  304. }
  305. }
  306. }
  307. func decodeUserJSON[T any](r io.Reader) (T, error) {
  308. var v T
  309. err := json.NewDecoder(r).Decode(&v)
  310. if err == nil {
  311. return v, nil
  312. }
  313. var zero T
  314. // Not sure why, but I can't seem to be able to use:
  315. //
  316. // errors.As(err, &json.UnmarshalTypeError{})
  317. //
  318. // This is working fine in stdlib, so I'm not sure what rules changed
  319. // and why this no longer works here. So, we do it the verbose way.
  320. var a *json.UnmarshalTypeError
  321. var b *json.SyntaxError
  322. if errors.As(err, &a) || errors.As(err, &b) {
  323. err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
  324. }
  325. if errors.Is(err, io.EOF) {
  326. err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
  327. }
  328. return zero, err
  329. }