server.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. // Package registry provides an http.Handler for handling local Ollama API
  2. // requests for performing tasks related to the ollama.com model registry and
  3. // the local disk cache.
  4. package registry
  5. import (
  6. "cmp"
  7. "encoding/json"
  8. "errors"
  9. "io"
  10. "log/slog"
  11. "net/http"
  12. "github.com/ollama/ollama/server/internal/client/ollama"
  13. )
  14. // Local is an http.Handler for handling local Ollama API requests for
  15. // performing tasks related to the ollama.com model registry combined with the
  16. // local disk cache.
  17. //
  18. // It is not concern of Local, or this package, to handle model creation, which
  19. // proceeds any registry operations for models it produces.
  20. //
  21. // NOTE: The package built for dealing with model creation should use
  22. // [DefaultCache] to access the blob store and not attempt to read or write
  23. // directly to the blob disk cache.
  24. type Local struct {
  25. Client *ollama.Registry // required
  26. Logger *slog.Logger // required
  27. // Fallback, if set, is used to handle requests that are not handled by
  28. // this handler.
  29. Fallback http.Handler
  30. // Prune, if set, is called to prune the local disk cache after a model
  31. // is deleted.
  32. Prune func() error // optional
  33. }
  // serverError is the handler-side counterpart of ollama.Error. It has the
// same JSON fields plus the HTTP status code to respond with.
//
// Status deliberately lives here rather than on ollama.Error. It is never
// serialized (json:"-") because the status code is not leaked into error
// bodies, so on the client side it would always read as 0 and only confuse.
type serverError struct {
	Status int `json:"-"`

	// TODO(bmizerany): Decide whether Code stays; it may be dropped now and
	// brought back later.
	Code    string `json:"code"`
	Message string `json:"error"`
}

// Error implements the error interface by returning the client-facing
// message.
func (e serverError) Error() string { return e.Message }

// Common API errors shared by the handlers in this package.
var (
	errMethodNotAllowed = &serverError{Status: 405, Code: "method_not_allowed", Message: "method not allowed"}
	errNotFound         = &serverError{Status: 404, Code: "not_found", Message: "not found"}
	errInternalError    = &serverError{Status: 500, Code: "internal_error", Message: "internal server error"}
)
  54. type statusCodeRecorder struct {
  55. _status int // use status() to get the status code
  56. http.ResponseWriter
  57. }
  58. func (r *statusCodeRecorder) WriteHeader(status int) {
  59. if r._status == 0 {
  60. r._status = status
  61. }
  62. r.ResponseWriter.WriteHeader(status)
  63. }
  64. var (
  65. _ http.ResponseWriter = (*statusCodeRecorder)(nil)
  66. _ http.CloseNotifier = (*statusCodeRecorder)(nil)
  67. _ http.Flusher = (*statusCodeRecorder)(nil)
  68. )
  69. // CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
  70. //
  71. // It panics if the underlying ResponseWriter is not a CloseNotifier.
  72. func (r *statusCodeRecorder) CloseNotify() <-chan bool {
  73. return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
  74. }
  75. // Flush implements the http.Flusher interface, for Gin. Remove with Gin.
  76. //
  77. // It panics if the underlying ResponseWriter is not a Flusher.
  78. func (r *statusCodeRecorder) Flush() {
  79. r.ResponseWriter.(http.Flusher).Flush()
  80. }
  81. func (r *statusCodeRecorder) status() int {
  82. return cmp.Or(r._status, 200)
  83. }
  // ServeHTTP implements http.Handler. It wraps w in a statusCodeRecorder so
// the final status code is available for request logging.
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.serveHTTP(&statusCodeRecorder{ResponseWriter: w}, r)
}
  // serveHTTP routes r, turns any handler error into a JSON error response,
