routes.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099
  1. package server
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "io/fs"
  9. "log"
  10. "net"
  11. "net/http"
  12. "os"
  13. "os/signal"
  14. "path/filepath"
  15. "reflect"
  16. "runtime"
  17. "strconv"
  18. "strings"
  19. "sync"
  20. "syscall"
  21. "time"
  22. "github.com/gin-contrib/cors"
  23. "github.com/gin-gonic/gin"
  24. "github.com/jmorganca/ollama/api"
  25. "github.com/jmorganca/ollama/llm"
  26. "github.com/jmorganca/ollama/parser"
  27. "github.com/jmorganca/ollama/version"
  28. )
// mode is the gin run mode that init passes to gin.SetMode. It starts out as
// gin.DebugMode.
var mode string = gin.DebugMode
// Server carries the state shared by the HTTP routes built in GenerateRoutes.
type Server struct {
	// WorkDir is a directory path. NewServer sets it to a freshly created
	// temporary directory, GenerateRoutes publishes it to handlers under the
	// "workDir" context key, and Serve removes it on SIGINT/SIGTERM.
	WorkDir string
}
  33. func init() {
  34. switch mode {
  35. case gin.DebugMode:
  36. case gin.ReleaseMode:
  37. case gin.TestMode:
  38. default:
  39. mode = gin.DebugMode
  40. }
  41. gin.SetMode(mode)
  42. }
