routes.go 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583
  1. package server
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math"
  12. "net"
  13. "net/http"
  14. "net/netip"
  15. "os"
  16. "os/signal"
  17. "path/filepath"
  18. "slices"
  19. "strings"
  20. "syscall"
  21. "time"
  22. "github.com/gin-contrib/cors"
  23. "github.com/gin-gonic/gin"
  24. "github.com/ollama/ollama/api"
  25. "github.com/ollama/ollama/envconfig"
  26. "github.com/ollama/ollama/gpu"
  27. "github.com/ollama/ollama/llm"
  28. "github.com/ollama/ollama/openai"
  29. "github.com/ollama/ollama/parser"
  30. "github.com/ollama/ollama/template"
  31. "github.com/ollama/ollama/types/errtypes"
  32. "github.com/ollama/ollama/types/model"
  33. "github.com/ollama/ollama/version"
  34. )
  35. var mode string = gin.DebugMode
  36. type Server struct {
  37. addr net.Addr
  38. sched *Scheduler
  39. }
  40. func init() {
  41. switch mode {
  42. case gin.DebugMode:
  43. case gin.ReleaseMode:
  44. case gin.TestMode:
  45. default:
  46. mode = gin.DebugMode
  47. }
  48. gin.SetMode(mode)
  49. }
  50. var (
  51. errRequired = errors.New("is required")
  52. errBadTemplate = errors.New("template error")
  53. )
  54. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  55. opts := api.DefaultOptions()
  56. if err := opts.FromMap(model.Options); err != nil {
  57. return api.Options{}, err
  58. }
  59. if err := opts.FromMap(requestOpts); err != nil {
  60. return api.Options{}, err
  61. }
  62. return opts, nil
  63. }
  64. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
  65. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
  66. func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
  67. if name == "" {
  68. return nil, nil, nil, fmt.Errorf("model %w", errRequired)
  69. }
  70. model, err := GetModel(name)
  71. if err != nil {
  72. return nil, nil, nil, err
  73. }
  74. if err := model.CheckCapabilities(caps...); err != nil {
  75. return nil, nil, nil, fmt.Errorf("%s %w", name, err)
  76. }
  77. opts, err := modelOptions(model, requestOpts)
  78. if err != nil {
  79. return nil, nil, nil, err
  80. }
  81. runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
  82. var runner *runnerRef
  83. select {
  84. case runner = <-runnerCh:
  85. case err = <-errCh:
  86. return nil, nil, nil, err
  87. }
  88. return runner.llama, model, &opts, nil
  89. }
  90. func (s *Server) GenerateHandler(c *gin.Context) {
  91. checkpointStart := time.Now()
  92. var req api.GenerateRequest
  93. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  94. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  95. return
  96. } else if err != nil {
  97. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  98. return
  99. }
  100. if req.Format != "" && req.Format != "json" {
  101. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
  102. return
  103. } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
  104. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  105. return
  106. }
  107. caps := []Capability{CapabilityCompletion}
  108. if req.Suffix != "" {
  109. caps = append(caps, CapabilityInsert)
  110. }
  111. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  112. if errors.Is(err, errCapabilityCompletion) {
  113. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
  114. return
  115. } else if err != nil {
  116. handleScheduleError(c, req.Model, err)
  117. return
  118. }
  119. checkpointLoaded := time.Now()
  120. if req.Prompt == "" {
  121. c.JSON(http.StatusOK, api.GenerateResponse{
  122. Model: req.Model,
  123. CreatedAt: time.Now().UTC(),
  124. Done: true,
  125. DoneReason: "load",
  126. })
  127. return
  128. }
  129. images := make([]llm.ImageData, len(req.Images))
  130. for i := range req.Images {
  131. images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
  132. }
  133. prompt := req.Prompt
  134. if !req.Raw {
  135. tmpl := m.Template
  136. if req.Template != "" {
  137. tmpl, err = template.Parse(req.Template)
  138. if err != nil {
  139. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  140. return
  141. }
  142. }
  143. var values template.Values
  144. if req.Suffix != "" {
  145. values.Prompt = prompt
  146. values.Suffix = req.Suffix
  147. } else {
  148. var msgs []api.Message
  149. if req.System != "" {
  150. msgs = append(msgs, api.Message{Role: "system", Content: req.System})
  151. } else if m.System != "" {
  152. msgs = append(msgs, api.Message{Role: "system", Content: m.System})
  153. }
  154. if req.Context == nil {
  155. msgs = append(msgs, m.Messages...)
  156. }
  157. for _, i := range images {
  158. msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
  159. }
  160. values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
  161. }
  162. var b bytes.Buffer
  163. if req.Context != nil {
  164. s, err := r.Detokenize(c.Request.Context(), req.Context)
  165. if err != nil {
  166. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  167. return
  168. }
  169. b.WriteString(s)
  170. }
  171. if err := tmpl.Execute(&b, values); err != nil {
  172. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  173. return
  174. }
  175. prompt = b.String()
  176. }
  177. slog.Debug("generate request", "prompt", prompt, "images", images)
  178. ch := make(chan any)
  179. go func() {
  180. // TODO (jmorganca): avoid building the response twice both here and below
  181. var sb strings.Builder
  182. defer close(ch)
  183. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  184. Prompt: prompt,
  185. Images: images,
  186. Format: req.Format,
  187. Options: opts,
  188. }, func(cr llm.CompletionResponse) {
  189. res := api.GenerateResponse{
  190. Model: req.Model,
  191. CreatedAt: time.Now().UTC(),
  192. Response: cr.Content,
  193. Done: cr.Done,
  194. DoneReason: cr.DoneReason,
  195. Metrics: api.Metrics{
  196. PromptEvalCount: cr.PromptEvalCount,
  197. PromptEvalDuration: cr.PromptEvalDuration,
  198. EvalCount: cr.EvalCount,
  199. EvalDuration: cr.EvalDuration,
  200. },
  201. }
  202. if _, err := sb.WriteString(cr.Content); err != nil {
  203. ch <- gin.H{"error": err.Error()}
  204. }
  205. if cr.Done {
  206. res.TotalDuration = time.Since(checkpointStart)
  207. res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  208. if !req.Raw {
  209. tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
  210. if err != nil {
  211. ch <- gin.H{"error": err.Error()}
  212. return
  213. }
  214. res.Context = tokens
  215. }
  216. }
  217. ch <- res
  218. }); err != nil {
  219. ch <- gin.H{"error": err.Error()}
  220. }
  221. }()
  222. if req.Stream != nil && !*req.Stream {
  223. var r api.GenerateResponse
  224. var sb strings.Builder
  225. for rr := range ch {
  226. switch t := rr.(type) {
  227. case api.GenerateResponse:
  228. sb.WriteString(t.Response)
  229. r = t
  230. case gin.H:
  231. msg, ok := t["error"].(string)
  232. if !ok {
  233. msg = "unexpected error format in response"
  234. }
  235. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  236. return
  237. default:
  238. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  239. return
  240. }
  241. }
  242. r.Response = sb.String()
  243. c.JSON(http.StatusOK, r)
  244. return
  245. }
  246. streamResponse(c, ch)
  247. }
  248. func (s *Server) EmbedHandler(c *gin.Context) {
  249. checkpointStart := time.Now()
  250. var req api.EmbedRequest
  251. err := c.ShouldBindJSON(&req)
  252. switch {
  253. case errors.Is(err, io.EOF):
  254. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  255. return
  256. case err != nil:
  257. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  258. return
  259. }
  260. truncate := true
  261. if req.Truncate != nil && !*req.Truncate {
  262. truncate = false
  263. }
  264. var input []string
  265. switch i := req.Input.(type) {
  266. case string:
  267. if len(i) > 0 {
  268. input = append(input, i)
  269. }
  270. case []any:
  271. for _, v := range i {
  272. if _, ok := v.(string); !ok {
  273. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
  274. return
  275. }
  276. input = append(input, v.(string))
  277. }
  278. default:
  279. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
  280. return
  281. }
  282. if len(input) == 0 {
  283. c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
  284. return
  285. }
  286. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  287. if err != nil {
  288. handleScheduleError(c, req.Model, err)
  289. return
  290. }
  291. checkpointLoaded := time.Now()
  292. kvData, err := getKVData(m.ModelPath, false)
  293. if err != nil {
  294. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  295. return
  296. }
  297. for i, s := range input {
  298. tokens, err := r.Tokenize(c.Request.Context(), s)
  299. if err != nil {
  300. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  301. return
  302. }
  303. ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
  304. if len(tokens) > ctxLen {
  305. if !truncate {
  306. c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
  307. return
  308. }
  309. tokens = tokens[:ctxLen]
  310. s, err = r.Detokenize(c.Request.Context(), tokens)
  311. if err != nil {
  312. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  313. return
  314. }
  315. }
  316. input[i] = s
  317. }
  318. embeddings, err := r.Embed(c.Request.Context(), input)
  319. if err != nil {
  320. slog.Error("embedding generation failed", "error", err)
  321. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  322. return
  323. }
  324. for i, e := range embeddings.Embedding {
  325. embeddings.Embedding[i] = normalize(e)
  326. }
  327. resp := api.EmbedResponse{
  328. Model: req.Model,
  329. Embeddings: embeddings.Embedding,
  330. TotalDuration: time.Since(checkpointStart),
  331. LoadDuration: checkpointLoaded.Sub(checkpointStart),
  332. PromptEvalCount: embeddings.PromptEvalCount,
  333. }
  334. c.JSON(http.StatusOK, resp)
  335. }
  // normalize scales vec in place to unit Euclidean (L2) length and returns the
