cmd.go 35 KB


  1. package cmd
  2. import (
  3. "archive/zip"
  4. "bufio"
  5. "bytes"
  6. "context"
  7. "crypto/ed25519"
  8. "crypto/rand"
  9. "crypto/sha256"
  10. "encoding/pem"
  11. "errors"
  12. "fmt"
  13. "io"
  14. "log"
  15. "math"
  16. "net"
  17. "net/http"
  18. "os"
  19. "os/signal"
  20. "path/filepath"
  21. "regexp"
  22. "runtime"
  23. "strconv"
  24. "strings"
  25. "sync/atomic"
  26. "syscall"
  27. "time"
  28. "github.com/containerd/console"
  29. "github.com/mattn/go-runewidth"
  30. "github.com/olekukonko/tablewriter"
  31. "github.com/spf13/cobra"
  32. "golang.org/x/crypto/ssh"
  33. "golang.org/x/term"
  34. "github.com/ollama/ollama/api"
  35. "github.com/ollama/ollama/auth"
  36. "github.com/ollama/ollama/envconfig"
  37. "github.com/ollama/ollama/format"
  38. "github.com/ollama/ollama/parser"
  39. "github.com/ollama/ollama/progress"
  40. "github.com/ollama/ollama/server"
  41. "github.com/ollama/ollama/types/errtypes"
  42. "github.com/ollama/ollama/types/model"
  43. "github.com/ollama/ollama/version"
  44. )
  // errModelNotFound is returned by tempZipFiles when a model directory holds
  // none of the recognized weight files.
  var errModelNotFound = errors.New("no Modelfile or safetensors files found")

  // errModelfileNotFound is returned by CreateHandler when the Modelfile named
  // by the --file flag does not exist.
  var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
  49. func getModelfileName(cmd *cobra.Command) (string, error) {
  50. fn, _ := cmd.Flags().GetString("file")
  51. filename := fn
  52. if filename == "" {
  53. filename = "Modelfile"
  54. }
  55. absName, err := filepath.Abs(filename)
  56. if err != nil {
  57. return "", err
  58. }
  59. _, err = os.Stat(absName)
  60. if err != nil {
  61. return fn, err
  62. }
  63. return absName, nil
  64. }
  65. func CreateHandler(cmd *cobra.Command, args []string) error {
  66. p := progress.NewProgress(os.Stderr)
  67. defer p.Stop()
  68. var reader io.Reader
  69. filename, err := getModelfileName(cmd)
  70. if os.IsNotExist(err) {
  71. if filename == "" {
  72. reader = strings.NewReader("FROM .\n")
  73. } else {
  74. return errModelfileNotFound
  75. }
  76. } else if err != nil {
  77. return err
  78. } else {
  79. f, err := os.Open(filename)
  80. if err != nil {
  81. return err
  82. }
  83. reader = f
  84. defer f.Close()
  85. }
  86. modelfile, err := parser.ParseFile(reader)
  87. if err != nil {
  88. return err
  89. }
  90. home, err := os.UserHomeDir()
  91. if err != nil {
  92. return err
  93. }
  94. status := "transferring model data"
  95. spinner := progress.NewSpinner(status)
  96. p.Add(status, spinner)
  97. defer p.Stop()
  98. client, err := api.ClientFromEnvironment()
  99. if err != nil {
  100. return err
  101. }
  102. for i := range modelfile.Commands {
  103. switch modelfile.Commands[i].Name {
  104. case "model", "adapter":
  105. path := modelfile.Commands[i].Args
  106. if path == "~" {
  107. path = home
  108. } else if strings.HasPrefix(path, "~/") {
  109. path = filepath.Join(home, path[2:])
  110. }
  111. if !filepath.IsAbs(path) {
  112. path = filepath.Join(filepath.Dir(filename), path)
  113. }
  114. fi, err := os.Stat(path)
  115. if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
  116. continue
  117. } else if err != nil {
  118. return err
  119. }
  120. if fi.IsDir() {
  121. // this is likely a safetensors or pytorch directory
  122. // TODO make this work w/ adapters
  123. tempfile, err := tempZipFiles(path)
  124. if err != nil {
  125. return err
  126. }
  127. defer os.RemoveAll(tempfile)
  128. path = tempfile
  129. }
  130. digest, err := createBlob(cmd, client, path, spinner)
  131. if err != nil {
  132. return err
  133. }
  134. modelfile.Commands[i].Args = "@" + digest
  135. }
  136. }
  137. bars := make(map[string]*progress.Bar)
  138. fn := func(resp api.ProgressResponse) error {
  139. if resp.Digest != "" {
  140. spinner.Stop()
  141. bar, ok := bars[resp.Digest]
  142. if !ok {
  143. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  144. bars[resp.Digest] = bar
  145. p.Add(resp.Digest, bar)
  146. }
  147. bar.Set(resp.Completed)
  148. } else if status != resp.Status {
  149. spinner.Stop()
  150. status = resp.Status
  151. spinner = progress.NewSpinner(status)
  152. p.Add(status, spinner)
  153. }
  154. return nil
  155. }
  156. quantize, _ := cmd.Flags().GetString("quantize")
  157. request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantize: quantize}
  158. if err := client.Create(cmd.Context(), &request, fn); err != nil {
  159. return err
  160. }
  161. return nil
  162. }
  163. func tempZipFiles(path string) (string, error) {
  164. tempfile, err := os.CreateTemp("", "ollama-tf")
  165. if err != nil {
  166. return "", err
  167. }
  168. defer tempfile.Close()
  169. detectContentType := func(path string) (string, error) {
  170. f, err := os.Open(path)
  171. if err != nil {
  172. return "", err
  173. }
  174. defer f.Close()
  175. var b bytes.Buffer
  176. b.Grow(512)
  177. if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
  178. return "", err
  179. }
  180. contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
  181. return contentType, nil
  182. }
  183. glob := func(pattern, contentType string) ([]string, error) {
  184. matches, err := filepath.Glob(pattern)
  185. if err != nil {
  186. return nil, err
  187. }
  188. for _, safetensor := range matches {
  189. if ct, err := detectContentType(safetensor); err != nil {
  190. return nil, err
  191. } else if ct != contentType {
  192. return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
  193. }
  194. }
  195. return matches, nil
  196. }
  197. var files []string
  198. if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
  199. // safetensors files might be unresolved git lfs references; skip if they are
  200. // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
  201. files = append(files, st...)
  202. } else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
  203. // covers adapters.safetensors
  204. files = append(files, st...)
  205. } else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
  206. // covers adapter_model.safetensors
  207. files = append(files, st...)
  208. } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
  209. // pytorch files might also be unresolved git lfs references; skip if they are
  210. // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
  211. files = append(files, pt...)
  212. } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
  213. // pytorch files might also be unresolved git lfs references; skip if they are
  214. // covers consolidated.x.pth, consolidated.pth
  215. files = append(files, pt...)
  216. } else {
  217. return "", errModelNotFound
  218. }
  219. // add configuration files, json files are detected as text/plain
  220. js, err := glob(filepath.Join(path, "*.json"), "text/plain")
  221. if err != nil {
  222. return "", err
  223. }
  224. files = append(files, js...)
  225. // bert models require a nested config.json
  226. // TODO(mxyng): merge this with the glob above
  227. js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
  228. if err != nil {
  229. return "", err
  230. }
  231. files = append(files, js...)
  232. if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
  233. // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
  234. // tokenizer.model might be a unresolved git lfs reference; error if it is
  235. files = append(files, tks...)
  236. } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
  237. // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
  238. files = append(files, tks...)
  239. }
  240. zipfile := zip.NewWriter(tempfile)
  241. defer zipfile.Close()
  242. for _, file := range files {
  243. f, err := os.Open(file)
  244. if err != nil {
  245. return "", err
  246. }
  247. defer f.Close()
  248. fi, err := f.Stat()
  249. if err != nil {
  250. return "", err
  251. }
  252. zfi, err := zip.FileInfoHeader(fi)
  253. if err != nil {
  254. return "", err
  255. }
  256. zfi.Name, err = filepath.Rel(path, file)
  257. if err != nil {
  258. return "", err
  259. }
  260. zf, err := zipfile.CreateHeader(zfi)
  261. if err != nil {
  262. return "", err
  263. }
  264. if _, err := io.Copy(zf, f); err != nil {
  265. return "", err
  266. }
  267. }
  268. return tempfile.Name(), nil
  269. }
  270. func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
  271. bin, err := os.Open(path)
  272. if err != nil {
  273. return "", err
  274. }
  275. defer bin.Close()
  276. // Get file info to retrieve the size
  277. fileInfo, err := bin.Stat()
  278. if err != nil {
  279. return "", err
  280. }
  281. fileSize := fileInfo.Size()
  282. hash := sha256.New()
  283. if _, err := io.Copy(hash, bin); err != nil {
  284. return "", err
  285. }
  286. if _, err := bin.Seek(0, io.SeekStart); err != nil {
  287. return "", err
  288. }
  289. var pw progressWriter
  290. status := "transferring model data 0%"
  291. spinner.SetMessage(status)
  292. done := make(chan struct{})
  293. defer close(done)
  294. go func() {
  295. ticker := time.NewTicker(60 * time.Millisecond)
  296. defer ticker.Stop()
  297. for {
  298. select {
  299. case <-ticker.C:
  300. spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
  301. case <-done:
  302. spinner.SetMessage("transferring model data 100%")
  303. return
  304. }
  305. }
  306. }()
  307. digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
  308. if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
  309. return "", err
  310. }
  311. return digest, nil
  312. }
  // progressWriter is an io.Writer that discards its input and only counts the
  // bytes written. The counter is atomic, so another goroutine can read it
  // while writes are in progress.
  type progressWriter struct {
  	n atomic.Int64
  }

  // Write adds len(p) to the running total. It never fails.
  func (w *progressWriter) Write(p []byte) (int, error) {
  	size := len(p)
  	w.n.Add(int64(size))
  	return size, nil
  }
  320. func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
  321. p := progress.NewProgress(os.Stderr)
  322. defer p.StopAndClear()
  323. spinner := progress.NewSpinner("")
  324. p.Add("", spinner)
  325. client, err := api.ClientFromEnvironment()
  326. if err != nil {
  327. return err
  328. }
  329. req := &api.GenerateRequest{
  330. Model: opts.Model,
  331. KeepAlive: opts.KeepAlive,
  332. }
  333. return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
  334. }
  335. func StopHandler(cmd *cobra.Command, args []string) error {
  336. opts := &runOptions{
  337. Model: args[0],
  338. KeepAlive: &api.Duration{Duration: 0},
  339. }
  340. if err := loadOrUnloadModel(cmd, opts); err != nil {
  341. if strings.Contains(err.Error(), "not found") {
  342. return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
  343. }
  344. }
  345. return nil
  346. }
  347. func RunHandler(cmd *cobra.Command, args []string) error {
  348. interactive := true
  349. opts := runOptions{
  350. Model: args[0],
  351. WordWrap: os.Getenv("TERM") == "xterm-256color",
  352. Options: map[string]interface{}{},
  353. }
  354. format, err := cmd.Flags().GetString("format")
  355. if err != nil {
  356. return err
  357. }
  358. opts.Format = format
  359. keepAlive, err := cmd.Flags().GetString("keepalive")
  360. if err != nil {
  361. return err
  362. }
  363. if keepAlive != "" {
  364. d, err := time.ParseDuration(keepAlive)
  365. if err != nil {
  366. return err
  367. }
  368. opts.KeepAlive = &api.Duration{Duration: d}
  369. }
  370. prompts := args[1:]
  371. // prepend stdin to the prompt if provided
  372. if !term.IsTerminal(int(os.Stdin.Fd())) {
  373. in, err := io.ReadAll(os.Stdin)
  374. if err != nil {
  375. return err
  376. }
  377. prompts = append([]string{string(in)}, prompts...)
  378. opts.WordWrap = false
  379. interactive = false
  380. }
  381. opts.Prompt = strings.Join(prompts, " ")
  382. if len(prompts) > 0 {
  383. interactive = false
  384. }
  385. nowrap, err := cmd.Flags().GetBool("nowordwrap")
  386. if err != nil {
  387. return err
  388. }
  389. opts.WordWrap = !nowrap
  390. // Fill out the rest of the options based on information about the
  391. // model.
  392. client, err := api.ClientFromEnvironment()
  393. if err != nil {
  394. return err
  395. }
  396. name := args[0]
  397. info, err := func() (*api.ShowResponse, error) {
  398. showReq := &api.ShowRequest{Name: name}
  399. info, err := client.Show(cmd.Context(), showReq)
  400. var se api.StatusError
  401. if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
  402. if err := PullHandler(cmd, []string{name}); err != nil {
  403. return nil, err
  404. }
  405. return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
  406. }
  407. return info, err
  408. }()
  409. if err != nil {
  410. return err
  411. }
  412. opts.MultiModal = len(info.ProjectorInfo) != 0
  413. opts.ParentModel = info.Details.ParentModel
  414. if interactive {
  415. if err := loadOrUnloadModel(cmd, &opts); err != nil {
  416. return err
  417. }
  418. for _, msg := range info.Messages {
  419. switch msg.Role {
  420. case "user":
  421. fmt.Printf(">>> %s\n", msg.Content)
  422. case "assistant":
  423. state := &displayResponseState{}
  424. displayResponse(msg.Content, opts.WordWrap, state)
  425. fmt.Println()
  426. fmt.Println()
  427. }
  428. }
  429. return generateInteractive(cmd, opts)
  430. }
  431. return generate(cmd, opts)
  432. }
  433. func errFromUnknownKey(unknownKeyErr error) error {
  434. // find SSH public key in the error message
  435. sshKeyPattern := `ssh-\w+ [^\s"]+`
  436. re := regexp.MustCompile(sshKeyPattern)
  437. matches := re.FindStringSubmatch(unknownKeyErr.Error())
  438. if len(matches) > 0 {
  439. serverPubKey := matches[0]
  440. localPubKey, err := auth.GetPublicKey()
  441. if err != nil {
  442. return unknownKeyErr
  443. }
  444. if runtime.GOOS == "linux" && serverPubKey != localPubKey {
  445. // try the ollama service public key
  446. svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
  447. if err != nil {
  448. return unknownKeyErr
  449. }
  450. localPubKey = strings.TrimSpace(string(svcPubKey))
  451. }
  452. // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
  453. if serverPubKey != localPubKey {
  454. return unknownKeyErr
  455. }
  456. var msg strings.Builder
  457. msg.WriteString(unknownKeyErr.Error())
  458. msg.WriteString("\n\nYour ollama key is:\n")
  459. msg.WriteString(localPubKey)
  460. msg.WriteString("\nAdd your key at:\n")
  461. msg.WriteString("https://ollama.com/settings/keys")
  462. return errors.New(msg.String())
  463. }
  464. return unknownKeyErr
  465. }
  466. func PushHandler(cmd *cobra.Command, args []string) error {
  467. client, err := api.ClientFromEnvironment()
  468. if err != nil {
  469. return err
  470. }
  471. insecure, err := cmd.Flags().GetBool("insecure")
  472. if err != nil {
  473. return err
  474. }
  475. p := progress.NewProgress(os.Stderr)
  476. defer p.Stop()
  477. bars := make(map[string]*progress.Bar)
  478. var status string
  479. var spinner *progress.Spinner
  480. fn := func(resp api.ProgressResponse) error {
  481. if resp.Digest != "" {
  482. if spinner != nil {
  483. spinner.Stop()
  484. }
  485. bar, ok := bars[resp.Digest]
  486. if !ok {
  487. bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  488. bars[resp.Digest] = bar
  489. p.Add(resp.Digest, bar)
  490. }
  491. bar.Set(resp.Completed)
  492. } else if status != resp.Status {
  493. if spinner != nil {
  494. spinner.Stop()
  495. }
  496. status = resp.Status
  497. spinner = progress.NewSpinner(status)
  498. p.Add(status, spinner)
  499. }
  500. return nil
  501. }
  502. request := api.PushRequest{Name: args[0], Insecure: insecure}
  503. if err := client.Push(cmd.Context(), &request, fn); err != nil {
  504. if spinner != nil {
  505. spinner.Stop()
  506. }
  507. if strings.Contains(err.Error(), "access denied") {
  508. return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
  509. }
  510. host := model.ParseName(args[0]).Host
  511. isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
  512. if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
  513. // the user has not added their ollama key to ollama.com
  514. // re-throw an error with a more user-friendly message
  515. return errFromUnknownKey(err)
  516. }
  517. return err
  518. }
  519. spinner.Stop()
  520. return nil
  521. }
  522. func ListHandler(cmd *cobra.Command, args []string) error {
  523. client, err := api.ClientFromEnvironment()
  524. if err != nil {
  525. return err
  526. }
  527. models, err := client.List(cmd.Context())
  528. if err != nil {
  529. return err
  530. }
  531. var data [][]string
  532. for _, m := range models.Models {
  533. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  534. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
  535. }
  536. }
  537. table := tablewriter.NewWriter(os.Stdout)
  538. table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
  539. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  540. table.SetAlignment(tablewriter.ALIGN_LEFT)
  541. table.SetHeaderLine(false)
  542. table.SetBorder(false)
  543. table.SetNoWhiteSpace(true)
  544. table.SetTablePadding(" ")
  545. table.AppendBulk(data)
  546. table.Render()
  547. return nil
  548. }
  549. func ListRunningHandler(cmd *cobra.Command, args []string) error {
  550. client, err := api.ClientFromEnvironment()
  551. if err != nil {
  552. return err
  553. }
  554. models, err := client.ListRunning(cmd.Context())
  555. if err != nil {
  556. return err
  557. }
  558. var data [][]string
  559. for _, m := range models.Models {
  560. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  561. var procStr string
  562. switch {
  563. case m.SizeVRAM == 0:
  564. procStr = "100% CPU"
  565. case m.SizeVRAM == m.Size:
  566. procStr = "100% GPU"
  567. case m.SizeVRAM > m.Size || m.Size == 0:
  568. procStr = "Unknown"
  569. default:
  570. sizeCPU := m.Size - m.SizeVRAM
  571. cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
  572. procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
  573. }
  574. var until string
  575. delta := time.Since(m.ExpiresAt)
  576. if delta > 0 {
  577. until = "Stopping..."
  578. } else {
  579. until = format.HumanTime(m.ExpiresAt, "Never")
  580. }
  581. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
  582. }
  583. }
  584. table := tablewriter.NewWriter(os.Stdout)
  585. table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
  586. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  587. table.SetAlignment(tablewriter.ALIGN_LEFT)
  588. table.SetHeaderLine(false)
  589. table.SetBorder(false)
  590. table.SetNoWhiteSpace(true)
  591. table.SetTablePadding(" ")
  592. table.AppendBulk(data)
  593. table.Render()
  594. return nil
  595. }
  596. func DeleteHandler(cmd *cobra.Command, args []string) error {
  597. client, err := api.ClientFromEnvironment()
  598. if err != nil {
  599. return err
  600. }
  601. // Unload the model if it's running before deletion
  602. opts := &runOptions{
  603. Model: args[0],
  604. KeepAlive: &api.Duration{Duration: 0},
  605. }
  606. if err := loadOrUnloadModel(cmd, opts); err != nil {
  607. if !strings.Contains(err.Error(), "not found") {
  608. return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
  609. }
  610. }
  611. for _, name := range args {
  612. req := api.DeleteRequest{Name: name}
  613. if err := client.Delete(cmd.Context(), &req); err != nil {
  614. return err
  615. }
  616. fmt.Printf("deleted '%s'\n", name)
  617. }
  618. return nil
  619. }
  620. func ShowHandler(cmd *cobra.Command, args []string) error {
  621. client, err := api.ClientFromEnvironment()
  622. if err != nil {
  623. return err
  624. }
  625. license, errLicense := cmd.Flags().GetBool("license")
  626. modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
  627. parameters, errParams := cmd.Flags().GetBool("parameters")
  628. system, errSystem := cmd.Flags().GetBool("system")
  629. template, errTemplate := cmd.Flags().GetBool("template")
  630. for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
  631. if boolErr != nil {
  632. return errors.New("error retrieving flags")
  633. }
  634. }
  635. flagsSet := 0
  636. showType := ""
  637. if license {
  638. flagsSet++
  639. showType = "license"
  640. }
  641. if modelfile {
  642. flagsSet++
  643. showType = "modelfile"
  644. }
  645. if parameters {
  646. flagsSet++
  647. showType = "parameters"
  648. }
  649. if system {
  650. flagsSet++
  651. showType = "system"
  652. }
  653. if template {
  654. flagsSet++
  655. showType = "template"
  656. }
  657. if flagsSet > 1 {
  658. return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
  659. }
  660. req := api.ShowRequest{Name: args[0]}
  661. resp, err := client.Show(cmd.Context(), &req)
  662. if err != nil {
  663. return err
  664. }
  665. if flagsSet == 1 {
  666. switch showType {
  667. case "license":
  668. fmt.Println(resp.License)
  669. case "modelfile":
  670. fmt.Println(resp.Modelfile)
  671. case "parameters":
  672. fmt.Println(resp.Parameters)
  673. case "system":
  674. fmt.Println(resp.System)
  675. case "template":
  676. fmt.Println(resp.Template)
  677. }
  678. return nil
  679. }
  680. return showInfo(resp, os.Stdout)
  681. }
  682. func showInfo(resp *api.ShowResponse, w io.Writer) error {
  683. tableRender := func(header string, rows func() [][]string) {
  684. fmt.Fprintln(w, " ", header)
  685. table := tablewriter.NewWriter(w)
  686. table.SetAlignment(tablewriter.ALIGN_LEFT)
  687. table.SetBorder(false)
  688. table.SetNoWhiteSpace(true)
  689. table.SetTablePadding(" ")
  690. switch header {
  691. case "Template", "System", "License":
  692. table.SetColWidth(100)
  693. }
  694. table.AppendBulk(rows())
  695. table.Render()
  696. fmt.Fprintln(w)
  697. }
  698. tableRender("Model", func() (rows [][]string) {
  699. if resp.ModelInfo != nil {
  700. arch := resp.ModelInfo["general.architecture"].(string)
  701. rows = append(rows, []string{"", "architecture", arch})
  702. rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
  703. rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
  704. rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
  705. } else {
  706. rows = append(rows, []string{"", "architecture", resp.Details.Family})
  707. rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
  708. }
  709. rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
  710. return
  711. })
  712. if resp.ProjectorInfo != nil {
  713. tableRender("Projector", func() (rows [][]string) {
  714. arch := resp.ProjectorInfo["general.architecture"].(string)
  715. rows = append(rows, []string{"", "architecture", arch})
  716. rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))})
  717. rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ProjectorInfo[fmt.Sprintf("%s.vision.embedding_length", arch)].(float64), 'f', -1, 64)})
  718. rows = append(rows, []string{"", "dimensions", strconv.FormatFloat(resp.ProjectorInfo[fmt.Sprintf("%s.vision.projection_dim", arch)].(float64), 'f', -1, 64)})
  719. return
  720. })
  721. }
  722. if resp.Parameters != "" {
  723. tableRender("Parameters", func() (rows [][]string) {
  724. scanner := bufio.NewScanner(strings.NewReader(resp.Parameters))
  725. for scanner.Scan() {
  726. if text := scanner.Text(); text != "" {
  727. rows = append(rows, append([]string{""}, strings.Fields(text)...))
  728. }
  729. }
  730. return
  731. })
  732. }
  733. head := func(s string, n int) (rows [][]string) {
  734. scanner := bufio.NewScanner(strings.NewReader(s))
  735. for scanner.Scan() && (len(rows) < n || n < 0) {
  736. if text := scanner.Text(); text != "" {
  737. rows = append(rows, []string{"", strings.TrimSpace(text)})
  738. }
  739. }
  740. return
  741. }
  742. if resp.System != "" {
  743. tableRender("System", func() [][]string {
  744. return head(resp.System, 2)
  745. })
  746. }
  747. if resp.License != "" {
  748. tableRender("License", func() [][]string {
  749. return head(resp.License, 2)
  750. })
  751. }
  752. return nil
  753. }
  754. func CopyHandler(cmd *cobra.Command, args []string) error {
  755. client, err := api.ClientFromEnvironment()
  756. if err != nil {
  757. return err
  758. }
  759. req := api.CopyRequest{Source: args[0], Destination: args[1]}
  760. if err := client.Copy(cmd.Context(), &req); err != nil {
  761. return err
  762. }
  763. fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
  764. return nil
  765. }
  766. func PullHandler(cmd *cobra.Command, args []string) error {
  767. insecure, err := cmd.Flags().GetBool("insecure")
  768. if err != nil {
  769. return err
  770. }
  771. client, err := api.ClientFromEnvironment()
  772. if err != nil {
  773. return err
  774. }
  775. p := progress.NewProgress(os.Stderr)
  776. defer p.Stop()
  777. bars := make(map[string]*progress.Bar)
  778. var status string
  779. var spinner *progress.Spinner
  780. fn := func(resp api.ProgressResponse) error {
  781. if resp.Digest != "" {
  782. if spinner != nil {
  783. spinner.Stop()
  784. }
  785. bar, ok := bars[resp.Digest]
  786. if !ok {
  787. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  788. bars[resp.Digest] = bar
  789. p.Add(resp.Digest, bar)
  790. }
  791. bar.Set(resp.Completed)
  792. } else if status != resp.Status {
  793. if spinner != nil {
  794. spinner.Stop()
  795. }
  796. status = resp.Status
  797. spinner = progress.NewSpinner(status)
  798. p.Add(status, spinner)
  799. }
  800. return nil
  801. }
  802. request := api.PullRequest{Name: args[0], Insecure: insecure}
  803. if err := client.Pull(cmd.Context(), &request, fn); err != nil {
  804. return err
  805. }
  806. return nil
  807. }
  808. type generateContextKey string
  809. type runOptions struct {
  810. Model string
  811. ParentModel string
  812. Prompt string
  813. Messages []api.Message
  814. WordWrap bool
  815. Format string
  816. System string
  817. Images []api.ImageData
  818. Options map[string]interface{}
  819. MultiModal bool
  820. KeepAlive *api.Duration
  821. }
  822. type displayResponseState struct {
  823. lineLength int
  824. wordBuffer string
  825. }
  826. func displayResponse(content string, wordWrap bool, state *displayResponseState) {
  827. termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
  828. if wordWrap && termWidth >= 10 {
  829. for _, ch := range content {
  830. if state.lineLength+1 > termWidth-5 {
  831. if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
  832. fmt.Printf("%s%c", state.wordBuffer, ch)
  833. state.wordBuffer = ""
  834. state.lineLength = 0
  835. continue
  836. }
  837. // backtrack the length of the last word and clear to the end of the line
  838. a := runewidth.StringWidth(state.wordBuffer)
  839. if a > 0 {
  840. fmt.Printf("\x1b[%dD", a)
  841. }
  842. fmt.Printf("\x1b[K\n")
  843. fmt.Printf("%s%c", state.wordBuffer, ch)
  844. chWidth := runewidth.RuneWidth(ch)
  845. state.lineLength = runewidth.StringWidth(state.wordBuffer) + chWidth
  846. } else {
  847. fmt.Print(string(ch))
  848. state.lineLength += runewidth.RuneWidth(ch)
  849. if runewidth.RuneWidth(ch) >= 2 {
  850. state.wordBuffer = ""
  851. continue
  852. }
  853. switch ch {
  854. case ' ':
  855. state.wordBuffer = ""
  856. case '\n':
  857. state.lineLength = 0
  858. default:
  859. state.wordBuffer += string(ch)
  860. }
  861. }
  862. }
  863. } else {
  864. fmt.Printf("%s%s", state.wordBuffer, content)
  865. if len(state.wordBuffer) > 0 {
  866. state.wordBuffer = ""
  867. }
  868. }
  869. }
  870. func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
  871. client, err := api.ClientFromEnvironment()
  872. if err != nil {
  873. return nil, err
  874. }
  875. p := progress.NewProgress(os.Stderr)
  876. defer p.StopAndClear()
  877. spinner := progress.NewSpinner("")
  878. p.Add("", spinner)
  879. cancelCtx, cancel := context.WithCancel(cmd.Context())
  880. defer cancel()
  881. sigChan := make(chan os.Signal, 1)
  882. signal.Notify(sigChan, syscall.SIGINT)
  883. go func() {
  884. <-sigChan
  885. cancel()
  886. }()
  887. var state *displayResponseState = &displayResponseState{}
  888. var latest api.ChatResponse
  889. var fullResponse strings.Builder
  890. var role string
  891. fn := func(response api.ChatResponse) error {
  892. p.StopAndClear()
  893. latest = response
  894. role = response.Message.Role
  895. content := response.Message.Content
  896. fullResponse.WriteString(content)
  897. displayResponse(content, opts.WordWrap, state)
  898. return nil
  899. }
  900. req := &api.ChatRequest{
  901. Model: opts.Model,
  902. Messages: opts.Messages,
  903. Format: opts.Format,
  904. Options: opts.Options,
  905. }
  906. if opts.KeepAlive != nil {
  907. req.KeepAlive = opts.KeepAlive
  908. }
  909. if err := client.Chat(cancelCtx, req, fn); err != nil {
  910. if errors.Is(err, context.Canceled) {
  911. return nil, nil
  912. }
  913. return nil, err
  914. }
  915. if len(opts.Messages) > 0 {
  916. fmt.Println()
  917. fmt.Println()
  918. }
  919. verbose, err := cmd.Flags().GetBool("verbose")
  920. if err != nil {
  921. return nil, err
  922. }
  923. if verbose {
  924. latest.Summary()
  925. }
  926. return &api.Message{Role: role, Content: fullResponse.String()}, nil
  927. }
  928. func generate(cmd *cobra.Command, opts runOptions) error {
  929. client, err := api.ClientFromEnvironment()
  930. if err != nil {
  931. return err
  932. }
  933. p := progress.NewProgress(os.Stderr)
  934. defer p.StopAndClear()
  935. spinner := progress.NewSpinner("")
  936. p.Add("", spinner)
  937. var latest api.GenerateResponse
  938. generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
  939. if !ok {
  940. generateContext = []int{}
  941. }
  942. ctx, cancel := context.WithCancel(cmd.Context())
  943. defer cancel()
  944. sigChan := make(chan os.Signal, 1)
  945. signal.Notify(sigChan, syscall.SIGINT)
  946. go func() {
  947. <-sigChan
  948. cancel()
  949. }()
  950. var state *displayResponseState = &displayResponseState{}
  951. fn := func(response api.GenerateResponse) error {
  952. p.StopAndClear()
  953. latest = response
  954. content := response.Response
  955. displayResponse(content, opts.WordWrap, state)
  956. return nil
  957. }
  958. if opts.MultiModal {
  959. opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
  960. if err != nil {
  961. return err
  962. }
  963. }
  964. request := api.GenerateRequest{
  965. Model: opts.Model,
  966. Prompt: opts.Prompt,
  967. Context: generateContext,
  968. Images: opts.Images,
  969. Format: opts.Format,
  970. System: opts.System,
  971. Options: opts.Options,
  972. KeepAlive: opts.KeepAlive,
  973. }
  974. if err := client.Generate(ctx, &request, fn); err != nil {
  975. if errors.Is(err, context.Canceled) {
  976. return nil
  977. }
  978. return err
  979. }
  980. if opts.Prompt != "" {
  981. fmt.Println()
  982. fmt.Println()
  983. }
  984. if !latest.Done {
  985. return nil
  986. }
  987. verbose, err := cmd.Flags().GetBool("verbose")
  988. if err != nil {
  989. return err
  990. }
  991. if verbose {
  992. latest.Summary()
  993. }
  994. ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
  995. cmd.SetContext(ctx)
  996. return nil
  997. }
  998. func RunServer(_ *cobra.Command, _ []string) error {
  999. if err := initializeKeypair(); err != nil {
  1000. return err
  1001. }
  1002. ln, err := net.Listen("tcp", envconfig.Host().Host)
  1003. if err != nil {
  1004. return err
  1005. }
  1006. err = server.Serve(ln)
  1007. if errors.Is(err, http.ErrServerClosed) {
  1008. return nil
  1009. }
  1010. return err
  1011. }
  1012. func initializeKeypair() error {
  1013. home, err := os.UserHomeDir()
  1014. if err != nil {
  1015. return err
  1016. }
  1017. privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
  1018. pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
  1019. _, err = os.Stat(privKeyPath)
  1020. if os.IsNotExist(err) {
  1021. fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
  1022. cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
  1023. if err != nil {
  1024. return err
  1025. }
  1026. privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
  1027. if err != nil {
  1028. return err
  1029. }
  1030. if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
  1031. return fmt.Errorf("could not create directory %w", err)
  1032. }
  1033. if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
  1034. return err
  1035. }
  1036. sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
  1037. if err != nil {
  1038. return err
  1039. }
  1040. publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
  1041. if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
  1042. return err
  1043. }
  1044. fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
  1045. }
  1046. return nil
  1047. }
  1048. func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
  1049. client, err := api.ClientFromEnvironment()
  1050. if err != nil {
  1051. return err
  1052. }
  1053. if err := client.Heartbeat(cmd.Context()); err != nil {
  1054. if !strings.Contains(err.Error(), " refused") {
  1055. return err
  1056. }
  1057. if err := startApp(cmd.Context(), client); err != nil {
  1058. return errors.New("could not connect to ollama app, is it running?")
  1059. }
  1060. }
  1061. return nil
  1062. }
  1063. func versionHandler(cmd *cobra.Command, _ []string) {
  1064. client, err := api.ClientFromEnvironment()
  1065. if err != nil {
  1066. return
  1067. }
  1068. serverVersion, err := client.Version(cmd.Context())
  1069. if err != nil {
  1070. fmt.Println("Warning: could not connect to a running Ollama instance")
  1071. }
  1072. if serverVersion != "" {
  1073. fmt.Printf("ollama version is %s\n", serverVersion)
  1074. }
  1075. if serverVersion != version.Version {
  1076. fmt.Printf("Warning: client version is %s\n", version.Version)
  1077. }
  1078. }
  1079. func appendEnvDocs(cmd *cobra.Command, envs []envconfig.EnvVar) {
  1080. if len(envs) == 0 {
  1081. return
  1082. }
  1083. envUsage := `
  1084. Environment Variables:
  1085. `
  1086. for _, e := range envs {
  1087. envUsage += fmt.Sprintf(" %-24s %s\n", e.Name, e.Description)
  1088. }
  1089. cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
  1090. }
  1091. func NewCLI() *cobra.Command {
  1092. log.SetFlags(log.LstdFlags | log.Lshortfile)
  1093. cobra.EnableCommandSorting = false
  1094. if runtime.GOOS == "windows" && term.IsTerminal(int(os.Stdout.Fd())) {
  1095. console.ConsoleFromFile(os.Stdin) //nolint:errcheck
  1096. }
  1097. rootCmd := &cobra.Command{
  1098. Use: "ollama",
  1099. Short: "Large language model runner",
  1100. SilenceUsage: true,
  1101. SilenceErrors: true,
  1102. CompletionOptions: cobra.CompletionOptions{
  1103. DisableDefaultCmd: true,
  1104. },
  1105. Run: func(cmd *cobra.Command, args []string) {
  1106. if version, _ := cmd.Flags().GetBool("version"); version {
  1107. versionHandler(cmd, args)
  1108. return
  1109. }
  1110. cmd.Print(cmd.UsageString())
  1111. },
  1112. }
  1113. rootCmd.Flags().BoolP("version", "v", false, "Show version information")
  1114. createCmd := &cobra.Command{
  1115. Use: "create MODEL",
  1116. Short: "Create a model from a Modelfile",
  1117. Args: cobra.ExactArgs(1),
  1118. PreRunE: checkServerHeartbeat,
  1119. RunE: CreateHandler,
  1120. }
  1121. createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
  1122. createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
  1123. showCmd := &cobra.Command{
  1124. Use: "show MODEL",
  1125. Short: "Show information for a model",
  1126. Args: cobra.ExactArgs(1),
  1127. PreRunE: checkServerHeartbeat,
  1128. RunE: ShowHandler,
  1129. }
  1130. showCmd.Flags().Bool("license", false, "Show license of a model")
  1131. showCmd.Flags().Bool("modelfile", false, "Show Modelfile of a model")
  1132. showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
  1133. showCmd.Flags().Bool("template", false, "Show template of a model")
  1134. showCmd.Flags().Bool("system", false, "Show system message of a model")
  1135. runCmd := &cobra.Command{
  1136. Use: "run MODEL [PROMPT]",
  1137. Short: "Run a model",
  1138. Args: cobra.MinimumNArgs(1),
  1139. PreRunE: checkServerHeartbeat,
  1140. RunE: RunHandler,
  1141. }
  1142. runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
  1143. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  1144. runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  1145. runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
  1146. runCmd.Flags().String("format", "", "Response format (e.g. json)")
  1147. stopCmd := &cobra.Command{
  1148. Use: "stop MODEL",
  1149. Short: "Stop a running model",
  1150. Args: cobra.ExactArgs(1),
  1151. PreRunE: checkServerHeartbeat,
  1152. RunE: StopHandler,
  1153. }
  1154. serveCmd := &cobra.Command{
  1155. Use: "serve",
  1156. Aliases: []string{"start"},
  1157. Short: "Start ollama",
  1158. Args: cobra.ExactArgs(0),
  1159. RunE: RunServer,
  1160. }
  1161. pullCmd := &cobra.Command{
  1162. Use: "pull MODEL",
  1163. Short: "Pull a model from a registry",
  1164. Args: cobra.ExactArgs(1),
  1165. PreRunE: checkServerHeartbeat,
  1166. RunE: PullHandler,
  1167. }
  1168. pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  1169. pushCmd := &cobra.Command{
  1170. Use: "push MODEL",
  1171. Short: "Push a model to a registry",
  1172. Args: cobra.ExactArgs(1),
  1173. PreRunE: checkServerHeartbeat,
  1174. RunE: PushHandler,
  1175. }
  1176. pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  1177. listCmd := &cobra.Command{
  1178. Use: "list",
  1179. Aliases: []string{"ls"},
  1180. Short: "List models",
  1181. PreRunE: checkServerHeartbeat,
  1182. RunE: ListHandler,
  1183. }
  1184. psCmd := &cobra.Command{
  1185. Use: "ps",
  1186. Short: "List running models",
  1187. PreRunE: checkServerHeartbeat,
  1188. RunE: ListRunningHandler,
  1189. }
  1190. copyCmd := &cobra.Command{
  1191. Use: "cp SOURCE DESTINATION",
  1192. Short: "Copy a model",
  1193. Args: cobra.ExactArgs(2),
  1194. PreRunE: checkServerHeartbeat,
  1195. RunE: CopyHandler,
  1196. }
  1197. deleteCmd := &cobra.Command{
  1198. Use: "rm MODEL [MODEL...]",
  1199. Short: "Remove a model",
  1200. Args: cobra.MinimumNArgs(1),
  1201. PreRunE: checkServerHeartbeat,
  1202. RunE: DeleteHandler,
  1203. }
  1204. envVars := envconfig.AsMap()
  1205. envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]}
  1206. for _, cmd := range []*cobra.Command{
  1207. createCmd,
  1208. showCmd,
  1209. runCmd,
  1210. stopCmd,
  1211. pullCmd,
  1212. pushCmd,
  1213. listCmd,
  1214. psCmd,
  1215. copyCmd,
  1216. deleteCmd,
  1217. serveCmd,
  1218. } {
  1219. switch cmd {
  1220. case runCmd:
  1221. appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
  1222. case serveCmd:
  1223. appendEnvDocs(cmd, []envconfig.EnvVar{
  1224. envVars["OLLAMA_DEBUG"],
  1225. envVars["OLLAMA_HOST"],
  1226. envVars["OLLAMA_KEEP_ALIVE"],
  1227. envVars["OLLAMA_MAX_LOADED_MODELS"],
  1228. envVars["OLLAMA_MAX_QUEUE"],
  1229. envVars["OLLAMA_MODELS"],
  1230. envVars["OLLAMA_NUM_PARALLEL"],
  1231. envVars["OLLAMA_NOPRUNE"],
  1232. envVars["OLLAMA_ORIGINS"],
  1233. envVars["OLLAMA_SCHED_SPREAD"],
  1234. envVars["OLLAMA_TMPDIR"],
  1235. envVars["OLLAMA_FLASH_ATTENTION"],
  1236. envVars["OLLAMA_LLM_LIBRARY"],
  1237. envVars["OLLAMA_GPU_OVERHEAD"],
  1238. envVars["OLLAMA_LOAD_TIMEOUT"],
  1239. })
  1240. default:
  1241. appendEnvDocs(cmd, envs)
  1242. }
  1243. }
  1244. rootCmd.AddCommand(
  1245. serveCmd,
  1246. createCmd,
  1247. showCmd,
  1248. runCmd,
  1249. stopCmd,
  1250. pullCmd,
  1251. pushCmd,
  1252. listCmd,
  1253. psCmd,
  1254. copyCmd,
  1255. deleteCmd,
  1256. )
  1257. return rootCmd
  1258. }