// loaded describes the llm runner currently held in memory, if any, together
// with the model and options it was started with. The generate, chat and
// embedding handlers lock mu for the whole request before calling load.
var loaded struct {
	mu sync.Mutex

	// runner is the active llm; nil when nothing is loaded.
	runner llm.LLM

	// expireAt is the time after which the expiry timer may close the runner.
	expireAt time.Time
	// expireTimer is created lazily by load and reset on every load call and
	// every prediction callback.
	expireTimer *time.Timer

	// Model and Options are embedded pointers; both are nil whenever runner is
	// nil, so their promoted fields must only be read after a runner nil check.
	*Model
	*api.Options
}
// defaultSessionDuration is the session duration the handlers in this file pass
// to load; load uses it to schedule closing the loaded runner.
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
//
// The effective options start from api.DefaultOptions; FromMap is applied first
// with the model's own options and then with reqOpts. A new runner is started
// only when none is loaded or when the model path, adapter paths, or runner
// options differ from what is currently loaded. sessionDuration sets how far in
// the future the runner's expiry is pushed.
//
// It returns the model obtained from GetModel, or the first error from the
// model lookup, option parsing, or llm.New.
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
	model, err := GetModel(modelName)
	if err != nil {
		return nil, err
	}

	// set by the middleware registered in GenerateRoutes
	workDir := c.GetString("workDir")

	opts := api.DefaultOptions()
	if err := opts.FromMap(model.Options); err != nil {
		log.Printf("could not load model options: %v", err)
		return nil, err
	}

	if err := opts.FromMap(reqOpts); err != nil {
		return nil, err
	}

	ctx := c.Request.Context()

	// check if the loaded model is still running in a subprocess, in case something unexpected happened
	if loaded.runner != nil {
		if err := loaded.runner.Ping(ctx); err != nil {
			log.Print("loaded llm process not responding, closing now")
			// the subprocess is no longer running, so close it
			loaded.runner.Close()
			loaded.runner = nil
			loaded.Model = nil
			loaded.Options = nil
		}
	}

	// The runner nil check must stay first: the remaining clauses read fields
	// through the embedded loaded.Model / loaded.Options pointers, which are nil
	// whenever no runner is loaded, and rely on short-circuit evaluation.
	needLoad := loaded.runner == nil || // is there a model loaded?
		loaded.ModelPath != model.ModelPath || // has the base model changed?
		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?

	if needLoad {
		// tear down the previous runner before starting a different one
		if loaded.runner != nil {
			log.Println("changing loaded model")
			loaded.runner.Close()
			loaded.runner = nil
			loaded.Model = nil
			loaded.Options = nil
		}

		llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
		if err != nil {
			// some older models are not compatible with newer versions of llama.cpp
			// show a generalized compatibility error until there is a better way to
			// check for model compatibility
			if strings.Contains(err.Error(), "failed to load model") {
				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
			}

			return nil, err
		}

		loaded.Model = model
		loaded.runner = llmRunner
		loaded.Options = &opts
	}

	// update options for the loaded llm
	// TODO(mxyng): this isn't thread safe, but it should be fine for now
	loaded.runner.SetOptions(opts)

	loaded.expireAt = time.Now().Add(sessionDuration)

	// The timer is created once and reused; its callback re-checks expireAt
	// under the lock so a firing that raced with a newer request does nothing.
	if loaded.expireTimer == nil {
		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
			loaded.mu.Lock()
			defer loaded.mu.Unlock()

			if time.Now().Before(loaded.expireAt) {
				return
			}

			if loaded.runner != nil {
				loaded.runner.Close()
			}

			loaded.runner = nil
			loaded.Model = nil
			loaded.Options = nil
		})
	}

	loaded.expireTimer.Reset(sessionDuration)
	return model, nil
}
// GenerateHandler serves POST /api/generate. It holds loaded.mu for the whole
// request, binds and validates an api.GenerateRequest, loads the model, builds
// the prompt, and runs a prediction in a goroutine that feeds results through
// a channel. Results are streamed as newline-delimited JSON unless the request
// sets stream to false, in which case they are accumulated into one response.
func GenerateHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	// start of the request; used for TotalDuration and LoadDuration
	checkpointStart := time.Now()

	var req api.GenerateRequest
	err := c.ShouldBindJSON(&req)

	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// validate the request
	switch {
	case req.Model == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	case len(req.Format) > 0 && req.Format != "json":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
		return
	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
		return
	}

	sessionDuration := defaultSessionDuration
	model, err := load(c, req.Model, req.Options, sessionDuration)
	if err != nil {
		// map load failures onto HTTP statuses: a path error is reported as a
		// missing model, invalid options as a bad request, anything else as 500
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
		case errors.Is(err, api.ErrInvalidOpts):
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	// an empty request loads the model
	if req.Prompt == "" && req.Template == "" && req.System == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			CreatedAt: time.Now().UTC(),
			Model:     req.Model,
			Done:      true})
		return
	}

	// model is loaded from here on; the gap to checkpointStart is LoadDuration
	checkpointLoaded := time.Now()

	// In raw mode the prompt is passed through untouched; otherwise it is
	// rendered through the model template. A request with neither raw mode nor
	// a prompt leaves the prompt empty.
	var prompt string
	switch {
	case req.Raw:
		prompt = req.Prompt
	case req.Prompt != "":
		if req.Template != "" {
			// override the default model template
			model.Template = req.Template
		}

		var rebuild strings.Builder
		if req.Context != nil {
			// TODO: context is deprecated, at some point the context logic within this conditional should be removed
			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			// Remove leading spaces from prevCtx if present
			prevCtx = strings.TrimPrefix(prevCtx, " ")
			rebuild.WriteString(prevCtx)
		}
		p, err := model.Prompt(PromptVars{
			System: req.System,
			Prompt: req.Prompt,
			First:  len(req.Context) == 0,
		})
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		rebuild.WriteString(p)
		prompt = rebuild.String()
	}

	// ch carries api.GenerateResponse values, or a gin.H holding an "error"
	// string; the goroutine below is the only sender and closes it when done.
	ch := make(chan any)
	// generated accumulates the model output so it can be encoded into the
	// response context once the prediction is done
	var generated strings.Builder
	go func() {
		defer close(ch)

		// fn is handed to Predict and invoked with each llm.PredictResult
		fn := func(r llm.PredictResult) {
			// Update model expiration
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			// Build up the full response
			if _, err := generated.WriteString(r.Content); err != nil {
				ch <- gin.H{"error": err.Error()}
				return
			}

			resp := api.GenerateResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Done:      r.Done,
				Response:  r.Content,
				Metrics: api.Metrics{
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			}

			// the final result additionally carries timings and, outside raw
			// mode, the encoded prompt plus output as the follow-up context
			if r.Done {
				resp.TotalDuration = time.Since(checkpointStart)
				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)

				if !req.Raw {
					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}
					resp.Context = embd
				}
			}

			ch <- resp
		}

		// Start prediction
		predictReq := llm.PredictOpts{
			Prompt: prompt,
			Format: req.Format,
			Images: req.Images,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	// streaming is the default; only an explicit stream=false collects results
	if req.Stream != nil && !*req.Stream {
		// Accumulate responses into the final response
		var final api.GenerateResponse
		var sb strings.Builder
		for resp := range ch {
			switch r := resp.(type) {
			case api.GenerateResponse:
				sb.WriteString(r.Response)
				// keep the latest chunk; its Response is replaced with the
				// concatenated text after the loop
				final = r
			case gin.H:
				if errorMsg, ok := r["error"].(string); ok {
					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
					return
				} else {
					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
					return
				}
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
				return
			}
		}

		final.Response = sb.String()
		c.JSON(http.StatusOK, final)
		return
	}

	streamResponse(c, ch)
}
  286. func EmbeddingHandler(c *gin.Context) {
  287. loaded.mu.Lock()
  288. defer loaded.mu.Unlock()
  289. var req api.EmbeddingRequest
  290. err := c.ShouldBindJSON(&req)
  291. switch {
  292. case errors.Is(err, io.EOF):
  293. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  294. return
  295. case err != nil:
  296. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  297. return
  298. }
  299. if req.Model == "" {
  300. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  301. return
  302. }
  303. sessionDuration := defaultSessionDuration
  304. _, err = load(c, req.Model, req.Options, sessionDuration)
  305. if err != nil {
  306. var pErr *fs.PathError
  307. switch {
  308. case errors.As(err, &pErr):
  309. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
  310. case errors.Is(err, api.ErrInvalidOpts):
  311. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  312. default:
  313. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  314. }
  315. return
  316. }
  317. if !loaded.Options.EmbeddingOnly {
  318. c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
  319. return
  320. }
  321. embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
  322. if err != nil {
  323. log.Printf("embedding generation failed: %v", err)
  324. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  325. return
  326. }
  327. resp := api.EmbeddingResponse{
  328. Embedding: embedding,
  329. }
  330. c.JSON(http.StatusOK, resp)
  331. }
  332. func PullModelHandler(c *gin.Context) {
  333. var req api.PullRequest
  334. err := c.ShouldBindJSON(&req)
  335. switch {
  336. case errors.Is(err, io.EOF):
  337. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  338. return
  339. case err != nil:
  340. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  341. return
  342. }
  343. if req.Name == "" {
  344. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
  345. return
  346. }
  347. ch := make(chan any)
  348. go func() {
  349. defer close(ch)
  350. fn := func(r api.ProgressResponse) {
  351. ch <- r
  352. }
  353. regOpts := &RegistryOptions{
  354. Insecure: req.Insecure,
  355. }
  356. ctx, cancel := context.WithCancel(c.Request.Context())
  357. defer cancel()
  358. if err := PullModel(ctx, req.Name, regOpts, fn); err != nil {
  359. ch <- gin.H{"error": err.Error()}
  360. }
  361. }()
  362. if req.Stream != nil && !*req.Stream {
  363. waitForStream(c, ch)
  364. return
  365. }
  366. streamResponse(c, ch)
  367. }
  368. func PushModelHandler(c *gin.Context) {
  369. var req api.PushRequest
  370. err := c.ShouldBindJSON(&req)
  371. switch {
  372. case errors.Is(err, io.EOF):
  373. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  374. return
  375. case err != nil:
  376. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  377. return
  378. }
  379. if req.Name == "" {
  380. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
  381. return
  382. }
  383. ch := make(chan any)
  384. go func() {
  385. defer close(ch)
  386. fn := func(r api.ProgressResponse) {
  387. ch <- r
  388. }
  389. regOpts := &RegistryOptions{
  390. Insecure: req.Insecure,
  391. }
  392. ctx, cancel := context.WithCancel(c.Request.Context())
  393. defer cancel()
  394. if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
  395. ch <- gin.H{"error": err.Error()}
  396. }
  397. }()
  398. if req.Stream != nil && !*req.Stream {
  399. waitForStream(c, ch)
  400. return
  401. }
  402. streamResponse(c, ch)
  403. }
  404. func CreateModelHandler(c *gin.Context) {
  405. var req api.CreateRequest
  406. err := c.ShouldBindJSON(&req)
  407. switch {
  408. case errors.Is(err, io.EOF):
  409. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  410. return
  411. case err != nil:
  412. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  413. return
  414. }
  415. if req.Name == "" {
  416. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
  417. return
  418. }
  419. if err := ParseModelPath(req.Name).Validate(); err != nil {
  420. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  421. return
  422. }
  423. if req.Path == "" && req.Modelfile == "" {
  424. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  425. return
  426. }
  427. var modelfile io.Reader = strings.NewReader(req.Modelfile)
  428. if req.Path != "" && req.Modelfile == "" {
  429. mf, err := os.Open(req.Path)
  430. if err != nil {
  431. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  432. return
  433. }
  434. defer mf.Close()
  435. modelfile = mf
  436. }
  437. commands, err := parser.Parse(modelfile)
  438. if err != nil {
  439. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  440. return
  441. }
  442. ch := make(chan any)
  443. go func() {
  444. defer close(ch)
  445. fn := func(resp api.ProgressResponse) {
  446. ch <- resp
  447. }
  448. ctx, cancel := context.WithCancel(c.Request.Context())
  449. defer cancel()
  450. if err := CreateModel(ctx, req.Name, filepath.Dir(req.Path), commands, fn); err != nil {
  451. ch <- gin.H{"error": err.Error()}
  452. }
  453. }()
  454. if req.Stream != nil && !*req.Stream {
  455. waitForStream(c, ch)
  456. return
  457. }
  458. streamResponse(c, ch)
  459. }
  460. func DeleteModelHandler(c *gin.Context) {
  461. var req api.DeleteRequest
  462. err := c.ShouldBindJSON(&req)
  463. switch {
  464. case errors.Is(err, io.EOF):
  465. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  466. return
  467. case err != nil:
  468. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  469. return
  470. }
  471. if req.Name == "" {
  472. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
  473. return
  474. }
  475. if err := DeleteModel(req.Name); err != nil {
  476. if os.IsNotExist(err) {
  477. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  478. } else {
  479. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  480. }
  481. return
  482. }
  483. manifestsPath, err := GetManifestPath()
  484. if err != nil {
  485. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  486. return
  487. }
  488. if err := PruneDirectory(manifestsPath); err != nil {
  489. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  490. return
  491. }
  492. c.JSON(http.StatusOK, nil)
  493. }
  494. func ShowModelHandler(c *gin.Context) {
  495. var req api.ShowRequest
  496. err := c.ShouldBindJSON(&req)
  497. switch {
  498. case errors.Is(err, io.EOF):
  499. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  500. return
  501. case err != nil:
  502. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  503. return
  504. }
  505. if req.Name == "" {
  506. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
  507. return
  508. }
  509. resp, err := GetModelInfo(req.Name)
  510. if err != nil {
  511. if os.IsNotExist(err) {
  512. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
  513. } else {
  514. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  515. }
  516. return
  517. }
  518. c.JSON(http.StatusOK, resp)
  519. }
  520. func GetModelInfo(name string) (*api.ShowResponse, error) {
  521. model, err := GetModel(name)
  522. if err != nil {
  523. return nil, err
  524. }
  525. modelDetails := api.ModelDetails{
  526. Format: model.Config.ModelFormat,
  527. Family: model.Config.ModelFamily,
  528. Families: model.Config.ModelFamilies,
  529. ParameterSize: model.Config.ModelType,
  530. QuantizationLevel: model.Config.FileType,
  531. }
  532. resp := &api.ShowResponse{
  533. License: strings.Join(model.License, "\n"),
  534. System: model.System,
  535. Template: model.Template,
  536. Details: modelDetails,
  537. }
  538. mf, err := ShowModelfile(model)
  539. if err != nil {
  540. return nil, err
  541. }
  542. resp.Modelfile = mf
  543. var params []string
  544. cs := 30
  545. for k, v := range model.Options {
  546. switch val := v.(type) {
  547. case string:
  548. params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
  549. case int:
  550. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
  551. case float64:
  552. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
  553. case bool:
  554. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
  555. case []interface{}:
  556. for _, nv := range val {
  557. switch nval := nv.(type) {
  558. case string:
  559. params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
  560. case int:
  561. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
  562. case float64:
  563. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
  564. case bool:
  565. params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
  566. }
  567. }
  568. }
  569. }
  570. resp.Parameters = strings.Join(params, "\n")
  571. return resp, nil
  572. }
  573. func ListModelsHandler(c *gin.Context) {
  574. models := make([]api.ModelResponse, 0)
  575. fp, err := GetManifestPath()
  576. if err != nil {
  577. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  578. return
  579. }
  580. modelResponse := func(modelName string) (api.ModelResponse, error) {
  581. model, err := GetModel(modelName)
  582. if err != nil {
  583. return api.ModelResponse{}, err
  584. }
  585. modelDetails := api.ModelDetails{
  586. Format: model.Config.ModelFormat,
  587. Family: model.Config.ModelFamily,
  588. Families: model.Config.ModelFamilies,
  589. ParameterSize: model.Config.ModelType,
  590. QuantizationLevel: model.Config.FileType,
  591. }
  592. return api.ModelResponse{
  593. Name: model.ShortName,
  594. Size: model.Size,
  595. Digest: model.Digest,
  596. Details: modelDetails,
  597. }, nil
  598. }
  599. walkFunc := func(path string, info os.FileInfo, _ error) error {
  600. if !info.IsDir() {
  601. dir, file := filepath.Split(path)
  602. dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
  603. tag := strings.Join([]string{dir, file}, ":")
  604. resp, err := modelResponse(tag)
  605. if err != nil {
  606. log.Printf("skipping file: %s", fp)
  607. return nil
  608. }
  609. resp.ModifiedAt = info.ModTime()
  610. models = append(models, resp)
  611. }
  612. return nil
  613. }
  614. if err := filepath.Walk(fp, walkFunc); err != nil {
  615. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  616. return
  617. }
  618. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  619. }
  620. func CopyModelHandler(c *gin.Context) {
  621. var req api.CopyRequest
  622. err := c.ShouldBindJSON(&req)
  623. switch {
  624. case errors.Is(err, io.EOF):
  625. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  626. return
  627. case err != nil:
  628. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  629. return
  630. }
  631. if req.Source == "" || req.Destination == "" {
  632. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
  633. return
  634. }
  635. if err := ParseModelPath(req.Destination).Validate(); err != nil {
  636. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  637. return
  638. }
  639. if err := CopyModel(req.Source, req.Destination); err != nil {
  640. if os.IsNotExist(err) {
  641. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
  642. } else {
  643. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  644. }
  645. return
  646. }
  647. }
  648. func HeadBlobHandler(c *gin.Context) {
  649. path, err := GetBlobsPath(c.Param("digest"))
  650. if err != nil {
  651. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  652. return
  653. }
  654. if _, err := os.Stat(path); err != nil {
  655. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  656. return
  657. }
  658. c.Status(http.StatusOK)
  659. }
  660. func CreateBlobHandler(c *gin.Context) {
  661. layer, err := NewLayer(c.Request.Body, "")
  662. if err != nil {
  663. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  664. return
  665. }
  666. if layer.Digest != c.Param("digest") {
  667. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  668. return
  669. }
  670. if _, err := layer.Commit(); err != nil {
  671. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  672. return
  673. }
  674. c.Status(http.StatusCreated)
  675. }
