cmd.go 6.9 KB


  1. package cmd
  import (
  	"bufio"
  	"context"
  	"errors"
  	"fmt"
  	"log"
  	"net"
  	"net/http"
  	"os"
  	"path/filepath"
  	"strings"
  	"time"

  	"github.com/dustin/go-humanize"
  	"github.com/olekukonko/tablewriter"
  	"github.com/schollz/progressbar/v3"
  	"github.com/spf13/cobra"
  	"golang.org/x/term"

  	"github.com/jmorganca/ollama/api"
  	"github.com/jmorganca/ollama/format"
  	"github.com/jmorganca/ollama/server"
  )
  22. func create(cmd *cobra.Command, args []string) error {
  23. filename, _ := cmd.Flags().GetString("file")
  24. client := api.NewClient()
  25. request := api.CreateRequest{Name: args[0], Path: filename}
  26. fn := func(resp api.CreateProgress) error {
  27. fmt.Println(resp.Status)
  28. return nil
  29. }
  30. if err := client.Create(context.Background(), &request, fn); err != nil {
  31. return err
  32. }
  33. return nil
  34. }
  35. func RunRun(cmd *cobra.Command, args []string) error {
  36. mp := server.ParseModelPath(args[0])
  37. fp, err := mp.GetManifestPath(false)
  38. if err != nil {
  39. return err
  40. }
  41. _, err = os.Stat(fp)
  42. switch {
  43. case errors.Is(err, os.ErrNotExist):
  44. if err := pull(args[0]); err != nil {
  45. var apiStatusError api.StatusError
  46. if !errors.As(err, &apiStatusError) {
  47. return err
  48. }
  49. if apiStatusError.StatusCode != http.StatusBadGateway {
  50. return err
  51. }
  52. }
  53. case err != nil:
  54. return err
  55. }
  56. return RunGenerate(cmd, args)
  57. }
  58. func push(cmd *cobra.Command, args []string) error {
  59. client := api.NewClient()
  60. request := api.PushRequest{Name: args[0]}
  61. fn := func(resp api.PushProgress) error {
  62. fmt.Println(resp.Status)
  63. return nil
  64. }
  65. if err := client.Push(context.Background(), &request, fn); err != nil {
  66. return err
  67. }
  68. return nil
  69. }
  70. func list(cmd *cobra.Command, args []string) error {
  71. client := api.NewClient()
  72. models, err := client.List(context.Background())
  73. if err != nil {
  74. return err
  75. }
  76. var data [][]string
  77. for _, m := range models.Models {
  78. data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
  79. }
  80. table := tablewriter.NewWriter(os.Stdout)
  81. table.SetHeader([]string{"NAME", "SIZE", "MODIFIED"})
  82. table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
  83. table.SetAlignment(tablewriter.ALIGN_LEFT)
  84. table.SetHeaderLine(false)
  85. table.SetBorder(false)
  86. table.SetNoWhiteSpace(true)
  87. table.SetTablePadding("\t")
  88. table.AppendBulk(data)
  89. table.Render()
  90. return nil
  91. }
  92. func RunPull(cmd *cobra.Command, args []string) error {
  93. return pull(args[0])
  94. }
  95. func pull(model string) error {
  96. client := api.NewClient()
  97. var bar *progressbar.ProgressBar
  98. currentLayer := ""
  99. request := api.PullRequest{Name: model}
  100. fn := func(resp api.PullProgress) error {
  101. if resp.Digest != currentLayer && resp.Digest != "" {
  102. if currentLayer != "" {
  103. fmt.Println()
  104. }
  105. currentLayer = resp.Digest
  106. layerStr := resp.Digest[7:23] + "..."
  107. bar = progressbar.DefaultBytes(
  108. int64(resp.Total),
  109. "pulling "+layerStr,
  110. )
  111. } else if resp.Digest == currentLayer && resp.Digest != "" {
  112. bar.Set(resp.Completed)
  113. } else {
  114. currentLayer = ""
  115. fmt.Println(resp.Status)
  116. }
  117. return nil
  118. }
  119. if err := client.Pull(context.Background(), &request, fn); err != nil {
  120. return err
  121. }
  122. return nil
  123. }
  124. func RunGenerate(cmd *cobra.Command, args []string) error {
  125. if len(args) > 1 {
  126. // join all args into a single prompt
  127. return generate(cmd, args[0], strings.Join(args[1:], " "))
  128. }
  129. if term.IsTerminal(int(os.Stdin.Fd())) {
  130. return generateInteractive(cmd, args[0])
  131. }
  132. return generateBatch(cmd, args[0])
  133. }
  // generateContextKeyType is the unexported type of the context key under
  // which generate stores the conversation context ([]int) on the command's
  // context.Context. It must be a distinct named type: an anonymous struct{}
  // key compares equal to any other package's struct{}{} key, so the values
  // could collide.
  type generateContextKeyType struct{}

  // generateContextKey is the context key for the conversation context.
  var generateContextKey generateContextKeyType
  135. func generate(cmd *cobra.Command, model, prompt string) error {
  136. if len(strings.TrimSpace(prompt)) > 0 {
  137. client := api.NewClient()
  138. spinner := progressbar.NewOptions(-1,
  139. progressbar.OptionSetWriter(os.Stderr),
  140. progressbar.OptionThrottle(60*time.Millisecond),
  141. progressbar.OptionSpinnerType(14),
  142. progressbar.OptionSetRenderBlankState(true),
  143. progressbar.OptionSetElapsedTime(false),
  144. progressbar.OptionClearOnFinish(),
  145. )
  146. go func() {
  147. for range time.Tick(60 * time.Millisecond) {
  148. if spinner.IsFinished() {
  149. break
  150. }
  151. spinner.Add(1)
  152. }
  153. }()
  154. var latest api.GenerateResponse
  155. generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
  156. if !ok {
  157. generateContext = []int{}
  158. }
  159. request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
  160. fn := func(resp api.GenerateResponse) error {
  161. if !spinner.IsFinished() {
  162. spinner.Finish()
  163. }
  164. latest = resp
  165. fmt.Print(resp.Response)
  166. cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
  167. return nil
  168. }
  169. if err := client.Generate(context.Background(), &request, fn); err != nil {
  170. return err
  171. }
  172. fmt.Println()
  173. fmt.Println()
  174. verbose, err := cmd.Flags().GetBool("verbose")
  175. if err != nil {
  176. return err
  177. }
  178. if verbose {
  179. latest.Summary()
  180. }
  181. }
  182. return nil
  183. }
  184. func generateInteractive(cmd *cobra.Command, model string) error {
  185. fmt.Print(">>> ")
  186. scanner := bufio.NewScanner(os.Stdin)
  187. for scanner.Scan() {
  188. if err := generate(cmd, model, scanner.Text()); err != nil {
  189. return err
  190. }
  191. fmt.Print(">>> ")
  192. }
  193. return nil
  194. }
  195. func generateBatch(cmd *cobra.Command, model string) error {
  196. scanner := bufio.NewScanner(os.Stdin)
  197. for scanner.Scan() {
  198. prompt := scanner.Text()
  199. fmt.Printf(">>> %s\n", prompt)
  200. if err := generate(cmd, model, prompt); err != nil {
  201. return err
  202. }
  203. }
  204. return nil
  205. }
  206. func RunServer(_ *cobra.Command, _ []string) error {
  207. host := os.Getenv("OLLAMA_HOST")
  208. if host == "" {
  209. host = "127.0.0.1"
  210. }
  211. port := os.Getenv("OLLAMA_PORT")
  212. if port == "" {
  213. port = "11434"
  214. }
  215. ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
  216. if err != nil {
  217. return err
  218. }
  219. return server.Serve(ln)
  220. }
  221. func NewCLI() *cobra.Command {
  222. log.SetFlags(log.LstdFlags | log.Lshortfile)
  223. rootCmd := &cobra.Command{
  224. Use: "ollama",
  225. Short: "Large language model runner",
  226. SilenceUsage: true,
  227. CompletionOptions: cobra.CompletionOptions{
  228. DisableDefaultCmd: true,
  229. },
  230. }
  231. cobra.EnableCommandSorting = false
  232. createCmd := &cobra.Command{
  233. Use: "create MODEL",
  234. Short: "Create a model from a Modelfile",
  235. Args: cobra.MinimumNArgs(1),
  236. RunE: create,
  237. }
  238. createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
  239. runCmd := &cobra.Command{
  240. Use: "run MODEL [PROMPT]",
  241. Short: "Run a model",
  242. Args: cobra.MinimumNArgs(1),
  243. RunE: RunRun,
  244. }
  245. runCmd.Flags().Bool("verbose", false, "Show timings for response")
  246. serveCmd := &cobra.Command{
  247. Use: "serve",
  248. Aliases: []string{"start"},
  249. Short: "Start ollama",
  250. RunE: RunServer,
  251. }
  252. pullCmd := &cobra.Command{
  253. Use: "pull MODEL",
  254. Short: "Pull a model from a registry",
  255. Args: cobra.MinimumNArgs(1),
  256. RunE: RunPull,
  257. }
  258. pushCmd := &cobra.Command{
  259. Use: "push MODEL",
  260. Short: "Push a model to a registry",
  261. Args: cobra.MinimumNArgs(1),
  262. RunE: push,
  263. }
  264. listCmd := &cobra.Command{
  265. Use: "list",
  266. Short: "List models",
  267. RunE: list,
  268. }
  269. rootCmd.AddCommand(
  270. serveCmd,
  271. createCmd,
  272. runCmd,
  273. pullCmd,
  274. pushCmd,
  275. listCmd,
  276. )
  277. return rootCmd
  278. }