routes.go 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608
  1. package server
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "math"
  12. "math/rand"
  13. "mime/multipart"
  14. "net"
  15. "net/http"
  16. "net/netip"
  17. "os"
  18. "os/exec"
  19. "os/signal"
  20. "path/filepath"
  21. "slices"
  22. "strconv"
  23. "strings"
  24. "syscall"
  25. "time"
  26. "github.com/gin-contrib/cors"
  27. "github.com/gin-gonic/gin"
  28. "github.com/ollama/ollama/api"
  29. "github.com/ollama/ollama/envconfig"
  30. "github.com/ollama/ollama/gpu"
  31. "github.com/ollama/ollama/llm"
  32. "github.com/ollama/ollama/openai"
  33. "github.com/ollama/ollama/parser"
  34. "github.com/ollama/ollama/template"
  35. "github.com/ollama/ollama/types/errtypes"
  36. "github.com/ollama/ollama/types/model"
  37. "github.com/ollama/ollama/version"
  38. )
  39. var mode string = gin.DebugMode
  40. type Server struct {
  41. addr net.Addr
  42. sched *Scheduler
  43. }
  44. func init() {
  45. switch mode {
  46. case gin.DebugMode:
  47. case gin.ReleaseMode:
  48. case gin.TestMode:
  49. default:
  50. mode = gin.DebugMode
  51. }
  52. gin.SetMode(mode)
  53. }
  // Sentinel errors shared by the route handlers.
  var (
  	// errRequired is wrapped after the name of a missing field, e.g.
  	// fmt.Errorf("model %w", errRequired) in scheduleRunner.
  	errRequired = errors.New("is required")

  	// errBadTemplate marks template failures; CreateModelHandler reports it
  	// with status 400.
  	errBadTemplate = errors.New("template error")
  )
  58. func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
  59. opts := api.DefaultOptions()
  60. if err := opts.FromMap(model.Options); err != nil {
  61. return api.Options{}, err
  62. }
  63. if err := opts.FromMap(requestOpts); err != nil {
  64. return api.Options{}, err
  65. }
  66. return opts, nil
  67. }
  68. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
  69. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
  70. func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
  71. if name == "" {
  72. return nil, nil, nil, fmt.Errorf("model %w", errRequired)
  73. }
  74. model, err := GetModel(name)
  75. if err != nil {
  76. return nil, nil, nil, err
  77. }
  78. if err := model.CheckCapabilities(caps...); err != nil {
  79. return nil, nil, nil, fmt.Errorf("%s %w", name, err)
  80. }
  81. opts, err := modelOptions(model, requestOpts)
  82. if err != nil {
  83. return nil, nil, nil, err
  84. }
  85. runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
  86. var runner *runnerRef
  87. select {
  88. case runner = <-runnerCh:
  89. case err = <-errCh:
  90. return nil, nil, nil, err
  91. }
  92. return runner.llama, model, &opts, nil
  93. }
  94. func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, modelPath string) {
  95. s.sched.whisperMu.Lock()
  96. if s.sched.whisperLoaded[modelPath] != nil {
  97. slog.Info("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath])
  98. portCh <- *s.sched.whisperLoaded[modelPath]
  99. s.sched.whisperMu.Unlock()
  100. return
  101. }
  102. whisperServer := "/Users/royhan-ollama/ollama/llm/whisper.cpp/server"
  103. // Find an available port for whisper
  104. port := 0
  105. params := []string{}
  106. if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
  107. var l *net.TCPListener
  108. if l, err = net.ListenTCP("tcp", a); err == nil {
  109. port = l.Addr().(*net.TCPAddr).Port
  110. l.Close()
  111. }
  112. }
  113. if port == 0 {
  114. slog.Debug("ResolveTCPAddr failed")
  115. port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
  116. }
  117. finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
  118. cmd := exec.Command(whisperServer, finalParams...)
  119. slog.Info("starting whisper server", "cmd", cmd.String())
  120. cmd.Stdout = os.Stdout
  121. cmd.Stderr = os.Stderr
  122. err := cmd.Start()
  123. if err != nil {
  124. slog.Error("failed to start whisper server", "error", err)
  125. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to start whisper server"})
  126. }
  127. // Wait for server connection
  128. retries := 10
  129. for range retries {
  130. time.Sleep(25 * time.Millisecond)
  131. conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
  132. if err == nil {
  133. conn.Close()
  134. break
  135. }
  136. }
  137. if err != nil {
  138. slog.Error("failed to connect to whisper server", "error", err)
  139. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to connect to whisper server"})
  140. }
  141. portCh <- port
  142. s.sched.whisperLoaded[modelPath] = &port
  143. s.sched.whisperMu.Unlock()
  144. // Wait for the whisper server to exit
  145. defer func() {
  146. err = cmd.Wait()
  147. if err != nil {
  148. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
  149. }
  150. s.sched.whisperMu.Lock()
  151. delete(s.sched.whisperLoaded, modelPath)
  152. s.sched.whisperMu.Unlock()
  153. }()
  154. }
  155. func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
  156. // Open the file
  157. file, err := os.Open(filePath)
  158. if err != nil {
  159. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
  160. return nil, err
  161. }
  162. defer file.Close()
  163. // Create a buffer to hold the multipart form data
  164. buffer := &bytes.Buffer{}
  165. writer := multipart.NewWriter(buffer)
  166. // Add the file to the multipart form
  167. part, err := writer.CreateFormFile("file", filepath.Base(filePath))
  168. if err != nil {
  169. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
  170. return nil, err
  171. }
  172. if _, err := io.Copy(part, file); err != nil {
  173. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
  174. return nil, err
  175. }
  176. // Add other fields to the form
  177. if err := writer.WriteField("temperature", "0.0"); err != nil {
  178. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
  179. return nil, err
  180. }
  181. // Close the writer to finalize the multipart form
  182. if err := writer.Close(); err != nil {
  183. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
  184. return nil, err
  185. }
  186. endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))
  187. serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
  188. if err != nil {
  189. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
  190. return nil, err
  191. }
  192. serverReq.Header.Set("Content-Type", writer.FormDataContentType())
  193. res, err := http.DefaultClient.Do(serverReq)
  194. if err != nil {
  195. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
  196. return nil, err
  197. }
  198. defer res.Body.Close()
  199. body, err := io.ReadAll(res.Body)
  200. if err != nil {
  201. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
  202. return nil, err
  203. }
  204. if res.StatusCode >= 400 {
  205. slog.Error("error response from whisper server", "status", res.Status, "body", string(body))
  206. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error response from whisper server"})
  207. }
  208. var w api.WhisperCompletion
  209. if err := json.Unmarshal(body, &w); err != nil {
  210. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
  211. return nil, err
  212. }
  213. return &w, nil
  214. }
  215. func (s *Server) GenerateHandler(c *gin.Context) {
  216. checkpointStart := time.Now()
  217. var req api.GenerateRequest
  218. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  219. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  220. return
  221. } else if err != nil {
  222. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  223. return
  224. }
  225. if req.Format != "" && req.Format != "json" {
  226. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
  227. return
  228. } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
  229. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
  230. return
  231. }
  232. caps := []Capability{CapabilityCompletion}
  233. if req.Suffix != "" {
  234. caps = append(caps, CapabilityInsert)
  235. }
  236. if req.Audio != "" {
  237. port := make(chan int, 1)
  238. go s.runWhisperServer(c, port, req.WhisperModel)
  239. w, err := whisperInference(c, req.Audio, <-port)
  240. if err != nil {
  241. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
  242. return
  243. }
  244. if req.Transcribe {
  245. c.JSON(http.StatusOK, api.GenerateResponse{
  246. Model: req.Model,
  247. CreatedAt: time.Now().UTC(),
  248. Response: w.Text,
  249. Done: true,
  250. DoneReason: "stop",
  251. })
  252. return
  253. }
  254. req.Prompt += w.Text
  255. }
  256. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  257. if errors.Is(err, errCapabilityCompletion) {
  258. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
  259. return
  260. } else if err != nil {
  261. handleScheduleError(c, req.Model, err)
  262. return
  263. }
  264. checkpointLoaded := time.Now()
  265. if req.Prompt == "" {
  266. c.JSON(http.StatusOK, api.GenerateResponse{
  267. Model: req.Model,
  268. CreatedAt: time.Now().UTC(),
  269. Done: true,
  270. DoneReason: "load",
  271. })
  272. return
  273. }
  274. images := make([]llm.ImageData, len(req.Images))
  275. for i := range req.Images {
  276. images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
  277. }
  278. prompt := req.Prompt
  279. if !req.Raw {
  280. tmpl := m.Template
  281. if req.Template != "" {
  282. tmpl, err = template.Parse(req.Template)
  283. if err != nil {
  284. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  285. return
  286. }
  287. }
  288. var values template.Values
  289. if req.Suffix != "" {
  290. values.Prompt = prompt
  291. values.Suffix = req.Suffix
  292. } else {
  293. var msgs []api.Message
  294. if req.System != "" {
  295. msgs = append(msgs, api.Message{Role: "system", Content: req.System})
  296. } else if m.System != "" {
  297. msgs = append(msgs, api.Message{Role: "system", Content: m.System})
  298. }
  299. if req.Context == nil {
  300. msgs = append(msgs, m.Messages...)
  301. }
  302. for _, i := range images {
  303. msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
  304. }
  305. values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
  306. }
  307. var b bytes.Buffer
  308. if req.Context != nil {
  309. s, err := r.Detokenize(c.Request.Context(), req.Context)
  310. if err != nil {
  311. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  312. return
  313. }
  314. b.WriteString(s)
  315. }
  316. if err := tmpl.Execute(&b, values); err != nil {
  317. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  318. return
  319. }
  320. prompt = b.String()
  321. }
  322. slog.Debug("generate request", "prompt", prompt, "images", images)
  323. ch := make(chan any)
  324. go func() {
  325. // TODO (jmorganca): avoid building the response twice both here and below
  326. var sb strings.Builder
  327. defer close(ch)
  328. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  329. Prompt: prompt,
  330. Images: images,
  331. Format: req.Format,
  332. Options: opts,
  333. }, func(cr llm.CompletionResponse) {
  334. res := api.GenerateResponse{
  335. Model: req.Model,
  336. CreatedAt: time.Now().UTC(),
  337. Response: cr.Content,
  338. Done: cr.Done,
  339. DoneReason: cr.DoneReason,
  340. Metrics: api.Metrics{
  341. PromptEvalCount: cr.PromptEvalCount,
  342. PromptEvalDuration: cr.PromptEvalDuration,
  343. EvalCount: cr.EvalCount,
  344. EvalDuration: cr.EvalDuration,
  345. },
  346. }
  347. if _, err := sb.WriteString(cr.Content); err != nil {
  348. ch <- gin.H{"error": err.Error()}
  349. }
  350. if cr.Done {
  351. res.TotalDuration = time.Since(checkpointStart)
  352. res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  353. if !req.Raw {
  354. tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
  355. if err != nil {
  356. ch <- gin.H{"error": err.Error()}
  357. return
  358. }
  359. res.Context = tokens
  360. }
  361. }
  362. ch <- res
  363. }); err != nil {
  364. ch <- gin.H{"error": err.Error()}
  365. }
  366. }()
  367. if req.Stream != nil && !*req.Stream {
  368. var r api.GenerateResponse
  369. var sb strings.Builder
  370. for rr := range ch {
  371. switch t := rr.(type) {
  372. case api.GenerateResponse:
  373. sb.WriteString(t.Response)
  374. r = t
  375. case gin.H:
  376. msg, ok := t["error"].(string)
  377. if !ok {
  378. msg = "unexpected error format in response"
  379. }
  380. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  381. return
  382. default:
  383. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  384. return
  385. }
  386. }
  387. r.Response = sb.String()
  388. c.JSON(http.StatusOK, r)
  389. return
  390. }
  391. streamResponse(c, ch)
  392. }
  393. func (s *Server) EmbedHandler(c *gin.Context) {
  394. checkpointStart := time.Now()
  395. var req api.EmbedRequest
  396. err := c.ShouldBindJSON(&req)
  397. switch {
  398. case errors.Is(err, io.EOF):
  399. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  400. return
  401. case err != nil:
  402. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  403. return
  404. }
  405. truncate := true
  406. if req.Truncate != nil && !*req.Truncate {
  407. truncate = false
  408. }
  409. var input []string
  410. switch i := req.Input.(type) {
  411. case string:
  412. if len(i) > 0 {
  413. input = append(input, i)
  414. }
  415. case []any:
  416. for _, v := range i {
  417. if _, ok := v.(string); !ok {
  418. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
  419. return
  420. }
  421. input = append(input, v.(string))
  422. }
  423. default:
  424. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
  425. return
  426. }
  427. if len(input) == 0 {
  428. c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
  429. return
  430. }
  431. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  432. if err != nil {
  433. handleScheduleError(c, req.Model, err)
  434. return
  435. }
  436. checkpointLoaded := time.Now()
  437. kvData, err := getKVData(m.ModelPath, false)
  438. if err != nil {
  439. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  440. return
  441. }
  442. for i, s := range input {
  443. tokens, err := r.Tokenize(c.Request.Context(), s)
  444. if err != nil {
  445. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  446. return
  447. }
  448. ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
  449. if len(tokens) > ctxLen {
  450. if !truncate {
  451. c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
  452. return
  453. }
  454. tokens = tokens[:ctxLen]
  455. s, err = r.Detokenize(c.Request.Context(), tokens)
  456. if err != nil {
  457. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  458. return
  459. }
  460. }
  461. input[i] = s
  462. }
  463. embeddings, err := r.Embed(c.Request.Context(), input)
  464. if err != nil {
  465. slog.Error("embedding generation failed", "error", err)
  466. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  467. return
  468. }
  469. for i, e := range embeddings.Embedding {
  470. embeddings.Embedding[i] = normalize(e)
  471. }
  472. resp := api.EmbedResponse{
  473. Model: req.Model,
  474. Embeddings: embeddings.Embedding,
  475. TotalDuration: time.Since(checkpointStart),
  476. LoadDuration: checkpointLoaded.Sub(checkpointStart),
  477. PromptEvalCount: embeddings.PromptEvalCount,
  478. }
  479. c.JSON(http.StatusOK, resp)
  480. }
  // normalize scales vec in place to unit L2 length and returns it. A vector
  // whose squared sum is not positive (e.g. all zeros) is multiplied by 0.
  // The squared sum is accumulated in float32; only the square root and its
  // reciprocal are computed in float64.
  func normalize(vec []float32) []float32 {
  	var sumSq float32
  	for i := range vec {
  		sumSq += vec[i] * vec[i]
  	}

  	var scale float32
  	if sumSq > 0 {
  		scale = float32(1.0 / math.Sqrt(float64(sumSq)))
  	}

  	for i := range vec {
  		vec[i] *= scale
  	}
  	return vec
  }
  495. func (s *Server) EmbeddingsHandler(c *gin.Context) {
  496. var req api.EmbeddingRequest
  497. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  498. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  499. return
  500. } else if err != nil {
  501. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  502. return
  503. }
  504. r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
  505. if err != nil {
  506. handleScheduleError(c, req.Model, err)
  507. return
  508. }
  509. // an empty request loads the model
  510. if req.Prompt == "" {
  511. c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
  512. return
  513. }
  514. embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
  515. if err != nil {
  516. slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
  517. c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
  518. return
  519. }
  520. embedding := make([]float64, len(embeddings.Embedding[0]))
  521. for i, v := range embeddings.Embedding[0] {
  522. embedding[i] = float64(v)
  523. }
  524. resp := api.EmbeddingResponse{
  525. Embedding: embedding,
  526. }
  527. c.JSON(http.StatusOK, resp)
  528. }
  529. func (s *Server) PullModelHandler(c *gin.Context) {
  530. var req api.PullRequest
  531. err := c.ShouldBindJSON(&req)
  532. switch {
  533. case errors.Is(err, io.EOF):
  534. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  535. return
  536. case err != nil:
  537. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  538. return
  539. }
  540. name := model.ParseName(cmp.Or(req.Model, req.Name))
  541. if !name.IsValid() {
  542. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
  543. return
  544. }
  545. if err := checkNameExists(name); err != nil {
  546. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  547. return
  548. }
  549. ch := make(chan any)
  550. go func() {
  551. defer close(ch)
  552. fn := func(r api.ProgressResponse) {
  553. ch <- r
  554. }
  555. regOpts := &registryOptions{
  556. Insecure: req.Insecure,
  557. }
  558. ctx, cancel := context.WithCancel(c.Request.Context())
  559. defer cancel()
  560. if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
  561. ch <- gin.H{"error": err.Error()}
  562. }
  563. }()
  564. if req.Stream != nil && !*req.Stream {
  565. waitForStream(c, ch)
  566. return
  567. }
  568. streamResponse(c, ch)
  569. }
  570. func (s *Server) PushModelHandler(c *gin.Context) {
  571. var req api.PushRequest
  572. err := c.ShouldBindJSON(&req)
  573. switch {
  574. case errors.Is(err, io.EOF):
  575. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  576. return
  577. case err != nil:
  578. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  579. return
  580. }
  581. var model string
  582. if req.Model != "" {
  583. model = req.Model
  584. } else if req.Name != "" {
  585. model = req.Name
  586. } else {
  587. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  588. return
  589. }
  590. ch := make(chan any)
  591. go func() {
  592. defer close(ch)
  593. fn := func(r api.ProgressResponse) {
  594. ch <- r
  595. }
  596. regOpts := &registryOptions{
  597. Insecure: req.Insecure,
  598. }
  599. ctx, cancel := context.WithCancel(c.Request.Context())
  600. defer cancel()
  601. if err := PushModel(ctx, model, regOpts, fn); err != nil {
  602. ch <- gin.H{"error": err.Error()}
  603. }
  604. }()
  605. if req.Stream != nil && !*req.Stream {
  606. waitForStream(c, ch)
  607. return
  608. }
  609. streamResponse(c, ch)
  610. }
  611. func checkNameExists(name model.Name) error {
  612. names, err := Manifests()
  613. if err != nil {
  614. return err
  615. }
  616. for n := range names {
  617. if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
  618. return errors.New("a model with that name already exists")
  619. }
  620. }
  621. return nil
  622. }
  623. func (s *Server) CreateModelHandler(c *gin.Context) {
  624. var r api.CreateRequest
  625. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  626. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  627. return
  628. } else if err != nil {
  629. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  630. return
  631. }
  632. name := model.ParseName(cmp.Or(r.Model, r.Name))
  633. if !name.IsValid() {
  634. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
  635. return
  636. }
  637. if err := checkNameExists(name); err != nil {
  638. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  639. return
  640. }
  641. if r.Path == "" && r.Modelfile == "" {
  642. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
  643. return
  644. }
  645. var sr io.Reader = strings.NewReader(r.Modelfile)
  646. if r.Path != "" && r.Modelfile == "" {
  647. f, err := os.Open(r.Path)
  648. if err != nil {
  649. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
  650. return
  651. }
  652. defer f.Close()
  653. sr = f
  654. }
  655. f, err := parser.ParseFile(sr)
  656. if err != nil {
  657. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  658. return
  659. }
  660. ch := make(chan any)
  661. go func() {
  662. defer close(ch)
  663. fn := func(resp api.ProgressResponse) {
  664. ch <- resp
  665. }
  666. ctx, cancel := context.WithCancel(c.Request.Context())
  667. defer cancel()
  668. quantization := cmp.Or(r.Quantize, r.Quantization)
  669. if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) {
  670. ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
  671. } else if err != nil {
  672. ch <- gin.H{"error": err.Error()}
  673. }
  674. }()
  675. if r.Stream != nil && !*r.Stream {
  676. waitForStream(c, ch)
  677. return
  678. }
  679. streamResponse(c, ch)
  680. }
  681. func (s *Server) DeleteModelHandler(c *gin.Context) {
  682. var r api.DeleteRequest
  683. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  684. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  685. return
  686. } else if err != nil {
  687. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  688. return
  689. }
  690. n := model.ParseName(cmp.Or(r.Model, r.Name))
  691. if !n.IsValid() {
  692. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
  693. return
  694. }
  695. m, err := ParseNamedManifest(n)
  696. if err != nil {
  697. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  698. return
  699. }
  700. if err := m.Remove(); err != nil {
  701. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  702. return
  703. }
  704. if err := m.RemoveLayers(); err != nil {
  705. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  706. return
  707. }
  708. }
  709. func (s *Server) ShowModelHandler(c *gin.Context) {
  710. var req api.ShowRequest
  711. err := c.ShouldBindJSON(&req)
  712. switch {
  713. case errors.Is(err, io.EOF):
  714. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  715. return
  716. case err != nil:
  717. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  718. return
  719. }
  720. if req.Model != "" {
  721. // noop
  722. } else if req.Name != "" {
  723. req.Model = req.Name
  724. } else {
  725. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
  726. return
  727. }
  728. resp, err := GetModelInfo(req)
  729. if err != nil {
  730. switch {
  731. case os.IsNotExist(err):
  732. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
  733. case err.Error() == "invalid model name":
  734. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  735. default:
  736. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  737. }
  738. return
  739. }
  740. c.JSON(http.StatusOK, resp)
  741. }
  742. func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
  743. m, err := GetModel(req.Model)
  744. if err != nil {
  745. return nil, err
  746. }
  747. modelDetails := api.ModelDetails{
  748. ParentModel: m.ParentModel,
  749. Format: m.Config.ModelFormat,
  750. Family: m.Config.ModelFamily,
  751. Families: m.Config.ModelFamilies,
  752. ParameterSize: m.Config.ModelType,
  753. QuantizationLevel: m.Config.FileType,
  754. }
  755. if req.System != "" {
  756. m.System = req.System
  757. }
  758. msgs := make([]api.Message, len(m.Messages))
  759. for i, msg := range m.Messages {
  760. msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
  761. }
  762. n := model.ParseName(req.Model)
  763. if !n.IsValid() {
  764. return nil, errors.New("invalid model name")
  765. }
  766. manifest, err := ParseNamedManifest(n)
  767. if err != nil {
  768. return nil, err
  769. }
  770. resp := &api.ShowResponse{
  771. License: strings.Join(m.License, "\n"),
  772. System: m.System,
  773. Template: m.Template.String(),
  774. Details: modelDetails,
  775. Messages: msgs,
  776. ModifiedAt: manifest.fi.ModTime(),
  777. }
  778. var params []string
  779. cs := 30
  780. for k, v := range m.Options {
  781. switch val := v.(type) {
  782. case []interface{}:
  783. for _, nv := range val {
  784. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
  785. }
  786. default:
  787. params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
  788. }
  789. }
  790. resp.Parameters = strings.Join(params, "\n")
  791. for k, v := range req.Options {
  792. if _, ok := req.Options[k]; ok {
  793. m.Options[k] = v
  794. }
  795. }
  796. var sb strings.Builder
  797. fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
  798. fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
  799. fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
  800. fmt.Fprint(&sb, m.String())
  801. resp.Modelfile = sb.String()
  802. kvData, err := getKVData(m.ModelPath, req.Verbose)
  803. if err != nil {
  804. return nil, err
  805. }
  806. delete(kvData, "general.name")
  807. delete(kvData, "tokenizer.chat_template")
  808. resp.ModelInfo = kvData
  809. if len(m.ProjectorPaths) > 0 {
  810. projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
  811. if err != nil {
  812. return nil, err
  813. }
  814. resp.ProjectorInfo = projectorData
  815. }
  816. return resp, nil
  817. }
  818. func getKVData(digest string, verbose bool) (llm.KV, error) {
  819. maxArraySize := 0
  820. if verbose {
  821. maxArraySize = -1
  822. }
  823. kvData, err := llm.LoadModel(digest, maxArraySize)
  824. if err != nil {
  825. return nil, err
  826. }
  827. kv := kvData.KV()
  828. if !verbose {
  829. for k := range kv {
  830. if t, ok := kv[k].([]any); len(t) > 5 && ok {
  831. kv[k] = []any{}
  832. }
  833. }
  834. }
  835. return kv, nil
  836. }
  837. func (s *Server) ListModelsHandler(c *gin.Context) {
  838. ms, err := Manifests()
  839. if err != nil {
  840. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  841. return
  842. }
  843. models := []api.ListModelResponse{}
  844. for n, m := range ms {
  845. var cf ConfigV2
  846. if m.Config.Digest != "" {
  847. f, err := m.Config.Open()
  848. if err != nil {
  849. slog.Warn("bad manifest filepath", "name", n, "error", err)
  850. continue
  851. }
  852. defer f.Close()
  853. if err := json.NewDecoder(f).Decode(&cf); err != nil {
  854. slog.Warn("bad manifest config", "name", n, "error", err)
  855. continue
  856. }
  857. }
  858. // tag should never be masked
  859. models = append(models, api.ListModelResponse{
  860. Model: n.DisplayShortest(),
  861. Name: n.DisplayShortest(),
  862. Size: m.Size(),
  863. Digest: m.digest,
  864. ModifiedAt: m.fi.ModTime(),
  865. Details: api.ModelDetails{
  866. Format: cf.ModelFormat,
  867. Family: cf.ModelFamily,
  868. Families: cf.ModelFamilies,
  869. ParameterSize: cf.ModelType,
  870. QuantizationLevel: cf.FileType,
  871. },
  872. })
  873. }
  874. slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
  875. // most recently modified first
  876. return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
  877. })
  878. c.JSON(http.StatusOK, api.ListResponse{Models: models})
  879. }
  880. func (s *Server) CopyModelHandler(c *gin.Context) {
  881. var r api.CopyRequest
  882. if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
  883. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  884. return
  885. } else if err != nil {
  886. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  887. return
  888. }
  889. src := model.ParseName(r.Source)
  890. if !src.IsValid() {
  891. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
  892. return
  893. }
  894. dst := model.ParseName(r.Destination)
  895. if !dst.IsValid() {
  896. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
  897. return
  898. }
  899. if err := checkNameExists(dst); err != nil {
  900. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  901. return
  902. }
  903. if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
  904. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
  905. } else if err != nil {
  906. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  907. }
  908. }
  909. func (s *Server) HeadBlobHandler(c *gin.Context) {
  910. path, err := GetBlobsPath(c.Param("digest"))
  911. if err != nil {
  912. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  913. return
  914. }
  915. if _, err := os.Stat(path); err != nil {
  916. c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
  917. return
  918. }
  919. c.Status(http.StatusOK)
  920. }
  921. func (s *Server) CreateBlobHandler(c *gin.Context) {
  922. if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
  923. p, err := GetBlobsPath(ib)
  924. if err != nil {
  925. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  926. return
  927. }
  928. if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
  929. slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
  930. delete(intermediateBlobs, c.Param("digest"))
  931. } else if err != nil {
  932. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  933. return
  934. } else {
  935. c.Status(http.StatusOK)
  936. return
  937. }
  938. }
  939. path, err := GetBlobsPath(c.Param("digest"))
  940. if err != nil {
  941. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  942. return
  943. }
  944. _, err = os.Stat(path)
  945. switch {
  946. case errors.Is(err, os.ErrNotExist):
  947. // noop
  948. case err != nil:
  949. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  950. return
  951. default:
  952. c.Status(http.StatusOK)
  953. return
  954. }
  955. layer, err := NewLayer(c.Request.Body, "")
  956. if err != nil {
  957. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  958. return
  959. }
  960. if layer.Digest != c.Param("digest") {
  961. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
  962. return
  963. }
  964. c.Status(http.StatusCreated)
  965. }
  966. func isLocalIP(ip netip.Addr) bool {
  967. if interfaces, err := net.Interfaces(); err == nil {
  968. for _, iface := range interfaces {
  969. addrs, err := iface.Addrs()
  970. if err != nil {
  971. continue
  972. }
  973. for _, a := range addrs {
  974. if parsed, _, err := net.ParseCIDR(a.String()); err == nil {
  975. if parsed.String() == ip.String() {
  976. return true
  977. }
  978. }
  979. }
  980. }
  981. }
  982. return false
  983. }
  984. func allowedHost(host string) bool {
  985. if host == "" || host == "localhost" {
  986. return true
  987. }
  988. if hostname, err := os.Hostname(); err == nil && host == hostname {
  989. return true
  990. }
  991. tlds := []string{
  992. "localhost",
  993. "local",
  994. "internal",
  995. }
  996. // check if the host is a local TLD
  997. for _, tld := range tlds {
  998. if strings.HasSuffix(host, "."+tld) {
  999. return true
  1000. }
  1001. }
  1002. return false
  1003. }
  1004. func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
  1005. return func(c *gin.Context) {
  1006. if addr == nil {
  1007. c.Next()
  1008. return
  1009. }
  1010. if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
  1011. c.Next()
  1012. return
  1013. }
  1014. host, _, err := net.SplitHostPort(c.Request.Host)
  1015. if err != nil {
  1016. host = c.Request.Host
  1017. }
  1018. if addr, err := netip.ParseAddr(host); err == nil {
  1019. if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
  1020. c.Next()
  1021. return
  1022. }
  1023. }
  1024. if allowedHost(host) {
  1025. if c.Request.Method == http.MethodOptions {
  1026. c.AbortWithStatus(http.StatusNoContent)
  1027. return
  1028. }
  1029. c.Next()
  1030. return
  1031. }
  1032. c.AbortWithStatus(http.StatusForbidden)
  1033. }
  1034. }
  1035. func (s *Server) GenerateRoutes() http.Handler {
  1036. config := cors.DefaultConfig()
  1037. config.AllowWildcard = true
  1038. config.AllowBrowserExtensions = true
  1039. config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
  1040. openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
  1041. for _, prop := range openAIProperties {
  1042. config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
  1043. }
  1044. config.AllowOrigins = envconfig.Origins()
  1045. r := gin.Default()
  1046. r.Use(
  1047. cors.New(config),
  1048. allowedHostsMiddleware(s.addr),
  1049. )
  1050. r.POST("/api/pull", s.PullModelHandler)
  1051. r.POST("/api/generate", s.GenerateHandler)
  1052. r.POST("/api/chat", s.ChatHandler)
  1053. r.POST("/api/embed", s.EmbedHandler)
  1054. r.POST("/api/embeddings", s.EmbeddingsHandler)
  1055. r.POST("/api/create", s.CreateModelHandler)
  1056. r.POST("/api/push", s.PushModelHandler)
  1057. r.POST("/api/copy", s.CopyModelHandler)
  1058. r.DELETE("/api/delete", s.DeleteModelHandler)
  1059. r.POST("/api/show", s.ShowModelHandler)
  1060. r.POST("/api/blobs/:digest", s.CreateBlobHandler)
  1061. r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
  1062. r.GET("/api/ps", s.ProcessHandler)
  1063. // Compatibility endpoints
  1064. r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
  1065. r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
  1066. r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
  1067. r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
  1068. r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
  1069. for _, method := range []string{http.MethodGet, http.MethodHead} {
  1070. r.Handle(method, "/", func(c *gin.Context) {
  1071. c.String(http.StatusOK, "Ollama is running")
  1072. })
  1073. r.Handle(method, "/api/tags", s.ListModelsHandler)
  1074. r.Handle(method, "/api/version", func(c *gin.Context) {
  1075. c.JSON(http.StatusOK, gin.H{"version": version.Version})
  1076. })
  1077. }
  1078. return r
  1079. }
  1080. func Serve(ln net.Listener) error {
  1081. level := slog.LevelInfo
  1082. if envconfig.Debug() {
  1083. level = slog.LevelDebug
  1084. }
  1085. slog.Info("server config", "env", envconfig.Values())
  1086. handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  1087. Level: level,
  1088. AddSource: true,
  1089. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  1090. if attr.Key == slog.SourceKey {
  1091. source := attr.Value.Any().(*slog.Source)
  1092. source.File = filepath.Base(source.File)
  1093. }
  1094. return attr
  1095. },
  1096. })
  1097. slog.SetDefault(slog.New(handler))
  1098. blobsDir, err := GetBlobsPath("")
  1099. if err != nil {
  1100. return err
  1101. }
  1102. if err := fixBlobs(blobsDir); err != nil {
  1103. return err
  1104. }
  1105. if !envconfig.NoPrune() {
  1106. // clean up unused layers and manifests
  1107. if err := PruneLayers(); err != nil {
  1108. return err
  1109. }
  1110. manifestsPath, err := GetManifestPath()
  1111. if err != nil {
  1112. return err
  1113. }
  1114. if err := PruneDirectory(manifestsPath); err != nil {
  1115. return err
  1116. }
  1117. }
  1118. ctx, done := context.WithCancel(context.Background())
  1119. schedCtx, schedDone := context.WithCancel(ctx)
  1120. sched := InitScheduler(schedCtx)
  1121. s := &Server{addr: ln.Addr(), sched: sched}
  1122. http.Handle("/", s.GenerateRoutes())
  1123. slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
  1124. srvr := &http.Server{
  1125. // Use http.DefaultServeMux so we get net/http/pprof for
  1126. // free.
  1127. //
  1128. // TODO(bmizerany): Decide if we want to make this
  1129. // configurable so it is not exposed by default, or allow
  1130. // users to bind it to a different port. This was a quick
  1131. // and easy way to get pprof, but it may not be the best
  1132. // way.
  1133. Handler: nil,
  1134. }
  1135. // listen for a ctrl+c and stop any loaded llm
  1136. signals := make(chan os.Signal, 1)
  1137. signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
  1138. go func() {
  1139. <-signals
  1140. srvr.Close()
  1141. schedDone()
  1142. sched.unloadAllRunners()
  1143. gpu.Cleanup()
  1144. done()
  1145. }()
  1146. if err := llm.Init(); err != nil {
  1147. return fmt.Errorf("unable to initialize llm library %w", err)
  1148. }
  1149. s.sched.Run(schedCtx)
  1150. // At startup we retrieve GPU information so we can get log messages before loading a model
  1151. // This will log warnings to the log in case we have problems with detected GPUs
  1152. gpus := gpu.GetGPUInfo()
  1153. gpus.LogDetails()
  1154. err = srvr.Serve(ln)
  1155. // If server is closed from the signal handler, wait for the ctx to be done
  1156. // otherwise error out quickly
  1157. if !errors.Is(err, http.ErrServerClosed) {
  1158. return err
  1159. }
  1160. <-ctx.Done()
  1161. return nil
  1162. }
  1163. func waitForStream(c *gin.Context, ch chan interface{}) {
  1164. c.Header("Content-Type", "application/json")
  1165. for resp := range ch {
  1166. switch r := resp.(type) {
  1167. case api.ProgressResponse:
  1168. if r.Status == "success" {
  1169. c.JSON(http.StatusOK, r)
  1170. return
  1171. }
  1172. case gin.H:
  1173. status, ok := r["status"].(int)
  1174. if !ok {
  1175. status = http.StatusInternalServerError
  1176. }
  1177. if errorMsg, ok := r["error"].(string); ok {
  1178. c.JSON(status, gin.H{"error": errorMsg})
  1179. return
  1180. } else {
  1181. c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
  1182. return
  1183. }
  1184. default:
  1185. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
  1186. return
  1187. }
  1188. }
  1189. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
  1190. }
  1191. func streamResponse(c *gin.Context, ch chan any) {
  1192. c.Header("Content-Type", "application/x-ndjson")
  1193. c.Stream(func(w io.Writer) bool {
  1194. val, ok := <-ch
  1195. if !ok {
  1196. return false
  1197. }
  1198. bts, err := json.Marshal(val)
  1199. if err != nil {
  1200. slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
  1201. return false
  1202. }
  1203. // Delineate chunks with new-line delimiter
  1204. bts = append(bts, '\n')
  1205. if _, err := w.Write(bts); err != nil {
  1206. slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
  1207. return false
  1208. }
  1209. return true
  1210. })
  1211. }
  1212. func (s *Server) ProcessHandler(c *gin.Context) {
  1213. models := []api.ProcessModelResponse{}
  1214. for _, v := range s.sched.loaded {
  1215. model := v.model
  1216. modelDetails := api.ModelDetails{
  1217. Format: model.Config.ModelFormat,
  1218. Family: model.Config.ModelFamily,
  1219. Families: model.Config.ModelFamilies,
  1220. ParameterSize: model.Config.ModelType,
  1221. QuantizationLevel: model.Config.FileType,
  1222. }
  1223. mr := api.ProcessModelResponse{
  1224. Model: model.ShortName,
  1225. Name: model.ShortName,
  1226. Size: int64(v.estimatedTotal),
  1227. SizeVRAM: int64(v.estimatedVRAM),
  1228. Digest: model.Digest,
  1229. Details: modelDetails,
  1230. ExpiresAt: v.expiresAt,
  1231. }
  1232. // The scheduler waits to set expiresAt, so if a model is loading it's
  1233. // possible that it will be set to the unix epoch. For those cases, just
  1234. // calculate the time w/ the sessionDuration instead.
  1235. var epoch time.Time
  1236. if v.expiresAt == epoch {
  1237. mr.ExpiresAt = time.Now().Add(v.sessionDuration)
  1238. }
  1239. models = append(models, mr)
  1240. }
  1241. slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
  1242. // longest duration remaining listed first
  1243. return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
  1244. })
  1245. c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
  1246. }
  1247. func (s *Server) ChatHandler(c *gin.Context) {
  1248. checkpointStart := time.Now()
  1249. var req api.ChatRequest
  1250. if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
  1251. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
  1252. return
  1253. } else if err != nil {
  1254. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1255. return
  1256. }
  1257. caps := []Capability{CapabilityCompletion}
  1258. if len(req.Tools) > 0 {
  1259. caps = append(caps, CapabilityTools)
  1260. }
  1261. r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
  1262. if errors.Is(err, errCapabilityCompletion) {
  1263. c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
  1264. return
  1265. } else if err != nil {
  1266. handleScheduleError(c, req.Model, err)
  1267. return
  1268. }
  1269. checkpointLoaded := time.Now()
  1270. if len(req.Messages) == 0 {
  1271. c.JSON(http.StatusOK, api.ChatResponse{
  1272. Model: req.Model,
  1273. CreatedAt: time.Now().UTC(),
  1274. Message: api.Message{Role: "assistant"},
  1275. Done: true,
  1276. DoneReason: "load",
  1277. })
  1278. return
  1279. }
  1280. msgs := append(m.Messages, req.Messages...)
  1281. if req.Messages[0].Role != "system" && m.System != "" {
  1282. msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
  1283. }
  1284. prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
  1285. if err != nil {
  1286. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1287. return
  1288. }
  1289. slog.Debug("chat request", "images", len(images), "prompt", prompt)
  1290. ch := make(chan any)
  1291. go func() {
  1292. defer close(ch)
  1293. if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
  1294. Prompt: prompt,
  1295. Images: images,
  1296. Format: req.Format,
  1297. Options: opts,
  1298. }, func(r llm.CompletionResponse) {
  1299. res := api.ChatResponse{
  1300. Model: req.Model,
  1301. CreatedAt: time.Now().UTC(),
  1302. Message: api.Message{Role: "assistant", Content: r.Content},
  1303. Done: r.Done,
  1304. DoneReason: r.DoneReason,
  1305. Metrics: api.Metrics{
  1306. PromptEvalCount: r.PromptEvalCount,
  1307. PromptEvalDuration: r.PromptEvalDuration,
  1308. EvalCount: r.EvalCount,
  1309. EvalDuration: r.EvalDuration,
  1310. },
  1311. }
  1312. if r.Done {
  1313. res.TotalDuration = time.Since(checkpointStart)
  1314. res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
  1315. }
  1316. ch <- res
  1317. }); err != nil {
  1318. ch <- gin.H{"error": err.Error()}
  1319. }
  1320. }()
  1321. if req.Stream != nil && !*req.Stream {
  1322. var resp api.ChatResponse
  1323. var sb strings.Builder
  1324. for rr := range ch {
  1325. switch t := rr.(type) {
  1326. case api.ChatResponse:
  1327. sb.WriteString(t.Message.Content)
  1328. resp = t
  1329. case gin.H:
  1330. msg, ok := t["error"].(string)
  1331. if !ok {
  1332. msg = "unexpected error format in response"
  1333. }
  1334. c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
  1335. return
  1336. default:
  1337. c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
  1338. return
  1339. }
  1340. }
  1341. resp.Message.Content = sb.String()
  1342. if len(req.Tools) > 0 {
  1343. if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
  1344. resp.Message.ToolCalls = toolCalls
  1345. resp.Message.Content = ""
  1346. }
  1347. }
  1348. c.JSON(http.StatusOK, resp)
  1349. return
  1350. }
  1351. streamResponse(c, ch)
  1352. }
  1353. func handleScheduleError(c *gin.Context, name string, err error) {
  1354. switch {
  1355. case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
  1356. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  1357. case errors.Is(err, context.Canceled):
  1358. c.JSON(499, gin.H{"error": "request canceled"})
  1359. case errors.Is(err, ErrMaxQueue):
  1360. c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
  1361. case errors.Is(err, os.ErrNotExist):
  1362. c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
  1363. default:
  1364. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  1365. }
  1366. }