cmd.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. package cmd
  2. import (
  3. "bufio"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "os"
  10. "path"
  11. "strings"
  12. "time"
  13. "github.com/schollz/progressbar/v3"
  14. "github.com/spf13/cobra"
  15. "golang.org/x/term"
  16. "github.com/jmorganca/ollama/api"
  17. "github.com/jmorganca/ollama/server"
  18. )
  // cacheDir returns the per-user ollama state directory, which is the
  // ".ollama" entry inside the current user's home directory.
  //
  // There is no fallback location: if the home directory cannot be
  // determined, cacheDir panics with the lookup error.
  func cacheDir() string {
  	userHome, lookupErr := os.UserHomeDir()
  	if lookupErr != nil {
  		panic(lookupErr)
  	}
  	return path.Join(userHome, ".ollama")
  }
  26. func RunRun(cmd *cobra.Command, args []string) error {
  27. _, err := os.Stat(args[0])
  28. switch {
  29. case errors.Is(err, os.ErrNotExist):
  30. if err := pull(args[0]); err != nil {
  31. return err
  32. }
  33. case err != nil:
  34. return err
  35. }
  36. return RunGenerate(cmd, args)
  37. }
  38. func pull(model string) error {
  39. // TODO: check if the local model is up to date with remote
  40. _, err := os.Stat(cacheDir() + "/models/" + model + ".bin")
  41. switch {
  42. case errors.Is(err, os.ErrNotExist):
  43. client := api.NewClient()
  44. var bar *progressbar.ProgressBar
  45. return client.Pull(
  46. context.Background(),
  47. &api.PullRequest{Model: model},
  48. func(progress api.PullProgress) error {
  49. if bar == nil {
  50. bar = progressbar.DefaultBytes(progress.Total)
  51. }
  52. return bar.Set64(progress.Completed)
  53. },
  54. )
  55. case err != nil:
  56. return err
  57. }
  58. return nil
  59. }
  60. func RunGenerate(_ *cobra.Command, args []string) error {
  61. if len(args) > 1 {
  62. return generateOneshot(args[0], args[1:]...)
  63. }
  64. if term.IsTerminal(int(os.Stdin.Fd())) {
  65. return generateInteractive(args[0])
  66. }
  67. return generateBatch(args[0])
  68. }
  69. func generate(model, prompt string) error {
  70. if len(strings.TrimSpace(prompt)) > 0 {
  71. client := api.NewClient()
  72. spinner := progressbar.NewOptions(-1,
  73. progressbar.OptionSetWriter(os.Stderr),
  74. progressbar.OptionThrottle(60*time.Millisecond),
  75. progressbar.OptionSpinnerType(14),
  76. progressbar.OptionSetRenderBlankState(true),
  77. progressbar.OptionSetElapsedTime(false),
  78. progressbar.OptionClearOnFinish(),
  79. )
  80. go func() {
  81. for range time.Tick(60 * time.Millisecond) {
  82. if spinner.IsFinished() {
  83. break
  84. }
  85. spinner.Add(1)
  86. }
  87. }()
  88. client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
  89. if !spinner.IsFinished() {
  90. spinner.Finish()
  91. }
  92. fmt.Print(resp.Response)
  93. return nil
  94. })
  95. fmt.Println()
  96. fmt.Println()
  97. }
  98. return nil
  99. }
  100. func generateOneshot(model string, prompts ...string) error {
  101. for _, prompt := range prompts {
  102. fmt.Printf(">>> %s\n", prompt)
  103. if err := generate(model, prompt); err != nil {
  104. return err
  105. }
  106. }
  107. return nil
  108. }
  109. func generateInteractive(model string) error {
  110. fmt.Print(">>> ")
  111. scanner := bufio.NewScanner(os.Stdin)
  112. for scanner.Scan() {
  113. if err := generate(model, scanner.Text()); err != nil {
  114. return err
  115. }
  116. fmt.Print(">>> ")
  117. }
  118. return nil
  119. }
  120. func generateBatch(model string) error {
  121. scanner := bufio.NewScanner(os.Stdin)
  122. for scanner.Scan() {
  123. prompt := scanner.Text()
  124. fmt.Printf(">>> %s\n", prompt)
  125. if err := generate(model, prompt); err != nil {
  126. return err
  127. }
  128. }
  129. return nil
  130. }
  131. func RunServer(_ *cobra.Command, _ []string) error {
  132. ln, err := net.Listen("tcp", "127.0.0.1:11434")
  133. if err != nil {
  134. return err
  135. }
  136. return server.Serve(ln)
  137. }
  138. func NewCLI() *cobra.Command {
  139. log.SetFlags(log.LstdFlags | log.Lshortfile)
  140. rootCmd := &cobra.Command{
  141. Use: "ollama",
  142. Short: "Large language model runner",
  143. SilenceUsage: true,
  144. CompletionOptions: cobra.CompletionOptions{
  145. DisableDefaultCmd: true,
  146. },
  147. PersistentPreRunE: func(_ *cobra.Command, args []string) error {
  148. // create the models directory and it's parent
  149. return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
  150. },
  151. }
  152. cobra.EnableCommandSorting = false
  153. runCmd := &cobra.Command{
  154. Use: "run MODEL [PROMPT]",
  155. Short: "Run a model",
  156. Args: cobra.MinimumNArgs(1),
  157. RunE: RunRun,
  158. }
  159. serveCmd := &cobra.Command{
  160. Use: "serve",
  161. Aliases: []string{"start"},
  162. Short: "Start ollama",
  163. RunE: RunServer,
  164. }
  165. rootCmd.AddCommand(
  166. serveCmd,
  167. runCmd,
  168. )
  169. return rootCmd
  170. }