llm_utils_test.go 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "os"
  6. "path"
  7. "runtime"
  8. "testing"
  9. "time"
  10. "github.com/jmorganca/ollama/api"
  11. "github.com/jmorganca/ollama/llm"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func SkipIFNoTestData(t *testing.T) {
  15. modelDir := getModelDir()
  16. if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
  17. t.Skipf("%s does not exist - skipping integration tests", modelDir)
  18. }
  19. }
  20. func getModelDir() string {
  21. _, filename, _, _ := runtime.Caller(0)
  22. return path.Dir(path.Dir(filename) + "/../test_data/models/.")
  23. }
  24. func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
  25. modelDir := getModelDir()
  26. os.Setenv("OLLAMA_MODELS", modelDir)
  27. model, err := GetModel(modelName)
  28. require.NoError(t, err, "GetModel ")
  29. err = opts.FromMap(model.Options)
  30. require.NoError(t, err, "opts from model ")
  31. runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
  32. require.NoError(t, err, "llm.New failed")
  33. return model, runner
  34. }
  35. func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
  36. checkpointStart := time.Now()
  37. prompt, err := model.Prompt(PromptVars{
  38. System: req.System,
  39. Prompt: req.Prompt,
  40. First: len(req.Context) == 0,
  41. })
  42. require.NoError(t, err, "prompt generation failed")
  43. success := make(chan bool, 1)
  44. response := ""
  45. cb := func(r llm.PredictResult) {
  46. if !r.Done {
  47. response += r.Content
  48. } else {
  49. success <- true
  50. }
  51. }
  52. checkpointLoaded := time.Now()
  53. predictReq := llm.PredictOpts{
  54. Prompt: prompt,
  55. Format: req.Format,
  56. CheckpointStart: checkpointStart,
  57. CheckpointLoaded: checkpointLoaded,
  58. }
  59. err = runner.Predict(ctx, predictReq, cb)
  60. require.NoError(t, err, "predict call failed")
  61. select {
  62. case <-ctx.Done():
  63. t.Errorf("failed to complete before timeout: \n%s", response)
  64. return ""
  65. case <-success:
  66. return response
  67. }
  68. }