updater.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. package updater
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "mime"
  11. "net/http"
  12. "net/url"
  13. "os"
  14. "path"
  15. "path/filepath"
  16. "runtime"
  17. "strings"
  18. "time"
  19. "github.com/ollama/ollama/auth"
  20. "github.com/ollama/ollama/version"
  21. )
  22. var (
  23. UpdateStageDir string
  24. )
  25. var (
  26. UpdateCheckURLBase = "https://ollama.com/api/update"
  27. UpdateDownloaded = false
  28. UpdateCheckInterval = 60 * 60 * time.Second
  29. )
  30. // TODO - maybe move up to the API package?
  31. type UpdateResponse struct {
  32. UpdateURL string `json:"url"`
  33. UpdateVersion string `json:"version"`
  34. }
  35. func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
  36. var updateResp UpdateResponse
  37. requestURL, err := url.Parse(UpdateCheckURLBase)
  38. if err != nil {
  39. return false, updateResp
  40. }
  41. query := requestURL.Query()
  42. query.Add("os", runtime.GOOS)
  43. query.Add("arch", runtime.GOARCH)
  44. query.Add("version", version.Version)
  45. query.Add("ts", fmt.Sprintf("%d", time.Now().Unix()))
  46. nonce, err := auth.NewNonce(rand.Reader, 16)
  47. if err != nil {
  48. return false, updateResp
  49. }
  50. query.Add("nonce", nonce)
  51. requestURL.RawQuery = query.Encode()
  52. data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
  53. signature, err := auth.Sign(ctx, data)
  54. if err != nil {
  55. return false, updateResp
  56. }
  57. req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
  58. if err != nil {
  59. slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
  60. return false, updateResp
  61. }
  62. req.Header.Set("Authorization", signature)
  63. req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  64. slog.Debug("checking for available update", "requestURL", requestURL)
  65. resp, err := http.DefaultClient.Do(req)
  66. if err != nil {
  67. slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
  68. return false, updateResp
  69. }
  70. defer resp.Body.Close()
  71. if resp.StatusCode == 204 {
  72. slog.Debug("check update response 204 (current version is up to date)")
  73. return false, updateResp
  74. }
  75. body, err := io.ReadAll(resp.Body)
  76. if err != nil {
  77. slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
  78. }
  79. if resp.StatusCode != 200 {
  80. slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
  81. return false, updateResp
  82. }
  83. err = json.Unmarshal(body, &updateResp)
  84. if err != nil {
  85. slog.Warn(fmt.Sprintf("malformed response checking for update: %s", err))
  86. return false, updateResp
  87. }
  88. // Extract the version string from the URL in the github release artifact path
  89. updateResp.UpdateVersion = path.Base(path.Dir(updateResp.UpdateURL))
  90. slog.Info("New update available at " + updateResp.UpdateURL)
  91. return true, updateResp
  92. }
  93. func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
  94. // Do a head first to check etag info
  95. req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
  96. if err != nil {
  97. return err
  98. }
  99. resp, err := http.DefaultClient.Do(req)
  100. if err != nil {
  101. return fmt.Errorf("error checking update: %w", err)
  102. }
  103. if resp.StatusCode != 200 {
  104. return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
  105. }
  106. resp.Body.Close()
  107. etag := strings.Trim(resp.Header.Get("etag"), "\"")
  108. if etag == "" {
  109. slog.Debug("no etag detected, falling back to filename based dedup")
  110. etag = "_"
  111. }
  112. filename := "OllamaSetup.exe"
  113. _, params, err := mime.ParseMediaType(resp.Header.Get("content-disposition"))
  114. if err == nil {
  115. filename = params["filename"]
  116. }
  117. stageFilename := filepath.Join(UpdateStageDir, etag, filename)
  118. // Check to see if we already have it downloaded
  119. _, err = os.Stat(stageFilename)
  120. if err == nil {
  121. slog.Info("update already downloaded")
  122. return nil
  123. }
  124. cleanupOldDownloads()
  125. req.Method = http.MethodGet
  126. resp, err = http.DefaultClient.Do(req)
  127. if err != nil {
  128. return fmt.Errorf("error checking update: %w", err)
  129. }
  130. defer resp.Body.Close()
  131. etag = strings.Trim(resp.Header.Get("etag"), "\"")
  132. if etag == "" {
  133. slog.Debug("no etag detected, falling back to filename based dedup") // TODO probably can get rid of this redundant log
  134. etag = "_"
  135. }
  136. stageFilename = filepath.Join(UpdateStageDir, etag, filename)
  137. _, err = os.Stat(filepath.Dir(stageFilename))
  138. if errors.Is(err, os.ErrNotExist) {
  139. if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil {
  140. return fmt.Errorf("create ollama dir %s: %v", filepath.Dir(stageFilename), err)
  141. }
  142. }
  143. payload, err := io.ReadAll(resp.Body)
  144. if err != nil {
  145. return fmt.Errorf("failed to read body response: %w", err)
  146. }
  147. fp, err := os.OpenFile(stageFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
  148. if err != nil {
  149. return fmt.Errorf("write payload %s: %w", stageFilename, err)
  150. }
  151. defer fp.Close()
  152. if n, err := fp.Write(payload); err != nil || n != len(payload) {
  153. return fmt.Errorf("write payload %s: %d vs %d -- %w", stageFilename, n, len(payload), err)
  154. }
  155. slog.Info("new update downloaded " + stageFilename)
  156. UpdateDownloaded = true
  157. return nil
  158. }
  159. func cleanupOldDownloads() {
  160. files, err := os.ReadDir(UpdateStageDir)
  161. if err != nil && errors.Is(err, os.ErrNotExist) {
  162. // Expected behavior on first run
  163. return
  164. } else if err != nil {
  165. slog.Warn(fmt.Sprintf("failed to list stage dir: %s", err))
  166. return
  167. }
  168. for _, file := range files {
  169. fullname := filepath.Join(UpdateStageDir, file.Name())
  170. slog.Debug("cleaning up old download: " + fullname)
  171. err = os.RemoveAll(fullname)
  172. if err != nil {
  173. slog.Warn(fmt.Sprintf("failed to cleanup stale update download %s", err))
  174. }
  175. }
  176. }
  177. func StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
  178. go func() {
  179. // Don't blast an update message immediately after startup
  180. // time.Sleep(30 * time.Second)
  181. time.Sleep(3 * time.Second)
  182. for {
  183. available, resp := IsNewReleaseAvailable(ctx)
  184. if available {
  185. err := DownloadNewRelease(ctx, resp)
  186. if err != nil {
  187. slog.Error(fmt.Sprintf("failed to download new release: %s", err))
  188. }
  189. err = cb(resp.UpdateVersion)
  190. if err != nil {
  191. slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
  192. }
  193. }
  194. select {
  195. case <-ctx.Done():
  196. slog.Debug("stopping background update checker")
  197. return
  198. default:
  199. time.Sleep(UpdateCheckInterval)
  200. }
  201. }
  202. }()
  203. }