routes.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398
  1. package server
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/fs"
  9. "log/slog"
  10. "math"
  11. "net"
  12. "net/http"
  13. "net/netip"
  14. "os"
  15. "os/signal"
  16. "path/filepath"
  17. "reflect"
  18. "runtime"
  19. "strconv"
  20. "strings"
  21. "sync"
  22. "syscall"
  23. "time"
  24. "github.com/gin-contrib/cors"
  25. "github.com/gin-gonic/gin"
  26. "golang.org/x/exp/slices"
  27. "github.com/ollama/ollama/api"
  28. "github.com/ollama/ollama/gpu"
  29. "github.com/ollama/ollama/llm"
  30. "github.com/ollama/ollama/openai"
  31. "github.com/ollama/ollama/parser"
  32. "github.com/ollama/ollama/version"
  33. )
  34. var mode string = gin.DebugMode
  35. type Server struct {
  36. addr net.Addr
  37. }
  38. func init() {
  39. switch mode {
  40. case gin.DebugMode:
  41. case gin.ReleaseMode:
  42. case gin.TestMode:
  43. default:
  44. mode = gin.DebugMode
  45. }
  46. gin.SetMode(mode)
  47. }
  48. var loaded struct {
  49. mu sync.Mutex
  50. runner llm.LLM
  51. expireAt time.Time
  52. expireTimer *time.Timer
  53. *Model
  54. *api.Options
  55. }
  56. var defaultSessionDuration = 5 * time.Minute
  57. // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
  58. func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
  59. needLoad := loaded.runner == nil || // is there a model loaded?
  60. loaded.ModelPath != model.ModelPath || // has the base model changed?
  61. !reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
  62. !reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
  63. if needLoad {
  64. if loaded.runner != nil {
  65. slog.Info("changing loaded model")
  66. loaded.runner.Close()
  67. loaded.runner = nil
  68. loaded.Model = nil
  69. loaded.Options = nil
  70. }
  71. llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
  72. if err != nil {
  73. // some older models are not compatible with newer versions of llama.cpp
  74. // show a generalized compatibility error until there is a better way to
  75. // check for model compatibility
  76. if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
  77. err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
  78. }
  79. return err
  80. }
  81. loaded.Model = model
  82. loaded.runner = llmRunner
  83. loaded.Options = &opts
  84. }
  85. loaded.expireAt = time.Now().Add(sessionDuration)
  86. if loaded.expireTimer == nil {
  87. loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
  88. loaded.mu.Lock()
  89. defer loaded.mu.Unlock()
  90. if time.Now().Before(loaded.expireAt) {
  91. return
  92. }
  93. if loaded.runner != nil {
  94. loaded.runner.Close()
  95. }
  96. loaded.runner = nil
  97. loaded.Model = nil
  98. loaded.Options = nil
  99. })
  100. }
  101. loaded.expireTimer.Reset(sessionDuration)
  102. return nil
  103. }
  104. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  105. opts := api.DefaultOptions()
  106. if err := opts.FromMap(model.Options); err != nil {
  107. return api.Options{}, err
  108. }
  109. if err := opts.FromMap(requestOpts); err != nil {
  110. return api.Options{}, err
  111. }
  112. return opts, nil
  113. }
  // isSupportedImageType reports whether image sniffs (via
