cmd.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package cmd
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "log"
  7. "net"
  8. "os"
  9. "path"
  10. "github.com/schollz/progressbar/v3"
  11. "github.com/spf13/cobra"
  12. "golang.org/x/term"
  13. "github.com/jmorganca/ollama/api"
  14. "github.com/jmorganca/ollama/server"
  15. )
  16. func cacheDir() string {
  17. home, err := os.UserHomeDir()
  18. if err != nil {
  19. panic(err)
  20. }
  21. return path.Join(home, ".ollama")
  22. }
  23. func RunRun(cmd *cobra.Command, args []string) error {
  24. if err := pull(args[0]); err != nil {
  25. return err
  26. }
  27. fmt.Println("Up to date.")
  28. return RunGenerate(cmd, args)
  29. }
  30. func pull(model string) error {
  31. client, err := NewAPIClient()
  32. if err != nil {
  33. return err
  34. }
  35. var bar *progressbar.ProgressBar
  36. return client.Pull(
  37. context.Background(),
  38. &api.PullRequest{Model: model},
  39. func(progress api.PullProgress) error {
  40. if bar == nil {
  41. bar = progressbar.DefaultBytes(progress.Total)
  42. }
  43. return bar.Set64(progress.Completed)
  44. },
  45. )
  46. }
  47. func RunGenerate(_ *cobra.Command, args []string) error {
  48. if len(args) > 1 {
  49. return generate(args[0], args[1:]...)
  50. }
  51. if term.IsTerminal(int(os.Stdin.Fd())) {
  52. return generateInteractive(args[0])
  53. }
  54. return generateBatch(args[0])
  55. }
  56. func generate(model string, prompts ...string) error {
  57. client, err := NewAPIClient()
  58. if err != nil {
  59. return err
  60. }
  61. for _, prompt := range prompts {
  62. client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
  63. fmt.Print(resp.Response)
  64. return nil
  65. })
  66. }
  67. fmt.Println()
  68. fmt.Println()
  69. return nil
  70. }
  71. func generateInteractive(model string) error {
  72. fmt.Print(">>> ")
  73. scanner := bufio.NewScanner(os.Stdin)
  74. for scanner.Scan() {
  75. if err := generate(model, scanner.Text()); err != nil {
  76. return err
  77. }
  78. fmt.Print(">>> ")
  79. }
  80. return nil
  81. }
  82. func generateBatch(model string) error {
  83. scanner := bufio.NewScanner(os.Stdin)
  84. for scanner.Scan() {
  85. prompt := scanner.Text()
  86. fmt.Printf(">>> %s\n", prompt)
  87. if err := generate(model, prompt); err != nil {
  88. return err
  89. }
  90. }
  91. return nil
  92. }
  93. func RunServer(_ *cobra.Command, _ []string) error {
  94. ln, err := net.Listen("tcp", "127.0.0.1:11434")
  95. if err != nil {
  96. return err
  97. }
  98. return server.Serve(ln)
  99. }
  100. func NewAPIClient() (*api.Client, error) {
  101. return &api.Client{
  102. URL: "http://localhost:11434",
  103. }, nil
  104. }
  105. func NewCLI() *cobra.Command {
  106. log.SetFlags(log.LstdFlags | log.Lshortfile)
  107. rootCmd := &cobra.Command{
  108. Use: "ollama",
  109. Short: "Large language model runner",
  110. SilenceUsage: true,
  111. CompletionOptions: cobra.CompletionOptions{
  112. DisableDefaultCmd: true,
  113. },
  114. PersistentPreRunE: func(_ *cobra.Command, args []string) error {
  115. // create the models directory and it's parent
  116. return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
  117. },
  118. }
  119. cobra.EnableCommandSorting = false
  120. runCmd := &cobra.Command{
  121. Use: "run MODEL [PROMPT]",
  122. Short: "Run a model",
  123. Args: cobra.MinimumNArgs(1),
  124. RunE: RunRun,
  125. }
  126. serveCmd := &cobra.Command{
  127. Use: "serve",
  128. Aliases: []string{"start"},
  129. Short: "Start ollama",
  130. RunE: RunServer,
  131. }
  132. rootCmd.AddCommand(
  133. serveCmd,
  134. runCmd,
  135. )
  136. return rootCmd
  137. }