// defaultAllowOrigins lists the hosts that GenerateRoutes always allows as CORS
// origins, over both http and https and with or without a port.
var defaultAllowOrigins = []string{
	"localhost",
	"127.0.0.1",
	"0.0.0.0",
}
  681. func NewServer() (*Server, error) {
  682. workDir, err := os.MkdirTemp("", "ollama")
  683. if err != nil {
  684. return nil, err
  685. }
  686. return &Server{
  687. WorkDir: workDir,
  688. }, nil
  689. }
  690. func (s *Server) GenerateRoutes() http.Handler {
  691. var origins []string
  692. if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
  693. origins = strings.Split(o, ",")
  694. }
  695. config := cors.DefaultConfig()
  696. config.AllowWildcard = true
  697. config.AllowOrigins = origins
  698. for _, allowOrigin := range defaultAllowOrigins {
  699. config.AllowOrigins = append(config.AllowOrigins,
  700. fmt.Sprintf("http://%s", allowOrigin),
  701. fmt.Sprintf("https://%s", allowOrigin),
  702. fmt.Sprintf("http://%s:*", allowOrigin),
  703. fmt.Sprintf("https://%s:*", allowOrigin),
  704. )
  705. }
  706. r := gin.Default()
  707. r.Use(
  708. cors.New(config),
  709. func(c *gin.Context) {
  710. c.Set("workDir", s.WorkDir)
  711. c.Next()
  712. },
  713. )
  714. r.POST("/api/pull", PullModelHandler)
  715. r.POST("/api/generate", GenerateHandler)
  716. r.POST("/api/chat", ChatHandler)
  717. r.POST("/api/embeddings", EmbeddingHandler)
  718. r.POST("/api/create", CreateModelHandler)
  719. r.POST("/api/push", PushModelHandler)
  720. r.POST("/api/copy", CopyModelHandler)
  721. r.DELETE("/api/delete", DeleteModelHandler)
  722. r.POST("/api/show", ShowModelHandler)
  723. r.POST("/api/blobs/:digest", CreateBlobHandler)
  724. r.HEAD("/api/blobs/:digest", HeadBlobHandler)
  725. for _, method := range []string{http.MethodGet, http.MethodHead} {
  726. r.Handle(method, "/", func(c *gin.Context) {
  727. c.String(http.StatusOK, "Ollama is running")
  728. })
  729. r.Handle(method, "/api/tags", ListModelsHandler)
  730. r.Handle(method, "/api/version", func(c *gin.Context) {
  731. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  732. })
  733. }
  734. return r
  735. }