// http.DetectContentType) as one of the accepted image MIME types.
func isSupportedImageType(image []byte) bool {
	accepted := []string{"image/jpeg", "image/jpg", "image/png"}
	sniffed := http.DetectContentType(image)
	return slices.Contains(accepted, sniffed)
}
  119. func GenerateHandler(c *gin.Context) {
  120. loaded.mu.Lock()
  121. defer loaded.mu.Unlock()
  122. checkpointStart := time.Now()
  123. var req api.GenerateRequest
  124. err := c.ShouldBindJSON(&req)
  125. switch {
  126. case errors.Is(err, io.EOF):
  127. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  128. return
  129. case err != nil:
  130. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  131. return
  132. }
  133. // validate the request
  134. switch {
  135. case req.Model == "":
  136. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  137. return
  138. case len(req.Format) > 0 && req.Format != "json":
  139. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
  140. return
  141. case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
  142. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  143. return
  144. }
  145. for _, img := range req.Images {
  146. if !isSupportedImageType(img) {
  147. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
  148. return
  149. }
  150. }
  151. model, err := GetModel(req.Model)
  152. if err != nil {
  153. var pErr *fs.PathError
  154. if errors.As(err, &pErr) {
  155. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
  156. return
  157. }
  158. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  159. return
  160. }
  161. if model.IsEmbedding() {
  162. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
  163. return
  164. }
  165. opts, err := modelOptions(model, req.Options)
  166. if err != nil {
  167. if errors.Is(err, api.ErrInvalidOpts) {
  168. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  169. return
  170. }
  171. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  172. return
  173. }
  174. var sessionDuration time.Duration
  175. if req.KeepAlive == nil {
  176. sessionDuration = getDefaultSessionDuration()
  177. } else {
  178. sessionDuration = req.KeepAlive.Duration
  179. }
  180. if err := load(c, model, opts, sessionDuration); err != nil {
  181. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  182. return
  183. }
  184. // an empty request loads the model
  185. // note: for a short while template was used in lieu
  186. // of `raw` mode so we need to check for it too
  187. if req.Prompt == "" && req.Template == "" && req.System == "" {
  188. c.JSON(http.StatusOK, api.GenerateResponse{
  189. CreatedAt: time.Now().UTC(),
  190. Model: req.Model,
  191. Done: true,
  192. })
  193. return
  194. }
  195. checkpointLoaded := time.Now()
  196. var prompt string
  197. switch {
  198. case req.Raw:
  199. prompt = req.Prompt
  200. case req.Prompt != "":
  201. if req.Template == "" {
  202. req.Template = model.Template
  203. }
  204. if req.System == "" {
  205. req.System = model.System
  206. }
  207. slog.Debug("generate handler", "prompt", req.Prompt)
  208. slog.Debug("generate handler", "template", req.Template)
  209. slog.Debug("generate handler", "system", req.System)
  210. var sb strings.Builder
  211. for i := range req.Images {
  212. fmt.Fprintf(&sb, "[img-%d] ", i)
  213. }
  214. sb.WriteString(req.Prompt)
  215. p, err := Prompt(req.Template, req.System, sb.String(), "", true)
  216. if err != nil {
  217. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  218. return
  219. }
  220. sb.Reset()
  221. if req.Context != nil {
  222. prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
  223. if err != nil {
  224. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  225. return
  226. }
  227. sb.WriteString(prev)
  228. }
  229. sb.WriteString(p)
  230. prompt = sb.String()
  231. }
  232. slog.Debug("generate handler", "prompt", prompt)
  233. ch := make(chan any)
  234. var generated strings.Builder
  235. go func() {
  236. defer close(ch)
  237. fn := func(r llm.PredictResult) {
  238. // Update model expiration
  239. loaded.expireAt = time.Now().Add(sessionDuration)
  240. loaded.expireTimer.Reset(sessionDuration)
  241. // Build up the full response
  242. if _, err := generated.WriteString(r.Content); err != nil {
  243. ch <- gin.H{"error": err.Error()}
  244. return
  245. }
  246. resp := api.GenerateResponse{
  247. Model: req.Model,
  248. CreatedAt: time.Now().UTC(),
  249. Done: r.Done,
  250. Response: r.Content,
  251. Metrics: api.Metrics{
  252. PromptEvalCount: r.PromptEvalCount,
  253. PromptEvalDuration: r.PromptEvalDuration,
  254. EvalCount: r.EvalCount,
  255. EvalDuration: r.EvalDuration,
  256. },
  257. }
  258. if r.Done {
  259. resp.TotalDuration = time.Since(checkpointStart)
  260. resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  261. if !req.Raw {
  262. p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
  263. if err != nil {
  264. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  265. return
  266. }
  267. // TODO (jmorganca): encode() should not strip special tokens
  268. tokens, err := loaded.runner.Encode(c.Request.Context(), p)
  269. if err != nil {
  270. ch <- gin.H{"error": err.Error()}
  271. return
  272. }
  273. resp.Context = append(req.Context, tokens...)
  274. }
  275. }
  276. ch <- resp
  277. }
  278. var images []llm.ImageData
  279. for i := range req.Images {
  280. images = append(images, llm.ImageData{
  281. ID: i,
  282. Data: req.Images[i],
  283. })
  284. }
  285. // Start prediction
  286. predictReq := llm.PredictOpts{
  287. Prompt: prompt,
  288. Format: req.Format,
  289. Images: images,
  290. Options: opts,
  291. }
  292. if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
  293. ch <- gin.H{"error": err.Error()}
  294. }
  295. }()
  296. if req.Stream != nil && !*req.Stream {
  297. // Accumulate responses into the final response
  298. var final api.GenerateResponse
  299. var sb strings.Builder
  300. for resp := range ch {
  301. switch r := resp.(type) {
  302. case api.GenerateResponse:
  303. sb.WriteString(r.Response)
  304. final = r
  305. case gin.H:
  306. if errorMsg, ok := r["error"].(string); ok {
  307. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  308. return
  309. } else {
  310. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
  311. return
  312. }
  313. default:
  314. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
  315. return
  316. }
  317. }
  318. final.Response = sb.String()
  319. c.JSON(http.StatusOK, final)
  320. return
  321. }
  322. streamResponse(c, ch)
  323. }
  324. func getDefaultSessionDuration() time.Duration {
  325. if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists {
  326. v, err := strconv.Atoi(t)
  327. if err != nil {
  328. d, err := time.ParseDuration(t)
  329. if err != nil {
  330. return defaultSessionDuration
  331. }
  332. if d < 0 {
  333. return time.Duration(math.MaxInt64)
  334. }
  335. return d
  336. }
  337. d := time.Duration(v) * time.Second
  338. if d < 0 {
  339. return time.Duration(math.MaxInt64)
  340. }
  341. return d
  342. }
  343. return defaultSessionDuration
  344. }
  345. func EmbeddingsHandler(c *gin.Context) {
  346. loaded.mu.Lock()
  347. defer loaded.mu.Unlock()
  348. var req api.EmbeddingRequest
  349. err := c.ShouldBindJSON(&req)
  350. switch {
  351. case errors.Is(err, io.EOF):
  352. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  353. return
  354. case err != nil:
  355. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  356. return
  357. }
  358. if req.Model == "" {
  359. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  360. return
  361. }
  362. model, err := GetModel(req.Model)
  363. if err != nil {
  364. var pErr *fs.PathError
  365. if errors.As(err, &pErr) {
  366. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
  367. return
  368. }
  369. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  370. return
  371. }
  372. opts, err := modelOptions(model, req.Options)
  373. if err != nil {
  374. if errors.Is(err, api.ErrInvalidOpts) {
  375. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  376. return
  377. }
  378. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  379. return
  380. }
  381. var sessionDuration time.Duration
  382. if req.KeepAlive == nil {
  383. sessionDuration = getDefaultSessionDuration()
  384. } else {
  385. sessionDuration = req.KeepAlive.Duration
  386. }
  387. if err := load(c, model, opts, sessionDuration); err != nil {
  388. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  389. return
  390. }
  391. // an empty request loads the model
  392. if req.Prompt == "" {
  393. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  394. return
  395. }
  396. embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
  397. if err != nil {
  398. slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  399. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  400. return
  401. }
  402. resp := api.EmbeddingResponse{
  403. Embedding: embedding,
  404. }
  405. c.JSON(http.StatusOK, resp)
  406. }
  407. func PullModelHandler(c *gin.Context) {
  408. var req api.PullRequest
  409. err := c.ShouldBindJSON(&req)
  410. switch {
  411. case errors.Is(err, io.EOF):
  412. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  413. return
  414. case err != nil:
  415. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  416. return
  417. }
  418. var model string
  419. if req.Model != "" {
  420. model = req.Model
  421. } else if req.Name != "" {
  422. model = req.Name
  423. } else {
  424. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  425. return
  426. }
  427. ch := make(chan any)
  428. go func() {
  429. defer close(ch)
  430. fn := func(r api.ProgressResponse) {
  431. ch <- r
  432. }
  433. regOpts := &registryOptions{
  434. Insecure: req.Insecure,
  435. }
  436. ctx, cancel := context.WithCancel(c.Request.Context())
  437. defer cancel()
  438. if err := PullModel(ctx, model, regOpts, fn); err != nil {
  439. ch <- gin.H{"error": err.Error()}
  440. }
  441. }()
  442. if req.Stream != nil && !*req.Stream {
  443. waitForStream(c, ch)
  444. return
  445. }
  446. streamResponse(c, ch)
  447. }
  448. func PushModelHandler(c *gin.Context) {
  449. var req api.PushRequest
  450. err := c.ShouldBindJSON(&req)
  451. switch {
  452. case errors.Is(err, io.EOF):
  453. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  454. return
  455. case err != nil:
  456. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  457. return
  458. }
  459. var model string
  460. if req.Model != "" {
  461. model = req.Model
  462. } else if req.Name != "" {
  463. model = req.Name
  464. } else {
  465. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  466. return
  467. }
  468. ch := make(chan any)
  469. go func() {
  470. defer close(ch)
  471. fn := func(r api.ProgressResponse) {
  472. ch <- r
  473. }
  474. regOpts := &registryOptions{
  475. Insecure: req.Insecure,
  476. }
  477. ctx, cancel := context.WithCancel(c.Request.Context())
  478. defer cancel()
  479. if err := PushModel(ctx, model, regOpts, fn); err != nil {
  480. ch <- gin.H{"error": err.Error()}
  481. }
  482. }()
  483. if req.Stream != nil && !*req.Stream {
  484. waitForStream(c, ch)
  485. return
  486. }
  487. streamResponse(c, ch)
  488. }
  489. func CreateModelHandler(c *gin.Context) {
  490. var req api.CreateRequest
  491. err := c.ShouldBindJSON(&req)
  492. switch {
  493. case errors.Is(err, io.EOF):
  494. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  495. return
  496. case err != nil:
  497. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  498. return
  499. }
  500. var model string
  501. if req.Model != "" {
  502. model = req.Model
  503. } else if req.Name != "" {
  504. model = req.Name
  505. } else {
  506. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  507. return
  508. }
  509. if err := ParseModelPath(model).Validate(); err != nil {
  510. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  511. return
  512. }
  513. if req.Path == "" && req.Modelfile == "" {
  514. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  515. return
  516. }
  517. var modelfile io.Reader = strings.NewReader(req.Modelfile)
  518. if req.Path != "" && req.Modelfile == "" {
  519. mf, err := os.Open(req.Path)
  520. if err != nil {
  521. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  522. return
  523. }
  524. defer mf.Close()
  525. modelfile = mf
  526. }
  527. commands, err := parser.Parse(modelfile)
  528. if err != nil {
  529. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  530. return
  531. }
  532. ch := make(chan any)
  533. go func() {
  534. defer close(ch)
  535. fn := func(resp api.ProgressResponse) {
  536. ch <- resp
  537. }
  538. ctx, cancel := context.WithCancel(c.Request.Context())
  539. defer cancel()
  540. if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
  541. ch <- gin.H{"error": err.Error()}
  542. }
  543. }()
  544. if req.Stream != nil && !*req.Stream {
  545. waitForStream(c, ch)
  546. return
  547. }
  548. streamResponse(c, ch)
  549. }
  550. func DeleteModelHandler(c *gin.Context) {
  551. var req api.DeleteRequest
  552. err := c.ShouldBindJSON(&req)
  553. switch {
  554. case errors.Is(err, io.EOF):
  555. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  556. return
  557. case err != nil:
  558. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  559. return
  560. }
  561. var model string
  562. if req.Model != "" {
  563. model = req.Model
  564. } else if req.Name != "" {
  565. model = req.Name
  566. } else {
  567. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  568. return
  569. }
  570. if err := DeleteModel(model); err != nil {
  571. if os.IsNotExist(err) {
  572. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
  573. } else {
  574. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  575. }
  576. return
  577. }
  578. manifestsPath, err := GetManifestPath()
  579. if err != nil {
  580. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  581. return
  582. }
  583. if err := PruneDirectory(manifestsPath); err != nil {
  584. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  585. return
  586. }
  587. c.JSON(http.StatusOK, nil)
  588. }
  589. func ShowModelHandler(c *gin.Context) {
  590. var req api.ShowRequest
  591. err := c.ShouldBindJSON(&req)
  592. switch {
  593. case errors.Is(err, io.EOF):
  594. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  595. return
  596. case err != nil:
  597. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  598. return
  599. }
  600. if req.Model != "" {
  601. // noop
  602. } else if req.Name != "" {
  603. req.Model = req.Name
  604. } else {
  605. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  606. return
  607. }
  608. resp, err := GetModelInfo(req)
  609. if err != nil {
  610. if os.IsNotExist(err) {
  611. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  612. } else {
  613. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  614. }
  615. return
  616. }
  617. c.JSON(http.StatusOK, resp)
  618. }
  619. func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  620. model, err := GetModel(req.Model)
  621. if err != nil {
  622. return nil, err
  623. }
  624. modelDetails := api.ModelDetails{
  625. ParentModel: model.ParentModel,
  626. Format: model.Config.ModelFormat,
  627. Family: model.Config.ModelFamily,
  628. Families: model.Config.ModelFamilies,
  629. ParameterSize: model.Config.ModelType,
  630. QuantizationLevel: model.Config.FileType,
  631. }
  632. if req.System != "" {
  633. model.System = req.System
  634. }
  635. if req.Template != "" {
  636. model.Template = req.Template
  637. }
  638. msgs := make([]api.Message, 0)
  639. for _, msg := range model.Messages {
  640. msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
  641. }
  642. resp := &api.ShowResponse{
  643. License: strings.Join(model.License, "\n"),
  644. System: model.System,
  645. Template: model.Template,
  646. Details: modelDetails,
  647. Messages: msgs,
  648. }
  649. var params []string
  650. cs := 30
  651. for k, v := range model.Options {
  652. switch val := v.(type) {
  653. case []interface{}:
  654. for _, nv := range val {
  655. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  656. }
  657. default:
  658. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
  659. }
  660. }
  661. resp.Parameters = strings.Join(params, "\n")
  662. for k, v := range req.Options {
  663. if _, ok := req.Options[k]; ok {
  664. model.Options[k] = v
  665. }
  666. }
  667. mf, err := ShowModelfile(model)
  668. if err != nil {
  669. return nil, err
  670. }
  671. resp.Modelfile = mf
  672. return resp, nil
  673. }
  674. func ListModelsHandler(c *gin.Context) {
  675. models := make([]api.ModelResponse, 0)
  676. manifestsPath, err := GetManifestPath()
  677. if err != nil {
  678. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  679. return
  680. }
  681. modelResponse := func(modelName string) (api.ModelResponse, error) {
  682. model, err := GetModel(modelName)
  683. if err != nil {
  684. return api.ModelResponse{}, err
  685. }
  686. modelDetails := api.ModelDetails{
  687. Format: model.Config.ModelFormat,
  688. Family: model.Config.ModelFamily,
  689. Families: model.Config.ModelFamilies,
  690. ParameterSize: model.Config.ModelType,
  691. QuantizationLevel: model.Config.FileType,
  692. }
  693. return api.ModelResponse{
  694. Model: model.ShortName,
  695. Name: model.ShortName,
  696. Size: model.Size,
  697. Digest: model.Digest,
  698. Details: modelDetails,
  699. }, nil
  700. }
  701. walkFunc := func(path string, info os.FileInfo, _ error) error {
  702. if !info.IsDir() {
  703. path, tag := filepath.Split(path)
  704. model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
  705. modelPath := strings.Join([]string{model, tag}, ":")
  706. canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")
  707. resp, err := modelResponse(canonicalModelPath)
  708. if err != nil {
  709. slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
  710. // nolint: nilerr
  711. return nil
  712. }
  713. resp.ModifiedAt = info.ModTime()
  714. models = append(models, resp)
  715. }
  716. return nil
  717. }
  718. if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
  719. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  720. return
  721. }
  722. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  723. }
  724. func CopyModelHandler(c *gin.Context) {
  725. var req api.CopyRequest
  726. err := c.ShouldBindJSON(&req)
  727. switch {
  728. case errors.Is(err, io.EOF):
  729. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  730. return
  731. case err != nil:
  732. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  733. return
  734. }
  735. if req.Source == "" || req.Destination == "" {
  736. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
  737. return
  738. }
  739. if err := ParseModelPath(req.Destination).Validate(); err != nil {
  740. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  741. return
  742. }
  743. if err := CopyModel(req.Source, req.Destination); err != nil {
  744. if os.IsNotExist(err) {
  745. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
  746. } else {
  747. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  748. }
  749. return
  750. }
  751. }
  752. func HeadBlobHandler(c *gin.Context) {
  753. path, err := GetBlobsPath(c.Param("digest"))
  754. if err != nil {
  755. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  756. return
  757. }
  758. if _, err := os.Stat(path); err != nil {
  759. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  760. return
  761. }
  762. c.Status(http.StatusOK)
  763. }
  764. func CreateBlobHandler(c *gin.Context) {
  765. layer, err := NewLayer(c.Request.Body, "")
  766. if err != nil {
  767. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  768. return
  769. }
  770. if layer.Digest != c.Param("digest") {
  771. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  772. return
  773. }
  774. if _, err := layer.Commit(); err != nil {
  775. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  776. return
  777. }
  778. c.Status(http.StatusCreated)
  779. }
  // defaultAllowOrigins are hosts always accepted as CORS origins, in