// and logs the requests this handler served itself. Requests delegated to
// Fallback are neither rewritten nor logged here.
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
	var (
		err     error
		handled = true // false when Fallback produced the response
	)
	switch r.URL.Path {
	case "/api/delete":
		err = s.handleDelete(rec, r)
	default:
		if s.Fallback == nil {
			err = errNotFound
		} else {
			s.Fallback.ServeHTTP(rec, r)
			handled = false
		}
	}

	// errattr stays the zero Attr when there is no error; slog's built-in
	// handlers skip empty attributes.
	var errattr slog.Attr
	if err != nil {
		// Always log the underlying error, even when the client only sees
		// a generic message.
		errattr = slog.String("error", err.Error())

		var se *serverError
		if !errors.As(err, &se) {
			if errors.Is(err, ollama.ErrNameInvalid) {
				se = &serverError{Status: 400, Code: "bad_request", Message: err.Error()}
			} else {
				se = errInternalError
			}
		}
		body, merr := json.Marshal(se)
		if merr != nil {
			// unreachable: serverError holds only an int and strings.
			panic(merr)
		}
		rec.Header().Set("Content-Type", "application/json")
		rec.WriteHeader(se.Status)
		rec.Write(body)
	}

	if !handled {
		// Only requests served by this handler are logged here.
		return
	}

	status := rec.status()
	level := slog.LevelInfo
	switch {
	case status >= 500:
		level = slog.LevelError
	case status >= 400:
		level = slog.LevelWarn
	}
	s.Logger.LogAttrs(r.Context(), level, "http",
		errattr, // first in line so errors are easy to spot
		// TODO(bmizerany): Write a test to ensure that we are logging
		// all of this correctly. That also goes for the level+error
		// logic above.
		slog.Int("status", status),
		slog.String("method", r.Method),
		slog.String("path", r.URL.Path),
		slog.Int64("content-length", r.ContentLength),
		slog.String("remote", r.RemoteAddr),
		slog.String("proto", r.Proto),
		slog.String("query", r.URL.RawQuery),
	)
}
  146. type params struct {
  147. DeprecatedName string `json:"name"` // Use [params.model]
  148. Model string `json:"model"` // Use [params.model]
  149. // AllowNonTLS is a flag that indicates a client using HTTP
  150. // is doing so, deliberately.
  151. //
  152. // Deprecated: This field is ignored and only present for this
  153. // deprecation message. It should be removed in a future release.
  154. //
  155. // Users can just use http or https+insecure to show intent to
  156. // communicate they want to do insecure things, without awkward and
  157. // confusing flags such as this.
  158. AllowNonTLS bool `json:"insecure"`
  159. // ProgressStream is a flag that indicates the client is expecting a stream of
  160. // progress updates.
  161. ProgressStream bool `json:"stream"`
  162. }
  163. // model returns the model name for both old and new API requests.
  164. func (p params) model() string {
  165. return cmp.Or(p.Model, p.DeprecatedName)
  166. }
  // handleDelete serves model deletion requests (routed from /api/delete).
// The JSON body names the model (see params). That name is passed to
// Client.Unlink, and if Unlink reports that something was removed, Prune is
// run when it is set.
//
// It returns errMethodNotAllowed for any method other than DELETE and a 404
// serverError when Unlink reports that nothing was removed. Otherwise it
// returns the first error from decoding the body, Unlink, or Prune. The
// ResponseWriter is unused; on a nil error the caller writes nothing and
// net/http sends an empty 200 response.
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
	if r.Method != http.MethodDelete {
		return errMethodNotAllowed
	}
	// Decode into a params value, not *params. A body of literal `null`
	// leaves a pointer nil, and calling p.model() on it would panic.
	p, err := decodeUserJSON[params](r.Body)
	if err != nil {
		return err
	}
	ok, err := s.Client.Unlink(p.model())
	if err != nil {
		return err
	}
	if !ok {
		return &serverError{Status: 404, Code: "not_found", Message: "model not found"}
	}
	if s.Prune == nil {
		return nil
	}
	// The model is already unlinked at this point; a Prune failure is still
	// reported to the client.
	return s.Prune()
}
  // decodeUserJSON decodes a single JSON value of type T from r, which holds
// untrusted client input.
//
// Malformed input is reported as a 400 "bad_request" serverError, so the
// client learns its request was at fault instead of seeing a 500:
//   - syntax errors and type mismatches carry the decoder's message;
//   - an empty body (io.EOF) reports "empty request body";
//   - a body that ends in the middle of a value (io.ErrUnexpectedEOF)
//     reports "unexpected end of request body".
//
// Any other error, such as a failure reading r, is returned unchanged. On
// error the zero T is returned.
func decodeUserJSON[T any](r io.Reader) (T, error) {
	var v T
	err := json.NewDecoder(r).Decode(&v)
	if err == nil {
		return v, nil
	}

	// errors.As needs a pointer to the concrete error type. encoding/json
	// returns these errors as pointers (*json.SyntaxError and so on), so the
	// targets must be **json.SyntaxError etc. Passing &json.UnmarshalTypeError{}
	// makes errors.As panic, because json.UnmarshalTypeError (non-pointer)
	// does not implement error.
	var (
		typeErr   *json.UnmarshalTypeError
		syntaxErr *json.SyntaxError
	)
	switch {
	case errors.As(err, &typeErr), errors.As(err, &syntaxErr):
		err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
	case errors.Is(err, io.EOF):
		err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
	case errors.Is(err, io.ErrUnexpectedEOF):
		// Decoder reports a body truncated mid-value this way. It is still
		// the client's fault, not an internal error.
		err = &serverError{Status: 400, Message: "unexpected end of request body", Code: "bad_request"}
	}
	var zero T
	return zero, err
}