// Serve runs the HTTP API on ln and blocks until the server stops, returning
// the error from http.Server.Serve.
//
// Unless the OLLAMA_NOPRUNE environment variable is set to a non-empty value,
// unused layers and the manifests directory are pruned first; any pruning
// error aborts start-up. On SIGINT or SIGTERM the loaded llm runner (if any) is
// closed, the server's work directory is removed, and the process exits with
// status 0.
func Serve(ln net.Listener) error {
	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
		// clean up unused layers and manifests
		if err := PruneLayers(); err != nil {
			return err
		}

		manifestsPath, err := GetManifestPath()
		if err != nil {
			return err
		}

		if err := PruneDirectory(manifestsPath); err != nil {
			return err
		}
	}

	s, err := NewServer()
	if err != nil {
		return err
	}
	r := s.GenerateRoutes()

	log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
	srvr := &http.Server{
		Handler: r,
	}

	// listen for a ctrl+c and stop any loaded llm
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-signals
		// NOTE(review): loaded.runner is read here without taking loaded.mu —
		// confirm this is intentional so shutdown is not blocked by a request
		// that holds the lock.
		if loaded.runner != nil {
			loaded.runner.Close()
		}
		os.RemoveAll(s.WorkDir)
		os.Exit(0)
	}()

	if runtime.GOOS == "linux" {
		// check compatibility to log warnings
		if _, err := llm.CheckVRAM(); err != nil {
			log.Print(err.Error())
		}
	}

	return srvr.Serve(ln)
}
  778. func waitForStream(c *gin.Context, ch chan interface{}) {
  779. c.Header("Content-Type", "application/json")
  780. for resp := range ch {
  781. switch r := resp.(type) {
  782. case api.ProgressResponse:
  783. if r.Status == "success" {
  784. c.JSON(http.StatusOK, r)
  785. return
  786. }
  787. case gin.H:
  788. if errorMsg, ok := r["error"].(string); ok {
  789. c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
  790. return
  791. } else {
  792. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
  793. return
  794. }
  795. default:
  796. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  797. return
  798. }
  799. }
  800. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  801. }
  802. func streamResponse(c *gin.Context, ch chan any) {
  803. c.Header("Content-Type", "application/x-ndjson")
  804. c.Stream(func(w io.Writer) bool {
  805. val, ok := <-ch
  806. if !ok {
  807. return false
  808. }
  809. bts, err := json.Marshal(val)
  810. if err != nil {
  811. log.Printf("streamResponse: json.Marshal failed with %s", err)
  812. return false
  813. }
  814. // Delineate chunks with new-line delimiter
  815. bts = append(bts, '\n')
  816. if _, err := w.Write(bts); err != nil {
  817. log.Printf("streamResponse: w.Write failed with %s", err)
  818. return false
  819. }
  820. return true
  821. })
  822. }
