cmd.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package cmd
  2. import (
  3. "bufio"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path"
  12. "strings"
  13. "time"
  14. "github.com/schollz/progressbar/v3"
  15. "github.com/spf13/cobra"
  16. "golang.org/x/term"
  17. "github.com/jmorganca/ollama/api"
  18. "github.com/jmorganca/ollama/server"
  19. )
  // cacheDir returns the ".ollama" directory located directly inside the current
  // user's home directory.
  //
  // It panics when os.UserHomeDir cannot determine the home directory.
  func cacheDir() string {
  	homeDir, err := os.UserHomeDir()
  	if err == nil {
  		return path.Join(homeDir, ".ollama")
  	}
  	panic(err)
  }
  27. func RunRun(cmd *cobra.Command, args []string) error {
  28. _, err := os.Stat(args[0])
  29. switch {
  30. case errors.Is(err, os.ErrNotExist):
  31. if err := pull(args[0]); err != nil {
  32. var apiStatusError api.StatusError
  33. if !errors.As(err, &apiStatusError) {
  34. return err
  35. }
  36. if apiStatusError.StatusCode != http.StatusBadGateway {
  37. return err
  38. }
  39. }
  40. case err != nil:
  41. return err
  42. }
  43. return RunGenerate(cmd, args)
  44. }
  45. func pull(model string) error {
  46. client := api.NewClient()
  47. var bar *progressbar.ProgressBar
  48. return client.Pull(
  49. context.Background(),
  50. &api.PullRequest{Model: model},
  51. func(progress api.PullProgress) error {
  52. if bar == nil {
  53. if progress.Percent >= 100 {
  54. // already downloaded
  55. return nil
  56. }
  57. bar = progressbar.DefaultBytes(progress.Total)
  58. }
  59. return bar.Set64(progress.Completed)
  60. },
  61. )
  62. }
  63. func RunGenerate(cmd *cobra.Command, args []string) error {
  64. if len(args) > 1 {
  65. // join all args into a single prompt
  66. return generate(cmd, args[0], strings.Join(args[1:], " "))
  67. }
  68. if term.IsTerminal(int(os.Stdin.Fd())) {
  69. return generateInteractive(cmd, args[0])
  70. }
  71. return generateBatch(cmd, args[0])
  72. }
  73. var generateContextKey struct{}
  74. func generate(cmd *cobra.Command, model, prompt string) error {
  75. if len(strings.TrimSpace(prompt)) > 0 {
  76. client := api.NewClient()
  77. spinner := progressbar.NewOptions(-1,
  78. progressbar.OptionSetWriter(os.Stderr),
  79. progressbar.OptionThrottle(60*time.Millisecond),
  80. progressbar.OptionSpinnerType(14),
  81. progressbar.OptionSetRenderBlankState(true),
  82. progressbar.OptionSetElapsedTime(false),
  83. progressbar.OptionClearOnFinish(),
  84. )
  85. go func() {
  86. for range time.Tick(60 * time.Millisecond) {
  87. if spinner.IsFinished() {
  88. break
  89. }
  90. spinner.Add(1)
  91. }
  92. }()
  93. var latest api.GenerateResponse
  94. generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
  95. if !ok {
  96. generateContext = []int{}
  97. }
  98. request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
  99. fn := func(resp api.GenerateResponse) error {
  100. if !spinner.IsFinished() {
  101. spinner.Finish()
  102. }
  103. latest = resp
  104. fmt.Print(resp.Response)
  105. cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
  106. return nil
  107. }
  108. if err := client.Generate(context.Background(), &request, fn); err != nil {
  109. return err
  110. }
  111. fmt.Println()
  112. fmt.Println()
  113. verbose, err := cmd.Flags().GetBool("verbose")
  114. if err != nil {
  115. return err
  116. }
  117. if verbose {
  118. latest.Summary()
  119. }
  120. }
  121. return nil
  122. }
  123. func generateInteractive(cmd *cobra.Command, model string) error {
  124. fmt.Print(">>> ")
  125. scanner := bufio.NewScanner(os.Stdin)
  126. for scanner.Scan() {
  127. if err := generate(cmd, model, scanner.Text()); err != nil {
  128. return err
  129. }
  130. fmt.Print(">>> ")
  131. }
  132. return nil
  133. }
  134. func generateBatch(cmd *cobra.Command, model string) error {
  135. scanner := bufio.NewScanner(os.Stdin)
  136. for scanner.Scan() {
  137. prompt := scanner.Text()
  138. fmt.Printf(">>> %s\n", prompt)
  139. if err := generate(cmd, model, prompt); err != nil {
  140. return err
  141. }
  142. }
  143. return nil
  144. }
  145. func RunServer(_ *cobra.Command, _ []string) error {
  146. host := os.Getenv("OLLAMA_HOST")
  147. if host == "" {
  148. host = "127.0.0.1"
  149. }
  150. port := os.Getenv("OLLAMA_PORT")
  151. if port == "" {
  152. port = "11434"
  153. }
  154. ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
  155. if err != nil {
  156. return err
  157. }
  158. return server.Serve(ln)
  159. }
  160. func NewCLI() *cobra.Command {
  161. log.SetFlags(log.LstdFlags | log.Lshortfile)
  162. rootCmd := &cobra.Command{
  163. Use: "ollama",
  164. Short: "Large language model runner",
  165. SilenceUsage: true,
  166. CompletionOptions: cobra.CompletionOptions{
  167. DisableDefaultCmd: true,
  168. },
  169. PersistentPreRunE: func(_ *cobra.Command, args []string) error {
  170. // create the models directory and it's parent
  171. return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
  172. },
  173. }
  174. cobra.EnableCommandSorting = false
  175. runCmd := &cobra.Command{
  176. Use: "run MODEL [PROMPT]",
  177. Short: "Run a model",
  178. Args: cobra.MinimumNArgs(1),
  179. RunE: RunRun,
  180. }
  181. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  182. serveCmd := &cobra.Command{
  183. Use: "serve",
  184. Aliases: []string{"start"},
  185. Short: "Start ollama",
  186. RunE: RunServer,
  187. }
  188. rootCmd.AddCommand(
  189. serveCmd,
  190. runCmd,
  191. )
  192. return rootCmd
  193. }