cmd.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package cmd
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "log"
  8. "net"
  9. "os"
  10. "path"
  11. "sync"
  12. "github.com/gosuri/uiprogress"
  13. "github.com/spf13/cobra"
  14. "golang.org/x/term"
  15. "github.com/jmorganca/ollama/api"
  16. "github.com/jmorganca/ollama/server"
  17. )
  18. func cacheDir() string {
  19. home, err := os.UserHomeDir()
  20. if err != nil {
  21. panic(err)
  22. }
  23. return path.Join(home, ".ollama")
  24. }
  25. func bytesToGB(bytes int) float64 {
  26. return float64(bytes) / float64(1<<30)
  27. }
  28. func RunRun(cmd *cobra.Command, args []string) error {
  29. client, err := NewAPIClient()
  30. if err != nil {
  31. return err
  32. }
  33. pr := api.PullRequest{
  34. Model: args[0],
  35. }
  36. var bar *uiprogress.Bar
  37. mutex := &sync.Mutex{}
  38. var progressData api.PullProgress
  39. pullCallback := func(progress api.PullProgress) {
  40. mutex.Lock()
  41. progressData = progress
  42. if bar == nil {
  43. uiprogress.Start()
  44. bar = uiprogress.AddBar(int(progress.Total))
  45. bar.PrependFunc(func(b *uiprogress.Bar) string {
  46. return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
  47. })
  48. bar.AppendFunc(func(b *uiprogress.Bar) string {
  49. return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
  50. })
  51. }
  52. bar.Set(int(progress.Completed))
  53. mutex.Unlock()
  54. }
  55. if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
  56. return err
  57. }
  58. fmt.Println("Up to date.")
  59. return RunGenerate(cmd, args)
  60. }
  61. func RunGenerate(_ *cobra.Command, args []string) error {
  62. if len(args) > 1 {
  63. return generate(args[0], args[1:]...)
  64. }
  65. if term.IsTerminal(int(os.Stdin.Fd())) {
  66. return generateInteractive(args[0])
  67. }
  68. return generateBatch(args[0])
  69. }
  70. func generate(model string, prompts ...string) error {
  71. client, err := NewAPIClient()
  72. if err != nil {
  73. return err
  74. }
  75. for _, prompt := range prompts {
  76. client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(bts []byte) {
  77. var resp api.GenerateResponse
  78. if err := json.Unmarshal(bts, &resp); err != nil {
  79. return
  80. }
  81. fmt.Print(resp.Response)
  82. })
  83. }
  84. fmt.Println()
  85. fmt.Println()
  86. return nil
  87. }
  88. func generateInteractive(model string) error {
  89. fmt.Print(">>> ")
  90. scanner := bufio.NewScanner(os.Stdin)
  91. for scanner.Scan() {
  92. if err := generate(model, scanner.Text()); err != nil {
  93. return err
  94. }
  95. fmt.Print(">>> ")
  96. }
  97. return nil
  98. }
  99. func generateBatch(model string) error {
  100. scanner := bufio.NewScanner(os.Stdin)
  101. for scanner.Scan() {
  102. prompt := scanner.Text()
  103. fmt.Printf(">>> %s\n", prompt)
  104. if err := generate(model, prompt); err != nil {
  105. return err
  106. }
  107. }
  108. return nil
  109. }
  110. func RunServer(_ *cobra.Command, _ []string) error {
  111. ln, err := net.Listen("tcp", "127.0.0.1:11434")
  112. if err != nil {
  113. return err
  114. }
  115. return server.Serve(ln)
  116. }
  117. func NewAPIClient() (*api.Client, error) {
  118. return &api.Client{
  119. URL: "http://localhost:11434",
  120. }, nil
  121. }
  122. func NewCLI() *cobra.Command {
  123. log.SetFlags(log.LstdFlags | log.Lshortfile)
  124. rootCmd := &cobra.Command{
  125. Use: "ollama",
  126. Short: "Large language model runner",
  127. SilenceUsage: true,
  128. CompletionOptions: cobra.CompletionOptions{
  129. DisableDefaultCmd: true,
  130. },
  131. PersistentPreRunE: func(_ *cobra.Command, args []string) error {
  132. // create the models directory and it's parent
  133. return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
  134. },
  135. }
  136. cobra.EnableCommandSorting = false
  137. runCmd := &cobra.Command{
  138. Use: "run MODEL [PROMPT]",
  139. Short: "Run a model",
  140. Args: cobra.MinimumNArgs(1),
  141. RunE: RunRun,
  142. }
  143. serveCmd := &cobra.Command{
  144. Use: "serve",
  145. Aliases: []string{"start"},
  146. Short: "Start ollama",
  147. RunE: RunServer,
  148. }
  149. rootCmd.AddCommand(
  150. serveCmd,
  151. runCmd,
  152. )
  153. return rootCmd
  154. }