cmd.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package cmd
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "os"
  8. "path"
  9. "sync"
  10. "github.com/gosuri/uiprogress"
  11. "github.com/jmorganca/ollama/api"
  12. "github.com/jmorganca/ollama/server"
  13. "github.com/spf13/cobra"
  14. )
  // cacheDir returns the ".ollama" directory inside the current user's home
  // directory. It panics when the home directory cannot be determined.
  func cacheDir() string {
  	home, err := os.UserHomeDir()
  	if err == nil {
  		return path.Join(home, ".ollama")
  	}
  	panic(err)
  }
  // bytesToGB converts a byte count to gigabytes using the binary definition
  // (1 GB = 1<<30 bytes).
  func bytesToGB(bytes int) float64 {
  	const bytesPerGB = 1 << 30
  	return float64(bytes) / bytesPerGB
  }
  25. func run(model string) error {
  26. client, err := NewAPIClient()
  27. if err != nil {
  28. return err
  29. }
  30. pr := api.PullRequest{
  31. Model: model,
  32. }
  33. var bar *uiprogress.Bar
  34. mutex := &sync.Mutex{}
  35. var progressData api.PullProgress
  36. pullCallback := func(progress api.PullProgress) {
  37. mutex.Lock()
  38. progressData = progress
  39. if bar == nil {
  40. uiprogress.Start() // start rendering
  41. bar = uiprogress.AddBar(int(progress.Total)) // Add a new bar
  42. // display the total file size and how much has downloaded so far
  43. bar.PrependFunc(func(b *uiprogress.Bar) string {
  44. return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
  45. })
  46. // display completion percentage
  47. bar.AppendFunc(func(b *uiprogress.Bar) string {
  48. return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
  49. })
  50. }
  51. bar.Set(int(progress.Completed))
  52. mutex.Unlock()
  53. }
  54. if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
  55. return err
  56. }
  57. fmt.Println("Up to date.")
  58. return nil
  59. }
  60. func serve() error {
  61. ln, err := net.Listen("tcp", "127.0.0.1:11434")
  62. if err != nil {
  63. return err
  64. }
  65. return server.Serve(ln)
  66. }
  67. func NewAPIClient() (*api.Client, error) {
  68. return &api.Client{
  69. URL: "http://localhost:11434",
  70. }, nil
  71. }
  72. func NewCLI() *cobra.Command {
  73. log.SetFlags(log.LstdFlags | log.Lshortfile)
  74. rootCmd := &cobra.Command{
  75. Use: "ollama",
  76. Short: "Large language model runner",
  77. CompletionOptions: cobra.CompletionOptions{
  78. DisableDefaultCmd: true,
  79. },
  80. PersistentPreRun: func(cmd *cobra.Command, args []string) {
  81. // Disable usage printing on errors
  82. cmd.SilenceUsage = true
  83. // create the models directory and it's parent
  84. if err := os.MkdirAll(path.Join(cacheDir(), "models"), 0o700); err != nil {
  85. panic(err)
  86. }
  87. },
  88. }
  89. cobra.EnableCommandSorting = false
  90. runCmd := &cobra.Command{
  91. Use: "run MODEL",
  92. Short: "Run a model",
  93. Args: cobra.ExactArgs(1),
  94. RunE: func(cmd *cobra.Command, args []string) error {
  95. return run(args[0])
  96. },
  97. }
  98. serveCmd := &cobra.Command{
  99. Use: "serve",
  100. Aliases: []string{"start"},
  101. Short: "Start ollama",
  102. RunE: func(cmd *cobra.Command, args []string) error {
  103. return serve()
  104. },
  105. }
  106. rootCmd.AddCommand(
  107. serveCmd,
  108. runCmd,
  109. )
  110. return rootCmd
  111. }