models.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package server
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "os"
  8. "path"
  9. "strconv"
  10. "github.com/jmorganca/ollama/api"
  11. )
  12. const directoryURL = "https://ollama.ai/api/models"
  13. type Model struct {
  14. Name string `json:"name"`
  15. DisplayName string `json:"display_name"`
  16. Parameters string `json:"parameters"`
  17. URL string `json:"url"`
  18. ShortDescription string `json:"short_description"`
  19. Description string `json:"description"`
  20. PublishedBy string `json:"published_by"`
  21. OriginalAuthor string `json:"original_author"`
  22. OriginalURL string `json:"original_url"`
  23. License string `json:"license"`
  24. }
  25. func (m *Model) FullName() string {
  26. home, err := os.UserHomeDir()
  27. if err != nil {
  28. panic(err)
  29. }
  30. return path.Join(home, ".ollama", "models", m.Name+".bin")
  31. }
  32. func pull(model string, progressCh chan<- api.PullProgress) error {
  33. remote, err := getRemote(model)
  34. if err != nil {
  35. return fmt.Errorf("failed to pull model: %w", err)
  36. }
  37. return saveModel(remote, progressCh)
  38. }
  39. func getRemote(model string) (*Model, error) {
  40. // resolve the model download from our directory
  41. resp, err := http.Get(directoryURL)
  42. if err != nil {
  43. return nil, fmt.Errorf("failed to get directory: %w", err)
  44. }
  45. defer resp.Body.Close()
  46. body, err := io.ReadAll(resp.Body)
  47. if err != nil {
  48. return nil, fmt.Errorf("failed to read directory: %w", err)
  49. }
  50. var models []Model
  51. err = json.Unmarshal(body, &models)
  52. if err != nil {
  53. return nil, fmt.Errorf("failed to parse directory: %w", err)
  54. }
  55. for _, m := range models {
  56. if m.Name == model {
  57. return &m, nil
  58. }
  59. }
  60. return nil, fmt.Errorf("model not found in directory: %s", model)
  61. }
  62. func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
  63. // this models cache directory is created by the server on startup
  64. client := &http.Client{}
  65. req, err := http.NewRequest("GET", model.URL, nil)
  66. if err != nil {
  67. return fmt.Errorf("failed to download model: %w", err)
  68. }
  69. // check for resume
  70. alreadyDownloaded := int64(0)
  71. fileInfo, err := os.Stat(model.FullName())
  72. if err != nil {
  73. if !os.IsNotExist(err) {
  74. return fmt.Errorf("failed to check resume model file: %w", err)
  75. }
  76. // file doesn't exist, create it now
  77. } else {
  78. alreadyDownloaded = fileInfo.Size()
  79. req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
  80. }
  81. resp, err := client.Do(req)
  82. if err != nil {
  83. return fmt.Errorf("failed to download model: %w", err)
  84. }
  85. defer resp.Body.Close()
  86. if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
  87. // already downloaded
  88. progressCh <- api.PullProgress{
  89. Total: alreadyDownloaded,
  90. Completed: alreadyDownloaded,
  91. Percent: 100,
  92. }
  93. return nil
  94. }
  95. if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
  96. return fmt.Errorf("failed to download model: %s", resp.Status)
  97. }
  98. out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
  99. if err != nil {
  100. panic(err)
  101. }
  102. defer out.Close()
  103. totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
  104. buf := make([]byte, 1024)
  105. totalBytes := alreadyDownloaded
  106. totalSize += alreadyDownloaded
  107. for {
  108. n, err := resp.Body.Read(buf)
  109. if err != nil && err != io.EOF {
  110. return err
  111. }
  112. if n == 0 {
  113. break
  114. }
  115. if _, err := out.Write(buf[:n]); err != nil {
  116. return err
  117. }
  118. totalBytes += int64(n)
  119. // send progress updates
  120. progressCh <- api.PullProgress{
  121. Total: totalSize,
  122. Completed: totalBytes,
  123. Percent: float64(totalBytes) / float64(totalSize) * 100,
  124. }
  125. }
  126. progressCh <- api.PullProgress{
  127. Total: totalSize,
  128. Completed: totalSize,
  129. Percent: 100,
  130. }
  131. return nil
  132. }