// routes_tokenize_test.go
  1. package server
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "testing"
  7. "time"
  8. "github.com/gin-gonic/gin"
  9. "github.com/ollama/ollama/api"
  10. "github.com/ollama/ollama/discover"
  11. "github.com/ollama/ollama/llama"
  12. "github.com/ollama/ollama/llm"
  13. )
// mockModelLoader is a test double for the server's model loader.
// Tests may set LoadModelFn to control what LoadModel returns; the
// zero value yields a loader whose LoadModel reports (nil, nil).
type mockModelLoader struct {
	LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
}
  17. func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
  18. if ml.LoadModelFn != nil {
  19. return ml.LoadModelFn(name, params)
  20. }
  21. return nil, nil
  22. }
// mockModel is a stub model for the tokenize tests. It embeds
// llama.Model to pick up the rest of the model's method set.
// TokenizeFn and TokenToPieceFn are per-test override hooks —
// NOTE(review): confirm the Tokenize/TokenToPiece methods actually
// consult these fields before relying on them in a test.
type mockModel struct {
	llama.Model
	TokenizeFn     func(text string, addBos bool, addEos bool) ([]int, error)
	TokenToPieceFn func(token int) string
}
  28. func (mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
  29. return []int{1, 2, 3}, nil
  30. }
  31. func (mockModel) TokenToPiece(token int) string {
  32. return fmt.Sprint(token)
  33. }
  34. func TestTokenizeHandler(t *testing.T) {
  35. gin.SetMode(gin.TestMode)
  36. mockModel := mockModel{}
  37. s := Server{
  38. sched: &Scheduler{
  39. pendingReqCh: make(chan *LlmRequest, 1),
  40. finishedReqCh: make(chan *LlmRequest, 1),
  41. expiredCh: make(chan *runnerRef, 1),
  42. unloadedCh: make(chan any, 1),
  43. loaded: make(map[string]*runnerRef),
  44. newServerFn: newMockServer(&mockRunner{}),
  45. getGpuFn: discover.GetGPUInfo,
  46. getCpuFn: discover.GetCPUInfo,
  47. reschedDelay: 250 * time.Millisecond,
  48. loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
  49. time.Sleep(time.Millisecond)
  50. req.successCh <- &runnerRef{
  51. llama: &mockRunner{},
  52. }
  53. },
  54. },
  55. ml: mockLoader,
  56. }
  57. t.Run("method not allowed", func(t *testing.T) {
  58. w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
  59. if w.Code != http.StatusMethodNotAllowed {
  60. t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
  61. }
  62. })
  63. t.Run("missing body", func(t *testing.T) {
  64. w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
  65. if w.Code != http.StatusBadRequest {
  66. t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
  67. }
  68. })
  69. t.Run("missing text", func(t *testing.T) {
  70. w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
  71. Model: "test",
  72. })
  73. if w.Code != http.StatusBadRequest {
  74. t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
  75. }
  76. })
  77. t.Run("missing model", func(t *testing.T) {
  78. w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
  79. Text: "test text",
  80. })
  81. if w.Code != http.StatusBadRequest {
  82. t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
  83. }
  84. })
  85. t.Run("model not found", func(t *testing.T) {
  86. w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
  87. Model: "nonexistent",
  88. Text: "test text",
  89. })
  90. if w.Code != http.StatusInternalServerError {
  91. t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
  92. }
  93. })
  94. t.Run("successful tokenization", func(t *testing.T) {
  95. w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
  96. Model: "test",
  97. Text: "test text",
  98. })
  99. if w.Code != http.StatusOK {
  100. t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
  101. }
  102. var resp api.TokenizeResponse
  103. if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
  104. t.Fatal(err)
  105. }
  106. expectedTokens := []int{0, 1}
  107. if len(resp.Tokens) != len(expectedTokens) {
  108. t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
  109. }
  110. for i, token := range resp.Tokens {
  111. if token != expectedTokens[i] {
  112. t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
  113. }
  114. }
  115. })
  116. }