opp.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package main
  2. import (
  3. "bytes"
  4. "cmp"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "log"
  12. "mime"
  13. "net/http"
  14. "os"
  15. "runtime"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. "time"
  20. "github.com/ollama/ollama/server/internal/cache/blob"
  21. "github.com/ollama/ollama/server/internal/client/ollama"
  22. "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
  23. "golang.org/x/sync/errgroup"
  24. )
  25. var stdout io.Writer = os.Stdout
  26. const usage = `Opp is a tool for pushing and pulling Ollama models.
  27. Usage:
  28. opp [flags] <push|pull|import>
  29. Commands:
  30. push Upload a model to the Ollama server.
  31. pull Download a model from the Ollama server.
  32. import Import a model from a local safetensor directory.
  33. Examples:
  34. # Pull a model from the Ollama server.
  35. opp pull library/llama3.2:latest
  36. # Push a model to the Ollama server.
  37. opp push username/my_model:8b
  38. # Import a model from a local safetensor directory.
  39. opp import /path/to/safetensor
  40. Envionment Variables:
  41. OLLAMA_MODELS
  42. The directory where models are pushed and pulled from
  43. (default ~/.ollama/models).
  44. `
  45. func main() {
  46. flag.Usage = func() {
  47. fmt.Fprint(os.Stderr, usage)
  48. }
  49. flag.Parse()
  50. c, err := ollama.DefaultCache()
  51. if err != nil {
  52. log.Fatal(err)
  53. }
  54. rc, err := ollama.DefaultRegistry()
  55. if err != nil {
  56. log.Fatal(err)
  57. }
  58. ctx := context.Background()
  59. err = func() error {
  60. switch cmd := flag.Arg(0); cmd {
  61. case "pull":
  62. return cmdPull(ctx, rc, c)
  63. case "push":
  64. return cmdPush(ctx, rc, c)
  65. case "import":
  66. return cmdImport(ctx, c)
  67. default:
  68. if cmd == "" {
  69. flag.Usage()
  70. } else {
  71. fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
  72. }
  73. os.Exit(2)
  74. return errors.New("unreachable")
  75. }
  76. }()
  77. if err != nil {
  78. fmt.Fprintf(os.Stderr, "opp: %v\n", err)
  79. os.Exit(1)
  80. }
  81. }
  82. func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
  83. model := flag.Arg(1)
  84. if model == "" {
  85. flag.Usage()
  86. os.Exit(1)
  87. }
  88. tr := http.DefaultTransport.(*http.Transport).Clone()
  89. // TODO(bmizerany): configure transport?
  90. rc.HTTPClient = &http.Client{Transport: tr}
  91. var mu sync.Mutex
  92. p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
  93. var pb bytes.Buffer
  94. printProgress := func() {
  95. pb.Reset()
  96. mu.Lock()
  97. for d, s := range p {
  98. // Write progress to a buffer first to avoid blocking
  99. // on stdout while holding the lock.
  100. stamp := time.Now().Format("2006/01/02 15:04:05")
  101. fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
  102. if s[0] == s[1] {
  103. delete(p, d)
  104. }
  105. }
  106. mu.Unlock()
  107. io.Copy(stdout, &pb)
  108. }
  109. ctx = ollama.WithTrace(ctx, &ollama.Trace{
  110. Update: func(l *ollama.Layer, n int64, err error) {
  111. if err != nil && !errors.Is(err, ollama.ErrCached) {
  112. fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
  113. return
  114. }
  115. mu.Lock()
  116. p[l.Digest] = [2]int64{l.Size, n}
  117. mu.Unlock()
  118. },
  119. })
  120. errc := make(chan error)
  121. go func() {
  122. errc <- rc.Pull(ctx, c, model)
  123. }()
  124. t := time.NewTicker(time.Second)
  125. defer t.Stop()
  126. for {
  127. select {
  128. case <-t.C:
  129. printProgress()
  130. case err := <-errc:
  131. printProgress()
  132. return err
  133. }
  134. }
  135. }
  136. func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
  137. args := flag.Args()[1:]
  138. flag := flag.NewFlagSet("push", flag.ExitOnError)
  139. flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
  140. flag.Usage = func() {
  141. fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
  142. flag.PrintDefaults()
  143. }
  144. flag.Parse(args)
  145. model := flag.Arg(0)
  146. if model == "" {
  147. return fmt.Errorf("missing model argument")
  148. }
  149. from := cmp.Or(*flagFrom, model)
  150. m, err := rc.ResolveLocal(c, from)
  151. if err != nil {
  152. return err
  153. }
  154. ctx = ollama.WithTrace(ctx, &ollama.Trace{
  155. Update: func(l *ollama.Layer, n int64, err error) {
  156. switch {
  157. case errors.Is(err, ollama.ErrCached):
  158. fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
  159. case err != nil:
  160. fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
  161. case n == 0:
  162. l := m.Layer(l.Digest)
  163. mt, p, _ := mime.ParseMediaType(l.MediaType)
  164. mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
  165. switch mt {
  166. case "tensor":
  167. fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
  168. default:
  169. fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
  170. }
  171. }
  172. },
  173. })
  174. return rc.Push(ctx, c, model, &ollama.PushParams{
  175. From: from,
  176. })
  177. }
  178. type trackingReader struct {
  179. io.Reader
  180. n *atomic.Int64
  181. }
  182. func (r *trackingReader) Read(p []byte) (n int, err error) {
  183. n, err = r.Reader.Read(p)
  184. r.n.Add(int64(n))
  185. return n, err
  186. }
  187. func cmdImport(ctx context.Context, c *blob.DiskCache) error {
  188. args := flag.Args()[1:]
  189. flag := flag.NewFlagSet("import", flag.ExitOnError)
  190. flagAs := flag.String("as", "", "Import using the provided name.")
  191. flag.Usage = func() {
  192. fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
  193. flag.PrintDefaults()
  194. }
  195. flag.Parse(args)
  196. if *flagAs == "" {
  197. return fmt.Errorf("missing -as flag")
  198. }
  199. as := ollama.CompleteName(*flagAs)
  200. dir := cmp.Or(flag.Arg(0), ".")
  201. fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
  202. m, err := safetensors.Read(os.DirFS(dir))
  203. if err != nil {
  204. return err
  205. }
  206. var total int64
  207. var tt []*safetensors.Tensor
  208. for t, err := range m.Tensors() {
  209. if err != nil {
  210. return err
  211. }
  212. tt = append(tt, t)
  213. total += t.Size()
  214. }
  215. var n atomic.Int64
  216. done := make(chan error)
  217. go func() {
  218. layers := make([]*ollama.Layer, len(tt))
  219. var g errgroup.Group
  220. g.SetLimit(runtime.GOMAXPROCS(0))
  221. var ctxErr error
  222. for i, t := range tt {
  223. if ctx.Err() != nil {
  224. // The context may cancel AFTER we exit the
  225. // loop, and so if we use ctx.Err() after the
  226. // loop we may report it as the error that
  227. // broke the loop, when it was not. This can
  228. // manifest as a false-negative, leading the
  229. // user to think their import failed when it
  230. // did not, so capture it if and only if we
  231. // exit the loop because of a ctx.Err() and
  232. // report it.
  233. ctxErr = ctx.Err()
  234. break
  235. }
  236. g.Go(func() (err error) {
  237. rc, err := t.Reader()
  238. if err != nil {
  239. return err
  240. }
  241. defer rc.Close()
  242. tr := &trackingReader{rc, &n}
  243. d, err := c.Import(tr, t.Size())
  244. if err != nil {
  245. return err
  246. }
  247. if err := rc.Close(); err != nil {
  248. return err
  249. }
  250. layers[i] = &ollama.Layer{
  251. Digest: d,
  252. Size: t.Size(),
  253. MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
  254. "name": t.Name(),
  255. "dtype": t.DataType(),
  256. "shape": t.Shape().String(),
  257. }),
  258. }
  259. return nil
  260. })
  261. }
  262. done <- func() error {
  263. if err := errors.Join(g.Wait(), ctxErr); err != nil {
  264. return err
  265. }
  266. m := &ollama.Manifest{Layers: layers}
  267. data, err := json.MarshalIndent(m, "", " ")
  268. if err != nil {
  269. return err
  270. }
  271. d := blob.DigestFromBytes(data)
  272. err = blob.PutBytes(c, d, data)
  273. if err != nil {
  274. return err
  275. }
  276. return c.Link(as, d)
  277. }()
  278. }()
  279. fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
  280. csiHideCursor(stdout)
  281. defer csiShowCursor(stdout)
  282. csiSavePos(stdout)
  283. writeProgress := func() {
  284. csiRestorePos(stdout)
  285. nn := n.Load()
  286. fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
  287. formatNatural(nn),
  288. formatNatural(total),
  289. nn*100/total,
  290. ansiClearToEndOfLine,
  291. )
  292. }
  293. ticker := time.NewTicker(time.Second)
  294. defer ticker.Stop()
  295. for {
  296. select {
  297. case <-ticker.C:
  298. writeProgress()
  299. case err := <-done:
  300. writeProgress()
  301. fmt.Println()
  302. fmt.Println("Successfully imported", as)
  303. return err
  304. }
  305. }
  306. }
  307. func formatNatural(n int64) string {
  308. switch {
  309. case n < 1024:
  310. return fmt.Sprintf("%d B", n)
  311. case n < 1024*1024:
  312. return fmt.Sprintf("%.1f KB", float64(n)/1024)
  313. case n < 1024*1024*1024:
  314. return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
  315. default:
  316. return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
  317. }
  318. }
  319. const ansiClearToEndOfLine = "\033[K"
  320. func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
  321. func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
  322. func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
  323. func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }