utils_test.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. //go:build integration
  2. package integration
  3. import (
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "math/rand"
  11. "net"
  12. "net/http"
  13. "os"
  14. "path/filepath"
  15. "runtime"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "testing"
  20. "time"
  21. "github.com/ollama/ollama/api"
  22. "github.com/ollama/ollama/app/lifecycle"
  23. "github.com/stretchr/testify/assert"
  24. )
  // FindPort returns, as a decimal string, a TCP port number for a test server.
  //
  // It first asks the kernel for a free port by listening on localhost:0. It
  // reads the assigned port and closes the listener again right away. If that
  // fails, it falls back to a pseudo-random port in [49152, 65534].
  //
  // NOTE(review): the listener is closed before the port is returned, so
  // another process could grab the port before the caller binds it.
  func FindPort() string {
  	chosen := 0
  	if addr, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
  		if listener, lerr := net.ListenTCP("tcp", addr); lerr == nil {
  			chosen = listener.Addr().(*net.TCPAddr).Port
  			// Only the number is wanted; release the socket immediately.
  			_ = listener.Close()
  		}
  	}
  	if chosen != 0 {
  		return strconv.Itoa(chosen)
  	}
  	// The kernel route failed: pick a port from the ephemeral range instead.
  	const ephemeralLow, ephemeralHigh = 49152, 65535
  	return strconv.Itoa(ephemeralLow + rand.Intn(ephemeralHigh-ephemeralLow))
  }
  39. func GetTestEndpoint() (string, string) {
  40. defaultPort := "11434"
  41. ollamaHost := os.Getenv("OLLAMA_HOST")
  42. scheme, hostport, ok := strings.Cut(ollamaHost, "://")
  43. if !ok {
  44. scheme, hostport = "http", ollamaHost
  45. }
  46. // trim trailing slashes
  47. hostport = strings.TrimRight(hostport, "/")
  48. host, port, err := net.SplitHostPort(hostport)
  49. if err != nil {
  50. host, port = "127.0.0.1", defaultPort
  51. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  52. host = ip.String()
  53. } else if hostport != "" {
  54. host = hostport
  55. }
  56. }
  57. if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
  58. port = FindPort()
  59. }
  60. url := fmt.Sprintf("%s:%s", host, port)
  61. slog.Info("server connection", "url", url)
  62. return scheme, url
  63. }
  // Shared state for the locally spawned test server.
  //
  // serverMutex is held by StartServer while it checks or sets serverReady. It
  // is also held by StartServer's shutdown goroutine while that goroutine waits
  // for the server's exit code and clears serverReady.
  //
  // serverReady is set once StartServer has spawned a server. It is reset to
  // false after that server's context is cancelled and its exit code has been
  // collected.
  //
  // TODO make fancier, grab logs, etc.
  var (
  	serverMutex sync.Mutex
  	serverReady bool
  )
  67. func StartServer(ctx context.Context, ollamaHost string) error {
  68. // Make sure the server has been built
  69. CLIName, err := filepath.Abs("../ollama")
  70. if err != nil {
  71. return err
  72. }
  73. if runtime.GOOS == "windows" {
  74. CLIName += ".exe"
  75. }
  76. _, err = os.Stat(CLIName)
  77. if err != nil {
  78. return fmt.Errorf("CLI missing, did you forget to build first? %w", err)
  79. }
  80. serverMutex.Lock()
  81. defer serverMutex.Unlock()
  82. if serverReady {
  83. return nil
  84. }
  85. if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
  86. slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
  87. os.Setenv("OLLAMA_HOST", ollamaHost)
  88. }
  89. slog.Info("starting server", "url", ollamaHost)
  90. done, err := lifecycle.SpawnServer(ctx, "../ollama")
  91. if err != nil {
  92. return fmt.Errorf("failed to start server: %w", err)
  93. }
  94. go func() {
  95. <-ctx.Done()
  96. serverMutex.Lock()
  97. defer serverMutex.Unlock()
  98. exitCode := <-done
  99. if exitCode > 0 {
  100. slog.Warn("server failure", "exit", exitCode)
  101. }
  102. serverReady = false
  103. }()
  104. // TODO wait only long enough for the server to be responsive...
  105. time.Sleep(500 * time.Millisecond)
  106. serverReady = true
  107. return nil
  108. }
  109. func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
  110. slog.Debug("checking status of model", "model", modelName)
  111. showReq := &api.ShowRequest{Name: modelName}
  112. requestJSON, err := json.Marshal(showReq)
  113. if err != nil {
  114. return err
  115. }
  116. req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
  117. if err != nil {
  118. return err
  119. }
  120. // Make the request with the HTTP client
  121. response, err := client.Do(req.WithContext(ctx))
  122. if err != nil {
  123. return err
  124. }
  125. defer response.Body.Close()
  126. if response.StatusCode == 200 {
  127. slog.Info("model already present", "model", modelName)
  128. return nil
  129. }
  130. slog.Info("model missing", "status", response.StatusCode)
  131. pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
  132. requestJSON, err = json.Marshal(pullReq)
  133. if err != nil {
  134. return err
  135. }
  136. req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
  137. if err != nil {
  138. return err
  139. }
  140. slog.Info("pulling", "model", modelName)
  141. response, err = client.Do(req.WithContext(ctx))
  142. if err != nil {
  143. return err
  144. }
  145. defer response.Body.Close()
  146. if response.StatusCode != 200 {
  147. return fmt.Errorf("failed to pull model") // TODO more details perhaps
  148. }
  149. slog.Info("model pulled", "model", modelName)
  150. return nil
  151. }
  152. func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
  153. requestJSON, err := json.Marshal(genReq)
  154. if err != nil {
  155. t.Fatalf("Error serializing request: %v", err)
  156. }
  157. defer func() {
  158. if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
  159. // TODO
  160. fp, err := os.Open(lifecycle.ServerLogFile)
  161. if err != nil {
  162. slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
  163. return
  164. }
  165. data, err := io.ReadAll(fp)
  166. if err != nil {
  167. slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
  168. return
  169. }
  170. slog.Warn("SERVER LOG FOLLOWS")
  171. os.Stderr.Write(data)
  172. slog.Warn("END OF SERVER")
  173. }
  174. err = os.Remove(lifecycle.ServerLogFile)
  175. if err != nil && !os.IsNotExist(err) {
  176. slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
  177. }
  178. }()
  179. scheme, testEndpoint := GetTestEndpoint()
  180. if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
  181. assert.NoError(t, StartServer(ctx, testEndpoint))
  182. }
  183. err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
  184. if err != nil {
  185. t.Fatalf("Error pulling model: %v", err)
  186. }
  187. // Make the request and get the response
  188. req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
  189. if err != nil {
  190. t.Fatalf("Error creating request: %v", err)
  191. }
  192. // Set the content type for the request
  193. req.Header.Set("Content-Type", "application/json")
  194. // Make the request with the HTTP client
  195. response, err := client.Do(req.WithContext(ctx))
  196. if err != nil {
  197. t.Fatalf("Error making request: %v", err)
  198. }
  199. defer response.Body.Close()
  200. body, err := io.ReadAll(response.Body)
  201. assert.NoError(t, err)
  202. assert.Equal(t, response.StatusCode, 200, string(body))
  203. // Verify the response is valid JSON
  204. var payload api.GenerateResponse
  205. err = json.Unmarshal(body, &payload)
  206. if err != nil {
  207. assert.NoError(t, err, body)
  208. }
  209. // Verify the response contains the expected data
  210. atLeastOne := false
  211. for _, resp := range anyResp {
  212. if strings.Contains(strings.ToLower(payload.Response), resp) {
  213. atLeastOne = true
  214. break
  215. }
  216. }
  217. assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
  218. }