// addition to OLLAMA_ORIGINS. GenerateRoutes expands each one into http and
// https origins, both with and without a wildcard port.
var defaultAllowOrigins = []string{"localhost", "127.0.0.1", "0.0.0.0"}
  // isLocalIP reports whether ip is assigned to one of this host's network
// interfaces. Addresses are compared by their string form. Failures to list
// interfaces or their addresses count as "not local" rather than being
// reported.
func isLocalIP(ip netip.Addr) bool {
	interfaces, err := net.Interfaces()
	if err != nil {
		return false
	}

	want := ip.String()
	for _, iface := range interfaces {
		addrs, err := iface.Addrs()
		if err != nil {
			continue
		}
		for _, a := range addrs {
			parsed, _, err := net.ParseCIDR(a.String())
			if err == nil && parsed.String() == want {
				return true
			}
		}
	}
	return false
}
  // allowedHost reports whether a (non-IP) Host value is acceptable. That means
// it is empty, "localhost", this machine's hostname, or a name under the
// .localhost, .local or .internal TLDs.
func allowedHost(host string) bool {
	switch host {
	case "", "localhost":
		return true
	}

	if name, err := os.Hostname(); err == nil && name == host {
		return true
	}

	// check if the host is a local TLD
	for _, suffix := range []string{".localhost", ".local", ".internal"} {
		if strings.HasSuffix(host, suffix) {
			return true
		}
	}
	return false
}
  823. func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  824. return func(c *gin.Context) {
  825. if addr == nil {
  826. c.Next()
  827. return
  828. }
  829. if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
  830. c.Next()
  831. return
  832. }
  833. host, _, err := net.SplitHostPort(c.Request.Host)
  834. if err != nil {
  835. host = c.Request.Host
  836. }
  837. if addr, err := netip.ParseAddr(host); err == nil {
  838. if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
  839. c.Next()
  840. return
  841. }
  842. }
  843. if allowedHost(host) {
  844. c.Next()
  845. return
  846. }
  847. c.AbortWithStatus(http.StatusForbidden)
  848. }
  849. }
  850. func (s *Server) GenerateRoutes() http.Handler {
  851. var origins []string
  852. if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
  853. origins = strings.Split(o, ",")
  854. }
  855. config := cors.DefaultConfig()
  856. config.AllowWildcard = true
  857. config.AllowBrowserExtensions = true
  858. config.AllowOrigins = origins
  859. for _, allowOrigin := range defaultAllowOrigins {
  860. config.AllowOrigins = append(config.AllowOrigins,
  861. fmt.Sprintf("http://%s", allowOrigin),
  862. fmt.Sprintf("https://%s", allowOrigin),
  863. fmt.Sprintf("http://%s:*", allowOrigin),
  864. fmt.Sprintf("https://%s:*", allowOrigin),
  865. )
  866. }
  867. r := gin.Default()
  868. r.Use(
  869. cors.New(config),
  870. allowedHostsMiddleware(s.addr),
  871. )
  872. r.POST("/api/pull", PullModelHandler)
  873. r.POST("/api/generate", GenerateHandler)
  874. r.POST("/api/chat", ChatHandler)
  875. r.POST("/api/embeddings", EmbeddingsHandler)
  876. r.POST("/api/create", CreateModelHandler)
  877. r.POST("/api/push", PushModelHandler)
  878. r.POST("/api/copy", CopyModelHandler)
  879. r.DELETE("/api/delete", DeleteModelHandler)
  880. r.POST("/api/show", ShowModelHandler)
  881. r.POST("/api/blobs/:digest", CreateBlobHandler)
  882. r.HEAD("/api/blobs/:digest", HeadBlobHandler)
  883. // Compatibility endpoints
  884. r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
  885. for _, method := range []string{http.MethodGet, http.MethodHead} {
  886. r.Handle(method, "/", func(c *gin.Context) {
  887. c.String(http.StatusOK, "Ollama is running")
  888. })
  889. r.Handle(method, "/api/tags", ListModelsHandler)
  890. r.Handle(method, "/api/version", func(c *gin.Context) {
  891. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  892. })
  893. }
  894. return r
  895. }
  896. func Serve(ln net.Listener) error {
  897. level := slog.LevelInfo
  898. if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
  899. level = slog.LevelDebug
  900. }
  901. handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  902. Level: level,
  903. AddSource: true,
  904. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  905. if attr.Key == slog.SourceKey {
  906. source := attr.Value.Any().(*slog.Source)
  907. source.File = filepath.Base(source.File)
  908. }
  909. return attr
  910. },
  911. })
  912. slog.SetDefault(slog.New(handler))
  913. blobsDir, err := GetBlobsPath("")
  914. if err != nil {
  915. return err
  916. }
  917. if err := fixBlobs(blobsDir); err != nil {
  918. return err
  919. }
  920. if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
  921. // clean up unused layers and manifests
  922. if err := PruneLayers(); err != nil {
  923. return err
  924. }
  925. manifestsPath, err := GetManifestPath()
  926. if err != nil {
  927. return err
  928. }
  929. if err := PruneDirectory(manifestsPath); err != nil {
  930. return err
  931. }
  932. }
  933. s := &Server{addr: ln.Addr()}
  934. r := s.GenerateRoutes()
  935. slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
  936. srvr := &http.Server{
  937. Handler: r,
  938. }
  939. // listen for a ctrl+c and stop any loaded llm
  940. signals := make(chan os.Signal, 1)
  941. signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  942. go func() {
  943. <-signals
  944. if loaded.runner != nil {
  945. loaded.runner.Close()
  946. }
  947. gpu.Cleanup()
  948. os.Exit(0)
  949. }()
  950. if err := llm.Init(); err != nil {
  951. return fmt.Errorf("unable to initialize llm library %w", err)
  952. }
  953. if runtime.GOOS == "linux" { // TODO - windows too
  954. // check compatibility to log warnings
  955. if _, err := gpu.CheckVRAM(); err != nil {
  956. slog.Info(err.Error())
  957. }
  958. }
  959. return srvr.Serve(ln)
  960. }
  961. func waitForStream(c *gin.Context, ch chan interface{}) {
  962. c.Header("Content-Type", "application/json")
  963. for resp := range ch {
  964. switch r := resp.(type) {
  965. case api.ProgressResponse:
  966. if r.Status == "success" {
  967. c.JSON(http.StatusOK, r)
  968. return
  969. }
  970. case gin.H:
  971. if errorMsg, ok := r["error"].(string); ok {
  972. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  973. return
  974. } else {
  975. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
  976. return
  977. }
  978. default:
  979. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  980. return
  981. }
  982. }
  983. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  984. }
  985. func streamResponse(c *gin.Context, ch chan any) {
  986. c.Header("Content-Type", "application/x-ndjson")
  987. c.Stream(func(w io.Writer) bool {
  988. val, ok := <-ch
  989. if !ok {
  990. return false
  991. }
  992. bts, err := json.Marshal(val)
  993. if err != nil {
  994. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  995. return false
  996. }
  997. // Delineate chunks with new-line delimiter
  998. bts = append(bts, '\n')
  999. if _, err := w.Write(bts); err != nil {
  1000. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  1001. return false
  1002. }
  1003. return true
  1004. })
  1005. }
  1006. // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
  1007. func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
  1008. encode := func(s string) ([]int, error) {
  1009. return loaded.runner.Encode(ctx, s)
  1010. }
  1011. prompt, err := ChatPrompt(template, messages, numCtx, encode)
  1012. if err != nil {
  1013. return "", err
  1014. }
  1015. return prompt, nil
  1016. }
  1017. func ChatHandler(c *gin.Context) {
  1018. loaded.mu.Lock()
  1019. defer loaded.mu.Unlock()
  1020. checkpointStart := time.Now()
  1021. var req api.ChatRequest
  1022. err := c.ShouldBindJSON(&req)
  1023. switch {
  1024. case errors.Is(err, io.EOF):
  1025. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1026. return
  1027. case err != nil:
  1028. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1029. return
  1030. }
  1031. // validate the request
  1032. switch {
  1033. case req.Model == "":
  1034. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  1035. return
  1036. case len(req.Format) > 0 && req.Format != "json":
  1037. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
  1038. return
  1039. }
  1040. model, err := GetModel(req.Model)
  1041. if err != nil {
  1042. var pErr *fs.PathError
  1043. if errors.As(err, &pErr) {
  1044. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
  1045. return
  1046. }
  1047. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1048. return
  1049. }
  1050. if model.IsEmbedding() {
  1051. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
  1052. return
  1053. }
  1054. opts, err := modelOptions(model, req.Options)
  1055. if err != nil {
  1056. if errors.Is(err, api.ErrInvalidOpts) {
  1057. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1058. return
  1059. }
  1060. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1061. return
  1062. }
  1063. var sessionDuration time.Duration
  1064. if req.KeepAlive == nil {
  1065. sessionDuration = getDefaultSessionDuration()
  1066. } else {
  1067. sessionDuration = req.KeepAlive.Duration
  1068. }
  1069. if err := load(c, model, opts, sessionDuration); err != nil {
  1070. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1071. return
  1072. }
  1073. checkpointLoaded := time.Now()
  1074. // if the first message is not a system message, then add the model's default system message
  1075. if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
  1076. req.Messages = append([]api.Message{
  1077. {
  1078. Role: "system",
  1079. Content: model.System,
  1080. },
  1081. }, req.Messages...)
  1082. }
  1083. prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
  1084. if err != nil {
  1085. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1086. return
  1087. }
  1088. // an empty request loads the model
  1089. if len(req.Messages) == 0 || prompt == "" {
  1090. resp := api.ChatResponse{
  1091. CreatedAt: time.Now().UTC(),
  1092. Model: req.Model,
  1093. Done: true,
  1094. Message: api.Message{Role: "assistant"},
  1095. }
  1096. c.JSON(http.StatusOK, resp)
  1097. return
  1098. }
  1099. // only send images that are in the prompt
  1100. var i int
  1101. var images []llm.ImageData
  1102. for _, m := range req.Messages {
  1103. for _, img := range m.Images {
  1104. if !isSupportedImageType(img) {
  1105. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
  1106. return
  1107. }
  1108. if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
  1109. images = append(images, llm.ImageData{Data: img, ID: i})
  1110. }
  1111. i += 1
  1112. }
  1113. }
  1114. slog.Debug("chat handler", "prompt", prompt, "images", len(images))
  1115. ch := make(chan any)
  1116. go func() {
  1117. defer close(ch)
  1118. fn := func(r llm.PredictResult) {
  1119. // Update model expiration
  1120. loaded.expireAt = time.Now().Add(sessionDuration)
  1121. loaded.expireTimer.Reset(sessionDuration)
  1122. resp := api.ChatResponse{
  1123. Model: req.Model,
  1124. CreatedAt: time.Now().UTC(),
  1125. Message: api.Message{Role: "assistant", Content: r.Content},
  1126. Done: r.Done,
  1127. Metrics: api.Metrics{
  1128. PromptEvalCount: r.PromptEvalCount,
  1129. PromptEvalDuration: r.PromptEvalDuration,
  1130. EvalCount: r.EvalCount,
  1131. EvalDuration: r.EvalDuration,
  1132. },
  1133. }
  1134. if r.Done {
  1135. resp.TotalDuration = time.Since(checkpointStart)
  1136. resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  1137. }
  1138. ch <- resp
  1139. }
  1140. // Start prediction
  1141. predictReq := llm.PredictOpts{
  1142. Prompt: prompt,
  1143. Format: req.Format,
  1144. Images: images,
  1145. Options: opts,
  1146. }
  1147. if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
  1148. ch <- gin.H{"error": err.Error()}
  1149. }
  1150. }()
  1151. if req.Stream != nil && !*req.Stream {
  1152. // Accumulate responses into the final response
  1153. var final api.ChatResponse
  1154. var sb strings.Builder
  1155. for resp := range ch {
  1156. switch r := resp.(type) {
  1157. case api.ChatResponse:
  1158. sb.WriteString(r.Message.Content)
  1159. final = r
  1160. case gin.H:
  1161. if errorMsg, ok := r["error"].(string); ok {
  1162. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  1163. return
  1164. } else {
  1165. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
  1166. return
  1167. }
  1168. default:
  1169. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
  1170. return
  1171. }
  1172. }
  1173. final.Message = api.Message{Role: "assistant", Content: sb.String()}
  1174. c.JSON(http.StatusOK, final)
  1175. return
  1176. }
  1177. streamResponse(c, ch)
  1178. }