cmd.go 35 KB


  1. package cmd
  import (
	"archive/zip"
	"bytes"
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/sha256"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"log"
	"math"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"path/filepath"
	"regexp"
	"runtime"
	"slices"
	"strings"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/containerd/console"
	"github.com/mattn/go-runewidth"
	"github.com/olekukonko/tablewriter"
	"github.com/spf13/cobra"
	"golang.org/x/crypto/ssh"
	"golang.org/x/term"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/auth"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/progress"
	"github.com/ollama/ollama/server"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
)
  45. func CreateHandler(cmd *cobra.Command, args []string) error {
  46. filename, _ := cmd.Flags().GetString("file")
  47. filename, err := filepath.Abs(filename)
  48. if err != nil {
  49. return err
  50. }
  51. client, err := api.ClientFromEnvironment()
  52. if err != nil {
  53. return err
  54. }
  55. p := progress.NewProgress(os.Stderr)
  56. defer p.Stop()
  57. f, err := os.Open(filename)
  58. if err != nil {
  59. return err
  60. }
  61. defer f.Close()
  62. modelfile, err := parser.ParseFile(f)
  63. if err != nil {
  64. return err
  65. }
  66. home, err := os.UserHomeDir()
  67. if err != nil {
  68. return err
  69. }
  70. status := "transferring model data"
  71. spinner := progress.NewSpinner(status)
  72. p.Add(status, spinner)
  73. defer p.Stop()
  74. for i := range modelfile.Commands {
  75. switch modelfile.Commands[i].Name {
  76. case "model", "adapter":
  77. path := modelfile.Commands[i].Args
  78. if path == "~" {
  79. path = home
  80. } else if strings.HasPrefix(path, "~/") {
  81. path = filepath.Join(home, path[2:])
  82. }
  83. if !filepath.IsAbs(path) {
  84. path = filepath.Join(filepath.Dir(filename), path)
  85. }
  86. fi, err := os.Stat(path)
  87. if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
  88. continue
  89. } else if err != nil {
  90. return err
  91. }
  92. if fi.IsDir() {
  93. // this is likely a safetensors or pytorch directory
  94. // TODO make this work w/ adapters
  95. tempfile, err := tempZipFiles(path)
  96. if err != nil {
  97. return err
  98. }
  99. defer os.RemoveAll(tempfile)
  100. path = tempfile
  101. }
  102. digest, err := createBlob(cmd, client, path, spinner)
  103. if err != nil {
  104. return err
  105. }
  106. modelfile.Commands[i].Args = "@" + digest
  107. }
  108. }
  109. bars := make(map[string]*progress.Bar)
  110. fn := func(resp api.ProgressResponse) error {
  111. if resp.Digest != "" {
  112. spinner.Stop()
  113. bar, ok := bars[resp.Digest]
  114. if !ok {
  115. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  116. bars[resp.Digest] = bar
  117. p.Add(resp.Digest, bar)
  118. }
  119. bar.Set(resp.Completed)
  120. } else if status != resp.Status {
  121. spinner.Stop()
  122. status = resp.Status
  123. spinner := progress.NewSpinner(status)
  124. p.Add(status, spinner)
  125. }
  126. return nil
  127. }
  128. quantize, _ := cmd.Flags().GetString("quantize")
  129. request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantize: quantize}
  130. if err := client.Create(cmd.Context(), &request, fn); err != nil {
  131. return err
  132. }
  133. return nil
  134. }
  // tempZipFiles packs the weights and configuration files of a safetensors or
// PyTorch model directory into a new temporary zip file and returns its path.
// The caller owns the returned file and is responsible for removing it.
//
// Weights come from the first pattern that matches: model*.safetensors,
// pytorch_model*.bin, then consolidated*.pth. All *.json files are added, plus
// tokenizer.model (or */tokenizer.model one directory down) when present.
// Every candidate's content type is sniffed with http.DetectContentType and
// compared to the expected type, to catch unresolved git lfs references.
// Entries are stored under their base names (zip.FileInfoHeader uses
// fs.FileInfo.Name).
//
// On any error the temporary file is removed and "" is returned.
func tempZipFiles(path string) (name string, err error) {
	tempfile, err := os.CreateTemp("", "ollama-tf")
	if err != nil {
		return "", err
	}
	defer func() {
		if cerr := tempfile.Close(); cerr != nil && err == nil {
			err = cerr
		}
		if err != nil {
			// Best effort: do not leave a half-written archive behind.
			os.Remove(tempfile.Name())
			name = ""
		}
	}()

	detectContentType := func(path string) (string, error) {
		f, err := os.Open(path)
		if err != nil {
			return "", err
		}
		defer f.Close()

		var b bytes.Buffer
		b.Grow(512)

		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
			return "", err
		}

		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
		return contentType, nil
	}

	glob := func(pattern, contentType string) ([]string, error) {
		matches, err := filepath.Glob(pattern)
		if err != nil {
			return nil, err
		}

		for _, match := range matches {
			ct, err := detectContentType(match)
			if err != nil {
				return nil, err
			}
			if ct != contentType {
				return nil, fmt.Errorf("invalid content type: expected %s, got %s for %s", contentType, ct, match)
			}
		}

		return matches, nil
	}

	var files []string
	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
		// safetensors files might be unresolved git lfs references; skip if they are
		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
		files = append(files, st...)
	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
		// pytorch files might also be unresolved git lfs references; skip if they are
		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
		files = append(files, pt...)
	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
		// pytorch files might also be unresolved git lfs references; skip if they are
		// covers consolidated.x.pth, consolidated.pth
		files = append(files, pt...)
	} else {
		return "", errors.New("no safetensors or torch files found")
	}

	// add configuration files, json files are detected as text/plain
	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
	if err != nil {
		return "", err
	}
	files = append(files, js...)

	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
		// tokenizer.model might be a unresolved git lfs reference; error if it is
		files = append(files, tks...)
	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
		files = append(files, tks...)
	}

	zw := zip.NewWriter(tempfile)
	for _, file := range files {
		if err := addFileToZip(zw, file); err != nil {
			return "", err
		}
	}

	// Close writes the zip central directory; its error must not be dropped.
	if err := zw.Close(); err != nil {
		return "", err
	}

	return tempfile.Name(), nil
}

