opp.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. package main
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "log"
  12. "mime"
  13. "net/http"
  14. "os"
  15. "runtime"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. "time"
  20. "github.com/ollama/ollama/server/internal/cache/blob"
  21. "github.com/ollama/ollama/server/internal/client/ollama"
  22. "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
  23. "golang.org/x/sync/errgroup"
  24. )
  25. var stdout io.Writer = os.Stdout
  26. const usage = `Opp is a tool for pushing and pulling Ollama models.
  27. Usage:
  28. opp [flags] <push|pull|import>
  29. Commands:
  30. push Upload a model to the Ollama server.
  31. pull Download a model from the Ollama server.
  32. import Import a model from a local safetensor directory.
  33. Examples:
  34. # Pull a model from the Ollama server.
  35. opp pull library/llama3.2:latest
  36. # Push a model to the Ollama server.
  37. opp push username/my_model:8b
  38. # Import a model from a local safetensor directory.
  39. opp import /path/to/safetensor
  40. Envionment Variables:
  41. OLLAMA_MODELS
  42. The directory where models are pushed and pulled from
  43. (default ~/.ollama/models).
  44. `
  45. func main() {
  46. flag.Usage = func() {
  47. fmt.Fprint(os.Stderr, usage)
  48. }
  49. flag.Parse()
  50. ctx := context.Background()
  51. err := func() error {
  52. switch cmd := flag.Arg(0); cmd {
  53. case "pull":
  54. rc, err := ollama.DefaultRegistry()
  55. if err != nil {
  56. log.Fatal(err)
  57. }
  58. return cmdPull(ctx, rc)
  59. case "push":
  60. rc, err := ollama.DefaultRegistry()
  61. if err != nil {
  62. log.Fatal(err)
  63. }
  64. return cmdPush(ctx, rc)
  65. case "import":
  66. c, err := ollama.DefaultCache()
  67. if err != nil {
  68. log.Fatal(err)
  69. }
  70. return cmdImport(ctx, c)
  71. default:
  72. if cmd == "" {
  73. flag.Usage()
  74. } else {
  75. fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
  76. }
  77. os.Exit(2)
  78. return errors.New("unreachable")
  79. }
  80. }()
  81. if err != nil {
  82. fmt.Fprintf(os.Stderr, "opp: %v\n", err)
  83. os.Exit(1)
  84. }
  85. }
  86. func cmdPull(ctx context.Context, rc *ollama.Registry) error {
  87. model := flag.Arg(1)
  88. if model == "" {
  89. flag.Usage()
  90. os.Exit(1)
  91. }
  92. tr := http.DefaultTransport.(*http.Transport).Clone()
  93. // TODO(bmizerany): configure transport?
  94. rc.HTTPClient = &http.Client{Transport: tr}
  95. var mu sync.Mutex
  96. p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
  97. var pb bytes.Buffer
  98. printProgress := func() {
  99. pb.Reset()
  100. mu.Lock()
  101. for d, s := range p {
  102. // Write progress to a buffer first to avoid blocking
  103. // on stdout while holding the lock.
  104. stamp := time.Now().Format("2006/01/02 15:04:05")
  105. fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
  106. if s[0] == s[1] {
  107. delete(p, d)
  108. }
  109. }
  110. mu.Unlock()
  111. io.Copy(stdout, &pb)
  112. }
  113. ctx = ollama.WithTrace(ctx, &ollama.Trace{
  114. Update: func(l *ollama.Layer, n int64, err error) {
  115. if err != nil && !errors.Is(err, ollama.ErrCached) {
  116. fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
  117. return
  118. }
  119. mu.Lock()
  120. p[l.Digest] = [2]int64{l.Size, n}
  121. mu.Unlock()
  122. },
  123. })
  124. errc := make(chan error)
  125. go func() {
  126. errc <- rc.Pull(ctx, model)
  127. }()
  128. t := time.NewTicker(time.Second)
  129. defer t.Stop()
  130. for {
  131. select {
  132. case <-t.C:
  133. printProgress()
  134. case err := <-errc:
  135. printProgress()
  136. return err
  137. }
  138. }
  139. }
  140. func cmdPush(ctx context.Context, rc *ollama.Registry) error {
  141. args := flag.Args()[1:]
  142. flag := flag.NewFlagSet("push", flag.ExitOnError)
  143. flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
  144. flag.Usage = func() {
  145. fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
  146. flag.PrintDefaults()
  147. }
  148. flag.Parse(args)
  149. model := flag.Arg(0)
  150. if model == "" {
  151. return fmt.Errorf("missing model argument")
  152. }
  153. from := cmp.Or(*flagFrom, model)
  154. m, err := rc.ResolveLocal(from)
  155. if err != nil {
  156. return err
  157. }
  158. ctx = ollama.WithTrace(ctx, &ollama.Trace{
  159. Update: func(l *ollama.Layer, n int64, err error) {
  160. switch {
  161. case errors.Is(err, ollama.ErrCached):
  162. fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
  163. case err != nil:
  164. fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
  165. case n == 0:
  166. l := m.Layer(l.Digest)
  167. mt, p, _ := mime.ParseMediaType(l.MediaType)
  168. mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
  169. switch mt {
  170. case "tensor":
  171. fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
  172. default:
  173. fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
  174. }
  175. }
  176. },
  177. })
  178. return rc.Push(ctx, model, &ollama.PushParams{
  179. From: from,
  180. })
  181. }
  182. type trackingReader struct {
  183. io.Reader
  184. n *atomic.Int64
  185. }
  186. func (r *trackingReader) Read(p []byte) (n int, err error) {
  187. n, err = r.Reader.Read(p)
  188. r.n.Add(int64(n))
  189. return n, err
  190. }
  191. func cmdImport(ctx context.Context, c *blob.DiskCache) error {
  192. args := flag.Args()[1:]
  193. flag := flag.NewFlagSet("import", flag.ExitOnError)
  194. flagAs := flag.String("as", "", "Import using the provided name.")
  195. flag.Usage = func() {
  196. fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
  197. flag.PrintDefaults()
  198. }
  199. flag.Parse(args)
  200. if *flagAs == "" {
  201. return fmt.Errorf("missing -as flag")
  202. }
  203. as := ollama.CompleteName(*flagAs)
  204. dir := cmp.Or(flag.Arg(0), ".")
  205. fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
  206. m, err := safetensors.Read(os.DirFS(dir))
  207. if err != nil {
  208. return err
  209. }
  210. var total int64
  211. var tt []*safetensors.Tensor
  212. for t, err := range m.Tensors() {
  213. if err != nil {
  214. return err
  215. }
  216. tt = append(tt, t)
  217. total += t.Size()
  218. }
  219. var n atomic.Int64
  220. done := make(chan error)
  221. go func() {
  222. layers := make([]*ollama.Layer, len(tt))
  223. var g errgroup.Group
  224. g.SetLimit(runtime.GOMAXPROCS(0))
  225. var ctxErr error
  226. for i, t := range tt {
  227. if ctx.Err() != nil {
  228. // The context may cancel AFTER we exit the
  229. // loop, and so if we use ctx.Err() after the
  230. // loop we may report it as the error that
  231. // broke the loop, when it was not. This can
  232. // manifest as a false-negative, leading the
  233. // user to think their import failed when it
  234. // did not, so capture it if and only if we
  235. // exit the loop because of a ctx.Err() and
  236. // report it.
  237. ctxErr = ctx.Err()
  238. break
  239. }
  240. g.Go(func() (err error) {
  241. rc, err := t.Reader()
  242. if err != nil {
  243. return err
  244. }
  245. defer rc.Close()
  246. tr := &trackingReader{rc, &n}
  247. d, err := c.Import(tr, t.Size())
  248. if err != nil {
  249. return err
  250. }
  251. if err := rc.Close(); err != nil {
  252. return err
  253. }
  254. layers[i] = &ollama.Layer{
  255. Digest: d,
  256. Size: t.Size(),
  257. MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
  258. "name": t.Name(),
  259. "dtype": t.DataType(),
  260. "shape": t.Shape().String(),
  261. }),
  262. }
  263. return nil
  264. })
  265. }
  266. done <- func() error {
  267. if err := errors.Join(g.Wait(), ctxErr); err != nil {
  268. return err
  269. }
  270. m := &ollama.Manifest{Layers: layers}
  271. data, err := json.MarshalIndent(m, "", " ")
  272. if err != nil {
  273. return err
  274. }
  275. d := blob.DigestFromBytes(data)
  276. err = blob.PutBytes(c, d, data)
  277. if err != nil {
  278. return err
  279. }
  280. return c.Link(as, d)
  281. }()
  282. }()
  283. fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
  284. csiHideCursor(stdout)
  285. defer csiShowCursor(stdout)
  286. csiSavePos(stdout)
  287. writeProgress := func() {
  288. csiRestorePos(stdout)
  289. nn := n.Load()
  290. fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
  291. formatNatural(nn),
  292. formatNatural(total),
  293. nn*100/total,
  294. ansiClearToEndOfLine,
  295. )
  296. }
  297. ticker := time.NewTicker(time.Second)
  298. defer ticker.Stop()
  299. for {
  300. select {
  301. case <-ticker.C:
  302. writeProgress()
  303. case err := <-done:
  304. writeProgress()
  305. fmt.Println()
  306. fmt.Println("Successfully imported", as)
  307. return err
  308. }
  309. }
  310. }
  311. func formatNatural(n int64) string {
  312. switch {
  313. case n < 1024:
  314. return fmt.Sprintf("%d B", n)
  315. case n < 1024*1024:
  316. return fmt.Sprintf("%.1f KB", float64(n)/1024)
  317. case n < 1024*1024*1024:
  318. return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
  319. default:
  320. return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
  321. }
  322. }
  323. const ansiClearToEndOfLine = "\033[K"
  324. func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
  325. func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
  326. func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
  327. func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }