routes.go 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597
  1. package server
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/binary"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/fs"
  12. "log/slog"
  13. "math"
  14. "net"
  15. "net/http"
  16. "net/netip"
  17. "os"
  18. "os/signal"
  19. "path/filepath"
  20. "slices"
  21. "strings"
  22. "syscall"
  23. "time"
  24. "github.com/gin-contrib/cors"
  25. "github.com/gin-gonic/gin"
  26. "golang.org/x/sync/errgroup"
  27. "github.com/ollama/ollama/api"
  28. "github.com/ollama/ollama/discover"
  29. "github.com/ollama/ollama/envconfig"
  30. "github.com/ollama/ollama/llm"
  31. "github.com/ollama/ollama/model/mllama"
  32. "github.com/ollama/ollama/openai"
  33. "github.com/ollama/ollama/template"
  34. "github.com/ollama/ollama/types/errtypes"
  35. "github.com/ollama/ollama/types/model"
  36. "github.com/ollama/ollama/version"
  37. )
  38. var mode string = gin.DebugMode
  39. type Server struct {
  40. addr net.Addr
  41. sched *Scheduler
  42. }
  43. func init() {
  44. switch mode {
  45. case gin.DebugMode:
  46. case gin.ReleaseMode:
  47. case gin.TestMode:
  48. default:
  49. mode = gin.DebugMode
  50. }
  51. gin.SetMode(mode)
  52. }
  53. var (
  54. errRequired = errors.New("is required")
  55. errBadTemplate = errors.New("template error")
  56. )
  57. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  58. opts := api.DefaultOptions()
  59. if err := opts.FromMap(model.Options); err != nil {
  60. return api.Options{}, err
  61. }
  62. if err := opts.FromMap(requestOpts); err != nil {
  63. return api.Options{}, err
  64. }
  65. return opts, nil
  66. }
  67. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
  68. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
  69. func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
  70. if name == "" {
  71. return nil, nil, nil, fmt.Errorf("model %w", errRequired)
  72. }
  73. model, err := GetModel(name)
  74. if err != nil {
  75. return nil, nil, nil, err
  76. }
  77. if err := model.CheckCapabilities(caps...); err != nil {
  78. return nil, nil, nil, fmt.Errorf("%s %w", name, err)
  79. }
  80. opts, err := modelOptions(model, requestOpts)
  81. if err != nil {
  82. return nil, nil, nil, err
  83. }
  84. runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
  85. var runner *runnerRef
  86. select {
  87. case runner = <-runnerCh:
  88. case err = <-errCh:
  89. return nil, nil, nil, err
  90. }
  91. return runner.llama, model, &opts, nil
  92. }
  93. func (s *Server) GenerateHandler(c *gin.Context) {
  94. checkpointStart := time.Now()
  95. var req api.GenerateRequest
  96. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  97. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  98. return
  99. } else if err != nil {
  100. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  101. return
  102. }
  103. name := model.ParseName(req.Model)
  104. if !name.IsValid() {
  105. // Ideally this is "invalid model name" but we're keeping with
  106. // what the API currently returns until we can change it.
  107. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  108. return
  109. }
  110. // We cannot currently consolidate this into GetModel because all we'll
  111. // induce infinite recursion given the current code structure.
  112. name, err := getExistingName(name)
  113. if err != nil {
  114. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  115. return
  116. }
  117. model, err := GetModel(name.String())
  118. if err != nil {
  119. switch {
  120. case errors.Is(err, fs.ErrNotExist):
  121. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  122. case err.Error() == errtypes.InvalidModelNameErrMsg:
  123. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  124. default:
  125. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  126. }
  127. return
  128. }
  129. // expire the runner
  130. if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
  131. s.sched.expireRunner(model)
  132. c.JSON(http.StatusOK, api.GenerateResponse{
  133. Model: req.Model,
  134. CreatedAt: time.Now().UTC(),
  135. Response: "",
  136. Done: true,
  137. DoneReason: "unload",
  138. })
  139. return
  140. }
  141. if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
  142. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  143. return
  144. }
  145. caps := []Capability{CapabilityCompletion}
  146. if req.Suffix != "" {
  147. caps = append(caps, CapabilityInsert)
  148. }
  149. r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
  150. if errors.Is(err, errCapabilityCompletion) {
  151. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
  152. return
  153. } else if err != nil {
  154. handleScheduleError(c, req.Model, err)
  155. return
  156. }
  157. checkpointLoaded := time.Now()
  158. // load the model
  159. if req.Prompt == "" {
  160. c.JSON(http.StatusOK, api.GenerateResponse{
  161. Model: req.Model,
  162. CreatedAt: time.Now().UTC(),
  163. Done: true,
  164. DoneReason: "load",
  165. })
  166. return
  167. }
  168. isMllama := checkMllamaModelFamily(model)
  169. if isMllama && len(req.Images) > 1 {
  170. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
  171. return
  172. }
  173. images := make([]llm.ImageData, len(req.Images))
  174. for i := range req.Images {
  175. if isMllama {
  176. data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
  177. if err != nil {
  178. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
  179. return
  180. }
  181. ar, ok := opts["aspectRatioIndex"].(int)
  182. if !ok {
  183. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
  184. return
  185. }
  186. buf := new(bytes.Buffer)
  187. err = binary.Write(buf, binary.LittleEndian, data)
  188. if err != nil {
  189. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
  190. return
  191. }
  192. images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: ar}
  193. } else {
  194. images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
  195. }
  196. }
  197. prompt := req.Prompt
  198. if !req.Raw {
  199. tmpl := m.Template
  200. if req.Template != "" {
  201. tmpl, err = template.Parse(req.Template)
  202. if err != nil {
  203. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  204. return
  205. }
  206. }
  207. var values template.Values
  208. if req.Suffix != "" {
  209. values.Prompt = prompt
  210. values.Suffix = req.Suffix
  211. } else {
  212. var msgs []api.Message
  213. if req.System != "" {
  214. msgs = append(msgs, api.Message{Role: "system", Content: req.System})
  215. } else if m.System != "" {
  216. msgs = append(msgs, api.Message{Role: "system", Content: m.System})
  217. }
  218. if req.Context == nil {
  219. msgs = append(msgs, m.Messages...)
  220. }
  221. for _, i := range images {
  222. imgPrompt := ""
  223. if isMllama {
  224. imgPrompt = "<|image|>"
  225. }
  226. msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
  227. }
  228. values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
  229. }
  230. var b bytes.Buffer
  231. if req.Context != nil {
  232. slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
  233. s, err := r.Detokenize(c.Request.Context(), req.Context)
  234. if err != nil {
  235. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  236. return
  237. }
  238. b.WriteString(s)
  239. }
  240. if err := tmpl.Execute(&b, values); err != nil {
  241. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  242. return
  243. }
  244. prompt = b.String()
  245. }
  246. slog.Debug("generate request", "images", len(images), "prompt", prompt)
  247. ch := make(chan any)
  248. go func() {
  249. // TODO (jmorganca): avoid building the response twice both here and below
  250. var sb strings.Builder
  251. defer close(ch)
  252. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  253. Prompt: prompt,
  254. Images: images,
  255. Format: req.Format,
  256. LogProbs: req.LogProbs,
  257. Options: opts,
  258. }, func(cr llm.CompletionResponse) {
  259. fmt.Printf("banana: %#v\n", cr)
  260. res := api.GenerateResponse{
  261. Model: req.Model,
  262. CreatedAt: time.Now().UTC(),
  263. Response: cr.Content,
  264. Done: cr.Done,
  265. DoneReason: cr.DoneReason,
  266. Metrics: api.Metrics{
  267. PromptEvalCount: cr.PromptEvalCount,
  268. PromptEvalDuration: cr.PromptEvalDuration,
  269. EvalCount: cr.EvalCount,
  270. EvalDuration: cr.EvalDuration,
  271. },
  272. }
  273. for _, p := range cr.LogProbs {
  274. res.LogProbs = append(res.LogProbs, api.TokenProbs{
  275. TokenID: p.TokenID,
  276. LogProb: p.LogProb,
  277. Token: p.Token,
  278. })
  279. }
  280. if _, err := sb.WriteString(cr.Content); err != nil {
  281. ch <- gin.H{"error": err.Error()}
  282. }
  283. if cr.Done {
  284. res.TotalDuration = time.Since(checkpointStart)
  285. res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  286. if !req.Raw {
  287. tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
  288. if err != nil {
  289. ch <- gin.H{"error": err.Error()}
  290. return
  291. }
  292. res.Context = tokens
  293. }
  294. }
  295. ch <- res
  296. }); err != nil {
  297. ch <- gin.H{"error": err.Error()}
  298. }
  299. }()
  300. if req.Stream != nil && !*req.Stream {
  301. var r api.GenerateResponse
  302. var sb strings.Builder
  303. for rr := range ch {
  304. switch t := rr.(type) {
  305. case api.GenerateResponse:
  306. sb.WriteString(t.Response)
  307. r = t
  308. case gin.H:
  309. msg, ok := t["error"].(string)
  310. if !ok {
  311. msg = "unexpected error format in response"
  312. }
  313. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  314. return
  315. default:
  316. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  317. return
  318. }
  319. }
  320. r.Response = sb.String()
  321. c.JSON(http.StatusOK, r)
  322. return
  323. }
  324. streamResponse(c, ch)
  325. }
  326. func (s *Server) EmbedHandler(c *gin.Context) {
  327. checkpointStart := time.Now()
  328. var req api.EmbedRequest
  329. err := c.ShouldBindJSON(&req)
  330. switch {
  331. case errors.Is(err, io.EOF):
  332. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  333. return
  334. case err != nil:
  335. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  336. return
  337. }
  338. truncate := true
  339. if req.Truncate != nil && !*req.Truncate {
  340. truncate = false
  341. }
  342. var input []string
  343. switch i := req.Input.(type) {
  344. case string:
  345. if len(i) > 0 {
  346. input = append(input, i)
  347. }
  348. case []any:
  349. for _, v := range i {
  350. if _, ok := v.(string); !ok {
  351. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
  352. return
  353. }
  354. input = append(input, v.(string))
  355. }
  356. default:
  357. if req.Input != nil {
  358. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
  359. return
  360. }
  361. }
  362. name, err := getExistingName(model.ParseName(req.Model))
  363. if err != nil {
  364. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  365. return
  366. }
  367. r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
  368. if err != nil {
  369. handleScheduleError(c, req.Model, err)
  370. return
  371. }
  372. checkpointLoaded := time.Now()
  373. if len(input) == 0 {
  374. c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
  375. return
  376. }
  377. kvData, err := getKVData(m.ModelPath, false)
  378. if err != nil {
  379. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  380. return
  381. }
  382. var count int
  383. for i, s := range input {
  384. tokens, err := r.Tokenize(c.Request.Context(), s)
  385. if err != nil {
  386. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  387. return
  388. }
  389. ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
  390. if len(tokens) > ctxLen {
  391. if !truncate {
  392. c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
  393. return
  394. }
  395. tokens = tokens[:ctxLen]
  396. s, err = r.Detokenize(c.Request.Context(), tokens)
  397. if err != nil {
  398. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  399. return
  400. }
  401. }
  402. count += len(tokens)
  403. input[i] = s
  404. }
  405. var g errgroup.Group
  406. embeddings := make([][]float32, len(input))
  407. for i, text := range input {
  408. g.Go(func() error {
  409. embedding, err := r.Embedding(c.Request.Context(), text)
  410. if err != nil {
  411. return err
  412. }
  413. embeddings[i] = normalize(embedding)
  414. return nil
  415. })
  416. }
  417. if err := g.Wait(); err != nil {
  418. slog.Error("embedding generation failed", "error", err)
  419. c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
  420. return
  421. }
  422. resp := api.EmbedResponse{
  423. Model: req.Model,
  424. Embeddings: embeddings,
  425. TotalDuration: time.Since(checkpointStart),
  426. LoadDuration: checkpointLoaded.Sub(checkpointStart),
  427. PromptEvalCount: count,
  428. }
  429. c.JSON(http.StatusOK, resp)
  430. }
  431. func normalize(vec []float32) []float32 {
  432. var sum float32
  433. for _, v := range vec {
  434. sum += v * v
  435. }
  436. norm := float32(0.0)
  437. if sum > 0 {
  438. norm = float32(1.0 / math.Sqrt(float64(sum)))
  439. }
  440. for i := range vec {
  441. vec[i] *= norm
  442. }
  443. return vec
  444. }
  445. func (s *Server) EmbeddingsHandler(c *gin.Context) {
  446. var req api.EmbeddingRequest
  447. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  448. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  449. return
  450. } else if err != nil {
  451. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  452. return
  453. }
  454. name := model.ParseName(req.Model)
  455. if !name.IsValid() {
  456. c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  457. return
  458. }
  459. r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
  460. if err != nil {
  461. handleScheduleError(c, req.Model, err)
  462. return
  463. }
  464. // an empty request loads the model
  465. if req.Prompt == "" {
  466. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  467. return
  468. }
  469. embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
  470. if err != nil {
  471. slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  472. c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
  473. return
  474. }
  475. var e []float64
  476. for _, v := range embedding {
  477. e = append(e, float64(v))
  478. }
  479. resp := api.EmbeddingResponse{
  480. Embedding: e,
  481. }
  482. c.JSON(http.StatusOK, resp)
  483. }
  484. func (s *Server) PullHandler(c *gin.Context) {
  485. var req api.PullRequest
  486. err := c.ShouldBindJSON(&req)
  487. switch {
  488. case errors.Is(err, io.EOF):
  489. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  490. return
  491. case err != nil:
  492. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  493. return
  494. }
  495. name := model.ParseName(cmp.Or(req.Model, req.Name))
  496. if !name.IsValid() {
  497. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
  498. return
  499. }
  500. name, err = getExistingName(name)
  501. if err != nil {
  502. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  503. return
  504. }
  505. ch := make(chan any)
  506. go func() {
  507. defer close(ch)
  508. fn := func(r api.ProgressResponse) {
  509. ch <- r
  510. }
  511. regOpts := &registryOptions{
  512. Insecure: req.Insecure,
  513. }
  514. ctx, cancel := context.WithCancel(c.Request.Context())
  515. defer cancel()
  516. if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
  517. ch <- gin.H{"error": err.Error()}
  518. }
  519. }()
  520. if req.Stream != nil && !*req.Stream {
  521. waitForStream(c, ch)
  522. return
  523. }
  524. streamResponse(c, ch)
  525. }
  526. func (s *Server) PushHandler(c *gin.Context) {
  527. var req api.PushRequest
  528. err := c.ShouldBindJSON(&req)
  529. switch {
  530. case errors.Is(err, io.EOF):
  531. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  532. return
  533. case err != nil:
  534. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  535. return
  536. }
  537. var mname string
  538. if req.Model != "" {
  539. mname = req.Model
  540. } else if req.Name != "" {
  541. mname = req.Name
  542. } else {
  543. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  544. return
  545. }
  546. ch := make(chan any)
  547. go func() {
  548. defer close(ch)
  549. fn := func(r api.ProgressResponse) {
  550. ch <- r
  551. }
  552. regOpts := &registryOptions{
  553. Insecure: req.Insecure,
  554. }
  555. ctx, cancel := context.WithCancel(c.Request.Context())
  556. defer cancel()
  557. name, err := getExistingName(model.ParseName(mname))
  558. if err != nil {
  559. ch <- gin.H{"error": err.Error()}
  560. return
  561. }
  562. if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
  563. ch <- gin.H{"error": err.Error()}
  564. }
  565. }()
  566. if req.Stream != nil && !*req.Stream {
  567. waitForStream(c, ch)
  568. return
  569. }
  570. streamResponse(c, ch)
  571. }
  572. // getExistingName searches the models directory for the longest prefix match of
  573. // the input name and returns the input name with all existing parts replaced
  574. // with each part found. If no parts are found, the input name is returned as
  575. // is.
  576. func getExistingName(n model.Name) (model.Name, error) {
  577. var zero model.Name
  578. existing, err := Manifests(true)
  579. if err != nil {
  580. return zero, err
  581. }
  582. var set model.Name // tracks parts already canonicalized
  583. for e := range existing {
  584. if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
  585. n.Host = e.Host
  586. }
  587. if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
  588. n.Namespace = e.Namespace
  589. }
  590. if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
  591. n.Model = e.Model
  592. }
  593. if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
  594. n.Tag = e.Tag
  595. }
  596. }
  597. return n, nil
  598. }
  599. func (s *Server) DeleteHandler(c *gin.Context) {
  600. var r api.DeleteRequest
  601. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  602. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  603. return
  604. } else if err != nil {
  605. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  606. return
  607. }
  608. n := model.ParseName(cmp.Or(r.Model, r.Name))
  609. if !n.IsValid() {
  610. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
  611. return
  612. }
  613. n, err := getExistingName(n)
  614. if err != nil {
  615. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
  616. return
  617. }
  618. m, err := ParseNamedManifest(n)
  619. if err != nil {
  620. switch {
  621. case os.IsNotExist(err):
  622. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
  623. default:
  624. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  625. }
  626. return
  627. }
  628. if err := m.Remove(); err != nil {
  629. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  630. return
  631. }
  632. if err := m.RemoveLayers(); err != nil {
  633. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  634. return
  635. }
  636. }
  637. func (s *Server) ShowHandler(c *gin.Context) {
  638. var req api.ShowRequest
  639. err := c.ShouldBindJSON(&req)
  640. switch {
  641. case errors.Is(err, io.EOF):
  642. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  643. return
  644. case err != nil:
  645. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  646. return
  647. }
  648. if req.Model != "" {
  649. // noop
  650. } else if req.Name != "" {
  651. req.Model = req.Name
  652. } else {
  653. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  654. return
  655. }
  656. resp, err := GetModelInfo(req)
  657. if err != nil {
  658. switch {
  659. case os.IsNotExist(err):
  660. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  661. case err.Error() == errtypes.InvalidModelNameErrMsg:
  662. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  663. default:
  664. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  665. }
  666. return
  667. }
  668. c.JSON(http.StatusOK, resp)
  669. }
  670. func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  671. name := model.ParseName(req.Model)
  672. if !name.IsValid() {
  673. return nil, errModelPathInvalid
  674. }
  675. name, err := getExistingName(name)
  676. if err != nil {
  677. return nil, err
  678. }
  679. m, err := GetModel(name.String())
  680. if err != nil {
  681. return nil, err
  682. }
  683. modelDetails := api.ModelDetails{
  684. ParentModel: m.ParentModel,
  685. Format: m.Config.ModelFormat,
  686. Family: m.Config.ModelFamily,
  687. Families: m.Config.ModelFamilies,
  688. ParameterSize: m.Config.ModelType,
  689. QuantizationLevel: m.Config.FileType,
  690. }
  691. if req.System != "" {
  692. m.System = req.System
  693. }
  694. msgs := make([]api.Message, len(m.Messages))
  695. for i, msg := range m.Messages {
  696. msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
  697. }
  698. manifest, err := ParseNamedManifest(name)
  699. if err != nil {
  700. return nil, err
  701. }
  702. resp := &api.ShowResponse{
  703. License: strings.Join(m.License, "\n"),
  704. System: m.System,
  705. Template: m.Template.String(),
  706. Details: modelDetails,
  707. Messages: msgs,
  708. ModifiedAt: manifest.fi.ModTime(),
  709. }
  710. var params []string
  711. cs := 30
  712. for k, v := range m.Options {
  713. switch val := v.(type) {
  714. case []interface{}:
  715. for _, nv := range val {
  716. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  717. }
  718. default:
  719. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
  720. }
  721. }
  722. resp.Parameters = strings.Join(params, "\n")
  723. for k, v := range req.Options {
  724. if _, ok := req.Options[k]; ok {
  725. m.Options[k] = v
  726. }
  727. }
  728. var sb strings.Builder
  729. fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
  730. fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
  731. fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
  732. fmt.Fprint(&sb, m.String())
  733. resp.Modelfile = sb.String()
  734. kvData, err := getKVData(m.ModelPath, req.Verbose)
  735. if err != nil {
  736. return nil, err
  737. }
  738. delete(kvData, "general.name")
  739. delete(kvData, "tokenizer.chat_template")
  740. resp.ModelInfo = kvData
  741. if len(m.ProjectorPaths) > 0 {
  742. projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
  743. if err != nil {
  744. return nil, err
  745. }
  746. resp.ProjectorInfo = projectorData
  747. }
  748. return resp, nil
  749. }
  750. func getKVData(digest string, verbose bool) (llm.KV, error) {
  751. maxArraySize := 0
  752. if verbose {
  753. maxArraySize = -1
  754. }
  755. kvData, err := llm.LoadModel(digest, maxArraySize)
  756. if err != nil {
  757. return nil, err
  758. }
  759. kv := kvData.KV()
  760. if !verbose {
  761. for k := range kv {
  762. if t, ok := kv[k].([]any); len(t) > 5 && ok {
  763. kv[k] = []any{}
  764. }
  765. }
  766. }
  767. return kv, nil
  768. }
  769. func (s *Server) ListHandler(c *gin.Context) {
  770. ms, err := Manifests(true)
  771. if err != nil {
  772. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  773. return
  774. }
  775. models := []api.ListModelResponse{}
  776. for n, m := range ms {
  777. var cf ConfigV2
  778. if m.Config.Digest != "" {
  779. f, err := m.Config.Open()
  780. if err != nil {
  781. slog.Warn("bad manifest filepath", "name", n, "error", err)
  782. continue
  783. }
  784. defer f.Close()
  785. if err := json.NewDecoder(f).Decode(&cf); err != nil {
  786. slog.Warn("bad manifest config", "name", n, "error", err)
  787. continue
  788. }
  789. }
  790. // tag should never be masked
  791. models = append(models, api.ListModelResponse{
  792. Model: n.DisplayShortest(),
  793. Name: n.DisplayShortest(),
  794. Size: m.Size(),
  795. Digest: m.digest,
  796. ModifiedAt: m.fi.ModTime(),
  797. Details: api.ModelDetails{
  798. Format: cf.ModelFormat,
  799. Family: cf.ModelFamily,
  800. Families: cf.ModelFamilies,
  801. ParameterSize: cf.ModelType,
  802. QuantizationLevel: cf.FileType,
  803. },
  804. })
  805. }
  806. slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
  807. // most recently modified first
  808. return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
  809. })
  810. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  811. }
  812. func (s *Server) CopyHandler(c *gin.Context) {
  813. var r api.CopyRequest
  814. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  815. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  816. return
  817. } else if err != nil {
  818. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  819. return
  820. }
  821. src := model.ParseName(r.Source)
  822. if !src.IsValid() {
  823. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
  824. return
  825. }
  826. src, err := getExistingName(src)
  827. if err != nil {
  828. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  829. return
  830. }
  831. dst := model.ParseName(r.Destination)
  832. if !dst.IsValid() {
  833. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
  834. return
  835. }
  836. dst, err = getExistingName(dst)
  837. if err != nil {
  838. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  839. return
  840. }
  841. if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
  842. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
  843. } else if err != nil {
  844. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  845. }
  846. }
  847. func (s *Server) HeadBlobHandler(c *gin.Context) {
  848. path, err := GetBlobsPath(c.Param("digest"))
  849. if err != nil {
  850. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  851. return
  852. }
  853. if _, err := os.Stat(path); err != nil {
  854. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  855. return
  856. }
  857. c.Status(http.StatusOK)
  858. }
  859. func (s *Server) CreateBlobHandler(c *gin.Context) {
  860. if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
  861. p, err := GetBlobsPath(ib)
  862. if err != nil {
  863. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  864. return
  865. }
  866. if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
  867. slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
  868. delete(intermediateBlobs, c.Param("digest"))
  869. } else if err != nil {
  870. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  871. return
  872. } else {
  873. c.Status(http.StatusOK)
  874. return
  875. }
  876. }
  877. path, err := GetBlobsPath(c.Param("digest"))
  878. if err != nil {
  879. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  880. return
  881. }
  882. _, err = os.Stat(path)
  883. switch {
  884. case errors.Is(err, os.ErrNotExist):
  885. // noop
  886. case err != nil:
  887. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  888. return
  889. default:
  890. c.Status(http.StatusOK)
  891. return
  892. }
  893. layer, err := NewLayer(c.Request.Body, "")
  894. if err != nil {
  895. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  896. return
  897. }
  898. if layer.Digest != c.Param("digest") {
  899. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  900. return
  901. }
  902. c.Status(http.StatusCreated)
  903. }
  904. func isLocalIP(ip netip.Addr) bool {
  905. if interfaces, err := net.Interfaces(); err == nil {
  906. for _, iface := range interfaces {
  907. addrs, err := iface.Addrs()
  908. if err != nil {
  909. continue
  910. }
  911. for _, a := range addrs {
  912. if parsed, _, err := net.ParseCIDR(a.String()); err == nil {
  913. if parsed.String() == ip.String() {
  914. return true
  915. }
  916. }
  917. }
  918. }
  919. }
  920. return false
  921. }
  922. func allowedHost(host string) bool {
  923. host = strings.ToLower(host)
  924. if host == "" || host == "localhost" {
  925. return true
  926. }
  927. if hostname, err := os.Hostname(); err == nil && host == strings.ToLower(hostname) {
  928. return true
  929. }
  930. tlds := []string{
  931. "localhost",
  932. "local",
  933. "internal",
  934. }
  935. // check if the host is a local TLD
  936. for _, tld := range tlds {
  937. if strings.HasSuffix(host, "."+tld) {
  938. return true
  939. }
  940. }
  941. return false
  942. }
  943. func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  944. return func(c *gin.Context) {
  945. if addr == nil {
  946. c.Next()
  947. return
  948. }
  949. if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
  950. c.Next()
  951. return
  952. }
  953. host, _, err := net.SplitHostPort(c.Request.Host)
  954. if err != nil {
  955. host = c.Request.Host
  956. }
  957. if addr, err := netip.ParseAddr(host); err == nil {
  958. if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
  959. c.Next()
  960. return
  961. }
  962. }
  963. if allowedHost(host) {
  964. if c.Request.Method == http.MethodOptions {
  965. c.AbortWithStatus(http.StatusNoContent)
  966. return
  967. }
  968. c.Next()
  969. return
  970. }
  971. c.AbortWithStatus(http.StatusForbidden)
  972. }
  973. }
  974. func (s *Server) GenerateRoutes() http.Handler {
  975. config := cors.DefaultConfig()
  976. config.AllowWildcard = true
  977. config.AllowBrowserExtensions = true
  978. config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
  979. openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"}
  980. for _, prop := range openAIProperties {
  981. config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
  982. }
  983. config.AllowOrigins = envconfig.Origins()
  984. r := gin.Default()
  985. r.Use(
  986. cors.New(config),
  987. allowedHostsMiddleware(s.addr),
  988. )
  989. r.POST("/api/pull", s.PullHandler)
  990. r.POST("/api/generate", s.GenerateHandler)
  991. r.POST("/api/chat", s.ChatHandler)
  992. r.POST("/api/embed", s.EmbedHandler)
  993. r.POST("/api/embeddings", s.EmbeddingsHandler)
  994. r.POST("/api/create", s.CreateHandler)
  995. r.POST("/api/push", s.PushHandler)
  996. r.POST("/api/copy", s.CopyHandler)
  997. r.DELETE("/api/delete", s.DeleteHandler)
  998. r.POST("/api/show", s.ShowHandler)
  999. r.POST("/api/blobs/:digest", s.CreateBlobHandler)
  1000. r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
  1001. r.GET("/api/ps", s.PsHandler)
  1002. // Compatibility endpoints
  1003. r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
  1004. r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
  1005. r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
  1006. r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
  1007. r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
  1008. for _, method := range []string{http.MethodGet, http.MethodHead} {
  1009. r.Handle(method, "/", func(c *gin.Context) {
  1010. c.String(http.StatusOK, "Ollama is running")
  1011. })
  1012. r.Handle(method, "/api/tags", s.ListHandler)
  1013. r.Handle(method, "/api/version", func(c *gin.Context) {
  1014. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  1015. })
  1016. }
  1017. return r
  1018. }
