llm_utils_test.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. //go:build integration
  2. package server
  3. import (
  4. "context"
  5. "errors"
  6. "os"
  7. "path"
  8. "runtime"
  9. "testing"
  10. "github.com/jmorganca/ollama/api"
  11. "github.com/jmorganca/ollama/llm"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func SkipIFNoTestData(t *testing.T) {
  15. modelDir := getModelDir()
  16. if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
  17. t.Skipf("%s does not exist - skipping integration tests", modelDir)
  18. }
  19. }
  20. func getModelDir() string {
  21. _, filename, _, _ := runtime.Caller(0)
  22. return path.Dir(path.Dir(filename) + "/../test_data/models/.")
  23. }
  24. func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
  25. modelDir := getModelDir()
  26. os.Setenv("OLLAMA_MODELS", modelDir)
  27. model, err := GetModel(modelName)
  28. require.NoError(t, err, "GetModel ")
  29. err = opts.FromMap(model.Options)
  30. require.NoError(t, err, "opts from model ")
  31. runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
  32. require.NoError(t, err, "llm.New failed")
  33. return model, runner
  34. }
  35. func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
  36. prompt, err := model.PreResponsePrompt(PromptVars{
  37. System: req.System,
  38. Prompt: req.Prompt,
  39. First: len(req.Context) == 0,
  40. })
  41. require.NoError(t, err, "prompt generation failed")
  42. success := make(chan bool, 1)
  43. response := ""
  44. cb := func(r llm.PredictResult) {
  45. if !r.Done {
  46. response += r.Content
  47. } else {
  48. success <- true
  49. }
  50. }
  51. predictReq := llm.PredictOpts{
  52. Prompt: prompt,
  53. Format: req.Format,
  54. Images: req.Images,
  55. }
  56. err = runner.Predict(ctx, predictReq, cb)
  57. require.NoError(t, err, "predict call failed")
  58. select {
  59. case <-ctx.Done():
  60. t.Errorf("failed to complete before timeout: \n%s", response)
  61. return ""
  62. case <-success:
  63. return response
  64. }
  65. }