opp.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. package main
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "log"
  12. "mime"
  13. "net/http"
  14. "os"
  15. "runtime"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. "time"
  20. "github.com/ollama/ollama/server/internal/cache/blob"
  21. "github.com/ollama/ollama/server/internal/client/ollama"
  22. "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
  23. "golang.org/x/sync/errgroup"
  24. )
  // stdout is the writer that the subcommands send progress and status
  // output to. It starts out as os.Stdout; keeping it in a variable
  // presumably lets it be redirected (e.g. by tests) — confirm against callers.
  var stdout = io.Writer(os.Stdout)
  // usage is the top-level help text. main installs it as the output of
  // flag.Usage, which is shown when no command is given.
  const usage = `Opp is a tool for pushing and pulling Ollama models.
Usage:
opp [flags] <push|pull|import>
Commands:
push Upload a model to the Ollama server.
pull Download a model from the Ollama server.
import Import a model from a local safetensor directory.
Examples:
# Pull a model from the Ollama server.
opp pull library/llama3.2:latest
# Push a model to the Ollama server.
opp push username/my_model:8b
# Import a model from a local safetensor directory.
opp import /path/to/safetensor
Environment Variables:
OLLAMA_MODELS
The directory where models are pushed and pulled from
(default ~/.ollama/models).
`
  45. func main() {
  46. flag.Usage = func() {
  47. fmt.Fprint(os.Stderr, usage)
  48. }
  49. flag.Parse()
  50. c, err := ollama.DefaultCache()
  51. if err != nil {
  52. log.Fatal(err)
  53. }
  54. rc, err := ollama.RegistryFromEnv()
  55. if err != nil {
  56. log.Fatal(err)
  57. }
  58. ctx := context.Background()
  59. err = func() error {
  60. switch cmd := flag.Arg(0); cmd {
  61. case "pull":
  62. return cmdPull(ctx, rc, c)
  63. case "push":
  64. return cmdPush(ctx, rc, c)
  65. case "import":
  66. return cmdImport(ctx, c)
  67. default:
  68. if cmd == "" {
  69. flag.Usage()
  70. } else {
  71. fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
  72. }
  73. os.Exit(2)
  74. return errors.New("unreachable")
  75. }
  76. }()
  77. if err != nil {
  78. fmt.Fprintf(os.Stderr, "opp: %v\n", err)
  79. os.Exit(1)
  80. }
  81. }
  82. func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
  83. model := flag.Arg(1)
  84. if model == "" {
  85. flag.Usage()
  86. os.Exit(1)
  87. }
  88. tr := http.DefaultTransport.(*http.Transport).Clone()
  89. // TODO(bmizerany): configure transport?
  90. rc.HTTPClient = &http.Client{Transport: tr}
  91. var mu sync.Mutex
  92. p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
  93. var pb bytes.Buffer
  94. printProgress := func() {
  95. pb.Reset()
  96. mu.Lock()
  97. for d, s := range p {
  98. // Write progress to a buffer first to avoid blocking
  99. // on stdout while holding the lock.
  100. stamp := time.Now().Format("2006/01/02 15:04:05")
  101. fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
  102. if s[0] == s[1] {
  103. delete(p, d)
  104. }
  105. }
  106. mu.Unlock()
  107. io.Copy(stdout, &pb)
  108. }
  109. ctx = ollama.WithTrace(ctx, &ollama.Trace{
  110. Update: func(l *ollama.Layer, n int64, err error) {
  111. if err != nil && !errors.Is(err, ollama.ErrCached) {
  112. fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
  113. return
  114. }
  115. mu.Lock()
  116. p[l.Digest] = [2]int64{l.Size, n}
  117. mu.Unlock()
  118. },
  119. })
  120. errc := make(chan error)
  121. go func() {
  122. errc <- rc.Pull(ctx, c, model)
  123. }()
  124. t := time.NewTicker(time.Second)
  125. defer t.Stop()
  126. for {
  127. select {
  128. case <-t.C:
  129. printProgress()
  130. case err := <-errc:
  131. printProgress()
  132. return err
  133. }
  134. }
  135. }
  136. func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
  137. args := flag.Args()[1:]
  138. flag := flag.NewFlagSet("push", flag.ExitOnError)
  139. flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
  140. flag.Usage = func() {
  141. fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
  142. flag.PrintDefaults()
  143. }
  144. flag.Parse(args)
  145. model := flag.Arg(0)
  146. if model == "" {
  147. return fmt.Errorf("missing model argument")
  148. }
  149. from := cmp.Or(*flagFrom, model)
  150. m, err := ollama.ResolveLocal(c, from)
  151. if err != nil {
  152. return err
  153. }
  154. ctx = ollama.WithTrace(ctx, &ollama.Trace{
  155. Update: func(l *ollama.Layer, n int64, err error) {
  156. switch {
  157. case errors.Is(err, ollama.ErrCached):
  158. fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
  159. case err != nil:
  160. fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
  161. case n == 0:
  162. l := m.Layer(l.Digest)
  163. mt, p, _ := mime.ParseMediaType(l.MediaType)
  164. mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
  165. switch mt {
  166. case "tensor":
  167. fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
  168. default:
  169. fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
  170. }
  171. }
  172. },
  173. })
  174. return rc.Push(ctx, c, model, &ollama.PushParams{
  175. From: from,
  176. })
  177. }
  // trackingReader wraps an io.Reader and adds the size of every successful
  // or partial read to the counter n.
  type trackingReader struct {
  	io.Reader
  	n *atomic.Int64 // running total of bytes read; dereferenced on every Read
  }

  // Compile-time check that *trackingReader is usable as an io.Reader.
  var _ io.Reader = (*trackingReader)(nil)

  // Read forwards to the embedded Reader, adds the number of bytes it
  // returned to the counter, and passes the result through unchanged.
  func (r *trackingReader) Read(p []byte) (n int, err error) {
  	count, readErr := r.Reader.Read(p)
  	r.n.Add(int64(count))
  	return count, readErr
  }
  187. func cmdImport(ctx context.Context, c *blob.DiskCache) error {
  188. args := flag.Args()[1:]
  189. flag := flag.NewFlagSet("import", flag.ExitOnError)
  190. flagAs := flag.String("as", "", "Import using the provided name.")
  191. flag.Usage = func() {
  192. fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
  193. flag.PrintDefaults()
  194. }
  195. flag.Parse(args)
  196. dir := cmp.Or(flag.Arg(0), ".")
  197. fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
  198. m, err := safetensors.Read(os.DirFS(dir))
  199. if err != nil {
  200. return err
  201. }
  202. var total int64
  203. var tt []*safetensors.Tensor
  204. for t, err := range m.Tensors() {
  205. if err != nil {
  206. return err
  207. }
  208. tt = append(tt, t)
  209. total += t.Size()
  210. }
  211. var n atomic.Int64
  212. done := make(chan error)
  213. go func() {
  214. layers := make([]*ollama.Layer, len(tt))
  215. var g errgroup.Group
  216. g.SetLimit(runtime.GOMAXPROCS(0))
  217. var ctxErr error
  218. for i, t := range tt {
  219. if ctx.Err() != nil {
  220. // The context may cancel AFTER we exit the
  221. // loop, and so if we use ctx.Err() after the
  222. // loop we may report it as the error that
  223. // broke the loop, when it was not. This can
  224. // manifest as a false-negative, leading the
  225. // user to think their import failed when it
  226. // did not, so capture it if and only if we
  227. // exit the loop because of a ctx.Err() and
  228. // report it.
  229. ctxErr = ctx.Err()
  230. break
  231. }
  232. g.Go(func() (err error) {
  233. rc, err := t.Reader()
  234. if err != nil {
  235. return err
  236. }
  237. defer rc.Close()
  238. tr := &trackingReader{rc, &n}
  239. d, err := c.Import(tr, t.Size())
  240. if err != nil {
  241. return err
  242. }
  243. if err := rc.Close(); err != nil {
  244. return err
  245. }
  246. layers[i] = &ollama.Layer{
  247. Digest: d,
  248. Size: t.Size(),
  249. MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
  250. "name": t.Name(),
  251. "dtype": t.DataType(),
  252. "shape": t.Shape().String(),
  253. }),
  254. }
  255. return nil
  256. })
  257. }
  258. done <- func() error {
  259. if err := errors.Join(g.Wait(), ctxErr); err != nil {
  260. return err
  261. }
  262. m := &ollama.Manifest{Layers: layers}
  263. data, err := json.MarshalIndent(m, "", " ")
  264. if err != nil {
  265. return err
  266. }
  267. d := blob.DigestFromBytes(data)
  268. err = blob.PutBytes(c, d, data)
  269. if err != nil {
  270. return err
  271. }
  272. return c.Link(*flagAs, d)
  273. }()
  274. }()
  275. fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
  276. csiHideCursor(stdout)
  277. defer csiShowCursor(stdout)
  278. csiSavePos(stdout)
  279. writeProgress := func() {
  280. csiRestorePos(stdout)
  281. nn := n.Load()
  282. fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
  283. formatNatural(nn),
  284. formatNatural(total),
  285. nn*100/total,
  286. ansiClearToEndOfLine,
  287. )
  288. }
  289. ticker := time.NewTicker(time.Second)
  290. defer ticker.Stop()
  291. for {
  292. select {
  293. case <-ticker.C:
  294. writeProgress()
  295. case err := <-done:
  296. writeProgress()
  297. return err
  298. }
  299. }
  300. }
  // formatNatural renders a byte count for display. Values below 1024 are
  // printed as a whole number of bytes ("B"); larger values are scaled by
  // powers of 1024 to KB, MB, or GB and printed with one decimal place. GB is
  // the largest unit, so anything from 1024*1024*1024 upward is shown in GB.
  func formatNatural(n int64) string {
  	const (
  		kb = 1 << 10
  		mb = 1 << 20
  		gb = 1 << 30
  	)

  	if n < kb {
  		return fmt.Sprintf("%d B", n)
  	}

  	// Pick the unit; GB is the fallback for everything that is not
  	// smaller than one GB.
  	unit, divisor := "GB", float64(gb)
  	switch {
  	case n < mb:
  		unit, divisor = "KB", kb
  	case n < gb:
  		unit, divisor = "MB", mb
  	}
  	return fmt.Sprintf("%.1f %s", float64(n)/divisor, unit)
  }
  // Terminal escape sequences used to draw the import progress line.
  const (
  	ansiClearToEndOfLine = "\033[K"

  	ansiSavePos    = "\033[s"
  	ansiRestorePos = "\033[u"
  	ansiHideCursor = "\033[?25l"
  	ansiShowCursor = "\033[?25h"
  )

  // csiSavePos writes the "save cursor position" sequence to w.
  func csiSavePos(w io.Writer) {
  	fmt.Fprint(w, ansiSavePos)
  }

  // csiRestorePos writes the "restore cursor position" sequence to w.
  func csiRestorePos(w io.Writer) {
  	fmt.Fprint(w, ansiRestorePos)
  }

  // csiHideCursor writes the "hide cursor" sequence to w.
  func csiHideCursor(w io.Writer) {
  	fmt.Fprint(w, ansiHideCursor)
  }

  // csiShowCursor writes the "show cursor" sequence to w.
  func csiShowCursor(w io.Writer) {
  	fmt.Fprint(w, ansiShowCursor)
  }