server.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package main
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "os/exec"
  10. "path/filepath"
  11. "time"
  12. "github.com/ollama/ollama/api"
  13. )
  14. type ServerOptions struct {
  15. Cors bool
  16. Expose bool
  17. ModelsPath string
  18. }
  19. func start(ctx context.Context, command string, options ServerOptions) (*exec.Cmd, error) {
  20. cmd := getCmd(ctx, command)
  21. // set environment variables
  22. if options.ModelsPath != "" {
  23. cmd.Env = append(cmd.Env, fmt.Sprintf("OLLAMA_MODELS=%s", options.ModelsPath))
  24. }
  25. if options.Cors {
  26. cmd.Env = append(cmd.Env, "OLLAMA_ORIGINS=*")
  27. }
  28. if options.Expose {
  29. cmd.Env = append(cmd.Env, "OLLAMA_HOST=0.0.0.0")
  30. }
  31. stdout, err := cmd.StdoutPipe()
  32. if err != nil {
  33. return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
  34. }
  35. stderr, err := cmd.StderrPipe()
  36. if err != nil {
  37. return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
  38. }
  39. // TODO - rotation
  40. logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
  41. if err != nil {
  42. return nil, fmt.Errorf("failed to create server log: %w", err)
  43. }
  44. go func() {
  45. defer logFile.Close()
  46. io.Copy(logFile, stdout) //nolint:errcheck
  47. }()
  48. go func() {
  49. defer logFile.Close()
  50. io.Copy(logFile, stderr) //nolint:errcheck
  51. }()
  52. // Re-wire context done behavior to attempt a graceful shutdown of the server
  53. cmd.Cancel = func() error {
  54. if cmd.Process != nil {
  55. err := terminate(cmd)
  56. if err != nil {
  57. slog.Warn("error trying to gracefully terminate server", "err", err)
  58. return cmd.Process.Kill()
  59. }
  60. tick := time.NewTicker(10 * time.Millisecond)
  61. defer tick.Stop()
  62. for {
  63. select {
  64. case <-tick.C:
  65. exited, err := isProcessExited(cmd.Process.Pid)
  66. if err != nil {
  67. return err
  68. }
  69. if exited {
  70. return nil
  71. }
  72. case <-time.After(5 * time.Second):
  73. slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
  74. return cmd.Process.Kill()
  75. }
  76. }
  77. }
  78. return nil
  79. }
  80. // run the command and wait for it to finish
  81. if err := cmd.Start(); err != nil {
  82. return nil, fmt.Errorf("failed to start server %w", err)
  83. }
  84. if cmd.Process != nil {
  85. slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
  86. }
  87. slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
  88. return cmd, nil
  89. }
  90. func SpawnServer(ctx context.Context, command string, options ServerOptions) (chan int, error) {
  91. logDir := filepath.Dir(ServerLogFile)
  92. _, err := os.Stat(logDir)
  93. if errors.Is(err, os.ErrNotExist) {
  94. if err := os.MkdirAll(logDir, 0o755); err != nil {
  95. return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
  96. }
  97. }
  98. done := make(chan int)
  99. go func() {
  100. // Keep the server running unless we're shuttind down the app
  101. crashCount := 0
  102. for {
  103. slog.Info(fmt.Sprintf("starting server..."))
  104. cmd, err := start(ctx, command, options)
  105. if err != nil {
  106. slog.Error(fmt.Sprintf("failed to start server %s", err))
  107. }
  108. cmd.Wait() //nolint:errcheck
  109. var code int
  110. if cmd.ProcessState != nil {
  111. code = cmd.ProcessState.ExitCode()
  112. }
  113. select {
  114. case <-ctx.Done():
  115. slog.Info(fmt.Sprintf("server shutdown with exit code %d", code))
  116. done <- code
  117. return
  118. default:
  119. crashCount++
  120. slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
  121. time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
  122. break
  123. }
  124. }
  125. }()
  126. return done, nil
  127. }
  128. func isServerRunning(ctx context.Context) bool {
  129. client, err := api.ClientFromEnvironment()
  130. if err != nil {
  131. slog.Info("unable to connect to server")
  132. return false
  133. }
  134. err = client.Heartbeat(ctx)
  135. if err != nil {
  136. slog.Debug(fmt.Sprintf("heartbeat from server: %s", err))
  137. slog.Info("unable to connect to server")
  138. return false
  139. }
  140. return true
  141. }