cmd.go 28 KB


  1. package cmd
  2. import (
  3. "archive/zip"
  4. "bytes"
  5. "context"
  6. "crypto/ed25519"
  7. "crypto/rand"
  8. "crypto/sha256"
  9. "encoding/pem"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "math"
  15. "net"
  16. "net/http"
  17. "os"
  18. "os/signal"
  19. "path/filepath"
  20. "regexp"
  21. "runtime"
  22. "strings"
  23. "syscall"
  24. "time"
  25. "github.com/containerd/console"
  26. "github.com/olekukonko/tablewriter"
  27. "github.com/spf13/cobra"
  28. "golang.org/x/crypto/ssh"
  29. "golang.org/x/exp/slices"
  30. "golang.org/x/term"
  31. "github.com/ollama/ollama/api"
  32. "github.com/ollama/ollama/auth"
  33. "github.com/ollama/ollama/format"
  34. "github.com/ollama/ollama/progress"
  35. "github.com/ollama/ollama/server"
  36. "github.com/ollama/ollama/types/errtypes"
  37. "github.com/ollama/ollama/types/model"
  38. "github.com/ollama/ollama/version"
  39. )
  40. func CreateHandler(cmd *cobra.Command, args []string) error {
  41. filename, _ := cmd.Flags().GetString("file")
  42. filename, err := filepath.Abs(filename)
  43. if err != nil {
  44. return err
  45. }
  46. client, err := api.ClientFromEnvironment()
  47. if err != nil {
  48. return err
  49. }
  50. p := progress.NewProgress(os.Stderr)
  51. defer p.Stop()
  52. f, err := os.Open(filename)
  53. if err != nil {
  54. return err
  55. }
  56. defer f.Close()
  57. modelfile, err := model.ParseFile(f)
  58. if err != nil {
  59. return err
  60. }
  61. home, err := os.UserHomeDir()
  62. if err != nil {
  63. return err
  64. }
  65. status := "transferring model data"
  66. spinner := progress.NewSpinner(status)
  67. p.Add(status, spinner)
  68. for i := range modelfile.Commands {
  69. switch modelfile.Commands[i].Name {
  70. case "model", "adapter":
  71. path := modelfile.Commands[i].Args
  72. if path == "~" {
  73. path = home
  74. } else if strings.HasPrefix(path, "~/") {
  75. path = filepath.Join(home, path[2:])
  76. }
  77. if !filepath.IsAbs(path) {
  78. path = filepath.Join(filepath.Dir(filename), path)
  79. }
  80. fi, err := os.Stat(path)
  81. if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
  82. continue
  83. } else if err != nil {
  84. return err
  85. }
  86. if fi.IsDir() {
  87. // this is likely a safetensors or pytorch directory
  88. // TODO make this work w/ adapters
  89. tempfile, err := tempZipFiles(path)
  90. if err != nil {
  91. return err
  92. }
  93. defer os.RemoveAll(tempfile)
  94. path = tempfile
  95. }
  96. digest, err := createBlob(cmd, client, path)
  97. if err != nil {
  98. return err
  99. }
  100. modelfile.Commands[i].Args = "@" + digest
  101. }
  102. }
  103. bars := make(map[string]*progress.Bar)
  104. fn := func(resp api.ProgressResponse) error {
  105. if resp.Digest != "" {
  106. spinner.Stop()
  107. bar, ok := bars[resp.Digest]
  108. if !ok {
  109. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  110. bars[resp.Digest] = bar
  111. p.Add(resp.Digest, bar)
  112. }
  113. bar.Set(resp.Completed)
  114. } else if status != resp.Status {
  115. spinner.Stop()
  116. status = resp.Status
  117. spinner = progress.NewSpinner(status)
  118. p.Add(status, spinner)
  119. }
  120. return nil
  121. }
  122. quantize, _ := cmd.Flags().GetString("quantize")
  123. request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantize: quantize}
  124. if err := client.Create(cmd.Context(), &request, fn); err != nil {
  125. return err
  126. }
  127. return nil
  128. }
  // tempZipFiles bundles the model files found directly in the directory path
