cmd.go 29 KB


  1. package cmd
  2. import (
  3. "archive/zip"
  4. "bytes"
  5. "context"
  6. "crypto/ed25519"
  7. "crypto/rand"
  8. "crypto/sha256"
  9. "encoding/pem"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "log"
  14. "math"
  15. "net"
  16. "net/http"
  17. "os"
  18. "os/signal"
  19. "path/filepath"
  20. "regexp"
  21. "runtime"
  22. "strings"
  23. "syscall"
  24. "time"
  25. "github.com/containerd/console"
  26. "github.com/mattn/go-runewidth"
  27. "github.com/olekukonko/tablewriter"
  28. "github.com/spf13/cobra"
  29. "golang.org/x/crypto/ssh"
  30. "golang.org/x/exp/slices"
  31. "golang.org/x/term"
  32. "github.com/ollama/ollama/api"
  33. "github.com/ollama/ollama/auth"
  34. "github.com/ollama/ollama/format"
  35. "github.com/ollama/ollama/parser"
  36. "github.com/ollama/ollama/progress"
  37. "github.com/ollama/ollama/server"
  38. "github.com/ollama/ollama/types/errtypes"
  39. "github.com/ollama/ollama/types/model"
  40. "github.com/ollama/ollama/version"
  41. )
  42. func CreateHandler(cmd *cobra.Command, args []string) error {
  43. filename, _ := cmd.Flags().GetString("file")
  44. filename, err := filepath.Abs(filename)
  45. if err != nil {
  46. return err
  47. }
  48. client, err := api.ClientFromEnvironment()
  49. if err != nil {
  50. return err
  51. }
  52. p := progress.NewProgress(os.Stderr)
  53. defer p.Stop()
  54. f, err := os.Open(filename)
  55. if err != nil {
  56. return err
  57. }
  58. defer f.Close()
  59. modelfile, err := parser.ParseFile(f)
  60. if err != nil {
  61. return err
  62. }
  63. home, err := os.UserHomeDir()
  64. if err != nil {
  65. return err
  66. }
  67. status := "transferring model data"
  68. spinner := progress.NewSpinner(status)
  69. p.Add(status, spinner)
  70. for i := range modelfile.Commands {
  71. switch modelfile.Commands[i].Name {
  72. case "model", "adapter":
  73. path := modelfile.Commands[i].Args
  74. if path == "~" {
  75. path = home
  76. } else if strings.HasPrefix(path, "~/") {
  77. path = filepath.Join(home, path[2:])
  78. }
  79. if !filepath.IsAbs(path) {
  80. path = filepath.Join(filepath.Dir(filename), path)
  81. }
  82. fi, err := os.Stat(path)
  83. if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
  84. continue
  85. } else if err != nil {
  86. return err
  87. }
  88. if fi.IsDir() {
  89. // this is likely a safetensors or pytorch directory
  90. // TODO make this work w/ adapters
  91. tempfile, err := tempZipFiles(path)
  92. if err != nil {
  93. return err
  94. }
  95. defer os.RemoveAll(tempfile)
  96. path = tempfile
  97. }
  98. digest, err := createBlob(cmd, client, path)
  99. if err != nil {
  100. return err
  101. }
  102. modelfile.Commands[i].Args = "@" + digest
  103. }
  104. }
  105. bars := make(map[string]*progress.Bar)
  106. fn := func(resp api.ProgressResponse) error {
  107. if resp.Digest != "" {
  108. spinner.Stop()
  109. bar, ok := bars[resp.Digest]
  110. if !ok {
  111. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  112. bars[resp.Digest] = bar
  113. p.Add(resp.Digest, bar)
  114. }
  115. bar.Set(resp.Completed)
  116. } else if status != resp.Status {
  117. spinner.Stop()
  118. status = resp.Status
  119. spinner = progress.NewSpinner(status)
  120. p.Add(status, spinner)
  121. }
  122. return nil
  123. }
  124. quantize, _ := cmd.Flags().GetString("quantize")
  125. request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantize: quantize}
  126. if err := client.Create(cmd.Context(), &request, fn); err != nil {
  127. return err
  128. }
  129. return nil
  130. }
  // tempZipFiles bundles the model files found directly under path into a zip
  // archive stored in a new temporary file and returns that file's name. The
  // caller is responsible for removing the file.
  //
  // Weights are taken from the first of these patterns that yields matches:
  // model*.safetensors, pytorch_model*.bin, consolidated*.pth. Every *.json file
  // is added as well, plus tokenizer.model when present (either in path or one
  // directory below it). Each matched file is content-sniffed; a weights set whose
  // content type is unexpected (for example an unresolved git lfs pointer) is
  // skipped, while a mismatching *.json file is an error.
  //
  // An error is returned when no weights are found or when any file cannot be
  // read or archived; in that case no temporary file is left behind.
  func tempZipFiles(path string) (string, error) {
  	// detectContentType sniffs the first 512 bytes of the named file and
  	// returns the media type without parameters.
  	detectContentType := func(name string) (string, error) {
  		f, err := os.Open(name)
  		if err != nil {
  			return "", err
  		}
  		defer f.Close()

  		var b bytes.Buffer
  		b.Grow(512)
  		// Files shorter than 512 bytes end with io.EOF, which is fine.
  		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
  			return "", err
  		}

  		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
  		return contentType, nil
  	}

  	// glob returns the files matching pattern, failing if any of them does
  	// not have the expected content type.
  	glob := func(pattern, contentType string) ([]string, error) {
  		matches, err := filepath.Glob(pattern)
  		if err != nil {
  			return nil, err
  		}

  		for _, match := range matches {
  			ct, err := detectContentType(match)
  			if err != nil {
  				return nil, err
  			}
  			if ct != contentType {
  				return nil, fmt.Errorf("invalid content type: expected %s, got %s for %s", contentType, ct, match)
  			}
  		}

  		return matches, nil
  	}

  	// Errors from the weight globs are deliberately ignored so that the next
  	// candidate format is tried instead.
  	var files []string
  	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
  		// safetensors files might be unresolved git lfs references; skip if they are
  		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
  		files = append(files, st...)
  	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
  		// pytorch files might also be unresolved git lfs references; skip if they are
  		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
  		files = append(files, pt...)
  	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
  		// pytorch files might also be unresolved git lfs references; skip if they are
  		// covers consolidated.x.pth, consolidated.pth
  		files = append(files, pt...)
  	} else {
  		return "", errors.New("no safetensors or torch files found")
  	}

  	// add configuration files, json files are detected as text/plain
  	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
  	if err != nil {
  		return "", err
  	}
  	files = append(files, js...)

  	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
  		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
  		// an unresolved git lfs reference fails the content type check and is skipped
  		files = append(files, tks...)
  	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
  		// sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
  		files = append(files, tks...)
  	}

  	// addFile copies one file into the archive. It is a separate closure so
  	// that each input file is closed as soon as it has been copied.
  	addFile := func(zw *zip.Writer, name string) error {
  		f, err := os.Open(name)
  		if err != nil {
  			return err
  		}
  		defer f.Close()

  		fi, err := f.Stat()
  		if err != nil {
  			return err
  		}

  		header, err := zip.FileInfoHeader(fi)
  		if err != nil {
  			return err
  		}

  		w, err := zw.CreateHeader(header)
  		if err != nil {
  			return err
  		}

  		_, err = io.Copy(w, f)
  		return err
  	}

  	// The temporary file is created only once there is something to archive.
  	tempfile, err := os.CreateTemp("", "ollama-tf")
  	if err != nil {
  		return "", err
  	}

  	// fail discards the partially written archive.
  	fail := func(err error) (string, error) {
  		tempfile.Close()
  		os.Remove(tempfile.Name())
  		return "", err
  	}

  	zipfile := zip.NewWriter(tempfile)
  	for _, file := range files {
  		if err := addFile(zipfile, file); err != nil {
  			return fail(err)
  		}
  	}

  	// Close writes the central directory; an error here means the archive is
  	// unusable and must not be reported as success.
  	if err := zipfile.Close(); err != nil {
  		return fail(err)
  	}

  	if err := tempfile.Close(); err != nil {
  		os.Remove(tempfile.Name())
  		return "", err
  	}

  	return tempfile.Name(), nil
  }
  221. func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
  222. bin, err := os.Open(path)
  223. if err != nil {
  224. return "", err
  225. }
  226. defer bin.Close()
  227. hash := sha256.New()
  228. if _, err := io.Copy(hash, bin); err != nil {
  229. return "", err
  230. }
  231. if _, err := bin.Seek(0, io.SeekStart); err != nil {
  232. return "", err
  233. }
  234. digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
  235. if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
  236. return "", err
  237. }
  238. return digest, nil
  239. }
  240. func RunHandler(cmd *cobra.Command, args []string) error {
  241. client, err := api.ClientFromEnvironment()
  242. if err != nil {
  243. return err
  244. }
  245. name := args[0]
  246. // check if the model exists on the server
  247. show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
  248. var statusError api.StatusError
  249. switch {
  250. case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
  251. if err := PullHandler(cmd, []string{name}); err != nil {
  252. return err
  253. }
  254. show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
  255. if err != nil {
  256. return err
  257. }
  258. case err != nil:
  259. return err
  260. }
  261. interactive := true
  262. opts := runOptions{
  263. Model: args[0],
  264. WordWrap: os.Getenv("TERM") == "xterm-256color",
  265. Options: map[string]interface{}{},
  266. MultiModal: slices.Contains(show.Details.Families, "clip"),
  267. ParentModel: show.Details.ParentModel,
  268. }
  269. format, err := cmd.Flags().GetString("format")
  270. if err != nil {
  271. return err
  272. }
  273. opts.Format = format
  274. keepAlive, err := cmd.Flags().GetString("keepalive")
  275. if err != nil {
  276. return err
  277. }
  278. if keepAlive != "" {
  279. d, err := time.ParseDuration(keepAlive)
  280. if err != nil {
  281. return err
  282. }
  283. opts.KeepAlive = &api.Duration{Duration: d}
  284. }
  285. prompts := args[1:]
  286. // prepend stdin to the prompt if provided
  287. if !term.IsTerminal(int(os.Stdin.Fd())) {
  288. in, err := io.ReadAll(os.Stdin)
  289. if err != nil {
  290. return err
  291. }
  292. prompts = append([]string{string(in)}, prompts...)
  293. opts.WordWrap = false
  294. interactive = false
  295. }
  296. opts.Prompt = strings.Join(prompts, " ")
  297. if len(prompts) > 0 {
  298. interactive = false
  299. }
  300. nowrap, err := cmd.Flags().GetBool("nowordwrap")
  301. if err != nil {
  302. return err
  303. }
  304. opts.WordWrap = !nowrap
  305. if !interactive {
  306. return generate(cmd, opts)
  307. }
  308. return generateInteractive(cmd, opts)
  309. }
  310. func errFromUnknownKey(unknownKeyErr error) error {
  311. // find SSH public key in the error message
  312. sshKeyPattern := `ssh-\w+ [^\s"]+`
  313. re := regexp.MustCompile(sshKeyPattern)
  314. matches := re.FindStringSubmatch(unknownKeyErr.Error())
  315. if len(matches) > 0 {
  316. serverPubKey := matches[0]
  317. localPubKey, err := auth.GetPublicKey()
  318. if err != nil {
  319. return unknownKeyErr
  320. }
  321. if runtime.GOOS == "linux" && serverPubKey != localPubKey {
  322. // try the ollama service public key
  323. svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
  324. if err != nil {
  325. return unknownKeyErr
  326. }
  327. localPubKey = strings.TrimSpace(string(svcPubKey))
  328. }
  329. // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
  330. if serverPubKey != localPubKey {
  331. return unknownKeyErr
  332. }
  333. var msg strings.Builder
  334. msg.WriteString(unknownKeyErr.Error())
  335. msg.WriteString("\n\nYour ollama key is:\n")
  336. msg.WriteString(localPubKey)
  337. msg.WriteString("\nAdd your key at:\n")
  338. msg.WriteString("https://ollama.com/settings/keys")
  339. return errors.New(msg.String())
  340. }
  341. return unknownKeyErr
  342. }
  343. func PushHandler(cmd *cobra.Command, args []string) error {
  344. client, err := api.ClientFromEnvironment()
  345. if err != nil {
  346. return err
  347. }
  348. insecure, err := cmd.Flags().GetBool("insecure")
  349. if err != nil {
  350. return err
  351. }
  352. p := progress.NewProgress(os.Stderr)
  353. defer p.Stop()
  354. bars := make(map[string]*progress.Bar)
  355. var status string
  356. var spinner *progress.Spinner
  357. fn := func(resp api.ProgressResponse) error {
  358. if resp.Digest != "" {
  359. if spinner != nil {
  360. spinner.Stop()
  361. }
  362. bar, ok := bars[resp.Digest]
  363. if !ok {
  364. bar = progress.NewBar(fmt.Sprintf("pushing %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  365. bars[resp.Digest] = bar
  366. p.Add(resp.Digest, bar)
  367. }
  368. bar.Set(resp.Completed)
  369. } else if status != resp.Status {
  370. if spinner != nil {
  371. spinner.Stop()
  372. }
  373. status = resp.Status
  374. spinner = progress.NewSpinner(status)
  375. p.Add(status, spinner)
  376. }
  377. return nil
  378. }
  379. request := api.PushRequest{Name: args[0], Insecure: insecure}
  380. if err := client.Push(cmd.Context(), &request, fn); err != nil {
  381. if spinner != nil {
  382. spinner.Stop()
  383. }
  384. if strings.Contains(err.Error(), "access denied") {
  385. return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
  386. }
  387. host := model.ParseName(args[0]).Host
  388. isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
  389. if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
  390. // the user has not added their ollama key to ollama.com
  391. // re-throw an error with a more user-friendly message
  392. return errFromUnknownKey(err)
  393. }
  394. return err
  395. }
  396. spinner.Stop()
  397. return nil
  398. }
  399. func ListHandler(cmd *cobra.Command, args []string) error {
  400. client, err := api.ClientFromEnvironment()
  401. if err != nil {
  402. return err
  403. }
  404. models, err := client.List(cmd.Context())
  405. if err != nil {
  406. return err
  407. }
  408. var data [][]string
  409. for _, m := range models.Models {
  410. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  411. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
  412. }
  413. }
  414. table := tablewriter.NewWriter(os.Stdout)
  415. table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
  416. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  417. table.SetAlignment(tablewriter.ALIGN_LEFT)
  418. table.SetHeaderLine(false)
  419. table.SetBorder(false)
  420. table.SetNoWhiteSpace(true)
  421. table.SetTablePadding("\t")
  422. table.AppendBulk(data)
  423. table.Render()
  424. return nil
  425. }
  426. func ListRunningHandler(cmd *cobra.Command, args []string) error {
  427. client, err := api.ClientFromEnvironment()
  428. if err != nil {
  429. return err
  430. }
  431. models, err := client.ListRunning(cmd.Context())
  432. if err != nil {
  433. return err
  434. }
  435. var data [][]string
  436. for _, m := range models.Models {
  437. if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
  438. var procStr string
  439. switch {
  440. case m.SizeVRAM == 0:
  441. procStr = "100% CPU"
  442. case m.SizeVRAM == m.Size:
  443. procStr = "100% GPU"
  444. case m.SizeVRAM > m.Size || m.Size == 0:
  445. procStr = "Unknown"
  446. default:
  447. sizeCPU := m.Size - m.SizeVRAM
  448. cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
  449. procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
  450. }
  451. data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
  452. }
  453. }
  454. table := tablewriter.NewWriter(os.Stdout)
  455. table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
  456. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  457. table.SetAlignment(tablewriter.ALIGN_LEFT)
  458. table.SetHeaderLine(false)
  459. table.SetBorder(false)
  460. table.SetNoWhiteSpace(true)
  461. table.SetTablePadding("\t")
  462. table.AppendBulk(data)
  463. table.Render()
  464. return nil
  465. }
  466. func DeleteHandler(cmd *cobra.Command, args []string) error {
  467. client, err := api.ClientFromEnvironment()
  468. if err != nil {
  469. return err
  470. }
  471. for _, name := range args {
  472. req := api.DeleteRequest{Name: name}
  473. if err := client.Delete(cmd.Context(), &req); err != nil {
  474. return err
  475. }
  476. fmt.Printf("deleted '%s'\n", name)
  477. }
  478. return nil
  479. }
  480. func ShowHandler(cmd *cobra.Command, args []string) error {
  481. client, err := api.ClientFromEnvironment()
  482. if err != nil {
  483. return err
  484. }
  485. if len(args) != 1 {
  486. return errors.New("missing model name")
  487. }
  488. license, errLicense := cmd.Flags().GetBool("license")
  489. modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
  490. parameters, errParams := cmd.Flags().GetBool("parameters")
  491. system, errSystem := cmd.Flags().GetBool("system")
  492. template, errTemplate := cmd.Flags().GetBool("template")
  493. for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
  494. if boolErr != nil {
  495. return errors.New("error retrieving flags")
  496. }
  497. }
  498. flagsSet := 0
  499. showType := ""
  500. if license {
  501. flagsSet++
  502. showType = "license"
  503. }
  504. if modelfile {
  505. flagsSet++
  506. showType = "modelfile"
  507. }
  508. if parameters {
  509. flagsSet++
  510. showType = "parameters"
  511. }
  512. if system {
  513. flagsSet++
  514. showType = "system"
  515. }
  516. if template {
  517. flagsSet++
  518. showType = "template"
  519. }
  520. if flagsSet > 1 {
  521. return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
  522. } else if flagsSet == 0 {
  523. return errors.New("one of '--license', '--modelfile', '--parameters', '--system', or '--template' must be specified")
  524. }
  525. req := api.ShowRequest{Name: args[0]}
  526. resp, err := client.Show(cmd.Context(), &req)
  527. if err != nil {
  528. return err
  529. }
  530. switch showType {
  531. case "license":
  532. fmt.Println(resp.License)
  533. case "modelfile":
  534. fmt.Println(resp.Modelfile)
  535. case "parameters":
  536. fmt.Println(resp.Parameters)
  537. case "system":
  538. fmt.Println(resp.System)
  539. case "template":
  540. fmt.Println(resp.Template)
  541. }
  542. return nil
  543. }
  544. func CopyHandler(cmd *cobra.Command, args []string) error {
  545. client, err := api.ClientFromEnvironment()
  546. if err != nil {
  547. return err
  548. }
  549. req := api.CopyRequest{Source: args[0], Destination: args[1]}
  550. if err := client.Copy(cmd.Context(), &req); err != nil {
  551. return err
  552. }
  553. fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
  554. return nil
  555. }
  556. func PullHandler(cmd *cobra.Command, args []string) error {
  557. insecure, err := cmd.Flags().GetBool("insecure")
  558. if err != nil {
  559. return err
  560. }
  561. client, err := api.ClientFromEnvironment()
  562. if err != nil {
  563. return err
  564. }
  565. p := progress.NewProgress(os.Stderr)
  566. defer p.Stop()
  567. bars := make(map[string]*progress.Bar)
  568. var status string
  569. var spinner *progress.Spinner
  570. fn := func(resp api.ProgressResponse) error {
  571. if resp.Digest != "" {
  572. if spinner != nil {
  573. spinner.Stop()
  574. }
  575. bar, ok := bars[resp.Digest]
  576. if !ok {
  577. bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
  578. bars[resp.Digest] = bar
  579. p.Add(resp.Digest, bar)
  580. }
  581. bar.Set(resp.Completed)
  582. } else if status != resp.Status {
  583. if spinner != nil {
  584. spinner.Stop()
  585. }
  586. status = resp.Status
  587. spinner = progress.NewSpinner(status)
  588. p.Add(status, spinner)
  589. }
  590. return nil
  591. }
  592. request := api.PullRequest{Name: args[0], Insecure: insecure}
  593. if err := client.Pull(cmd.Context(), &request, fn); err != nil {
  594. return err
  595. }
  596. return nil
  597. }
  // generateContextKey is the key type under which generate stores the context
  // returned by the last generation in the command's context.Context.
  type generateContextKey string
  599. type runOptions struct {
  600. Model string
  601. ParentModel string
  602. Prompt string
  603. Messages []api.Message
  604. WordWrap bool
  605. Format string
  606. System string
  607. Template string
  608. Images []api.ImageData
  609. Options map[string]interface{}
  610. MultiModal bool
  611. KeepAlive *api.Duration
  612. }
  // displayResponseState is the word-wrapping state that displayResponse keeps
  // between successive chunks of one streamed response.
  type displayResponseState struct {
  	// lineLength is the display width already printed on the current line.
  	lineLength int
  	// wordBuffer holds the characters of the word currently being printed.
  	wordBuffer string
  }
  617. func displayResponse(content string, wordWrap bool, state *displayResponseState) {
  618. termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
  619. if wordWrap && termWidth >= 10 {
  620. for _, ch := range content {
  621. if state.lineLength+1 > termWidth-5 {
  622. if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
  623. fmt.Printf("%s%c", state.wordBuffer, ch)
  624. state.wordBuffer = ""
  625. state.lineLength = 0
  626. continue
  627. }
  628. // backtrack the length of the last word and clear to the end of the line
  629. fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
  630. fmt.Printf("%s%c", state.wordBuffer, ch)
  631. chWidth := runewidth.RuneWidth(ch)
  632. state.lineLength = runewidth.StringWidth(state.wordBuffer) + chWidth
  633. } else {
  634. fmt.Print(string(ch))
  635. state.lineLength += runewidth.RuneWidth(ch)
  636. if runewidth.RuneWidth(ch) >= 2 {
  637. state.wordBuffer = ""
  638. continue
  639. }
  640. switch ch {
  641. case ' ':
  642. state.wordBuffer = ""
  643. case '\n':
  644. state.lineLength = 0
  645. default:
  646. state.wordBuffer += string(ch)
  647. }
  648. }
  649. }
  650. } else {
  651. fmt.Printf("%s%s", state.wordBuffer, content)
  652. if len(state.wordBuffer) > 0 {
  653. state.wordBuffer = ""
  654. }
  655. }
  656. }
  657. func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
  658. client, err := api.ClientFromEnvironment()
  659. if err != nil {
  660. return nil, err
  661. }
  662. p := progress.NewProgress(os.Stderr)
  663. defer p.StopAndClear()
  664. spinner := progress.NewSpinner("")
  665. p.Add("", spinner)
  666. cancelCtx, cancel := context.WithCancel(cmd.Context())
  667. defer cancel()
  668. sigChan := make(chan os.Signal, 1)
  669. signal.Notify(sigChan, syscall.SIGINT)
  670. go func() {
  671. <-sigChan
  672. cancel()
  673. }()
  674. var state *displayResponseState = &displayResponseState{}
  675. var latest api.ChatResponse
  676. var fullResponse strings.Builder
  677. var role string
  678. fn := func(response api.ChatResponse) error {
  679. p.StopAndClear()
  680. latest = response
  681. role = response.Message.Role
  682. content := response.Message.Content
  683. fullResponse.WriteString(content)
  684. displayResponse(content, opts.WordWrap, state)
  685. return nil
  686. }
  687. req := &api.ChatRequest{
  688. Model: opts.Model,
  689. Messages: opts.Messages,
  690. Format: opts.Format,
  691. Options: opts.Options,
  692. }
  693. if opts.KeepAlive != nil {
  694. req.KeepAlive = opts.KeepAlive
  695. }
  696. if err := client.Chat(cancelCtx, req, fn); err != nil {
  697. if errors.Is(err, context.Canceled) {
  698. return nil, nil
  699. }
  700. return nil, err
  701. }
  702. if len(opts.Messages) > 0 {
  703. fmt.Println()
  704. fmt.Println()
  705. }
  706. verbose, err := cmd.Flags().GetBool("verbose")
  707. if err != nil {
  708. return nil, err
  709. }
  710. if verbose {
  711. latest.Summary()
  712. }
  713. return &api.Message{Role: role, Content: fullResponse.String()}, nil
  714. }
  715. func generate(cmd *cobra.Command, opts runOptions) error {
  716. client, err := api.ClientFromEnvironment()
  717. if err != nil {
  718. return err
  719. }
  720. p := progress.NewProgress(os.Stderr)
  721. defer p.StopAndClear()
  722. spinner := progress.NewSpinner("")
  723. p.Add("", spinner)
  724. var latest api.GenerateResponse
  725. generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
  726. if !ok {
  727. generateContext = []int{}
  728. }
  729. ctx, cancel := context.WithCancel(cmd.Context())
  730. defer cancel()
  731. sigChan := make(chan os.Signal, 1)
  732. signal.Notify(sigChan, syscall.SIGINT)
  733. go func() {
  734. <-sigChan
  735. cancel()
  736. }()
  737. var state *displayResponseState = &displayResponseState{}
  738. fn := func(response api.GenerateResponse) error {
  739. p.StopAndClear()
  740. latest = response
  741. content := response.Response
  742. displayResponse(content, opts.WordWrap, state)
  743. return nil
  744. }
  745. if opts.MultiModal {
  746. opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
  747. if err != nil {
  748. return err
  749. }
  750. }
  751. request := api.GenerateRequest{
  752. Model: opts.Model,
  753. Prompt: opts.Prompt,
  754. Context: generateContext,
  755. Images: opts.Images,
  756. Format: opts.Format,
  757. System: opts.System,
  758. Template: opts.Template,
  759. Options: opts.Options,
  760. KeepAlive: opts.KeepAlive,
  761. }
  762. if err := client.Generate(ctx, &request, fn); err != nil {
  763. if errors.Is(err, context.Canceled) {
  764. return nil
  765. }
  766. return err
  767. }
  768. if opts.Prompt != "" {
  769. fmt.Println()
  770. fmt.Println()
  771. }
  772. if !latest.Done {
  773. return nil
  774. }
  775. verbose, err := cmd.Flags().GetBool("verbose")
  776. if err != nil {
  777. return err
  778. }
  779. if verbose {
  780. latest.Summary()
  781. }
  782. ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
  783. cmd.SetContext(ctx)
  784. return nil
  785. }
  786. func RunServer(cmd *cobra.Command, _ []string) error {
  787. // retrieve the OLLAMA_HOST environment variable
  788. ollamaHost, err := api.GetOllamaHost()
  789. if err != nil {
  790. return err
  791. }
  792. if err := initializeKeypair(); err != nil {
  793. return err
  794. }
  795. ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
  796. if err != nil {
  797. return err
  798. }
  799. err = server.Serve(ln)
  800. if errors.Is(err, http.ErrServerClosed) {
  801. return nil
  802. }
  803. return err
  804. }
  805. func initializeKeypair() error {
  806. home, err := os.UserHomeDir()
  807. if err != nil {
  808. return err
  809. }
  810. privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
  811. pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
  812. _, err = os.Stat(privKeyPath)
  813. if os.IsNotExist(err) {
  814. fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
  815. cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
  816. if err != nil {
  817. return err
  818. }
  819. privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
  820. if err != nil {
  821. return err
  822. }
  823. if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
  824. return fmt.Errorf("could not create directory %w", err)
  825. }
  826. if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
  827. return err
  828. }
  829. sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
  830. if err != nil {
  831. return err
  832. }
  833. publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
  834. if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
  835. return err
  836. }
  837. fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
  838. }
  839. return nil
  840. }
  841. //nolint:unused
  842. func waitForServer(ctx context.Context, client *api.Client) error {
  843. // wait for the server to start
  844. timeout := time.After(5 * time.Second)
  845. tick := time.Tick(500 * time.Millisecond)
  846. for {
  847. select {
  848. case <-timeout:
  849. return errors.New("timed out waiting for server to start")
  850. case <-tick:
  851. if err := client.Heartbeat(ctx); err == nil {
  852. return nil // server has started
  853. }
  854. }
  855. }
  856. }
  857. func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
  858. client, err := api.ClientFromEnvironment()
  859. if err != nil {
  860. return err
  861. }
  862. if err := client.Heartbeat(cmd.Context()); err != nil {
  863. if !strings.Contains(err.Error(), " refused") {
  864. return err
  865. }
  866. if err := startApp(cmd.Context(), client); err != nil {
  867. return fmt.Errorf("could not connect to ollama app, is it running?")
  868. }
  869. }
  870. return nil
  871. }
  872. func versionHandler(cmd *cobra.Command, _ []string) {
  873. client, err := api.ClientFromEnvironment()
  874. if err != nil {
  875. return
  876. }
  877. serverVersion, err := client.Version(cmd.Context())
  878. if err != nil {
  879. fmt.Println("Warning: could not connect to a running Ollama instance")
  880. }
  881. if serverVersion != "" {
  882. fmt.Printf("ollama version is %s\n", serverVersion)
  883. }
  884. if serverVersion != version.Version {
  885. fmt.Printf("Warning: client version is %s\n", version.Version)
  886. }
  887. }
  // EnvironmentVar describes one environment variable listed in a command's
  // usage text.
  type EnvironmentVar struct {
  	// Name is the name of the variable.
  	Name string
  	// Description is the help text shown next to the name.
  	Description string
  }
  892. func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
  893. if len(envs) == 0 {
  894. return
  895. }
  896. envUsage := `
  897. Environment Variables:
  898. `
  899. for _, e := range envs {
  900. envUsage += fmt.Sprintf(" %-16s %s\n", e.Name, e.Description)
  901. }
  902. cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
  903. }
  904. func NewCLI() *cobra.Command {
  905. log.SetFlags(log.LstdFlags | log.Lshortfile)
  906. cobra.EnableCommandSorting = false
  907. if runtime.GOOS == "windows" {
  908. console.ConsoleFromFile(os.Stdin) //nolint:errcheck
  909. }
  910. rootCmd := &cobra.Command{
  911. Use: "ollama",
  912. Short: "Large language model runner",
  913. SilenceUsage: true,
  914. SilenceErrors: true,
  915. CompletionOptions: cobra.CompletionOptions{
  916. DisableDefaultCmd: true,
  917. },
  918. Run: func(cmd *cobra.Command, args []string) {
  919. if version, _ := cmd.Flags().GetBool("version"); version {
  920. versionHandler(cmd, args)
  921. return
  922. }
  923. cmd.Print(cmd.UsageString())
  924. },
  925. }
  926. rootCmd.Flags().BoolP("version", "v", false, "Show version information")
  927. createCmd := &cobra.Command{
  928. Use: "create MODEL",
  929. Short: "Create a model from a Modelfile",
  930. Args: cobra.ExactArgs(1),
  931. PreRunE: checkServerHeartbeat,
  932. RunE: CreateHandler,
  933. }
  934. createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
  935. createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
  936. showCmd := &cobra.Command{
  937. Use: "show MODEL",
  938. Short: "Show information for a model",
  939. Args: cobra.ExactArgs(1),
  940. PreRunE: checkServerHeartbeat,
  941. RunE: ShowHandler,
  942. }
  943. showCmd.Flags().Bool("license", false, "Show license of a model")
  944. showCmd.Flags().Bool("modelfile", false, "Show Modelfile of a model")
  945. showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
  946. showCmd.Flags().Bool("template", false, "Show template of a model")
  947. showCmd.Flags().Bool("system", false, "Show system message of a model")
  948. runCmd := &cobra.Command{
  949. Use: "run MODEL [PROMPT]",
  950. Short: "Run a model",
  951. Args: cobra.MinimumNArgs(1),
  952. PreRunE: checkServerHeartbeat,
  953. RunE: RunHandler,
  954. }
  955. runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
  956. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  957. runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  958. runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
  959. runCmd.Flags().String("format", "", "Response format (e.g. json)")
  960. serveCmd := &cobra.Command{
  961. Use: "serve",
  962. Aliases: []string{"start"},
  963. Short: "Start ollama",
  964. Args: cobra.ExactArgs(0),
  965. RunE: RunServer,
  966. }
  967. serveCmd.SetUsageTemplate(serveCmd.UsageTemplate() + `
  968. Environment Variables:
  969. OLLAMA_HOST The host:port to bind to (default "127.0.0.1:11434")
  970. OLLAMA_ORIGINS A comma separated list of allowed origins
  971. OLLAMA_MODELS The path to the models directory (default "~/.ollama/models")
  972. OLLAMA_KEEP_ALIVE The duration that models stay loaded in memory (default "5m")
  973. OLLAMA_DEBUG Set to 1 to enable additional debug logging
  974. `)
  975. pullCmd := &cobra.Command{
  976. Use: "pull MODEL",
  977. Short: "Pull a model from a registry",
  978. Args: cobra.ExactArgs(1),
  979. PreRunE: checkServerHeartbeat,
  980. RunE: PullHandler,
  981. }
  982. pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  983. pushCmd := &cobra.Command{
  984. Use: "push MODEL",
  985. Short: "Push a model to a registry",
  986. Args: cobra.ExactArgs(1),
  987. PreRunE: checkServerHeartbeat,
  988. RunE: PushHandler,
  989. }
  990. pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
  991. listCmd := &cobra.Command{
  992. Use: "list",
  993. Aliases: []string{"ls"},
  994. Short: "List models",
  995. PreRunE: checkServerHeartbeat,
  996. RunE: ListHandler,
  997. }
  998. psCmd := &cobra.Command{
  999. Use: "ps",
  1000. Short: "List running models",
  1001. PreRunE: checkServerHeartbeat,
  1002. RunE: ListRunningHandler,
  1003. }
  1004. copyCmd := &cobra.Command{
  1005. Use: "cp SOURCE DESTINATION",
  1006. Short: "Copy a model",
  1007. Args: cobra.ExactArgs(2),
  1008. PreRunE: checkServerHeartbeat,
  1009. RunE: CopyHandler,
  1010. }
  1011. deleteCmd := &cobra.Command{
  1012. Use: "rm MODEL [MODEL...]",
  1013. Short: "Remove a model",
  1014. Args: cobra.MinimumNArgs(1),
  1015. PreRunE: checkServerHeartbeat,
  1016. RunE: DeleteHandler,
  1017. }
  1018. ollamaHostEnv := EnvironmentVar{"OLLAMA_HOST", "The host:port or base URL of the Ollama server (e.g. http://localhost:11434)"}
  1019. ollamaNoHistoryEnv := EnvironmentVar{"OLLAMA_NOHISTORY", "Disable readline history"}
  1020. envs := []EnvironmentVar{ollamaHostEnv}
  1021. for _, cmd := range []*cobra.Command{
  1022. createCmd,
  1023. showCmd,
  1024. runCmd,
  1025. pullCmd,
  1026. pushCmd,
  1027. listCmd,
  1028. psCmd,
  1029. copyCmd,
  1030. deleteCmd,
  1031. } {
  1032. switch cmd {
  1033. case runCmd:
  1034. appendEnvDocs(cmd, []EnvironmentVar{ollamaHostEnv, ollamaNoHistoryEnv})
  1035. default:
  1036. appendEnvDocs(cmd, envs)
  1037. }
  1038. }
  1039. rootCmd.AddCommand(
  1040. serveCmd,
  1041. createCmd,
  1042. showCmd,
  1043. runCmd,
  1044. pullCmd,
  1045. pushCmd,
  1046. listCmd,
  1047. psCmd,
  1048. copyCmd,
  1049. deleteCmd,
  1050. )
  1051. return rootCmd
  1052. }