cmd.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package cmd
  2. import (
  3. "bufio"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "path/filepath"
  12. "strings"
  13. "time"
  14. "github.com/schollz/progressbar/v3"
  15. "github.com/spf13/cobra"
  16. "golang.org/x/term"
  17. "github.com/jmorganca/ollama/api"
  18. "github.com/jmorganca/ollama/server"
  19. )
  // cacheDir returns the directory that holds ollama's local data: ".ollama"
  // inside the current user's home directory. It panics if os.UserHomeDir
  // reports an error.
  func cacheDir() string {
  	homeDir, err := os.UserHomeDir()
  	if err == nil {
  		return filepath.Join(homeDir, ".ollama")
  	}
  	panic(err)
  }
  27. func create(cmd *cobra.Command, args []string) error {
  28. filename, _ := cmd.Flags().GetString("file")
  29. client := api.NewClient()
  30. request := api.CreateRequest{Name: args[0], Path: filename}
  31. fn := func(resp api.CreateProgress) error {
  32. fmt.Println(resp.Status)
  33. return nil
  34. }
  35. if err := client.Create(context.Background(), &request, fn); err != nil {
  36. return err
  37. }
  38. return nil
  39. }
  40. func RunRun(cmd *cobra.Command, args []string) error {
  41. _, err := os.Stat(args[0])
  42. switch {
  43. case errors.Is(err, os.ErrNotExist):
  44. if err := pull(args[0]); err != nil {
  45. var apiStatusError api.StatusError
  46. if !errors.As(err, &apiStatusError) {
  47. return err
  48. }
  49. if apiStatusError.StatusCode != http.StatusBadGateway {
  50. return err
  51. }
  52. }
  53. case err != nil:
  54. return err
  55. }
  56. return RunGenerate(cmd, args)
  57. }
  58. func push(cmd *cobra.Command, args []string) error {
  59. client := api.NewClient()
  60. request := api.PushRequest{Name: args[0]}
  61. fn := func(resp api.PushProgress) error {
  62. fmt.Println(resp.Status)
  63. return nil
  64. }
  65. if err := client.Push(context.Background(), &request, fn); err != nil {
  66. return err
  67. }
  68. return nil
  69. }
  70. func RunPull(cmd *cobra.Command, args []string) error {
  71. return pull(args[0])
  72. }
  73. func pull(model string) error {
  74. client := api.NewClient()
  75. var bar *progressbar.ProgressBar
  76. currentLayer := ""
  77. request := api.PullRequest{Name: model}
  78. fn := func(resp api.PullProgress) error {
  79. if resp.Digest != currentLayer && resp.Digest != "" {
  80. if currentLayer != "" {
  81. fmt.Println()
  82. }
  83. currentLayer = resp.Digest
  84. layerStr := resp.Digest[7:23] + "..."
  85. bar = progressbar.DefaultBytes(
  86. int64(resp.Total),
  87. "pulling "+layerStr,
  88. )
  89. } else if resp.Digest == currentLayer && resp.Digest != "" {
  90. bar.Set(resp.Completed)
  91. } else {
  92. currentLayer = ""
  93. fmt.Println(resp.Status)
  94. }
  95. return nil
  96. }
  97. if err := client.Pull(context.Background(), &request, fn); err != nil {
  98. return err
  99. }
  100. return nil
  101. }
  102. func RunGenerate(cmd *cobra.Command, args []string) error {
  103. if len(args) > 1 {
  104. // join all args into a single prompt
  105. return generate(cmd, args[0], strings.Join(args[1:], " "))
  106. }
  107. if term.IsTerminal(int(os.Stdin.Fd())) {
  108. return generateInteractive(cmd, args[0])
  109. }
  110. return generateBatch(cmd, args[0])
  111. }
  // generateContextKeyType is the unexported, defined type of
  // generateContextKey. A defined type guarantees the key can never compare
  // equal to a key created by another package; a bare struct{}{} value would.
  type generateContextKeyType struct{}

  // generateContextKey is the context.Context key under which generate stores
  // the Context of the most recent response. The next call to generate reads
  // the value back from there and sends it with its request.
  var generateContextKey generateContextKeyType
  113. func generate(cmd *cobra.Command, model, prompt string) error {
  114. if len(strings.TrimSpace(prompt)) > 0 {
  115. client := api.NewClient()
  116. spinner := progressbar.NewOptions(-1,
  117. progressbar.OptionSetWriter(os.Stderr),
  118. progressbar.OptionThrottle(60*time.Millisecond),
  119. progressbar.OptionSpinnerType(14),
  120. progressbar.OptionSetRenderBlankState(true),
  121. progressbar.OptionSetElapsedTime(false),
  122. progressbar.OptionClearOnFinish(),
  123. )
  124. go func() {
  125. for range time.Tick(60 * time.Millisecond) {
  126. if spinner.IsFinished() {
  127. break
  128. }
  129. spinner.Add(1)
  130. }
  131. }()
  132. var latest api.GenerateResponse
  133. generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
  134. if !ok {
  135. generateContext = []int{}
  136. }
  137. request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
  138. fn := func(resp api.GenerateResponse) error {
  139. if !spinner.IsFinished() {
  140. spinner.Finish()
  141. }
  142. latest = resp
  143. fmt.Print(resp.Response)
  144. cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
  145. return nil
  146. }
  147. if err := client.Generate(context.Background(), &request, fn); err != nil {
  148. return err
  149. }
  150. fmt.Println()
  151. fmt.Println()
  152. verbose, err := cmd.Flags().GetBool("verbose")
  153. if err != nil {
  154. return err
  155. }
  156. if verbose {
  157. latest.Summary()
  158. }
  159. }
  160. return nil
  161. }
  162. func generateInteractive(cmd *cobra.Command, model string) error {
  163. fmt.Print(">>> ")
  164. scanner := bufio.NewScanner(os.Stdin)
  165. for scanner.Scan() {
  166. if err := generate(cmd, model, scanner.Text()); err != nil {
  167. return err
  168. }
  169. fmt.Print(">>> ")
  170. }
  171. return nil
  172. }
  173. func generateBatch(cmd *cobra.Command, model string) error {
  174. scanner := bufio.NewScanner(os.Stdin)
  175. for scanner.Scan() {
  176. prompt := scanner.Text()
  177. fmt.Printf(">>> %s\n", prompt)
  178. if err := generate(cmd, model, prompt); err != nil {
  179. return err
  180. }
  181. }
  182. return nil
  183. }
  184. func RunServer(_ *cobra.Command, _ []string) error {
  185. host := os.Getenv("OLLAMA_HOST")
  186. if host == "" {
  187. host = "127.0.0.1"
  188. }
  189. port := os.Getenv("OLLAMA_PORT")
  190. if port == "" {
  191. port = "11434"
  192. }
  193. ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
  194. if err != nil {
  195. return err
  196. }
  197. return server.Serve(ln)
  198. }
  199. func NewCLI() *cobra.Command {
  200. log.SetFlags(log.LstdFlags | log.Lshortfile)
  201. rootCmd := &cobra.Command{
  202. Use: "ollama",
  203. Short: "Large language model runner",
  204. SilenceUsage: true,
  205. CompletionOptions: cobra.CompletionOptions{
  206. DisableDefaultCmd: true,
  207. },
  208. PersistentPreRunE: func(_ *cobra.Command, args []string) error {
  209. // create the models directory and it's parent
  210. return os.MkdirAll(filepath.Join(cacheDir(), "models"), 0o700)
  211. },
  212. }
  213. cobra.EnableCommandSorting = false
  214. createCmd := &cobra.Command{
  215. Use: "create MODEL",
  216. Short: "Create a model from a Modelfile",
  217. Args: cobra.MinimumNArgs(1),
  218. RunE: create,
  219. }
  220. createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
  221. runCmd := &cobra.Command{
  222. Use: "run MODEL [PROMPT]",
  223. Short: "Run a model",
  224. Args: cobra.MinimumNArgs(1),
  225. RunE: RunRun,
  226. }
  227. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  228. serveCmd := &cobra.Command{
  229. Use: "serve",
  230. Aliases: []string{"start"},
  231. Short: "Start ollama",
  232. RunE: RunServer,
  233. }
  234. pullCmd := &cobra.Command{
  235. Use: "pull MODEL",
  236. Short: "Pull a model from a registry",
  237. Args: cobra.MinimumNArgs(1),
  238. RunE: RunPull,
  239. }
  240. pushCmd := &cobra.Command{
  241. Use: "push MODEL",
  242. Short: "Push a model to a registry",
  243. Args: cobra.MinimumNArgs(1),
  244. RunE: push,
  245. }
  246. rootCmd.AddCommand(
  247. serveCmd,
  248. createCmd,
  249. runCmd,
  250. pullCmd,
  251. pushCmd,
  252. )
  253. return rootCmd
  254. }