// ChatHandler serves POST /api/chat. It holds loaded.mu for the whole request,
// binds and validates an api.ChatRequest, loads the model, builds the prompt
// from the request messages, and runs a prediction in a goroutine that feeds
// results through a channel. Results are streamed as newline-delimited JSON
// unless the request sets stream to false, in which case they are accumulated
// into one response.
func ChatHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	// start of the request; used for TotalDuration and LoadDuration
	checkpointStart := time.Now()

	var req api.ChatRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// validate the request
	switch {
	case req.Model == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	case len(req.Format) > 0 && req.Format != "json":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
		return
	}

	sessionDuration := defaultSessionDuration
	model, err := load(c, req.Model, req.Options, sessionDuration)
	if err != nil {
		// map load failures onto HTTP statuses: a path error is reported as a
		// missing model, invalid options as a bad request, anything else as 500
		var pErr *fs.PathError
		switch {
		case errors.As(err, &pErr):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
		case errors.Is(err, api.ErrInvalidOpts):
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	// an empty request loads the model
	if len(req.Messages) == 0 {
		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
		return
	}

	// model is loaded from here on; the gap to checkpointStart is LoadDuration
	checkpointLoaded := time.Now()

	prompt, images, err := model.ChatPrompt(req.Messages)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// ch carries api.ChatResponse values, or a gin.H holding an "error" string;
	// the goroutine below is the only sender and closes it when done.
	ch := make(chan any)

	go func() {
		defer close(ch)

		// fn is handed to Predict and invoked with each llm.PredictResult
		fn := func(r llm.PredictResult) {
			// Update model expiration
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			resp := api.ChatResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Message:   api.Message{Role: "assistant", Content: r.Content},
				Done:      r.Done,
				Metrics: api.Metrics{
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			}

			// only the final result carries the timing summary
			if r.Done {
				resp.TotalDuration = time.Since(checkpointStart)
				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
			}

			ch <- resp
		}

		// Start prediction
		predictReq := llm.PredictOpts{
			Prompt: prompt,
			Format: req.Format,
			Images: images,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	// streaming is the default; only an explicit stream=false collects results
	if req.Stream != nil && !*req.Stream {
		// Accumulate responses into the final response
		var final api.ChatResponse
		var sb strings.Builder
		for resp := range ch {
			switch r := resp.(type) {
			case api.ChatResponse:
				sb.WriteString(r.Message.Content)
				// keep the latest chunk; its Message is replaced with the
				// concatenated text after the loop
				final = r
			case gin.H:
				if errorMsg, ok := r["error"].(string); ok {
					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
					return
				} else {
					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
					return
				}
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
				return
			}
		}

		final.Message = api.Message{Role: "assistant", Content: sb.String()}
		c.JSON(http.StatusOK, final)
		return
	}

	streamResponse(c, ch)
}
  930. return
  931. }
  932. streamResponse(c, ch)
  933. }