// same slice. An all-zero vector stays all zeros.
//
// The sum of squares and the scale factor are computed in float64. Squaring
// float32 values overflows to +Inf once |v| exceeds ~1.8e19, which would make
// the scale 0 and collapse the vector. Float64 accumulation also reduces
// rounding error on long vectors.
func normalize(vec []float32) []float32 {
	var sum float64
	for _, v := range vec {
		f := float64(v)
		sum += f * f
	}

	scale := 0.0
	if sum > 0 {
		scale = 1.0 / math.Sqrt(sum)
	}

	for i, v := range vec {
		vec[i] = float32(float64(v) * scale)
	}
	return vec
}
  350. func (s *Server) EmbeddingsHandler(c *gin.Context) {
  351. var req api.EmbeddingRequest
  352. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  353. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  354. return
  355. } else if err != nil {
  356. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  357. return
  358. }
  359. r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  360. if err != nil {
  361. handleScheduleError(c, req.Model, err)
  362. return
  363. }
  364. // an empty request loads the model
  365. if req.Prompt == "" {
  366. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  367. return
  368. }
  369. embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
  370. if err != nil {
  371. slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  372. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  373. return
  374. }
  375. embedding := make([]float64, len(embeddings.Embedding[0]))
  376. for i, v := range embeddings.Embedding[0] {
  377. embedding[i] = float64(v)
  378. }
  379. resp := api.EmbeddingResponse{
  380. Embedding: embedding,
  381. }
  382. c.JSON(http.StatusOK, resp)
  383. }
  384. func (s *Server) PullModelHandler(c *gin.Context) {
  385. var req api.PullRequest
  386. err := c.ShouldBindJSON(&req)
  387. switch {
  388. case errors.Is(err, io.EOF):
  389. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  390. return
  391. case err != nil:
  392. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  393. return
  394. }
  395. name := model.ParseName(cmp.Or(req.Model, req.Name))
  396. if !name.IsValid() {
  397. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
  398. return
  399. }
  400. if err := checkNameExists(name); err != nil {
  401. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  402. return
  403. }
  404. ch := make(chan any)
  405. go func() {
  406. defer close(ch)
  407. fn := func(r api.ProgressResponse) {
  408. ch <- r
  409. }
  410. regOpts := &registryOptions{
  411. Insecure: req.Insecure,
  412. }
  413. ctx, cancel := context.WithCancel(c.Request.Context())
  414. defer cancel()
  415. if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
  416. ch <- gin.H{"error": err.Error()}
  417. }
  418. }()
  419. if req.Stream != nil && !*req.Stream {
  420. waitForStream(c, ch)
  421. return
  422. }
  423. streamResponse(c, ch)
  424. }
  425. func (s *Server) PushModelHandler(c *gin.Context) {
  426. var req api.PushRequest
  427. err := c.ShouldBindJSON(&req)
  428. switch {
  429. case errors.Is(err, io.EOF):
  430. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  431. return
  432. case err != nil:
  433. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  434. return
  435. }
  436. var model string
  437. if req.Model != "" {
  438. model = req.Model
  439. } else if req.Name != "" {
  440. model = req.Name
  441. } else {
  442. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  443. return
  444. }
  445. ch := make(chan any)
  446. go func() {
  447. defer close(ch)
  448. fn := func(r api.ProgressResponse) {
  449. ch <- r
  450. }
  451. regOpts := &registryOptions{
  452. Insecure: req.Insecure,
  453. }
  454. ctx, cancel := context.WithCancel(c.Request.Context())
  455. defer cancel()
  456. if err := PushModel(ctx, model, regOpts, fn); err != nil {
  457. ch <- gin.H{"error": err.Error()}
  458. }
  459. }()
  460. if req.Stream != nil && !*req.Stream {
  461. waitForStream(c, ch)
  462. return
  463. }
  464. streamResponse(c, ch)
  465. }
  466. func checkNameExists(name model.Name) error {
  467. names, err := Manifests()
  468. if err != nil {
  469. return err
  470. }
  471. for n := range names {
  472. if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
  473. return errors.New("a model with that name already exists")
  474. }
  475. }
  476. return nil
  477. }
  478. func (s *Server) CreateModelHandler(c *gin.Context) {
  479. var r api.CreateRequest
  480. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  481. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  482. return
  483. } else if err != nil {
  484. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  485. return
  486. }
  487. name := model.ParseName(cmp.Or(r.Model, r.Name))
  488. if !name.IsValid() {
  489. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
  490. return
  491. }
  492. if err := checkNameExists(name); err != nil {
  493. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  494. return
  495. }
  496. if r.Path == "" && r.Modelfile == "" {
  497. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  498. return
  499. }
  500. var sr io.Reader = strings.NewReader(r.Modelfile)
  501. if r.Path != "" && r.Modelfile == "" {
  502. f, err := os.Open(r.Path)
  503. if err != nil {
  504. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  505. return
  506. }
  507. defer f.Close()
  508. sr = f
  509. }
  510. f, err := parser.ParseFile(sr)
  511. if err != nil {
  512. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  513. return
  514. }
  515. ch := make(chan any)
  516. go func() {
  517. defer close(ch)
  518. fn := func(resp api.ProgressResponse) {
  519. ch <- resp
  520. }
  521. ctx, cancel := context.WithCancel(c.Request.Context())
  522. defer cancel()
  523. quantization := cmp.Or(r.Quantize, r.Quantization)
  524. if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) {
  525. ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
  526. } else if err != nil {
  527. ch <- gin.H{"error": err.Error()}
  528. }
  529. }()
  530. if r.Stream != nil && !*r.Stream {
  531. waitForStream(c, ch)
  532. return
  533. }
  534. streamResponse(c, ch)
  535. }
  536. func (s *Server) DeleteModelHandler(c *gin.Context) {
  537. var r api.DeleteRequest
  538. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  539. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  540. return
  541. } else if err != nil {
  542. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  543. return
  544. }
  545. n := model.ParseName(cmp.Or(r.Model, r.Name))
  546. if !n.IsValid() {
  547. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
  548. return
  549. }
  550. m, err := ParseNamedManifest(n)
  551. if err != nil {
  552. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  553. return
  554. }
  555. if err := m.Remove(); err != nil {
  556. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  557. return
  558. }
  559. if err := m.RemoveLayers(); err != nil {
  560. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  561. return
  562. }
  563. }
  564. func (s *Server) ShowModelHandler(c *gin.Context) {
  565. var req api.ShowRequest
  566. err := c.ShouldBindJSON(&req)
  567. switch {
  568. case errors.Is(err, io.EOF):
  569. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  570. return
  571. case err != nil:
  572. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  573. return
  574. }
  575. if req.Model != "" {
  576. // noop
  577. } else if req.Name != "" {
  578. req.Model = req.Name
  579. } else {
  580. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  581. return
  582. }
  583. resp, err := GetModelInfo(req)
  584. if err != nil {
  585. switch {
  586. case os.IsNotExist(err):
  587. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  588. case err.Error() == "invalid model name":
  589. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  590. default:
  591. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  592. }
  593. return
  594. }
  595. c.JSON(http.StatusOK, resp)
  596. }
  597. func manifestLayers(m *Manifest, exclude []string) (map[string]any, error) {
  598. r := map[string]any{
  599. "name": m.name.DisplayShortest(),
  600. "digest": m.digest,
  601. "size": m.Size(),
  602. "modified_at": m.fi.ModTime(),
  603. }
  604. excludeAll := slices.Contains(exclude, "all")
  605. excludeDetails := slices.Contains(exclude, "details")
  606. for _, layer := range m.Layers {
  607. var errExcludeKey = errors.New("exclude key")
  608. key, content, err := func() (string, any, error) {
  609. key := strings.TrimPrefix(layer.MediaType, "application/vnd.ollama.image.")
  610. if slices.Contains(exclude, key) || excludeAll {
  611. return "", nil, errExcludeKey
  612. }
  613. f, err := layer.Open()
  614. if err != nil {
  615. return "", nil, err
  616. }
  617. defer f.Close()
  618. switch key {
  619. case "model", "projector", "adapter":
  620. ggml, _, err := llm.DecodeGGML(f, 0)
  621. if err != nil {
  622. return "", nil, err
  623. }
  624. content := map[string]any{
  625. "architecture": ggml.KV().Architecture(),
  626. "file_type": ggml.KV().FileType().String(),
  627. "parameter_count": ggml.KV().ParameterCount(),
  628. }
  629. if !slices.Contains(exclude, key+".details") && !excludeAll && !excludeDetails {
  630. // exclude any extraneous or redundant fields
  631. delete(ggml.KV(), "general.basename")
  632. delete(ggml.KV(), "general.description")
  633. delete(ggml.KV(), "general.filename")
  634. delete(ggml.KV(), "general.finetune")
  635. delete(ggml.KV(), "general.languages")
  636. delete(ggml.KV(), "general.license")
  637. delete(ggml.KV(), "general.license.link")
  638. delete(ggml.KV(), "general.name")
  639. delete(ggml.KV(), "general.paramter_count")
  640. delete(ggml.KV(), "general.size_label")
  641. delete(ggml.KV(), "general.tags")
  642. delete(ggml.KV(), "general.type")
  643. delete(ggml.KV(), "general.quantization_version")
  644. delete(ggml.KV(), "tokenizer.chat_template")
  645. content["details"] = ggml.KV()
  646. }
  647. return key, content, nil
  648. case "params", "messages":
  649. var content any
  650. if err := json.NewDecoder(f).Decode(&content); err != nil {
  651. return "", nil, err
  652. }
  653. return key, content, nil
  654. case "template", "system", "license":
  655. bts, err := io.ReadAll(f)
  656. if err != nil {
  657. return "", nil, err
  658. }
  659. if key == "license" {
  660. return key, []any{string(bts)}, nil
  661. }
  662. return key, string(bts), nil
  663. }
  664. return layer.MediaType, nil, nil
  665. }()
  666. if errors.Is(err, errExcludeKey) {
  667. continue
  668. } else if err != nil {
  669. return nil, err
  670. }
  671. if s, ok := r[key].([]any); ok {
  672. r[key] = append(s, content)
  673. } else {
  674. r[key] = content
  675. }
  676. }
  677. return r, nil
  678. }
  679. func (s *Server) GetModelsHandler(c *gin.Context) {
  680. ms, err := Manifests()
  681. if err != nil {
  682. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  683. return
  684. }
  685. var rs []map[string]any
  686. for _, m := range ms {
  687. r, err := manifestLayers(m, c.QueryArray("exclude"))
  688. if err != nil {
  689. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  690. return
  691. }
  692. rs = append(rs, r)
  693. }
  694. slices.SortStableFunc(rs, func(i, j map[string]any) int {
  695. // most recently modified first
  696. return cmp.Compare(
  697. j["modified_at"].(time.Time).Unix(),
  698. i["modified_at"].(time.Time).Unix(),
  699. )
  700. })
  701. c.JSON(http.StatusOK, rs)
  702. }
  703. func (s *Server) GetModelHandler(c *gin.Context) {
  704. n := model.ParseName(strings.TrimPrefix(c.Param("model"), "/"))
  705. if !n.IsValid() {
  706. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
  707. return
  708. }
  709. m, err := ParseNamedManifest(n)
  710. if err != nil {
  711. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  712. return
  713. }
  714. r, err := manifestLayers(m, c.QueryArray("exclude"))
  715. if err != nil {
  716. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  717. return
  718. }
  719. c.JSON(http.StatusOK, r)
  720. }
  721. func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  722. m, err := GetModel(req.Model)
  723. if err != nil {
  724. return nil, err
  725. }
  726. modelDetails := api.ModelDetails{
  727. ParentModel: m.ParentModel,
  728. Format: m.Config.ModelFormat,
  729. Family: m.Config.ModelFamily,
  730. Families: m.Config.ModelFamilies,
  731. ParameterSize: m.Config.ModelType,
  732. QuantizationLevel: m.Config.FileType,
  733. }
  734. if req.System != "" {
  735. m.System = req.System
  736. }
  737. msgs := make([]api.Message, len(m.Messages))
  738. for i, msg := range m.Messages {
  739. msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
  740. }
  741. n := model.ParseName(req.Model)
  742. if !n.IsValid() {
  743. return nil, errors.New("invalid model name")
  744. }
  745. manifest, err := ParseNamedManifest(n)
  746. if err != nil {
  747. return nil, err
  748. }
  749. resp := &api.ShowResponse{
  750. License: strings.Join(m.License, "\n"),
  751. System: m.System,
  752. Template: m.Template.String(),
  753. Details: modelDetails,
  754. Messages: msgs,
  755. ModifiedAt: manifest.fi.ModTime(),
  756. }
  757. var params []string
  758. cs := 30
  759. for k, v := range m.Options {
  760. switch val := v.(type) {
  761. case []interface{}:
  762. for _, nv := range val {
  763. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  764. }
  765. default:
  766. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
  767. }
  768. }
  769. resp.Parameters = strings.Join(params, "\n")
  770. for k, v := range req.Options {
  771. if _, ok := req.Options[k]; ok {
  772. m.Options[k] = v
  773. }
  774. }
  775. var sb strings.Builder
  776. fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
  777. fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
  778. fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
  779. fmt.Fprint(&sb, m.String())
  780. resp.Modelfile = sb.String()
  781. kvData, err := getKVData(m.ModelPath, req.Verbose)
  782. if err != nil {
  783. return nil, err
  784. }
  785. delete(kvData, "general.name")
  786. delete(kvData, "tokenizer.chat_template")
  787. resp.ModelInfo = kvData
  788. if len(m.ProjectorPaths) > 0 {
  789. projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
  790. if err != nil {
  791. return nil, err
  792. }
  793. resp.ProjectorInfo = projectorData
  794. }
  795. return resp, nil
  796. }
  797. func getKVData(digest string, verbose bool) (llm.KV, error) {
  798. maxArraySize := 0
  799. if verbose {
  800. maxArraySize = -1
  801. }
  802. kvData, err := llm.LoadModel(digest, maxArraySize)
  803. if err != nil {
  804. return nil, err
  805. }
  806. kv := kvData.KV()
  807. if !verbose {
  808. for k := range kv {
  809. if t, ok := kv[k].([]any); len(t) > 5 && ok {
  810. kv[k] = []any{}
  811. }
  812. }
  813. }
  814. return kv, nil
  815. }
  816. func (s *Server) ListModelsHandler(c *gin.Context) {
  817. ms, err := Manifests()
  818. if err != nil {
  819. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  820. return
  821. }
  822. models := []api.ListModelResponse{}
  823. for n, m := range ms {
  824. f, err := m.Config.Open()
  825. if err != nil {
  826. slog.Warn("bad manifest filepath", "name", n, "error", err)
  827. continue
  828. }
  829. defer f.Close()
  830. var cf ConfigV2
  831. if err := json.NewDecoder(f).Decode(&cf); err != nil {
  832. slog.Warn("bad manifest config", "name", n, "error", err)
  833. continue
  834. }
  835. // tag should never be masked
  836. models = append(models, api.ListModelResponse{
  837. Model: n.DisplayShortest(),
  838. Name: n.DisplayShortest(),
  839. Size: m.Size(),
  840. Digest: m.digest,
  841. ModifiedAt: m.fi.ModTime(),
  842. Details: api.ModelDetails{
  843. Format: cf.ModelFormat,
  844. Family: cf.ModelFamily,
  845. Families: cf.ModelFamilies,
  846. ParameterSize: cf.ModelType,
  847. QuantizationLevel: cf.FileType,
  848. },
  849. })
  850. }
  851. slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
  852. // most recently modified first
  853. return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
  854. })
  855. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  856. }
  857. func (s *Server) CopyModelHandler(c *gin.Context) {
  858. var r api.CopyRequest
  859. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  860. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  861. return
  862. } else if err != nil {
  863. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  864. return
  865. }
  866. src := model.ParseName(r.Source)
  867. if !src.IsValid() {
  868. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
  869. return
  870. }
  871. dst := model.ParseName(r.Destination)
  872. if !dst.IsValid() {
  873. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
  874. return
  875. }
  876. if err := checkNameExists(dst); err != nil {
  877. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  878. return
  879. }
  880. if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
  881. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
  882. } else if err != nil {
  883. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  884. }
  885. }
  886. func (s *Server) HeadBlobHandler(c *gin.Context) {
  887. path, err := GetBlobsPath(c.Param("digest"))
  888. if err != nil {
  889. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  890. return
  891. }
  892. if _, err := os.Stat(path); err != nil {
  893. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  894. return
  895. }
  896. c.Status(http.StatusOK)
  897. }
  898. func (s *Server) CreateBlobHandler(c *gin.Context) {
  899. if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
  900. p, err := GetBlobsPath(ib)
  901. if err != nil {
  902. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  903. return
  904. }
  905. if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
  906. slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
  907. delete(intermediateBlobs, c.Param("digest"))
  908. } else if err != nil {
  909. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  910. return
  911. } else {
  912. c.Status(http.StatusOK)
  913. return
  914. }
  915. }
  916. path, err := GetBlobsPath(c.Param("digest"))
  917. if err != nil {
  918. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  919. return
  920. }
  921. _, err = os.Stat(path)
  922. switch {
  923. case errors.Is(err, os.ErrNotExist):
  924. // noop
  925. case err != nil:
  926. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  927. return
  928. default:
  929. c.Status(http.StatusOK)
  930. return
  931. }
  932. layer, err := NewLayer(c.Request.Body, "")
  933. if err != nil {
  934. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  935. return
  936. }
  937. if layer.Digest != c.Param("digest") {
  938. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  939. return
  940. }
  941. c.Status(http.StatusCreated)
  942. }
  943. func isLocalIP(ip netip.Addr) bool {
  944. if interfaces, err := net.Interfaces(); err == nil {
  945. for _, iface := range interfaces {
  946. addrs, err := iface.Addrs()
  947. if err != nil {
  948. continue
  949. }
  950. for _, a := range addrs {
  951. if parsed, _, err := net.ParseCIDR(a.String()); err == nil {
  952. if parsed.String() == ip.String() {
  953. return true
  954. }
  955. }
  956. }
  957. }
  958. }
  959. return false
  960. }
  961. func allowedHost(host string) bool {
  962. if host == "" || host == "localhost" {
  963. return true
  964. }
  965. if hostname, err := os.Hostname(); err == nil && host == hostname {
  966. return true
  967. }
  968. tlds := []string{
  969. "localhost",
  970. "local",
  971. "internal",
  972. }
  973. // check if the host is a local TLD
  974. for _, tld := range tlds {
  975. if strings.HasSuffix(host, "."+tld) {
  976. return true
  977. }
  978. }
  979. return false
  980. }
  981. func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  982. return func(c *gin.Context) {
  983. if addr == nil {
  984. c.Next()
  985. return
  986. }
  987. if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
  988. c.Next()
  989. return
  990. }
  991. host, _, err := net.SplitHostPort(c.Request.Host)
  992. if err != nil {
  993. host = c.Request.Host
  994. }
  995. if addr, err := netip.ParseAddr(host); err == nil {
  996. if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
  997. c.Next()
  998. return
  999. }
  1000. }
  1001. if allowedHost(host) {
  1002. if c.Request.Method == http.MethodOptions {
  1003. c.AbortWithStatus(http.StatusNoContent)
  1004. return
  1005. }
  1006. c.Next()
  1007. return
  1008. }
  1009. c.AbortWithStatus(http.StatusForbidden)
  1010. }
  1011. }
  1012. func (s *Server) GenerateRoutes() http.Handler {
  1013. config := cors.DefaultConfig()
  1014. config.AllowWildcard = true
  1015. config.AllowBrowserExtensions = true
  1016. config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
  1017. openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
  1018. for _, prop := range openAIProperties {
  1019. config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
  1020. }
  1021. config.AllowOrigins = envconfig.Origins()
  1022. r := gin.Default()
  1023. r.Use(
  1024. cors.New(config),
  1025. allowedHostsMiddleware(s.addr),
  1026. )
  1027. r.POST("/api/pull", s.PullModelHandler)
  1028. r.POST("/api/generate", s.GenerateHandler)
  1029. r.POST("/api/chat", s.ChatHandler)
  1030. r.POST("/api/embed", s.EmbedHandler)
  1031. r.POST("/api/embeddings", s.EmbeddingsHandler)
  1032. r.POST("/api/create", s.CreateModelHandler)
  1033. r.POST("/api/push", s.PushModelHandler)
  1034. r.POST("/api/copy", s.CopyModelHandler)
  1035. r.DELETE("/api/delete", s.DeleteModelHandler)
  1036. r.POST("/api/show", s.ShowModelHandler)
  1037. r.POST("/api/blobs/:digest", s.CreateBlobHandler)
  1038. r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
  1039. r.GET("/api/ps", s.ProcessHandler)
  1040. // Compatibility endpoints
  1041. r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
  1042. r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
  1043. r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
  1044. r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
  1045. r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
  1046. for _, method := range []string{http.MethodGet, http.MethodHead} {
  1047. r.Handle(method, "/", func(c *gin.Context) {
  1048. c.String(http.StatusOK, "Ollama is running")
  1049. })
  1050. r.Handle(method, "/api/models", s.GetModelsHandler)
  1051. r.Handle(method, "/api/models/*model", s.GetModelHandler)
  1052. r.Handle(method, "/api/tags", s.ListModelsHandler)
  1053. r.Handle(method, "/api/version", func(c *gin.Context) {
  1054. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  1055. })
  1056. }
  1057. return r
  1058. }
  1059. func Serve(ln net.Listener) error {
  1060. level := slog.LevelInfo
  1061. if envconfig.Debug() {
  1062. level = slog.LevelDebug
  1063. }
  1064. slog.Info("server config", "env", envconfig.Values())
  1065. handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  1066. Level: level,
  1067. AddSource: true,
  1068. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  1069. if attr.Key == slog.SourceKey {
  1070. source := attr.Value.Any().(*slog.Source)
  1071. source.File = filepath.Base(source.File)
  1072. }
  1073. return attr
  1074. },
  1075. })
  1076. slog.SetDefault(slog.New(handler))
  1077. blobsDir, err := GetBlobsPath("")
  1078. if err != nil {
  1079. return err
  1080. }
  1081. if err := fixBlobs(blobsDir); err != nil {
  1082. return err
  1083. }
  1084. if !envconfig.NoPrune() {
  1085. // clean up unused layers and manifests
  1086. if err := PruneLayers(); err != nil {
  1087. return err
  1088. }
  1089. manifestsPath, err := GetManifestPath()
  1090. if err != nil {
  1091. return err
  1092. }
  1093. if err := PruneDirectory(manifestsPath); err != nil {
  1094. return err
  1095. }
  1096. }
  1097. ctx, done := context.WithCancel(context.Background())
  1098. schedCtx, schedDone := context.WithCancel(ctx)
  1099. sched := InitScheduler(schedCtx)
  1100. s := &Server{addr: ln.Addr(), sched: sched}
  1101. http.Handle("/", s.GenerateRoutes())
  1102. slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
  1103. srvr := &http.Server{
  1104. // Use http.DefaultServeMux so we get net/http/pprof for
  1105. // free.
  1106. //
  1107. // TODO(bmizerany): Decide if we want to make this
  1108. // configurable so it is not exposed by default, or allow
  1109. // users to bind it to a different port. This was a quick
  1110. // and easy way to get pprof, but it may not be the best
  1111. // way.
  1112. Handler: nil,
  1113. }
  1114. // listen for a ctrl+c and stop any loaded llm
  1115. signals := make(chan os.Signal, 1)
  1116. signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  1117. go func() {
  1118. <-signals
  1119. srvr.Close()
  1120. schedDone()
  1121. sched.unloadAllRunners()
  1122. gpu.Cleanup()
  1123. done()
  1124. }()
  1125. if err := llm.Init(); err != nil {
  1126. return fmt.Errorf("unable to initialize llm library %w", err)
  1127. }
  1128. s.sched.Run(schedCtx)
  1129. // At startup we retrieve GPU information so we can get log messages before loading a model
  1130. // This will log warnings to the log in case we have problems with detected GPUs
  1131. gpus := gpu.GetGPUInfo()
  1132. gpus.LogDetails()
  1133. err = srvr.Serve(ln)
  1134. // If server is closed from the signal handler, wait for the ctx to be done
  1135. // otherwise error out quickly
  1136. if !errors.Is(err, http.ErrServerClosed) {
  1137. return err
  1138. }
  1139. <-ctx.Done()
  1140. return nil
  1141. }
  1142. func waitForStream(c *gin.Context, ch chan interface{}) {
  1143. c.Header("Content-Type", "application/json")
  1144. for resp := range ch {
  1145. switch r := resp.(type) {
  1146. case api.ProgressResponse:
  1147. if r.Status == "success" {
  1148. c.JSON(http.StatusOK, r)
  1149. return
  1150. }
  1151. case gin.H:
  1152. status, ok := r["status"].(int)
  1153. if !ok {
  1154. status = http.StatusInternalServerError
  1155. }
  1156. if errorMsg, ok := r["error"].(string); ok {
  1157. c.JSON(status, gin.H{"error": errorMsg})
  1158. return
  1159. } else {
  1160. c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
  1161. return
  1162. }
  1163. default:
  1164. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  1165. return
  1166. }
  1167. }
  1168. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  1169. }
  1170. func streamResponse(c *gin.Context, ch chan any) {
  1171. c.Header("Content-Type", "application/x-ndjson")
  1172. c.Stream(func(w io.Writer) bool {
  1173. val, ok := <-ch
  1174. if !ok {
  1175. return false
  1176. }
  1177. bts, err := json.Marshal(val)
  1178. if err != nil {
  1179. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  1180. return false
  1181. }
  1182. // Delineate chunks with new-line delimiter
  1183. bts = append(bts, '\n')
  1184. if _, err := w.Write(bts); err != nil {
  1185. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  1186. return false
  1187. }
  1188. return true
  1189. })
  1190. }
  1191. func (s *Server) ProcessHandler(c *gin.Context) {
  1192. models := []api.ProcessModelResponse{}
  1193. for _, v := range s.sched.loaded {
  1194. model := v.model
  1195. modelDetails := api.ModelDetails{
  1196. Format: model.Config.ModelFormat,
  1197. Family: model.Config.ModelFamily,
  1198. Families: model.Config.ModelFamilies,
  1199. ParameterSize: model.Config.ModelType,
  1200. QuantizationLevel: model.Config.FileType,
  1201. }
  1202. mr := api.ProcessModelResponse{
  1203. Model: model.ShortName,
  1204. Name: model.ShortName,
  1205. Size: int64(v.estimatedTotal),
  1206. SizeVRAM: int64(v.estimatedVRAM),
  1207. Digest: model.Digest,
  1208. Details: modelDetails,
  1209. ExpiresAt: v.expiresAt,
  1210. }
  1211. // The scheduler waits to set expiresAt, so if a model is loading it's
  1212. // possible that it will be set to the unix epoch. For those cases, just
  1213. // calculate the time w/ the sessionDuration instead.
  1214. var epoch time.Time
  1215. if v.expiresAt == epoch {
  1216. mr.ExpiresAt = time.Now().Add(v.sessionDuration)
  1217. }
  1218. models = append(models, mr)
  1219. }
  1220. slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
  1221. // longest duration remaining listed first
  1222. return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
  1223. })
  1224. c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
  1225. }
  1226. func (s *Server) ChatHandler(c *gin.Context) {
  1227. checkpointStart := time.Now()
  1228. var req api.ChatRequest
  1229. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  1230. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1231. return
  1232. } else if err != nil {
  1233. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1234. return
  1235. }
  1236. caps := []Capability{CapabilityCompletion}
  1237. if len(req.Tools) > 0 {
  1238. caps = append(caps, CapabilityTools)
  1239. }
  1240. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  1241. if errors.Is(err, errCapabilityCompletion) {
  1242. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
  1243. return
  1244. } else if err != nil {
  1245. handleScheduleError(c, req.Model, err)
  1246. return
  1247. }
  1248. checkpointLoaded := time.Now()
  1249. if len(req.Messages) == 0 {
  1250. c.JSON(http.StatusOK, api.ChatResponse{
  1251. Model: req.Model,
  1252. CreatedAt: time.Now().UTC(),
  1253. Message: api.Message{Role: "assistant"},
  1254. Done: true,
  1255. DoneReason: "load",
  1256. })
  1257. return
  1258. }
  1259. msgs := append(m.Messages, req.Messages...)
  1260. if req.Messages[0].Role != "system" && m.System != "" {
  1261. msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
  1262. }
  1263. prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
  1264. if err != nil {
  1265. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1266. return
  1267. }
  1268. slog.Debug("chat request", "images", len(images), "prompt", prompt)
  1269. ch := make(chan any)
  1270. go func() {
  1271. defer close(ch)
  1272. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  1273. Prompt: prompt,
  1274. Images: images,
  1275. Format: req.Format,
  1276. Options: opts,
  1277. }, func(r llm.CompletionResponse) {
  1278. res := api.ChatResponse{
  1279. Model: req.Model,
  1280. CreatedAt: time.Now().UTC(),
  1281. Message: api.Message{Role: "assistant", Content: r.Content},
  1282. Done: r.Done,
  1283. DoneReason: r.DoneReason,
  1284. Metrics: api.Metrics{
  1285. PromptEvalCount: r.PromptEvalCount,
  1286. PromptEvalDuration: r.PromptEvalDuration,
  1287. EvalCount: r.EvalCount,
  1288. EvalDuration: r.EvalDuration,
  1289. },
  1290. }
  1291. if r.Done {
  1292. res.TotalDuration = time.Since(checkpointStart)
  1293. res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  1294. }
  1295. ch <- res
  1296. }); err != nil {
  1297. ch <- gin.H{"error": err.Error()}
  1298. }
  1299. }()
  1300. if req.Stream != nil && !*req.Stream {
  1301. var resp api.ChatResponse
  1302. var sb strings.Builder
  1303. for rr := range ch {
  1304. switch t := rr.(type) {
  1305. case api.ChatResponse:
  1306. sb.WriteString(t.Message.Content)
  1307. resp = t
  1308. case gin.H:
  1309. msg, ok := t["error"].(string)
  1310. if !ok {
  1311. msg = "unexpected error format in response"
  1312. }
  1313. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  1314. return
  1315. default:
  1316. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  1317. return
  1318. }
  1319. }
  1320. resp.Message.Content = sb.String()
  1321. if len(req.Tools) > 0 {
  1322. if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
  1323. resp.Message.ToolCalls = toolCalls
  1324. resp.Message.Content = ""
  1325. }
  1326. }
  1327. c.JSON(http.StatusOK, resp)
  1328. return
  1329. }
  1330. streamResponse(c, ch)
  1331. }
  1332. func handleScheduleError(c *gin.Context, name string, err error) {
  1333. switch {
  1334. case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
  1335. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1336. case errors.Is(err, context.Canceled):
  1337. c.JSON(499, gin.H{"error": "request canceled"})
  1338. case errors.Is(err, ErrMaxQueue):
  1339. c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
  1340. case errors.Is(err, os.ErrNotExist):
  1341. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
  1342. default:
  1343. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1344. }
  1345. }