// into a zip archive stored in a new temporary file and returns that file's
// name. The caller is responsible for removing the file.
//
// The archive contains the first matching set of weight files (safetensors,
// then pytorch_model*.bin, then consolidated*.pth), every *.json file, and
// tokenizer.model when present. An error is returned when no weight files
// are found, when a *.json file does not look like plain text, or when any
// file cannot be read or written. On error the temporary file is removed.
func tempZipFiles(path string) (name string, err error) {
	tempfile, err := os.CreateTemp("", "ollama-tf")
	if err != nil {
		return "", err
	}

	zipfile := zip.NewWriter(tempfile)
	defer func() {
		// Closing the zip writer is what writes the central directory, so a
		// failure here means the archive is unusable and must be reported.
		if cerr := zipfile.Close(); err == nil {
			err = cerr
		}
		if cerr := tempfile.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			// Do not leave a partial archive behind.
			os.Remove(tempfile.Name())
			name = ""
		}
	}()

	// detectContentType sniffs the first 512 bytes of the file and returns
	// the media type without any parameters (e.g. "text/plain").
	detectContentType := func(path string) (string, error) {
		f, err := os.Open(path)
		if err != nil {
			return "", err
		}
		defer f.Close()

		var b bytes.Buffer
		b.Grow(512)

		// Files shorter than 512 bytes end with io.EOF, which is fine.
		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
			return "", err
		}

		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
		return contentType, nil
	}

	// glob returns the files matching pattern, failing if any of them does
	// not have the expected content type.
	glob := func(pattern, contentType string) ([]string, error) {
		matches, err := filepath.Glob(pattern)
		if err != nil {
			return nil, err
		}

		for _, match := range matches {
			ct, err := detectContentType(match)
			if err != nil {
				return nil, err
			}
			if ct != contentType {
				return nil, fmt.Errorf("invalid content type: expected %s for %s, got %s", contentType, match, ct)
			}
		}

		return matches, nil
	}

	// Weight file layouts, in order of preference. A content type mismatch
	// (for example an unresolved git lfs reference, which is a small text
	// file) makes glob fail, and that layout is then skipped.
	weights := []struct{ pattern, contentType string }{
		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
		{"model*.safetensors", "application/octet-stream"},
		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
		{"pytorch_model*.bin", "application/zip"},
		// covers consolidated.x.pth, consolidated.pth
		{"consolidated*.pth", "application/octet-stream"},
	}

	var files []string
	for _, w := range weights {
		if matches, _ := glob(filepath.Join(path, w.pattern), w.contentType); len(matches) > 0 {
			files = append(files, matches...)
			break
		}
	}
	if len(files) == 0 {
		return "", errors.New("no safetensors or torch files found")
	}

	// add configuration files, json files are detected as text/plain
	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
	if err != nil {
		return "", err
	}
	files = append(files, js...)

	// add tokenizer.model if it exists, tokenizer.json is automatically
	// picked up by the previous glob. A tokenizer.model with an unexpected
	// content type is skipped rather than reported.
	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
		files = append(files, tks...)
	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
		files = append(files, tks...)
	}

	// addFile copies one file into the archive. It is a separate function so
	// each input file is closed as soon as it has been copied instead of
	// staying open until tempZipFiles returns.
	addFile := func(file string) error {
		f, err := os.Open(file)
		if err != nil {
			return err
		}
		defer f.Close()

		fi, err := f.Stat()
		if err != nil {
			return err
		}

		zfi, err := zip.FileInfoHeader(fi)
		if err != nil {
			return err
		}

		zf, err := zipfile.CreateHeader(zfi)
		if err != nil {
			return err
		}

		_, err = io.Copy(zf, f)
		return err
	}

	for _, file := range files {
		if err := addFile(file); err != nil {
			return "", err
		}
	}

	return tempfile.Name(), nil
}
  219. func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
  220. bin, err := os.Open(path)
  221. if err != nil {
  222. return "", err
  223. }
  224. defer bin.Close()
  225. hash := sha256.New()
  226. if _, err := io.Copy(hash, bin); err != nil {
  227. return "", err
  228. }
  229. if _, err := bin.Seek(0, io.SeekStart); err != nil {
  230. return "", err
  231. }
  232. digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
  233. if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
  234. return "", err
  235. }
  236. return digest, nil
  237. }
  238. func RunHandler(cmd *cobra.Command, args []string) error {
  239. client, err := api.ClientFromEnvironment()
  240. if err != nil {
  241. return err
  242. }
  243. name := args[0]
  244. // check if the model exists on the server
  245. show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
  246. var statusError api.StatusError
  247. switch {
  248. case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
  249. if err := PullHandler(cmd, []string{name}); err != nil {
  250. return err
  251. }
  252. show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
  253. if err != nil {
  254. return err
  255. }
  256. case err != nil:
  257. return err
  258. }
  259. interactive := true
  260. opts := runOptions{
  261. Model: args[0],
  262. WordWrap: os.Getenv("TERM") == "xterm-256color",
  263. Options: map[string]interface{}{},
  264. MultiModal: slices.Contains(show.Details.Families, "clip"),
  265. ParentModel: show.Details.ParentModel,
  266. }
  267. format, err := cmd.Flags().GetString("format")
  268. if err != nil {
  269. return err
  270. }
  271. opts.Format = format
  272. keepAlive, err := cmd.Flags().GetString("keepalive")
  273. if err != nil {
  274. return err
  275. }
  276. if keepAlive != "" {
  277. d, err := time.ParseDuration(keepAlive)
  278. if err != nil {
  279. return err
  280. }
  281. opts.KeepAlive = &api.Duration{Duration: d}
  282. }
  283. prompts := args[1:]
  284. // prepend stdin to the prompt if provided
  285. if !term.IsTerminal(int(os.Stdin.Fd())) {
  286. in, err := io.ReadAll(os.Stdin)
  287. if err != nil {
  288. return err
  289. }
  290. prompts = append([]string{string(in)}, prompts...)
  291. opts.WordWrap = false
  292. interactive = false
  293. }
  294. opts.Prompt = strings.Join(prompts, " ")
  295. if len(prompts) > 0 {
  296. interactive = false
  297. }
  298. nowrap, err := cmd.Flags().GetBool("nowordwrap")
  299. if err != nil {
  300. return err
  301. }
  302. opts.WordWrap = !nowrap
  303. if !interactive {
  304. return generate(cmd, opts)
  305. }
  306. return generateInteractive(cmd, opts)
  307. }
  308. func errFromUnknownKey(unknownKeyErr error) error {
  309. // find SSH public key in the error message
  310. sshKeyPattern := `ssh-\w+ [^\s"]+`
  311. re := regexp.MustCompile(sshKeyPattern)
  312. matches := re.FindStringSubmatch(unknownKeyErr.Error())
  313. if len(matches) > 0 {
  314. serverPubKey := matches[0]
  315. localPubKey, err := auth.GetPublicKey()
  316. if err != nil {
  317. return unknownKeyErr
  318. }
  319. if runtime.GOOS == "linux" && serverPubKey != localPubKey {
  320. // try the ollama service public key
  321. svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
  322. if err != nil {
  323. return unknownKeyErr
  324. }
  325. localPubKey = strings.TrimSpace(string(svcPubKey))
  326. }
  327. // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
  328. if serverPubKey != localPubKey {
  329. return unknownKeyErr
  330. }
  331. var msg strings.Builder
  332. msg.WriteString(unknownKeyErr.Error())
  333. msg.WriteString("\n\nYour ollama key is:\n")
  334. msg.WriteString(localPubKey)
  335. msg.WriteString("\nAdd your key at:\n")
  336. msg.WriteString("https://ollama.com/settings/keys")
  337. return errors.New(msg.String())
  338. }
  339. return unknownKeyErr
  340. }
  341. func PushHandler(cmd *cobra.Command, args []string) error {
  342. client, err := api.ClientFromEnvironment()
  343. if err != nil {
  344. return err
  345. }
  346. insecure, err := cmd.Flags().GetBool("insecure")
  347. if err != nil {
  348. return err
  349. }
  350. p := progress.NewProgress(os.Stderr)
  351. defer p.Stop()
  352. bars := make(map[string]*progress.Bar)
  353. var status string
  354. var spinner *progress.Spinner
  355. fn := func(resp api.ProgressResponse) error {
  356. if resp.Digest != "" {
  357. if spinner != nil {
  358. spinner.Stop()
  359. }
  360. bar, ok := bars[resp.Digest]
  361. if !ok {
  362. bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  363. bars[resp.Digest] = bar
  364. p.Add(resp.Digest, bar)
  365. }
  366. bar.Set(resp.Completed)
  367. } else if status != resp.Status {
  368. if spinner != nil {
  369. spinner.Stop()
  370. }
  371. status = resp.Status
  372. spinner = progress.NewSpinner(status)
  373. p.Add(status, spinner)
  374. }
  375. return nil
  376. }
  377. request := api.PushRequest{Name: args[0], Insecure: insecure}
  378. if err := client.Push(cmd.Context(), &request, fn); err != nil {
  379. if spinner != nil {
  380. spinner.Stop()
  381. }
  382. if strings.Contains(err.Error(), "access denied") {
  383. return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
  384. }
  385. host := model.ParseName(args[0]).Host
  386. isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
  387. if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
  388. // the user has not added their ollama key to ollama.com
  389. // re-throw an error with a more user-friendly message
  390. return errFromUnknownKey(err)
  391. }
  392. return err
  393. }
  394. spinner.Stop()
  395. return nil
  396. }
  397. func ListHandler(cmd *cobra.Command, args []string) error {
  398. client, err := api.ClientFromEnvironment()
  399. if err != nil {
  400. return err
  401. }
  402. models, err := client.List(cmd.Context())
  403. if err != nil {
  404. return err
  405. }
  406. var data [][]string
  407. for _, m := range models.Models {
  408. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  409. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
  410. }
  411. }
  412. table := tablewriter.NewWriter(os.Stdout)
  413. table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
  414. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  415. table.SetAlignment(tablewriter.ALIGN_LEFT)
  416. table.SetHeaderLine(false)
  417. table.SetBorder(false)
  418. table.SetNoWhiteSpace(true)
  419. table.SetTablePadding("\t")
  420. table.AppendBulk(data)
  421. table.Render()
  422. return nil
  423. }
  424. func ListRunningHandler(cmd *cobra.Command, args []string) error {
  425. client, err := api.ClientFromEnvironment()
  426. if err != nil {
  427. return err
  428. }
  429. models, err := client.ListRunning(cmd.Context())
  430. if err != nil {
  431. return err
  432. }
  433. var data [][]string
  434. for _, m := range models.Models {
  435. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  436. var procStr string
  437. switch {
  438. case m.SizeVRAM == 0:
  439. procStr = "100% CPU"
  440. case m.SizeVRAM == m.Size:
  441. procStr = "100% GPU"
  442. case m.SizeVRAM > m.Size || m.Size == 0:
  443. procStr = "Unknown"
  444. default:
  445. sizeCPU := m.Size - m.SizeVRAM
  446. cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
  447. procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
  448. }
  449. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
  450. }
  451. }
  452. table := tablewriter.NewWriter(os.Stdout)
  453. table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
  454. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  455. table.SetAlignment(tablewriter.ALIGN_LEFT)
  456. table.SetHeaderLine(false)
  457. table.SetBorder(false)
  458. table.SetNoWhiteSpace(true)
  459. table.SetTablePadding("\t")
  460. table.AppendBulk(data)
  461. table.Render()
  462. return nil
  463. }
  464. func DeleteHandler(cmd *cobra.Command, args []string) error {
  465. client, err := api.ClientFromEnvironment()
  466. if err != nil {
  467. return err
  468. }
  469. for _, name := range args {
  470. req := api.DeleteRequest{Name: name}
  471. if err := client.Delete(cmd.Context(), &req); err != nil {
  472. return err
  473. }
  474. fmt.Printf("deleted '%s'\n", name)
  475. }
  476. return nil
  477. }
  478. func ShowHandler(cmd *cobra.Command, args []string) error {
  479. client, err := api.ClientFromEnvironment()
  480. if err != nil {
  481. return err
  482. }
  483. if len(args) != 1 {
  484. return errors.New("missing model name")
  485. }
  486. license, errLicense := cmd.Flags().GetBool("license")
  487. modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
  488. parameters, errParams := cmd.Flags().GetBool("parameters")
  489. system, errSystem := cmd.Flags().GetBool("system")
  490. template, errTemplate := cmd.Flags().GetBool("template")
  491. for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
  492. if boolErr != nil {
  493. return errors.New("error retrieving flags")
  494. }
  495. }
  496. flagsSet := 0
  497. showType := ""
  498. if license {
  499. flagsSet++
  500. showType = "license"
  501. }
  502. if modelfile {
  503. flagsSet++
  504. showType = "modelfile"
  505. }
  506. if parameters {
  507. flagsSet++
  508. showType = "parameters"
  509. }
  510. if system {
  511. flagsSet++
  512. showType = "system"
  513. }
  514. if template {
  515. flagsSet++
  516. showType = "template"
  517. }
  518. if flagsSet > 1 {
  519. return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
  520. } else if flagsSet == 0 {
  521. return errors.New("one of '--license', '--modelfile', '--parameters', '--system', or '--template' must be specified")
  522. }
  523. req := api.ShowRequest{Name: args[0]}
  524. resp, err := client.Show(cmd.Context(), &req)
  525. if err != nil {
  526. return err
  527. }
  528. switch showType {
  529. case "license":
  530. fmt.Println(resp.License)
  531. case "modelfile":
  532. fmt.Println(resp.Modelfile)
  533. case "parameters":
  534. fmt.Println(resp.Parameters)
  535. case "system":
  536. fmt.Println(resp.System)
  537. case "template":
  538. fmt.Println(resp.Template)
  539. }
  540. return nil
  541. }
  542. func CopyHandler(cmd *cobra.Command, args []string) error {
  543. client, err := api.ClientFromEnvironment()
  544. if err != nil {
  545. return err
  546. }
  547. req := api.CopyRequest{Source: args[0], Destination: args[1]}
  548. if err := client.Copy(cmd.Context(), &req); err != nil {
  549. return err
  550. }
  551. fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
  552. return nil
  553. }
  554. func PullHandler(cmd *cobra.Command, args []string) error {
  555. insecure, err := cmd.Flags().GetBool("insecure")
  556. if err != nil {
  557. return err
  558. }
  559. client, err := api.ClientFromEnvironment()
  560. if err != nil {
  561. return err
  562. }
  563. p := progress.NewProgress(os.Stderr)
  564. defer p.Stop()
  565. bars := make(map[string]*progress.Bar)
  566. var status string
  567. var spinner *progress.Spinner
  568. fn := func(resp api.ProgressResponse) error {
  569. if resp.Digest != "" {
  570. if spinner != nil {
  571. spinner.Stop()
  572. }
  573. bar, ok := bars[resp.Digest]
  574. if !ok {
  575. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  576. bars[resp.Digest] = bar
  577. p.Add(resp.Digest, bar)
  578. }
  579. bar.Set(resp.Completed)
  580. } else if status != resp.Status {
  581. if spinner != nil {
  582. spinner.Stop()
  583. }
  584. status = resp.Status
  585. spinner = progress.NewSpinner(status)
  586. p.Add(status, spinner)
  587. }
  588. return nil
  589. }
  590. request := api.PullRequest{Name: args[0], Insecure: insecure}
  591. if err := client.Pull(cmd.Context(), &request, fn); err != nil {
  592. return err
  593. }
  594. return nil
  595. }
  596. type generateContextKey string
  597. type runOptions struct {
  598. Model string
  599. ParentModel string
  600. Prompt string
  601. Messages []api.Message
  602. WordWrap bool
  603. Format string
  604. System string
  605. Template string
  606. Images []api.ImageData
  607. Options map[string]interface{}
  608. MultiModal bool
  609. KeepAlive *api.Duration
  610. }
  611. type displayResponseState struct {
  612. lineLength int
  613. wordBuffer string
  614. }
  615. func displayResponse(content string, wordWrap bool, state *displayResponseState) {
  616. termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
  617. if wordWrap && termWidth >= 10 {
  618. for _, ch := range content {
  619. if state.lineLength+1 > termWidth-5 {
  620. if len(state.wordBuffer) > termWidth-10 {
  621. fmt.Printf("%s%c", state.wordBuffer, ch)
  622. state.wordBuffer = ""
  623. state.lineLength = 0
  624. continue
  625. }
  626. // backtrack the length of the last word and clear to the end of the line
  627. fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer))
  628. fmt.Printf("%s%c", state.wordBuffer, ch)
  629. state.lineLength = len(state.wordBuffer) + 1
  630. } else {
  631. fmt.Print(string(ch))
  632. state.lineLength += 1
  633. switch ch {
  634. case ' ':
  635. state.wordBuffer = ""
  636. case '\n':
  637. state.lineLength = 0
  638. default:
  639. state.wordBuffer += string(ch)
  640. }
  641. }
  642. }
  643. } else {
  644. fmt.Printf("%s%s", state.wordBuffer, content)
  645. if len(state.wordBuffer) > 0 {
  646. state.wordBuffer = ""
  647. }
  648. }
  649. }
  650. func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
  651. client, err := api.ClientFromEnvironment()
  652. if err != nil {
  653. return nil, err
  654. }
  655. p := progress.NewProgress(os.Stderr)
  656. defer p.StopAndClear()
  657. spinner := progress.NewSpinner("")
  658. p.Add("", spinner)
  659. cancelCtx, cancel := context.WithCancel(cmd.Context())
  660. defer cancel()
  661. sigChan := make(chan os.Signal, 1)
  662. signal.Notify(sigChan, syscall.SIGINT)
  663. go func() {
  664. <-sigChan
  665. cancel()
  666. }()
  667. var state *displayResponseState = &displayResponseState{}
  668. var latest api.ChatResponse
  669. var fullResponse strings.Builder
  670. var role string
  671. fn := func(response api.ChatResponse) error {
  672. p.StopAndClear()
  673. latest = response
  674. role = response.Message.Role
  675. content := response.Message.Content
  676. fullResponse.WriteString(content)
  677. displayResponse(content, opts.WordWrap, state)
  678. return nil
  679. }
  680. req := &api.ChatRequest{
  681. Model: opts.Model,
  682. Messages: opts.Messages,
  683. Format: opts.Format,
  684. Options: opts.Options,
  685. }
  686. if opts.KeepAlive != nil {
  687. req.KeepAlive = opts.KeepAlive
  688. }
  689. if err := client.Chat(cancelCtx, req, fn); err != nil {
  690. if errors.Is(err, context.Canceled) {
  691. return nil, nil
  692. }
  693. return nil, err
  694. }
  695. if len(opts.Messages) > 0 {
  696. fmt.Println()
  697. fmt.Println()
  698. }
  699. verbose, err := cmd.Flags().GetBool("verbose")
  700. if err != nil {
  701. return nil, err
  702. }
  703. if verbose {
  704. latest.Summary()
  705. }
  706. return &api.Message{Role: role, Content: fullResponse.String()}, nil
  707. }
  708. func generate(cmd *cobra.Command, opts runOptions) error {
  709. client, err := api.ClientFromEnvironment()
  710. if err != nil {
  711. return err
  712. }
  713. p := progress.NewProgress(os.Stderr)
  714. defer p.StopAndClear()
  715. spinner := progress.NewSpinner("")
  716. p.Add("", spinner)
  717. var latest api.GenerateResponse
  718. generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
  719. if !ok {
  720. generateContext = []int{}
  721. }
  722. ctx, cancel := context.WithCancel(cmd.Context())
  723. defer cancel()
  724. sigChan := make(chan os.Signal, 1)
  725. signal.Notify(sigChan, syscall.SIGINT)
  726. go func() {
  727. <-sigChan
  728. cancel()
  729. }()
  730. var state *displayResponseState = &displayResponseState{}
  731. fn := func(response api.GenerateResponse) error {
  732. p.StopAndClear()
  733. latest = response
  734. content := response.Response
  735. displayResponse(content, opts.WordWrap, state)
  736. return nil
  737. }
  738. if opts.MultiModal {
  739. opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
  740. if err != nil {
  741. return err
  742. }
  743. }
  744. request := api.GenerateRequest{
  745. Model: opts.Model,
  746. Prompt: opts.Prompt,
  747. Context: generateContext,
  748. Images: opts.Images,
  749. Format: opts.Format,
  750. System: opts.System,
  751. Template: opts.Template,
  752. Options: opts.Options,
  753. }
  754. if err := client.Generate(ctx, &request, fn); err != nil {
  755. if errors.Is(err, context.Canceled) {
  756. return nil
  757. }
  758. return err
  759. }
  760. if opts.Prompt != "" {
  761. fmt.Println()
  762. fmt.Println()
  763. }
  764. if !latest.Done {
  765. return nil
  766. }
  767. verbose, err := cmd.Flags().GetBool("verbose")
  768. if err != nil {
  769. return err
  770. }
  771. if verbose {
  772. latest.Summary()
  773. }
  774. ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
  775. cmd.SetContext(ctx)
  776. return nil
  777. }
  778. func RunServer(cmd *cobra.Command, _ []string) error {
  779. // retrieve the OLLAMA_HOST environment variable
  780. ollamaHost, err := api.GetOllamaHost()
  781. if err != nil {
  782. return err
  783. }
  784. if err := initializeKeypair(); err != nil {
  785. return err
  786. }
  787. ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
  788. if err != nil {
  789. return err
  790. }
  791. err = server.Serve(ln)
  792. if errors.Is(err, http.ErrServerClosed) {
  793. return nil
  794. }
  795. return err
  796. }
  797. func initializeKeypair() error {
  798. home, err := os.UserHomeDir()
  799. if err != nil {
  800. return err
  801. }
  802. privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
  803. pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
  804. _, err = os.Stat(privKeyPath)
  805. if os.IsNotExist(err) {
  806. fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
  807. cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
  808. if err != nil {
  809. return err
  810. }
  811. privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
  812. if err != nil {
  813. return err
  814. }
  815. if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
  816. return fmt.Errorf("could not create directory %w", err)
  817. }
  818. if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
  819. return err
  820. }
  821. sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
  822. if err != nil {
  823. return err
  824. }
  825. publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
  826. if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
  827. return err
  828. }
  829. fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
  830. }
  831. return nil
  832. }
  833. //nolint:unused
  834. func waitForServer(ctx context.Context, client *api.Client) error {
  835. // wait for the server to start
  836. timeout := time.After(5 * time.Second)
  837. tick := time.Tick(500 * time.Millisecond)
  838. for {
  839. select {
  840. case <-timeout:
  841. return errors.New("timed out waiting for server to start")
  842. case <-tick:
  843. if err := client.Heartbeat(ctx); err == nil {
  844. return nil // server has started
  845. }
  846. }
  847. }
  848. }
  849. func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
  850. client, err := api.ClientFromEnvironment()
  851. if err != nil {
  852. return err
  853. }
  854. if err := client.Heartbeat(cmd.Context()); err != nil {
  855. if !strings.Contains(err.Error(), " refused") {
  856. return err
  857. }
  858. if err := startApp(cmd.Context(), client); err != nil {
  859. return fmt.Errorf("could not connect to ollama app, is it running?")
  860. }
  861. }
  862. return nil
  863. }
  864. func versionHandler(cmd *cobra.Command, _ []string) {
  865. client, err := api.ClientFromEnvironment()
  866. if err != nil {
  867. return
  868. }
  869. serverVersion, err := client.Version(cmd.Context())
  870. if err != nil {
  871. fmt.Println("Warning: could not connect to a running Ollama instance")
  872. }
  873. if serverVersion != "" {
  874. fmt.Printf("ollama version is %s\n", serverVersion)
  875. }
  876. if serverVersion != version.Version {
  877. fmt.Printf("Warning: client version is %s\n", version.Version)
  878. }
  879. }
  880. func appendHostEnvDocs(cmd *cobra.Command) {
  881. const hostEnvDocs = `
  882. Environment Variables:
  883. OLLAMA_HOST The host:port or base URL of the Ollama server (e.g. http://localhost:11434)
  884. `
  885. cmd.SetUsageTemplate(cmd.UsageTemplate() + hostEnvDocs)
  886. }
  887. func NewCLI() *cobra.Command {
  888. log.SetFlags(log.LstdFlags | log.Lshortfile)
  889. cobra.EnableCommandSorting = false
  890. if runtime.GOOS == "windows" {
  891. console.ConsoleFromFile(os.Stdin) //nolint:errcheck
  892. }
  893. rootCmd := &cobra.Command{
  894. Use: "ollama",
  895. Short: "Large language model runner",
  896. SilenceUsage: true,
  897. SilenceErrors: true,
  898. CompletionOptions: cobra.CompletionOptions{
  899. DisableDefaultCmd: true,
  900. },
  901. Run: func(cmd *cobra.Command, args []string) {
  902. if version, _ := cmd.Flags().GetBool("version"); version {
  903. versionHandler(cmd, args)
  904. return
  905. }
  906. cmd.Print(cmd.UsageString())
  907. },
  908. }
  909. rootCmd.Flags().BoolP("version", "v", false, "Show version information")
  910. createCmd := &cobra.Command{
  911. Use: "create MODEL",
  912. Short: "Create a model from a Modelfile",
  913. Args: cobra.ExactArgs(1),
  914. PreRunE: checkServerHeartbeat,
  915. RunE: CreateHandler,
  916. }
  917. createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
  918. createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
  919. showCmd := &cobra.Command{
  920. Use: "show MODEL",
  921. Short: "Show information for a model",
  922. Args: cobra.ExactArgs(1),
  923. PreRunE: checkServerHeartbeat,
  924. RunE: ShowHandler,
  925. }
  926. showCmd.Flags().Bool("license", false, "Show license of a model")
  927. showCmd.Flags().Bool("modelfile", false, "Show Modelfile of a model")
  928. showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
  929. showCmd.Flags().Bool("template", false, "Show template of a model")
  930. showCmd.Flags().Bool("system", false, "Show system message of a model")
  931. runCmd := &cobra.Command{
  932. Use: "run MODEL [PROMPT]",
  933. Short: "Run a model",
  934. Args: cobra.MinimumNArgs(1),
  935. PreRunE: checkServerHeartbeat,
  936. RunE: RunHandler,
  937. }
  938. runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
  939. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  940. runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  941. runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
  942. runCmd.Flags().String("format", "", "Response format (e.g. json)")
  943. serveCmd := &cobra.Command{
  944. Use: "serve",
  945. Aliases: []string{"start"},
  946. Short: "Start ollama",
  947. Args: cobra.ExactArgs(0),
  948. RunE: RunServer,
  949. }
  950. serveCmd.SetUsageTemplate(serveCmd.UsageTemplate() + `
  951. Environment Variables:
  952. OLLAMA_HOST The host:port to bind to (default "127.0.0.1:11434")
  953. OLLAMA_ORIGINS A comma separated list of allowed origins
  954. OLLAMA_MODELS The path to the models directory (default "~/.ollama/models")
  955. OLLAMA_KEEP_ALIVE The duration that models stay loaded in memory (default "5m")
  956. OLLAMA_DEBUG Set to 1 to enable additional debug logging
  957. `)
  958. pullCmd := &cobra.Command{
  959. Use: "pull MODEL",
  960. Short: "Pull a model from a registry",
  961. Args: cobra.ExactArgs(1),
  962. PreRunE: checkServerHeartbeat,
  963. RunE: PullHandler,
  964. }
  965. pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  966. pushCmd := &cobra.Command{
  967. Use: "push MODEL",
  968. Short: "Push a model to a registry",
  969. Args: cobra.ExactArgs(1),
  970. PreRunE: checkServerHeartbeat,
  971. RunE: PushHandler,
  972. }
  973. pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  974. listCmd := &cobra.Command{
  975. Use: "list",
  976. Aliases: []string{"ls"},
  977. Short: "List models",
  978. PreRunE: checkServerHeartbeat,
  979. RunE: ListHandler,
  980. }
  981. psCmd := &cobra.Command{
  982. Use: "ps",
  983. Short: "List running models",
  984. PreRunE: checkServerHeartbeat,
  985. RunE: ListRunningHandler,
  986. }
  987. copyCmd := &cobra.Command{
  988. Use: "cp SOURCE DESTINATION",
  989. Short: "Copy a model",
  990. Args: cobra.ExactArgs(2),
  991. PreRunE: checkServerHeartbeat,
  992. RunE: CopyHandler,
  993. }
  994. deleteCmd := &cobra.Command{
  995. Use: "rm MODEL [MODEL...]",
  996. Short: "Remove a model",
  997. Args: cobra.MinimumNArgs(1),
  998. PreRunE: checkServerHeartbeat,
  999. RunE: DeleteHandler,
  1000. }
  1001. for _, cmd := range []*cobra.Command{
  1002. createCmd,
  1003. showCmd,
  1004. runCmd,
  1005. pullCmd,
  1006. pushCmd,
  1007. listCmd,
  1008. psCmd,
  1009. copyCmd,
  1010. deleteCmd,
  1011. } {
  1012. appendHostEnvDocs(cmd)
  1013. }
  1014. rootCmd.AddCommand(
  1015. serveCmd,
  1016. createCmd,
  1017. showCmd,
  1018. runCmd,
  1019. pullCmd,
  1020. pushCmd,
  1021. listCmd,
  1022. psCmd,
  1023. copyCmd,
  1024. deleteCmd,
  1025. )
  1026. return rootCmd
  1027. }