// Serve configures logging, then maintains the local store:
//   - repairs the blobs directory with fixBlobs;
//   - prunes unused layers and manifests, unless envconfig.NoPrune is set.
//
// It then starts the scheduler and serves the API on ln until the process
// receives SIGINT or SIGTERM.
//
// It returns an error if blob/manifest maintenance fails, or if the HTTP server
// stops for any reason other than the signal-triggered Close. After a signal it
// waits until loaded runners have been unloaded and then returns nil.
func Serve(ln net.Listener) error {
	level := slog.LevelInfo
	if envconfig.Debug() {
		level = slog.LevelDebug
	}

	// NOTE(review): this is logged before slog.SetDefault below, so it goes
	// through the previous default logger rather than the handler built here.
	slog.Info("server config", "env", envconfig.Values())
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			// Shorten the source location to the file's base name.
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}
			return attr
		},
	})

	slog.SetDefault(slog.New(handler))

	blobsDir, err := GetBlobsPath("")
	if err != nil {
		return err
	}
	if err := fixBlobs(blobsDir); err != nil {
		return err
	}

	if !envconfig.NoPrune() {
		// Prune only when every manifest loads cleanly (presumably so layers
		// referenced by an unreadable manifest are not treated as unused).
		if _, err := Manifests(false); err != nil {
			slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
		} else {
			// clean up unused layers and manifests
			if err := PruneLayers(); err != nil {
				return err
			}

			manifestsPath, err := GetManifestPath()
			if err != nil {
				return err
			}

			if err := PruneDirectory(manifestsPath); err != nil {
				return err
			}
		}
	}

	// ctx is cancelled (via done) only by the signal goroutine below, after the
	// runners are unloaded; Serve waits on it before returning nil.
	ctx, done := context.WithCancel(context.Background())
	schedCtx, schedDone := context.WithCancel(ctx)
	sched := InitScheduler(schedCtx)
	s := &Server{addr: ln.Addr(), sched: sched}

	// Registers the API on http.DefaultServeMux, which srvr uses (Handler: nil).
	http.Handle("/", s.GenerateRoutes())

	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
	srvr := &http.Server{
		// Use http.DefaultServeMux so we get net/http/pprof for
		// free.
		//
		// TODO(bmizerany): Decide if we want to make this
		// configurable so it is not exposed by default, or allow
		// users to bind it to a different port. This was a quick
		// and easy way to get pprof, but it may not be the best
		// way.
		Handler: nil,
	}

	// listen for a ctrl+c and stop any loaded llm
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-signals
		// Close makes srvr.Serve below return http.ErrServerClosed.
		srvr.Close()
		schedDone()
		sched.unloadAllRunners()
		done()
	}()

	s.sched.Run(schedCtx)

	// At startup we retrieve GPU information so we can get log messages before loading a model
	// This will log warnings to the log in case we have problems with detected GPUs
	gpus := discover.GetGPUInfo()
	gpus.LogDetails()

	err = srvr.Serve(ln)
	// If server is closed from the signal handler, wait for the ctx to be done
	// otherwise error out quickly
	if !errors.Is(err, http.ErrServerClosed) {
		// NOTE(review): on this path neither done nor schedDone is called, so
		// the scheduler context stays live until the process exits.
		return err
	}
	<-ctx.Done()
	return nil
}
  1102. func waitForStream(c *gin.Context, ch chan interface{}) {
  1103. c.Header("Content-Type", "application/json")
  1104. for resp := range ch {
  1105. switch r := resp.(type) {
  1106. case api.ProgressResponse:
  1107. if r.Status == "success" {
  1108. c.JSON(http.StatusOK, r)
  1109. return
  1110. }
  1111. case gin.H:
  1112. status, ok := r["status"].(int)
  1113. if !ok {
  1114. status = http.StatusInternalServerError
  1115. }
  1116. if errorMsg, ok := r["error"].(string); ok {
  1117. c.JSON(status, gin.H{"error": errorMsg})
  1118. return
  1119. } else {
  1120. c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
  1121. return
  1122. }
  1123. default:
  1124. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  1125. return
  1126. }
  1127. }
  1128. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  1129. }
  1130. func streamResponse(c *gin.Context, ch chan any) {
  1131. c.Header("Content-Type", "application/x-ndjson")
  1132. c.Stream(func(w io.Writer) bool {
  1133. val, ok := <-ch
  1134. if !ok {
  1135. return false
  1136. }
  1137. bts, err := json.Marshal(val)
  1138. if err != nil {
  1139. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  1140. return false
  1141. }
  1142. // Delineate chunks with new-line delimiter
  1143. bts = append(bts, '\n')
  1144. if _, err := w.Write(bts); err != nil {
  1145. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  1146. return false
  1147. }
  1148. return true
  1149. })
  1150. }
  1151. func (s *Server) PsHandler(c *gin.Context) {
  1152. models := []api.ProcessModelResponse{}
  1153. for _, v := range s.sched.loaded {
  1154. model := v.model
  1155. modelDetails := api.ModelDetails{
  1156. Format: model.Config.ModelFormat,
  1157. Family: model.Config.ModelFamily,
  1158. Families: model.Config.ModelFamilies,
  1159. ParameterSize: model.Config.ModelType,
  1160. QuantizationLevel: model.Config.FileType,
  1161. }
  1162. mr := api.ProcessModelResponse{
  1163. Model: model.ShortName,
  1164. Name: model.ShortName,
  1165. Size: int64(v.estimatedTotal),
  1166. SizeVRAM: int64(v.estimatedVRAM),
  1167. Digest: model.Digest,
  1168. Details: modelDetails,
  1169. ExpiresAt: v.expiresAt,
  1170. }
  1171. // The scheduler waits to set expiresAt, so if a model is loading it's
  1172. // possible that it will be set to the unix epoch. For those cases, just
  1173. // calculate the time w/ the sessionDuration instead.
  1174. var epoch time.Time
  1175. if v.expiresAt == epoch {
  1176. mr.ExpiresAt = time.Now().Add(v.sessionDuration)
  1177. }
  1178. models = append(models, mr)
  1179. }
  1180. slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
  1181. // longest duration remaining listed first
  1182. return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
  1183. })
  1184. c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
  1185. }
  1186. func (s *Server) ChatHandler(c *gin.Context) {
  1187. checkpointStart := time.Now()
  1188. var req api.ChatRequest
  1189. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  1190. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1191. return
  1192. } else if err != nil {
  1193. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1194. return
  1195. }
  1196. // expire the runner
  1197. if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
  1198. model, err := GetModel(req.Model)
  1199. if err != nil {
  1200. switch {
  1201. case os.IsNotExist(err):
  1202. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  1203. case err.Error() == errtypes.InvalidModelNameErrMsg:
  1204. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1205. default:
  1206. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1207. }
  1208. return
  1209. }
  1210. s.sched.expireRunner(model)
  1211. c.JSON(http.StatusOK, api.ChatResponse{
  1212. Model: req.Model,
  1213. CreatedAt: time.Now().UTC(),
  1214. Message: api.Message{Role: "assistant"},
  1215. Done: true,
  1216. DoneReason: "unload",
  1217. })
  1218. return
  1219. }
  1220. caps := []Capability{CapabilityCompletion}
  1221. if len(req.Tools) > 0 {
  1222. caps = append(caps, CapabilityTools)
  1223. }
  1224. name := model.ParseName(req.Model)
  1225. if !name.IsValid() {
  1226. c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  1227. return
  1228. }
  1229. name, err := getExistingName(name)
  1230. if err != nil {
  1231. c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  1232. return
  1233. }
  1234. r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
  1235. if errors.Is(err, errCapabilityCompletion) {
  1236. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
  1237. return
  1238. } else if err != nil {
  1239. handleScheduleError(c, req.Model, err)
  1240. return
  1241. }
  1242. checkpointLoaded := time.Now()
  1243. if len(req.Messages) == 0 {
  1244. c.JSON(http.StatusOK, api.ChatResponse{
  1245. Model: req.Model,
  1246. CreatedAt: time.Now().UTC(),
  1247. Message: api.Message{Role: "assistant"},
  1248. Done: true,
  1249. DoneReason: "load",
  1250. })
  1251. return
  1252. }
  1253. msgs := append(m.Messages, req.Messages...)
  1254. if req.Messages[0].Role != "system" && m.System != "" {
  1255. msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
  1256. }
  1257. prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
  1258. if err != nil {
  1259. slog.Error("chat prompt error", "error", err)
  1260. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1261. return
  1262. }
  1263. slog.Debug("chat request", "images", len(images), "prompt", prompt)
  1264. ch := make(chan any)
  1265. go func() {
  1266. defer close(ch)
  1267. var sb strings.Builder
  1268. var toolCallIndex int = 0
  1269. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  1270. Prompt: prompt,
  1271. Images: images,
  1272. Format: req.Format,
  1273. LogProbs: req.LogProbs,
  1274. Options: opts,
  1275. }, func(r llm.CompletionResponse) {
  1276. res := api.ChatResponse{
  1277. Model: req.Model,
  1278. CreatedAt: time.Now().UTC(),
  1279. Message: api.Message{Role: "assistant", Content: r.Content},
  1280. Done: r.Done,
  1281. DoneReason: r.DoneReason,
  1282. Metrics: api.Metrics{
  1283. PromptEvalCount: r.PromptEvalCount,
  1284. PromptEvalDuration: r.PromptEvalDuration,
  1285. EvalCount: r.EvalCount,
  1286. EvalDuration: r.EvalDuration,
  1287. },
  1288. }
  1289. for _, p := range r.LogProbs {
  1290. res.LogProbs = append(res.LogProbs, api.TokenProbs{
  1291. TokenID: p.TokenID,
  1292. LogProb: p.LogProb,
  1293. Token: p.Token,
  1294. })
  1295. }
  1296. if r.Done {
  1297. res.TotalDuration = time.Since(checkpointStart)
  1298. res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  1299. }
  1300. // TODO: tool call checking and filtering should be moved outside of this callback once streaming
  1301. // however this was a simple change for now without reworking streaming logic of this (and other)
  1302. // handlers
  1303. if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
  1304. ch <- res
  1305. return
  1306. }
  1307. // Streaming tool calls:
  1308. // If tools are recognized, use a flag to track the sending of a tool downstream
  1309. // This ensures that content is cleared from the message on the last chunk sent
  1310. sb.WriteString(r.Content)
  1311. if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
  1312. res.Message.ToolCalls = toolCalls
  1313. for i := range toolCalls {
  1314. toolCalls[i].Function.Index = toolCallIndex
  1315. toolCallIndex++
  1316. }
  1317. res.Message.Content = ""
  1318. sb.Reset()
  1319. ch <- res
  1320. return
  1321. }
  1322. if r.Done {
  1323. // Send any remaining content if no tool calls were detected
  1324. if toolCallIndex == 0 {
  1325. res.Message.Content = sb.String()
  1326. }
  1327. ch <- res
  1328. }
  1329. }); err != nil {
  1330. ch <- gin.H{"error": err.Error()}
  1331. }
  1332. }()
  1333. if req.Stream != nil && !*req.Stream {
  1334. var resp api.ChatResponse
  1335. var sb strings.Builder
  1336. for rr := range ch {
  1337. switch t := rr.(type) {
  1338. case api.ChatResponse:
  1339. sb.WriteString(t.Message.Content)
  1340. resp = t
  1341. case gin.H:
  1342. msg, ok := t["error"].(string)
  1343. if !ok {
  1344. msg = "unexpected error format in response"
  1345. }
  1346. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  1347. return
  1348. default:
  1349. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  1350. return
  1351. }
  1352. }
  1353. resp.Message.Content = sb.String()
  1354. if len(req.Tools) > 0 {
  1355. if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
  1356. resp.Message.ToolCalls = toolCalls
  1357. resp.Message.Content = ""
  1358. }
  1359. }
  1360. c.JSON(http.StatusOK, resp)
  1361. return
  1362. }
  1363. streamResponse(c, ch)
  1364. }
  1365. func handleScheduleError(c *gin.Context, name string, err error) {
  1366. switch {
  1367. case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
  1368. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1369. case errors.Is(err, context.Canceled):
  1370. c.JSON(499, gin.H{"error": "request canceled"})
  1371. case errors.Is(err, ErrMaxQueue):
  1372. c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
  1373. case errors.Is(err, os.ErrNotExist):
  1374. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
  1375. default:
  1376. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1377. }
  1378. }