cmd.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package cmd
  2. import (
  3. "bufio"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "os"
  10. "path"
  11. "time"
  12. "github.com/schollz/progressbar/v3"
  13. "github.com/spf13/cobra"
  14. "golang.org/x/term"
  15. "github.com/jmorganca/ollama/api"
  16. "github.com/jmorganca/ollama/server"
  17. )
  18. func cacheDir() string {
  19. home, err := os.UserHomeDir()
  20. if err != nil {
  21. panic(err)
  22. }
  23. return path.Join(home, ".ollama")
  24. }
  25. func RunRun(cmd *cobra.Command, args []string) error {
  26. _, err := os.Stat(args[0])
  27. switch {
  28. case errors.Is(err, os.ErrNotExist):
  29. if err := pull(args[0]); err != nil {
  30. return err
  31. }
  32. case err != nil:
  33. return err
  34. }
  35. return RunGenerate(cmd, args)
  36. }
  37. func pull(model string) error {
  38. // TODO: check if the local model is up to date with remote
  39. _, err := os.Stat(cacheDir() + "/models/" + model + ".bin")
  40. switch {
  41. case errors.Is(err, os.ErrNotExist):
  42. client := api.NewClient()
  43. var bar *progressbar.ProgressBar
  44. return client.Pull(
  45. context.Background(),
  46. &api.PullRequest{Model: model},
  47. func(progress api.PullProgress) error {
  48. if bar == nil {
  49. bar = progressbar.DefaultBytes(progress.Total)
  50. }
  51. return bar.Set64(progress.Completed)
  52. },
  53. )
  54. case err != nil:
  55. return err
  56. }
  57. return nil
  58. }
  59. func RunGenerate(_ *cobra.Command, args []string) error {
  60. if len(args) > 1 {
  61. return generateOneshot(args[0], args[1:]...)
  62. }
  63. if term.IsTerminal(int(os.Stdin.Fd())) {
  64. return generateInteractive(args[0])
  65. }
  66. return generateBatch(args[0])
  67. }
  68. func generate(model, prompt string) error {
  69. client := api.NewClient()
  70. spinner := progressbar.NewOptions(-1,
  71. progressbar.OptionSetWriter(os.Stderr),
  72. progressbar.OptionThrottle(60*time.Millisecond),
  73. progressbar.OptionSpinnerType(14),
  74. progressbar.OptionSetRenderBlankState(true),
  75. progressbar.OptionSetElapsedTime(false),
  76. progressbar.OptionClearOnFinish(),
  77. )
  78. go func() {
  79. for range time.Tick(60 * time.Millisecond) {
  80. if spinner.IsFinished() {
  81. break
  82. }
  83. spinner.Add(1)
  84. }
  85. }()
  86. client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
  87. if !spinner.IsFinished() {
  88. spinner.Finish()
  89. }
  90. fmt.Print(resp.Response)
  91. return nil
  92. })
  93. fmt.Println()
  94. fmt.Println()
  95. return nil
  96. }
  97. func generateOneshot(model string, prompts ...string) error {
  98. for _, prompt := range prompts {
  99. fmt.Printf(">>> %s\n", prompt)
  100. if err := generate(model, prompt); err != nil {
  101. return err
  102. }
  103. }
  104. return nil
  105. }
  106. func generateInteractive(model string) error {
  107. fmt.Print(">>> ")
  108. scanner := bufio.NewScanner(os.Stdin)
  109. for scanner.Scan() {
  110. if err := generate(model, scanner.Text()); err != nil {
  111. return err
  112. }
  113. fmt.Print(">>> ")
  114. }
  115. return nil
  116. }
  117. func generateBatch(model string) error {
  118. scanner := bufio.NewScanner(os.Stdin)
  119. for scanner.Scan() {
  120. prompt := scanner.Text()
  121. fmt.Printf(">>> %s\n", prompt)
  122. if err := generate(model, prompt); err != nil {
  123. return err
  124. }
  125. }
  126. return nil
  127. }
  128. func RunServer(_ *cobra.Command, _ []string) error {
  129. ln, err := net.Listen("tcp", "127.0.0.1:11434")
  130. if err != nil {
  131. return err
  132. }
  133. return server.Serve(ln)
  134. }
  135. func NewCLI() *cobra.Command {
  136. log.SetFlags(log.LstdFlags | log.Lshortfile)
  137. rootCmd := &cobra.Command{
  138. Use: "ollama",
  139. Short: "Large language model runner",
  140. SilenceUsage: true,
  141. CompletionOptions: cobra.CompletionOptions{
  142. DisableDefaultCmd: true,
  143. },
  144. PersistentPreRunE: func(_ *cobra.Command, args []string) error {
  145. // create the models directory and it's parent
  146. return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
  147. },
  148. }
  149. cobra.EnableCommandSorting = false
  150. runCmd := &cobra.Command{
  151. Use: "run MODEL [PROMPT]",
  152. Short: "Run a model",
  153. Args: cobra.MinimumNArgs(1),
  154. RunE: RunRun,
  155. }
  156. serveCmd := &cobra.Command{
  157. Use: "serve",
  158. Aliases: []string{"start"},
  159. Short: "Start ollama",
  160. RunE: RunServer,
  161. }
  162. rootCmd.AddCommand(
  163. serveCmd,
  164. runCmd,
  165. )
  166. return rootCmd
  167. }