llm_utils_test.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "os"
  6. "path"
  7. "runtime"
  8. "testing"
  9. "github.com/jmorganca/ollama/api"
  10. "github.com/jmorganca/ollama/llm"
  11. "github.com/stretchr/testify/require"
  12. )
  13. func SkipIFNoTestData(t *testing.T) {
  14. modelDir := getModelDir()
  15. if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
  16. t.Skipf("%s does not exist - skipping integration tests", modelDir)
  17. }
  18. }
  // getModelDir returns the test model directory: "test_data/models" inside the
  // parent of the directory holding this source file, as a cleaned
  // slash-separated path (no "." elements, ".." collapsed where possible).
  func getModelDir() string {
  	// runtime.Caller(0) reports the file this function is compiled from; the
  	// ok flag is deliberately ignored, matching the original behavior.
  	_, thisFile, _, _ := runtime.Caller(0)
  	sourceDir := path.Dir(thisFile)
  	// path.Join cleans its result, folding ".." against sourceDir.
  	return path.Join(sourceDir, "..", "test_data", "models")
  }
  23. func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
  24. modelDir := getModelDir()
  25. os.Setenv("OLLAMA_MODELS", modelDir)
  26. model, err := GetModel(modelName)
  27. require.NoError(t, err, "GetModel ")
  28. err = opts.FromMap(model.Options)
  29. require.NoError(t, err, "opts from model ")
  30. runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
  31. require.NoError(t, err, "llm.New failed")
  32. return model, runner
  33. }
  34. func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
  35. prompt, err := model.Prompt(PromptVars{
  36. System: req.System,
  37. Prompt: req.Prompt,
  38. First: len(req.Context) == 0,
  39. })
  40. require.NoError(t, err, "prompt generation failed")
  41. success := make(chan bool, 1)
  42. response := ""
  43. cb := func(r llm.PredictResult) {
  44. if !r.Done {
  45. response += r.Content
  46. } else {
  47. success <- true
  48. }
  49. }
  50. predictReq := llm.PredictOpts{
  51. Prompt: prompt,
  52. Format: req.Format,
  53. Images: req.Images,
  54. }
  55. err = runner.Predict(ctx, predictReq, cb)
  56. require.NoError(t, err, "predict call failed")
  57. select {
  58. case <-ctx.Done():
  59. t.Errorf("failed to complete before timeout: \n%s", response)
  60. return ""
  61. case <-success:
  62. return response
  63. }
  64. }