server.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package main
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "time"
  11. "github.com/ollama/ollama/api"
  12. )
  13. func SpawnServer(ctx context.Context, command string) (chan int, error) {
  14. done := make(chan int)
  15. logDir := filepath.Dir(ServerLogFile)
  16. _, err := os.Stat(logDir)
  17. if errors.Is(err, os.ErrNotExist) {
  18. if err := os.MkdirAll(logDir, 0o755); err != nil {
  19. return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
  20. }
  21. }
  22. cmd := getCmd(ctx, command)
  23. stdout, err := cmd.StdoutPipe()
  24. if err != nil {
  25. return done, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
  26. }
  27. stderr, err := cmd.StderrPipe()
  28. if err != nil {
  29. return done, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
  30. }
  31. stdin, err := cmd.StdinPipe()
  32. if err != nil {
  33. return done, fmt.Errorf("failed to spawn server stdin pipe: %w", err)
  34. }
  35. // TODO - rotation
  36. logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
  37. if err != nil {
  38. return done, fmt.Errorf("failed to create server log: %w", err)
  39. }
  40. go func() {
  41. defer logFile.Close()
  42. io.Copy(logFile, stdout) //nolint:errcheck
  43. }()
  44. go func() {
  45. defer logFile.Close()
  46. io.Copy(logFile, stderr) //nolint:errcheck
  47. }()
  48. // Re-wire context done behavior to attempt a graceful shutdown of the server
  49. cmd.Cancel = func() error {
  50. if cmd.Process != nil {
  51. err := terminate(cmd)
  52. if err != nil {
  53. slog.Warn("error trying to gracefully terminate server", "err", err)
  54. return cmd.Process.Kill()
  55. }
  56. tick := time.NewTicker(10 * time.Millisecond)
  57. defer tick.Stop()
  58. for {
  59. select {
  60. case <-tick.C:
  61. exited, err := isProcessExited(cmd.Process.Pid)
  62. if err != nil {
  63. return err
  64. }
  65. if exited {
  66. return nil
  67. }
  68. case <-time.After(5 * time.Second):
  69. slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
  70. return cmd.Process.Kill()
  71. }
  72. }
  73. }
  74. return nil
  75. }
  76. // run the command and wait for it to finish
  77. if err := cmd.Start(); err != nil {
  78. return done, fmt.Errorf("failed to start server %w", err)
  79. }
  80. if cmd.Process != nil {
  81. slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
  82. }
  83. slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
  84. go func() {
  85. // Keep the server running unless we're shuttind down the app
  86. crashCount := 0
  87. for {
  88. cmd.Wait() //nolint:errcheck
  89. stdin.Close()
  90. var code int
  91. if cmd.ProcessState != nil {
  92. code = cmd.ProcessState.ExitCode()
  93. }
  94. select {
  95. case <-ctx.Done():
  96. slog.Info(fmt.Sprintf("server shutdown with exit code %d", code))
  97. done <- code
  98. return
  99. default:
  100. crashCount++
  101. slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
  102. time.Sleep(500 * time.Millisecond)
  103. if err := cmd.Start(); err != nil {
  104. slog.Error(fmt.Sprintf("failed to restart server %s", err))
  105. // Keep trying, but back off if we keep failing
  106. time.Sleep(time.Duration(crashCount) * time.Second)
  107. }
  108. }
  109. }
  110. }()
  111. return done, nil
  112. }
  113. func isServerRunning(ctx context.Context) bool {
  114. client, err := api.ClientFromEnvironment()
  115. if err != nil {
  116. slog.Info("unable to connect to server")
  117. return false
  118. }
  119. err = client.Heartbeat(ctx)
  120. if err != nil {
  121. slog.Debug(fmt.Sprintf("heartbeat from server: %s", err))
  122. slog.Info("unable to connect to server")
  123. return false
  124. }
  125. return true
  126. }