utils_test.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. //go:build integration
  2. package integration
  3. import (
  4. "bytes"
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "math/rand"
  11. "net"
  12. "net/http"
  13. "net/url"
  14. "os"
  15. "path/filepath"
  16. "runtime"
  17. "strconv"
  18. "strings"
  19. "sync"
  20. "testing"
  21. "time"
  22. "github.com/ollama/ollama/api"
  23. "github.com/ollama/ollama/app/lifecycle"
  24. "github.com/stretchr/testify/require"
  25. )
  26. func Init() {
  27. lifecycle.InitLogging()
  28. }
  // FindPort returns a TCP port number, formatted as a decimal string, for a
  // test server to listen on.
  //
  // It first asks the operating system for an unused port by listening on
  // "localhost:0" and closing that listener straight away. If the probe fails
  // for any reason it falls back to a random port in the dynamic/private range
  // 49152-65535 (inclusive).
  //
  // The port is not reserved: the listener is closed before returning, so
  // another process may claim the port before the caller binds to it.
  func FindPort() string {
  	const (
  		ephemeralMin = 49152
  		ephemeralMax = 65535
  	)

  	port := 0
  	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
  		if l, err := net.ListenTCP("tcp", a); err == nil {
  			if addr, ok := l.Addr().(*net.TCPAddr); ok {
  				port = addr.Port
  			}
  			// Best effort: only the port number is needed, and a failed close
  			// leaves nothing useful to do here.
  			_ = l.Close()
  		}
  	}
  	if port == 0 {
  		// rand.Intn(n) yields a value in [0, n), so the +1 is required for
  		// ephemeralMax itself to be reachable.
  		port = rand.Intn(ephemeralMax-ephemeralMin+1) + ephemeralMin
  	}
  	return strconv.Itoa(port)
  }
  43. func GetTestEndpoint() (*api.Client, string) {
  44. defaultPort := "11434"
  45. ollamaHost := os.Getenv("OLLAMA_HOST")
  46. scheme, hostport, ok := strings.Cut(ollamaHost, "://")
  47. if !ok {
  48. scheme, hostport = "http", ollamaHost
  49. }
  50. // trim trailing slashes
  51. hostport = strings.TrimRight(hostport, "/")
  52. host, port, err := net.SplitHostPort(hostport)
  53. if err != nil {
  54. host, port = "127.0.0.1", defaultPort
  55. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  56. host = ip.String()
  57. } else if hostport != "" {
  58. host = hostport
  59. }
  60. }
  61. if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
  62. port = FindPort()
  63. }
  64. slog.Info("server connection", "host", host, "port", port)
  65. return api.NewClient(
  66. &url.URL{
  67. Scheme: scheme,
  68. Host: net.JoinHostPort(host, port),
  69. },
  70. http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
  71. }
  // Shared state used by startServer.
  var (
  	// serverMutex is held by startServer while it checks and updates
  	// serverReady, and by its background goroutine while it waits for the
  	// server's exit code.
  	serverMutex sync.Mutex

  	// serverReady is set by startServer once a server has been spawned and is
  	// cleared again after the server's context is done and it has exited.
  	serverReady bool
  )
  74. func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
  75. // Make sure the server has been built
  76. CLIName, err := filepath.Abs("../ollama")
  77. if err != nil {
  78. return err
  79. }
  80. if runtime.GOOS == "windows" {
  81. CLIName += ".exe"
  82. }
  83. _, err = os.Stat(CLIName)
  84. if err != nil {
  85. return fmt.Errorf("CLI missing, did you forget to build first? %w", err)
  86. }
  87. serverMutex.Lock()
  88. defer serverMutex.Unlock()
  89. if serverReady {
  90. return nil
  91. }
  92. if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
  93. slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
  94. t.Setenv("OLLAMA_HOST", ollamaHost)
  95. envconfig.LoadConfig()
  96. }
  97. slog.Info("starting server", "url", ollamaHost)
  98. done, err := lifecycle.SpawnServer(ctx, "../ollama")
  99. if err != nil {
  100. return fmt.Errorf("failed to start server: %w", err)
  101. }
  102. go func() {
  103. <-ctx.Done()
  104. serverMutex.Lock()
  105. defer serverMutex.Unlock()
  106. exitCode := <-done
  107. if exitCode > 0 {
  108. slog.Warn("server failure", "exit", exitCode)
  109. }
  110. serverReady = false
  111. }()
  112. // TODO wait only long enough for the server to be responsive...
  113. time.Sleep(500 * time.Millisecond)
  114. serverReady = true
  115. return nil
  116. }
  117. func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
  118. slog.Info("checking status of model", "model", modelName)
  119. showReq := &api.ShowRequest{Name: modelName}
  120. showCtx, cancel := context.WithDeadlineCause(
  121. ctx,
  122. time.Now().Add(10*time.Second),
  123. fmt.Errorf("show for existing model %s took too long", modelName),
  124. )
  125. defer cancel()
  126. _, err := client.Show(showCtx, showReq)
  127. var statusError api.StatusError
  128. switch {
  129. case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
  130. break
  131. case err != nil:
  132. return err
  133. default:
  134. slog.Info("model already present", "model", modelName)
  135. return nil
  136. }
  137. slog.Info("model missing", "model", modelName)
  138. stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
  139. stallTimer := time.NewTimer(stallDuration)
  140. fn := func(resp api.ProgressResponse) error {
  141. // fmt.Print(".")
  142. if !stallTimer.Reset(stallDuration) {
  143. return fmt.Errorf("stall was detected, aborting status reporting")
  144. }
  145. return nil
  146. }
  147. stream := true
  148. pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
  149. var pullError error
  150. done := make(chan int)
  151. go func() {
  152. pullError = client.Pull(ctx, pullReq, fn)
  153. done <- 0
  154. }()
  155. select {
  156. case <-stallTimer.C:
  157. return fmt.Errorf("download stalled")
  158. case <-done:
  159. return pullError
  160. }
  161. }
  162. var serverProcMutex sync.Mutex
  163. // Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
  164. // Starts the server if needed
  165. func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
  166. client, testEndpoint := GetTestEndpoint()
  167. if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
  168. serverProcMutex.Lock()
  169. fp, err := os.CreateTemp("", "ollama-server-*.log")
  170. if err != nil {
  171. t.Fatalf("failed to generate log file: %s", err)
  172. }
  173. lifecycle.ServerLogFile = fp.Name()
  174. fp.Close()
  175. require.NoError(t, startServer(t, ctx, testEndpoint))
  176. }
  177. return client, testEndpoint, func() {
  178. if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
  179. defer serverProcMutex.Unlock()
  180. if t.Failed() {
  181. fp, err := os.Open(lifecycle.ServerLogFile)
  182. if err != nil {
  183. slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
  184. return
  185. }
  186. data, err := io.ReadAll(fp)
  187. if err != nil {
  188. slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
  189. return
  190. }
  191. slog.Warn("SERVER LOG FOLLOWS")
  192. os.Stderr.Write(data)
  193. slog.Warn("END OF SERVER")
  194. }
  195. err := os.Remove(lifecycle.ServerLogFile)
  196. if err != nil && !os.IsNotExist(err) {
  197. slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
  198. }
  199. }
  200. }
  201. }
  202. func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
  203. client, _, cleanup := InitServerConnection(ctx, t)
  204. defer cleanup()
  205. require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
  206. DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
  207. }
  208. func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
  209. stallTimer := time.NewTimer(initialTimeout)
  210. var buf bytes.Buffer
  211. fn := func(response api.GenerateResponse) error {
  212. // fmt.Print(".")
  213. buf.Write([]byte(response.Response))
  214. if !stallTimer.Reset(streamTimeout) {
  215. return fmt.Errorf("stall was detected while streaming response, aborting")
  216. }
  217. return nil
  218. }
  219. stream := true
  220. genReq.Stream = &stream
  221. done := make(chan int)
  222. var genErr error
  223. go func() {
  224. genErr = client.Generate(ctx, &genReq, fn)
  225. done <- 0
  226. }()
  227. select {
  228. case <-stallTimer.C:
  229. if buf.Len() == 0 {
  230. t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
  231. } else {
  232. t.Errorf("generate stalled. Response so far:%s", buf.String())
  233. }
  234. case <-done:
  235. require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
  236. // Verify the response contains the expected data
  237. response := buf.String()
  238. atLeastOne := false
  239. for _, resp := range anyResp {
  240. if strings.Contains(strings.ToLower(response), resp) {
  241. atLeastOne = true
  242. break
  243. }
  244. }
  245. require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
  246. slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
  247. case <-ctx.Done():
  248. t.Error("outer test context done while waiting for generate")
  249. }
  250. }
  251. // Generate a set of requests
  252. // By default each request uses orca-mini as the model
  253. func GenerateRequests() ([]api.GenerateRequest, [][]string) {
  254. return []api.GenerateRequest{
  255. {
  256. Model: "orca-mini",
  257. Prompt: "why is the ocean blue?",
  258. Stream: &stream,
  259. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  260. Options: map[string]interface{}{
  261. "seed": 42,
  262. "temperature": 0.0,
  263. },
  264. }, {
  265. Model: "orca-mini",
  266. Prompt: "why is the color of dirt brown?",
  267. Stream: &stream,
  268. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  269. Options: map[string]interface{}{
  270. "seed": 42,
  271. "temperature": 0.0,
  272. },
  273. }, {
  274. Model: "orca-mini",
  275. Prompt: "what is the origin of the us thanksgiving holiday?",
  276. Stream: &stream,
  277. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  278. Options: map[string]interface{}{
  279. "seed": 42,
  280. "temperature": 0.0,
  281. },
  282. }, {
  283. Model: "orca-mini",
  284. Prompt: "what is the origin of independence day?",
  285. Stream: &stream,
  286. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  287. Options: map[string]interface{}{
  288. "seed": 42,
  289. "temperature": 0.0,
  290. },
  291. }, {
  292. Model: "orca-mini",
  293. Prompt: "what is the composition of air?",
  294. Stream: &stream,
  295. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  296. Options: map[string]interface{}{
  297. "seed": 42,
  298. "temperature": 0.0,
  299. },
  300. },
  301. },
  302. [][]string{
  303. []string{"sunlight"},
  304. []string{"soil", "organic", "earth", "black", "tan"},
  305. []string{"england", "english", "massachusetts", "pilgrims", "british"},
  306. []string{"fourth", "july", "declaration", "independence"},
  307. []string{"nitrogen", "oxygen", "carbon", "dioxide"},
  308. }
  309. }