model_test.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package bert_test
  2. import (
  3. "encoding/json"
  4. "os"
  5. "path/filepath"
  6. "strings"
  7. "testing"
  8. "github.com/ollama/ollama/ml"
  9. "github.com/ollama/ollama/model"
  10. )
  // blob returns the on-disk path of the model-weights layer of the
  // all-minilm:<tag> model in the local Ollama model store.
  //
  // The store root is $OLLAMA_MODELS when that variable is set, otherwise
  // ~/.ollama/models. The manifest is read from
  // <root>/manifests/registry.ollama.ai/library/all-minilm/<tag>. The layer with
  // media type "application/vnd.ollama.image.model" is located in it, and its
  // digest ("sha256:<hex>") is mapped to the blob file <root>/blobs/sha256-<hex>.
  //
  // These files only exist on a machine with a local Ollama install that has
  // pulled the model (e.g. `ollama pull all-minilm`). When either the manifest or
  // the blob is absent, the calling test is skipped rather than failed. Any
  // other I/O or decoding problem fails the test.
  func blob(t *testing.T, tag string) string {
  	t.Helper()

  	root := os.Getenv("OLLAMA_MODELS")
  	if root == "" {
  		home, err := os.UserHomeDir()
  		if err != nil {
  			t.Fatalf("resolving home directory: %v", err)
  		}
  		root = filepath.Join(home, ".ollama", "models")
  	}

  	manifestPath := filepath.Join(root, "manifests", "registry.ollama.ai", "library", "all-minilm", tag)
  	manifestBytes, err := os.ReadFile(manifestPath)
  	switch {
  	case os.IsNotExist(err):
  		t.Skipf("model manifest %q not found; pull all-minilm:%s to enable this test", manifestPath, tag)
  	case err != nil:
  		t.Fatalf("reading manifest %q: %v", manifestPath, err)
  	}

  	var manifest struct {
  		Layers []struct {
  			MediaType string `json:"mediaType"`
  			Digest    string `json:"digest"`
  		}
  	}
  	if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
  		t.Fatalf("decoding manifest %q: %v", manifestPath, err)
  	}

  	var digest string
  	for _, layer := range manifest.Layers {
  		if layer.MediaType == "application/vnd.ollama.image.model" {
  			digest = layer.Digest
  			break
  		}
  	}
  	if digest == "" {
  		t.Fatal("no model layer found")
  	}

  	// Blob files are named after the digest with ':' replaced by '-'.
  	path := filepath.Join(root, "blobs", strings.ReplaceAll(digest, ":", "-"))
  	_, err = os.Stat(path)
  	switch {
  	case os.IsNotExist(err):
  		t.Skipf("model blob %q not found", path)
  	case err != nil:
  		t.Fatalf("checking model blob %q: %v", path, err)
  	}
  	return path
  }
  43. func TestEmbedding(t *testing.T) {
  44. m, err := model.New(blob(t, "latest"))
  45. if err != nil {
  46. t.Fatal(err)
  47. }
  48. text, err := os.ReadFile(filepath.Join("..", "testdata", "war-and-peace.txt"))
  49. if err != nil {
  50. t.Fatal(err)
  51. }
  52. inputIDs, err := m.(model.TextProcessor).Encode(string(text))
  53. if err != nil {
  54. t.Fatal(err)
  55. }
  56. logit, err := model.Forward(m, model.WithInputIDs(inputIDs))
  57. if err != nil {
  58. t.Fatal(err)
  59. }
  60. t.Log(ml.Dump(logit))
  61. }