// addFileToZip copies the file at path into zw as an entry named after the
// file's base name. The source file is closed before returning.
func addFileToZip(zw *zip.Writer, path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		return err
	}

	zfi, err := zip.FileInfoHeader(fi)
	if err != nil {
		return err
	}

	zf, err := zw.CreateHeader(zfi)
	if err != nil {
		return err
	}

	_, err = io.Copy(zf, f)
	return err
}
  225. var ErrBlobExists = errors.New("blob exists")
  226. func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
  227. bin, err := os.Open(path)
  228. if err != nil {
  229. return "", err
  230. }
  231. defer bin.Close()
  232. // Get file info to retrieve the size
  233. fileInfo, err := bin.Stat()
  234. if err != nil {
  235. return "", err
  236. }
  237. fileSize := fileInfo.Size()
  238. hash := sha256.New()
  239. if _, err := io.Copy(hash, bin); err != nil {
  240. return "", err
  241. }
  242. if _, err := bin.Seek(0, io.SeekStart); err != nil {
  243. return "", err
  244. }
  245. var pw progressWriter
  246. status := "transferring model data 0%"
  247. spinner.SetMessage(status)
  248. ticker := time.NewTicker(60 * time.Millisecond)
  249. done := make(chan struct{})
  250. defer close(done)
  251. go func() {
  252. defer ticker.Stop()
  253. for {
  254. select {
  255. case <-ticker.C:
  256. spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
  257. case <-done:
  258. spinner.SetMessage("transferring model data 100%")
  259. return
  260. }
  261. }
  262. }()
  263. digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
  264. // We check if we can find the models directory locally
  265. // If we can, we return the path to the directory
  266. // If we can't, we return an error
  267. // If the blob exists already, we return the digest
  268. dest, err := getLocalPath(cmd.Context(), digest)
  269. if errors.Is(err, ErrBlobExists) {
  270. return digest, nil
  271. }
  272. // Successfully found the model directory
  273. if err == nil {
  274. // Copy blob in via OS specific copy
  275. // Linux errors out to use io.copy
  276. err = localCopy(path, dest)
  277. if err == nil {
  278. return digest, nil
  279. }
  280. // Default copy using io.copy
  281. err = defaultCopy(path, dest)
  282. if err == nil {
  283. return digest, nil
  284. }
  285. }
  286. // If at any point copying the blob over locally fails, we default to the copy through the server
  287. if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
  288. return "", err
  289. }
  290. return digest, nil
  291. }
  292. type progressWriter struct {
  293. n int64
  294. }
  295. func (w *progressWriter) Write(p []byte) (n int, err error) {
  296. w.n += int64(len(p))
  297. return len(p), nil
  298. }
  299. func getLocalPath(ctx context.Context, digest string) (string, error) {
  300. ollamaHost := envconfig.Host
  301. client := http.DefaultClient
  302. base := &url.URL{
  303. Scheme: ollamaHost.Scheme,
  304. Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
  305. }
  306. data, err := json.Marshal(digest)
  307. if err != nil {
  308. return "", err
  309. }
  310. reqBody := bytes.NewReader(data)
  311. path := fmt.Sprintf("/api/blobs/%s", digest)
  312. requestURL := base.JoinPath(path)
  313. request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
  314. if err != nil {
  315. return "", err
  316. }
  317. authz, err := api.Authorization(ctx, request)
  318. if err != nil {
  319. return "", err
  320. }
  321. request.Header.Set("Authorization", authz)
  322. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  323. request.Header.Set("X-Redirect-Create", "1")
  324. resp, err := client.Do(request)
  325. if err != nil {
  326. return "", err
  327. }
  328. defer resp.Body.Close()
  329. if resp.StatusCode == http.StatusTemporaryRedirect {
  330. dest := resp.Header.Get("LocalLocation")
  331. return dest, nil
  332. }
  333. return "", ErrBlobExists
  334. }
  // defaultCopy copies the file at path to dest with a plain buffered copy,
// creating dest's parent directories (mode 0755) as needed, and fsyncs the
// result before returning.
//
// It is the portable fallback createBlob uses when the server is local. If
// anything fails after dest has been created, dest is removed so a truncated
// file is not left in place. Underlying errors are wrapped with %w so callers
// can inspect them with errors.Is / errors.As.
func defaultCopy(path string, dest string) (err error) {
	if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
		return err
	}

	sourceFile, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("could not open source file: %w", err)
	}
	defer sourceFile.Close()

	destFile, err := os.Create(dest)
	if err != nil {
		return fmt.Errorf("could not create destination file: %w", err)
	}
	defer func() {
		if cerr := destFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("error closing file: %w", cerr)
		}
		if err != nil {
			// Best effort cleanup of the partial copy.
			os.Remove(dest)
		}
	}()

	if _, err := io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024)); err != nil {
		return fmt.Errorf("error copying file: %w", err)
	}

	if err := destFile.Sync(); err != nil {
		return fmt.Errorf("error flushing file: %w", err)
	}

	return nil
}
  363. func RunHandler(cmd *cobra.Command, args []string) error {
  364. interactive := true
  365. opts := runOptions{
  366. Model: args[0],
  367. WordWrap: os.Getenv("TERM") == "xterm-256color",
  368. Options: map[string]interface{}{},
  369. }
  370. format, err := cmd.Flags().GetString("format")
  371. if err != nil {
  372. return err
  373. }
  374. opts.Format = format
  375. keepAlive, err := cmd.Flags().GetString("keepalive")
  376. if err != nil {
  377. return err
  378. }
  379. if keepAlive != "" {
  380. d, err := time.ParseDuration(keepAlive)
  381. if err != nil {
  382. return err
  383. }
  384. opts.KeepAlive = &api.Duration{Duration: d}
  385. }
  386. prompts := args[1:]
  387. // prepend stdin to the prompt if provided
  388. if !term.IsTerminal(int(os.Stdin.Fd())) {
  389. in, err := io.ReadAll(os.Stdin)
  390. if err != nil {
  391. return err
  392. }
  393. prompts = append([]string{string(in)}, prompts...)
  394. opts.WordWrap = false
  395. interactive = false
  396. }
  397. opts.Prompt = strings.Join(prompts, " ")
  398. if len(prompts) > 0 {
  399. interactive = false
  400. }
  401. nowrap, err := cmd.Flags().GetBool("nowordwrap")
  402. if err != nil {
  403. return err
  404. }
  405. opts.WordWrap = !nowrap
  406. // Fill out the rest of the options based on information about the
  407. // model.
  408. client, err := api.ClientFromEnvironment()
  409. if err != nil {
  410. return err
  411. }
  412. name := args[0]
  413. info, err := func() (*api.ShowResponse, error) {
  414. showReq := &api.ShowRequest{Name: name}
  415. info, err := client.Show(cmd.Context(), showReq)
  416. var se api.StatusError
  417. if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
  418. if err := PullHandler(cmd, []string{name}); err != nil {
  419. return nil, err
  420. }
  421. return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
  422. }
  423. return info, err
  424. }()
  425. if err != nil {
  426. return err
  427. }
  428. opts.MultiModal = slices.Contains(info.Details.Families, "clip")
  429. opts.ParentModel = info.Details.ParentModel
  430. opts.Messages = append(opts.Messages, info.Messages...)
  431. if interactive {
  432. return generateInteractive(cmd, opts)
  433. }
  434. return generate(cmd, opts)
  435. }
  436. func errFromUnknownKey(unknownKeyErr error) error {
  437. // find SSH public key in the error message
  438. sshKeyPattern := `ssh-\w+ [^\s"]+`
  439. re := regexp.MustCompile(sshKeyPattern)
  440. matches := re.FindStringSubmatch(unknownKeyErr.Error())
  441. if len(matches) > 0 {
  442. serverPubKey := matches[0]
  443. publicKey, err := auth.GetPublicKey()
  444. if err != nil {
  445. return unknownKeyErr
  446. }
  447. localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
  448. if runtime.GOOS == "linux" && serverPubKey != localPubKey {
  449. // try the ollama service public key
  450. svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
  451. if err != nil {
  452. return unknownKeyErr
  453. }
  454. localPubKey = strings.TrimSpace(string(svcPubKey))
  455. }
  456. // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
  457. if serverPubKey != localPubKey {
  458. return unknownKeyErr
  459. }
  460. var msg strings.Builder
  461. msg.WriteString(unknownKeyErr.Error())
  462. msg.WriteString("\n\nYour ollama key is:\n")
  463. msg.WriteString(localPubKey)
  464. msg.WriteString("\nAdd your key at:\n")
  465. msg.WriteString("https://ollama.com/settings/keys")
  466. return errors.New(msg.String())
  467. }
  468. return unknownKeyErr
  469. }
  470. func PushHandler(cmd *cobra.Command, args []string) error {
  471. client, err := api.ClientFromEnvironment()
  472. if err != nil {
  473. return err
  474. }
  475. insecure, err := cmd.Flags().GetBool("insecure")
  476. if err != nil {
  477. return err
  478. }
  479. p := progress.NewProgress(os.Stderr)
  480. defer p.Stop()
  481. bars := make(map[string]*progress.Bar)
  482. var status string
  483. var spinner *progress.Spinner
  484. fn := func(resp api.ProgressResponse) error {
  485. if resp.Digest != "" {
  486. if spinner != nil {
  487. spinner.Stop()
  488. }
  489. bar, ok := bars[resp.Digest]
  490. if !ok {
  491. bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  492. bars[resp.Digest] = bar
  493. p.Add(resp.Digest, bar)
  494. }
  495. bar.Set(resp.Completed)
  496. } else if status != resp.Status {
  497. if spinner != nil {
  498. spinner.Stop()
  499. }
  500. status = resp.Status
  501. spinner = progress.NewSpinner(status)
  502. p.Add(status, spinner)
  503. }
  504. return nil
  505. }
  506. request := api.PushRequest{Name: args[0], Insecure: insecure}
  507. if err := client.Push(cmd.Context(), &request, fn); err != nil {
  508. if spinner != nil {
  509. spinner.Stop()
  510. }
  511. if strings.Contains(err.Error(), "access denied") {
  512. return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
  513. }
  514. host := model.ParseName(args[0]).Host
  515. isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
  516. if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
  517. // the user has not added their ollama key to ollama.com
  518. // re-throw an error with a more user-friendly message
  519. return errFromUnknownKey(err)
  520. }
  521. return err
  522. }
  523. spinner.Stop()
  524. return nil
  525. }
  526. func ListHandler(cmd *cobra.Command, args []string) error {
  527. client, err := api.ClientFromEnvironment()
  528. if err != nil {
  529. return err
  530. }
  531. models, err := client.List(cmd.Context())
  532. if err != nil {
  533. return err
  534. }
  535. var data [][]string
  536. for _, m := range models.Models {
  537. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  538. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
  539. }
  540. }
  541. table := tablewriter.NewWriter(os.Stdout)
  542. table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
  543. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  544. table.SetAlignment(tablewriter.ALIGN_LEFT)
  545. table.SetHeaderLine(false)
  546. table.SetBorder(false)
  547. table.SetNoWhiteSpace(true)
  548. table.SetTablePadding("\t")
  549. table.AppendBulk(data)
  550. table.Render()
  551. return nil
  552. }
  553. func ListRunningHandler(cmd *cobra.Command, args []string) error {
  554. client, err := api.ClientFromEnvironment()
  555. if err != nil {
  556. return err
  557. }
  558. models, err := client.ListRunning(cmd.Context())
  559. if err != nil {
  560. return err
  561. }
  562. var data [][]string
  563. for _, m := range models.Models {
  564. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  565. var procStr string
  566. switch {
  567. case m.SizeVRAM == 0:
  568. procStr = "100% CPU"
  569. case m.SizeVRAM == m.Size:
  570. procStr = "100% GPU"
  571. case m.SizeVRAM > m.Size || m.Size == 0:
  572. procStr = "Unknown"
  573. default:
  574. sizeCPU := m.Size - m.SizeVRAM
  575. cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
  576. procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
  577. }
  578. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
  579. }
  580. }
  581. table := tablewriter.NewWriter(os.Stdout)
  582. table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
  583. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  584. table.SetAlignment(tablewriter.ALIGN_LEFT)
  585. table.SetHeaderLine(false)
  586. table.SetBorder(false)
  587. table.SetNoWhiteSpace(true)
  588. table.SetTablePadding("\t")
  589. table.AppendBulk(data)
  590. table.Render()
  591. return nil
  592. }
  593. func DeleteHandler(cmd *cobra.Command, args []string) error {
  594. client, err := api.ClientFromEnvironment()
  595. if err != nil {
  596. return err
  597. }
  598. for _, name := range args {
  599. req := api.DeleteRequest{Name: name}
  600. if err := client.Delete(cmd.Context(), &req); err != nil {
  601. return err
  602. }
  603. fmt.Printf("deleted '%s'\n", name)
  604. }
  605. return nil
  606. }
  607. func ShowHandler(cmd *cobra.Command, args []string) error {
  608. client, err := api.ClientFromEnvironment()
  609. if err != nil {
  610. return err
  611. }
  612. license, errLicense := cmd.Flags().GetBool("license")
  613. modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
  614. parameters, errParams := cmd.Flags().GetBool("parameters")
  615. system, errSystem := cmd.Flags().GetBool("system")
  616. template, errTemplate := cmd.Flags().GetBool("template")
  617. for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
  618. if boolErr != nil {
  619. return errors.New("error retrieving flags")
  620. }
  621. }
  622. flagsSet := 0
  623. showType := ""
  624. if license {
  625. flagsSet++
  626. showType = "license"
  627. }
  628. if modelfile {
  629. flagsSet++
  630. showType = "modelfile"
  631. }
  632. if parameters {
  633. flagsSet++
  634. showType = "parameters"
  635. }
  636. if system {
  637. flagsSet++
  638. showType = "system"
  639. }
  640. if template {
  641. flagsSet++
  642. showType = "template"
  643. }
  644. if flagsSet > 1 {
  645. return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
  646. }
  647. req := api.ShowRequest{Name: args[0]}
  648. resp, err := client.Show(cmd.Context(), &req)
  649. if err != nil {
  650. return err
  651. }
  652. if flagsSet == 1 {
  653. switch showType {
  654. case "license":
  655. fmt.Println(resp.License)
  656. case "modelfile":
  657. fmt.Println(resp.Modelfile)
  658. case "parameters":
  659. fmt.Println(resp.Parameters)
  660. case "system":
  661. fmt.Println(resp.System)
  662. case "template":
  663. fmt.Println(resp.Template)
  664. }
  665. return nil
  666. }
  667. showInfo(resp)
  668. return nil
  669. }
  670. func showInfo(resp *api.ShowResponse) {
  671. arch := resp.ModelInfo["general.architecture"].(string)
  672. modelData := [][]string{
  673. {"arch", arch},
  674. {"parameters", resp.Details.ParameterSize},
  675. {"quantization", resp.Details.QuantizationLevel},
  676. {"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))},
  677. {"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
  678. }
  679. mainTableData := [][]string{
  680. {"Model"},
  681. {renderSubTable(modelData, false)},
  682. }
  683. if resp.ProjectorInfo != nil {
  684. projectorData := [][]string{
  685. {"arch", "clip"},
  686. {"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
  687. }
  688. if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {
  689. projectorData = append(projectorData, []string{"projector type", projectorType.(string)})
  690. }
  691. projectorData = append(projectorData,
  692. []string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
  693. []string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
  694. )
  695. mainTableData = append(mainTableData,
  696. []string{"Projector"},
  697. []string{renderSubTable(projectorData, false)},
  698. )
  699. }
  700. if resp.Parameters != "" {
  701. mainTableData = append(mainTableData, []string{"Parameters"}, []string{formatParams(resp.Parameters)})
  702. }
  703. if resp.System != "" {
  704. mainTableData = append(mainTableData, []string{"System"}, []string{renderSubTable(twoLines(resp.System), true)})
  705. }
  706. if resp.License != "" {
  707. mainTableData = append(mainTableData, []string{"License"}, []string{renderSubTable(twoLines(resp.License), true)})
  708. }
  709. table := tablewriter.NewWriter(os.Stdout)
  710. table.SetAutoWrapText(false)
  711. table.SetBorder(false)
  712. table.SetAlignment(tablewriter.ALIGN_LEFT)
  713. for _, v := range mainTableData {
  714. table.Append(v)
  715. }
  716. table.Render()
  717. }
  718. func renderSubTable(data [][]string, file bool) string {
  719. var buf bytes.Buffer
  720. table := tablewriter.NewWriter(&buf)
  721. table.SetAutoWrapText(!file)
  722. table.SetBorder(false)
  723. table.SetNoWhiteSpace(true)
  724. table.SetTablePadding("\t")
  725. table.SetAlignment(tablewriter.ALIGN_LEFT)
  726. for _, v := range data {
  727. table.Append(v)
  728. }
  729. table.Render()
  730. renderedTable := buf.String()
  731. lines := strings.Split(renderedTable, "\n")
  732. for i, line := range lines {
  733. lines[i] = "\t" + line
  734. }
  735. return strings.Join(lines, "\n")
  736. }
  // twoLines returns up to the first two non-blank lines of s, each trimmed of
// surrounding whitespace and wrapped as a one-cell table row. The result is
// never nil, even when s has no non-blank lines.
func twoLines(s string) [][]string {
	rows := [][]string{}
	for _, raw := range strings.Split(s, "\n") {
		if len(rows) == 2 {
			break
		}
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			continue
		}
		rows = append(rows, []string{trimmed})
	}
	return rows
}
  753. func formatParams(s string) string {
  754. lines := strings.Split(s, "\n")
  755. table := [][]string{}
  756. for _, line := range lines {
  757. table = append(table, strings.Fields(line))
  758. }
  759. return renderSubTable(table, false)
  760. }
  761. func CopyHandler(cmd *cobra.Command, args []string) error {
  762. client, err := api.ClientFromEnvironment()
  763. if err != nil {
  764. return err
  765. }
  766. req := api.CopyRequest{Source: args[0], Destination: args[1]}
  767. if err := client.Copy(cmd.Context(), &req); err != nil {
  768. return err
  769. }
  770. fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
  771. return nil
  772. }
  773. func PullHandler(cmd *cobra.Command, args []string) error {
  774. insecure, err := cmd.Flags().GetBool("insecure")
  775. if err != nil {
  776. return err
  777. }
  778. client, err := api.ClientFromEnvironment()
  779. if err != nil {
  780. return err
  781. }
  782. p := progress.NewProgress(os.Stderr)
  783. defer p.Stop()
  784. bars := make(map[string]*progress.Bar)
  785. var status string
  786. var spinner *progress.Spinner
  787. fn := func(resp api.ProgressResponse) error {
  788. if resp.Digest != "" {
  789. if spinner != nil {
  790. spinner.Stop()
  791. }
  792. bar, ok := bars[resp.Digest]
  793. if !ok {
  794. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  795. bars[resp.Digest] = bar
  796. p.Add(resp.Digest, bar)
  797. }
  798. bar.Set(resp.Completed)
  799. } else if status != resp.Status {
  800. if spinner != nil {
  801. spinner.Stop()
  802. }
  803. status = resp.Status
  804. spinner = progress.NewSpinner(status)
  805. p.Add(status, spinner)
  806. }
  807. return nil
  808. }
  809. request := api.PullRequest{Name: args[0], Insecure: insecure}
  810. if err := client.Pull(cmd.Context(), &request, fn); err != nil {
  811. return err
  812. }
  813. return nil
  814. }
  815. type generateContextKey string
  816. type runOptions struct {
  817. Model string
  818. ParentModel string
  819. Prompt string
  820. Messages []api.Message
  821. WordWrap bool
  822. Format string
  823. System string
  824. Template string
  825. Images []api.ImageData
  826. Options map[string]interface{}
  827. MultiModal bool
  828. KeepAlive *api.Duration
  829. }
  830. type displayResponseState struct {
  831. lineLength int
  832. wordBuffer string
  833. }
  834. func displayResponse(content string, wordWrap bool, state *displayResponseState) {
  835. termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
  836. if wordWrap && termWidth >= 10 {
  837. for _, ch := range content {
  838. if state.lineLength+1 > termWidth-5 {
  839. if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
  840. fmt.Printf("%s%c", state.wordBuffer, ch)
  841. state.wordBuffer = ""
  842. state.lineLength = 0
  843. continue
  844. }
  845. // backtrack the length of the last word and clear to the end of the line
  846. a := runewidth.StringWidth(state.wordBuffer)
  847. if a > 0 {
  848. fmt.Printf("\x1b[%dD", a)
  849. }
  850. fmt.Printf("\x1b[K\n")
  851. fmt.Printf("%s%c", state.wordBuffer, ch)
  852. chWidth := runewidth.RuneWidth(ch)
  853. state.lineLength = runewidth.StringWidth(state.wordBuffer) + chWidth
  854. } else {
  855. fmt.Print(string(ch))
  856. state.lineLength += runewidth.RuneWidth(ch)
  857. if runewidth.RuneWidth(ch) >= 2 {
  858. state.wordBuffer = ""
  859. continue
  860. }
  861. switch ch {
  862. case ' ':
  863. state.wordBuffer = ""
  864. case '\n':
  865. state.lineLength = 0
  866. default:
  867. state.wordBuffer += string(ch)
  868. }
  869. }
  870. }
  871. } else {
  872. fmt.Printf("%s%s", state.wordBuffer, content)
  873. if len(state.wordBuffer) > 0 {
  874. state.wordBuffer = ""
  875. }
  876. }
  877. }
  878. func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
  879. client, err := api.ClientFromEnvironment()
  880. if err != nil {
  881. return nil, err
  882. }
  883. p := progress.NewProgress(os.Stderr)
  884. defer p.StopAndClear()
  885. spinner := progress.NewSpinner("")
  886. p.Add("", spinner)
  887. cancelCtx, cancel := context.WithCancel(cmd.Context())
  888. defer cancel()
  889. sigChan := make(chan os.Signal, 1)
  890. signal.Notify(sigChan, syscall.SIGINT)
  891. go func() {
  892. <-sigChan
  893. cancel()
  894. }()
  895. var state *displayResponseState = &displayResponseState{}
  896. var latest api.ChatResponse
  897. var fullResponse strings.Builder
  898. var role string
  899. fn := func(response api.ChatResponse) error {
  900. p.StopAndClear()
  901. latest = response
  902. role = response.Message.Role
  903. content := response.Message.Content
  904. fullResponse.WriteString(content)
  905. displayResponse(content, opts.WordWrap, state)
  906. return nil
  907. }
  908. req := &api.ChatRequest{
  909. Model: opts.Model,
  910. Messages: opts.Messages,
  911. Format: opts.Format,
  912. Options: opts.Options,
  913. }
  914. if opts.KeepAlive != nil {
  915. req.KeepAlive = opts.KeepAlive
  916. }
  917. if err := client.Chat(cancelCtx, req, fn); err != nil {
  918. if errors.Is(err, context.Canceled) {
  919. return nil, nil
  920. }
  921. return nil, err
  922. }
  923. if len(opts.Messages) > 0 {
  924. fmt.Println()
  925. fmt.Println()
  926. }
  927. verbose, err := cmd.Flags().GetBool("verbose")
  928. if err != nil {
  929. return nil, err
  930. }
  931. if verbose {
  932. latest.Summary()
  933. }
  934. return &api.Message{Role: role, Content: fullResponse.String()}, nil
  935. }
  936. func generate(cmd *cobra.Command, opts runOptions) error {
  937. client, err := api.ClientFromEnvironment()
  938. if err != nil {
  939. return err
  940. }
  941. p := progress.NewProgress(os.Stderr)
  942. defer p.StopAndClear()
  943. spinner := progress.NewSpinner("")
  944. p.Add("", spinner)
  945. var latest api.GenerateResponse
  946. generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
  947. if !ok {
  948. generateContext = []int{}
  949. }
  950. ctx, cancel := context.WithCancel(cmd.Context())
  951. defer cancel()
  952. sigChan := make(chan os.Signal, 1)
  953. signal.Notify(sigChan, syscall.SIGINT)
  954. go func() {
  955. <-sigChan
  956. cancel()
  957. }()
  958. var state *displayResponseState = &displayResponseState{}
  959. fn := func(response api.GenerateResponse) error {
  960. p.StopAndClear()
  961. latest = response
  962. content := response.Response
  963. displayResponse(content, opts.WordWrap, state)
  964. return nil
  965. }
  966. if opts.MultiModal {
  967. opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
  968. if err != nil {
  969. return err
  970. }
  971. }
  972. request := api.GenerateRequest{
  973. Model: opts.Model,
  974. Prompt: opts.Prompt,
  975. Context: generateContext,
  976. Images: opts.Images,
  977. Format: opts.Format,
  978. System: opts.System,
  979. Template: opts.Template,
  980. Options: opts.Options,
  981. KeepAlive: opts.KeepAlive,
  982. }
  983. if err := client.Generate(ctx, &request, fn); err != nil {
  984. if errors.Is(err, context.Canceled) {
  985. return nil
  986. }
  987. return err
  988. }
  989. if opts.Prompt != "" {
  990. fmt.Println()
  991. fmt.Println()
  992. }
  993. if !latest.Done {
  994. return nil
  995. }
  996. verbose, err := cmd.Flags().GetBool("verbose")
  997. if err != nil {
  998. return err
  999. }
  1000. if verbose {
  1001. latest.Summary()
  1002. }
  1003. ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
  1004. cmd.SetContext(ctx)
  1005. return nil
  1006. }
  1007. func RunServer(cmd *cobra.Command, _ []string) error {
  1008. if err := initializeKeypair(); err != nil {
  1009. return err
  1010. }
  1011. ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
  1012. if err != nil {
  1013. return err
  1014. }
  1015. err = server.Serve(ln)
  1016. if errors.Is(err, http.ErrServerClosed) {
  1017. return nil
  1018. }
  1019. return err
  1020. }
  1021. func initializeKeypair() error {
  1022. home, err := os.UserHomeDir()
  1023. if err != nil {
  1024. return err
  1025. }
  1026. privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
  1027. pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
  1028. _, err = os.Stat(privKeyPath)
  1029. if os.IsNotExist(err) {
  1030. fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
  1031. cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
  1032. if err != nil {
  1033. return err
  1034. }
  1035. privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
  1036. if err != nil {
  1037. return err
  1038. }
  1039. if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
  1040. return fmt.Errorf("could not create directory %w", err)
  1041. }
  1042. if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
  1043. return err
  1044. }
  1045. sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
  1046. if err != nil {
  1047. return err
  1048. }
  1049. publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
  1050. if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
  1051. return err
  1052. }
  1053. fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
  1054. }
  1055. return nil
  1056. }
  1057. func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
  1058. client, err := api.ClientFromEnvironment()
  1059. if err != nil {
  1060. return err
  1061. }
  1062. if err := client.Heartbeat(cmd.Context()); err != nil {
  1063. if !strings.Contains(err.Error(), " refused") {
  1064. return err
  1065. }
  1066. if err := startApp(cmd.Context(), client); err != nil {
  1067. return fmt.Errorf("could not connect to ollama app, is it running?")
  1068. }
  1069. }
  1070. return nil
  1071. }
  1072. func versionHandler(cmd *cobra.Command, _ []string) {
  1073. client, err := api.ClientFromEnvironment()
  1074. if err != nil {
  1075. return
  1076. }
  1077. serverVersion, err := client.Version(cmd.Context())
  1078. if err != nil {
  1079. fmt.Println("Warning: could not connect to a running Ollama instance")
  1080. }
  1081. if serverVersion != "" {
  1082. fmt.Printf("ollama version is %s\n", serverVersion)
  1083. }
  1084. if serverVersion != version.Version {
  1085. fmt.Printf("Warning: client version is %s\n", version.Version)
  1086. }
  1087. }
  1088. func appendEnvDocs(cmd *cobra.Command, envs []envconfig.EnvVar) {
  1089. if len(envs) == 0 {
  1090. return
  1091. }
  1092. envUsage := `
  1093. Environment Variables:
  1094. `
  1095. for _, e := range envs {
  1096. envUsage += fmt.Sprintf(" %-24s %s\n", e.Name, e.Description)
  1097. }
  1098. cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
  1099. }
  1100. func NewCLI() *cobra.Command {
  1101. log.SetFlags(log.LstdFlags | log.Lshortfile)
  1102. cobra.EnableCommandSorting = false
  1103. if runtime.GOOS == "windows" {
  1104. console.ConsoleFromFile(os.Stdin) //nolint:errcheck
  1105. }
  1106. rootCmd := &cobra.Command{
  1107. Use: "ollama",
  1108. Short: "Large language model runner",
  1109. SilenceUsage: true,
  1110. SilenceErrors: true,
  1111. CompletionOptions: cobra.CompletionOptions{
  1112. DisableDefaultCmd: true,
  1113. },
  1114. Run: func(cmd *cobra.Command, args []string) {
  1115. if version, _ := cmd.Flags().GetBool("version"); version {
  1116. versionHandler(cmd, args)
  1117. return
  1118. }
  1119. cmd.Print(cmd.UsageString())
  1120. },
  1121. }
  1122. rootCmd.Flags().BoolP("version", "v", false, "Show version information")
  1123. createCmd := &cobra.Command{
  1124. Use: "create MODEL",
  1125. Short: "Create a model from a Modelfile",
  1126. Args: cobra.ExactArgs(1),
  1127. PreRunE: checkServerHeartbeat,
  1128. RunE: CreateHandler,
  1129. }
  1130. createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
  1131. createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
  1132. showCmd := &cobra.Command{
  1133. Use: "show MODEL",
  1134. Short: "Show information for a model",
  1135. Args: cobra.ExactArgs(1),
  1136. PreRunE: checkServerHeartbeat,
  1137. RunE: ShowHandler,
  1138. }
  1139. showCmd.Flags().Bool("license", false, "Show license of a model")
  1140. showCmd.Flags().Bool("modelfile", false, "Show Modelfile of a model")
  1141. showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
  1142. showCmd.Flags().Bool("template", false, "Show template of a model")
  1143. showCmd.Flags().Bool("system", false, "Show system message of a model")
  1144. runCmd := &cobra.Command{
  1145. Use: "run MODEL [PROMPT]",
  1146. Short: "Run a model",
  1147. Args: cobra.MinimumNArgs(1),
  1148. PreRunE: checkServerHeartbeat,
  1149. RunE: RunHandler,
  1150. }
  1151. runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
  1152. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  1153. runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  1154. runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
  1155. runCmd.Flags().String("format", "", "Response format (e.g. json)")
  1156. serveCmd := &cobra.Command{
  1157. Use: "serve",
  1158. Aliases: []string{"start"},
  1159. Short: "Start ollama",
  1160. Args: cobra.ExactArgs(0),
  1161. RunE: RunServer,
  1162. }
  1163. pullCmd := &cobra.Command{
  1164. Use: "pull MODEL",
  1165. Short: "Pull a model from a registry",
  1166. Args: cobra.ExactArgs(1),
  1167. PreRunE: checkServerHeartbeat,
  1168. RunE: PullHandler,
  1169. }
  1170. pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  1171. pushCmd := &cobra.Command{
  1172. Use: "push MODEL",
  1173. Short: "Push a model to a registry",
  1174. Args: cobra.ExactArgs(1),
  1175. PreRunE: checkServerHeartbeat,
  1176. RunE: PushHandler,
  1177. }
  1178. pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  1179. listCmd := &cobra.Command{
  1180. Use: "list",
  1181. Aliases: []string{"ls"},
  1182. Short: "List models",
  1183. PreRunE: checkServerHeartbeat,
  1184. RunE: ListHandler,
  1185. }
  1186. psCmd := &cobra.Command{
  1187. Use: "ps",
  1188. Short: "List running models",
  1189. PreRunE: checkServerHeartbeat,
  1190. RunE: ListRunningHandler,
  1191. }
  1192. copyCmd := &cobra.Command{
  1193. Use: "cp SOURCE DESTINATION",
  1194. Short: "Copy a model",
  1195. Args: cobra.ExactArgs(2),
  1196. PreRunE: checkServerHeartbeat,
  1197. RunE: CopyHandler,
  1198. }
  1199. deleteCmd := &cobra.Command{
  1200. Use: "rm MODEL [MODEL...]",
  1201. Short: "Remove a model",
  1202. Args: cobra.MinimumNArgs(1),
  1203. PreRunE: checkServerHeartbeat,
  1204. RunE: DeleteHandler,
  1205. }
  1206. envVars := envconfig.AsMap()
  1207. envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]}
  1208. for _, cmd := range []*cobra.Command{
  1209. createCmd,
  1210. showCmd,
  1211. runCmd,
  1212. pullCmd,
  1213. pushCmd,
  1214. listCmd,
  1215. psCmd,
  1216. copyCmd,
  1217. deleteCmd,
  1218. serveCmd,
  1219. } {
  1220. switch cmd {
  1221. case runCmd:
  1222. appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
  1223. case serveCmd:
  1224. appendEnvDocs(cmd, []envconfig.EnvVar{
  1225. envVars["OLLAMA_DEBUG"],
  1226. envVars["OLLAMA_HOST"],
  1227. envVars["OLLAMA_KEEP_ALIVE"],
  1228. envVars["OLLAMA_MAX_LOADED_MODELS"],
  1229. envVars["OLLAMA_MAX_QUEUE"],
  1230. envVars["OLLAMA_MODELS"],
  1231. envVars["OLLAMA_NUM_PARALLEL"],
  1232. envVars["OLLAMA_NOPRUNE"],
  1233. envVars["OLLAMA_ORIGINS"],
  1234. envVars["OLLAMA_TMPDIR"],
  1235. envVars["OLLAMA_FLASH_ATTENTION"],
  1236. envVars["OLLAMA_LLM_LIBRARY"],
  1237. envVars["OLLAMA_MAX_VRAM"],
  1238. })
  1239. default:
  1240. appendEnvDocs(cmd, envs)
  1241. }
  1242. }
  1243. rootCmd.AddCommand(
  1244. serveCmd,
  1245. createCmd,
  1246. showCmd,
  1247. runCmd,
  1248. pullCmd,
  1249. pushCmd,
  1250. listCmd,
  1251. psCmd,
  1252. copyCmd,
  1253. deleteCmd,
  1254. )
  1255. return rootCmd
  1256. }