routes_tokenization_test.go

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llm"
)
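
// TestTokenize exercises the /api/tokenize endpoint: requests with no
// body, requests with no text, tokenizing a short string against a mock
// runner, and empty text.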
func TestTokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())
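
	// The handlers are exercised directly with httptest recorders and
	// requests; no gin router is involved in these subtests.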
	t.Run("missing body", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
		s.TokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing text", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
		s.TokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
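
	// The happy path needs a resolvable model, so create a minimal
	// llama-architecture model before tokenizing.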
	t.Run("tokenize text", func(t *testing.T) {
		// First create the model
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "test",
			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
				"general.architecture":          "llama",
				"llama.block_count":             uint32(1),
				"llama.context_length":          uint32(8192),
				"llama.embedding_length":        uint32(4096),
				"llama.attention.head_count":    uint32(32),
				"llama.attention.head_count_kv": uint32(8),
				"tokenizer.ggml.tokens":         []string{""},
				"tokenizer.ggml.scores":         []float32{0},
				"tokenizer.ggml.token_type":     []int32{0},
			}, []llm.Tensor{
				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			})),
		})
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create model: %d", w.Code)
		}

		// Now test tokenization
		body, err := json.Marshal(api.TokenizeRequest{
			Model: "test",
			Text:  "Hello world how are you",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w = httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.TokenizeHandler(w, r)
		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
		}

		var resp api.TokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Errorf("failed to decode response: %v", err)
		}

		// Our mock tokenizer creates sequential tokens based on word count
		expected := []int{0, 1, 2, 3, 4}
		if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
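
	// Empty text should be rejected the same way as a missing `text` field.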
	t.Run("tokenize empty text", func(t *testing.T) {
		body, err := json.Marshal(api.TokenizeRequest{
			Model: "test",
			Text:  "",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.TokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}
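
// TestDetokenize exercises the /api/detokenize endpoint: a round trip
// against the mock runner, missing tokens, and a missing model name.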
func TestDetokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())
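
	// Round trip: create the same minimal model, then detokenize a small
	// token slice.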
	t.Run("detokenize tokens", func(t *testing.T) {
		// Create the model first
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "test",
			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
				"general.architecture":          "llama",
				"llama.block_count":             uint32(1),
				"llama.context_length":          uint32(8192),
				"llama.embedding_length":        uint32(4096),
				"llama.attention.head_count":    uint32(32),
				"llama.attention.head_count_kv": uint32(8),
				"tokenizer.ggml.tokens":         []string{""},
				"tokenizer.ggml.scores":         []float32{0},
				"tokenizer.ggml.token_type":     []int32{0},
			}, []llm.Tensor{
				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			})),
			Stream: &stream,
		})
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
		}

		body, err := json.Marshal(api.DetokenizeRequest{
			Model:  "test",
			Tokens: []int{0, 1, 2, 3, 4},
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w = httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.DetokenizeHandler(w, r)
		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
		}

		// The response only needs to decode cleanly; the mock runner's
		// detokenized text is not asserted here.
		var resp api.DetokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Errorf("failed to decode response: %v", err)
		}
	})
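
	// A request without tokens should fail validation.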
	t.Run("detokenize empty tokens", func(t *testing.T) {
		body, err := json.Marshal(api.DetokenizeRequest{
			Model: "test",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.DetokenizeHandler(w, r)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
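
	// An empty model name is looked up and reported as not found (404),
	// not as a validation error.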
	t.Run("detokenize missing model", func(t *testing.T) {
		body, err := json.Marshal(api.DetokenizeRequest{
			Tokens: []int{0, 1, 2},
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.DetokenizeHandler(w, r)
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}
		if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}
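
// To run just these tests (assuming this file lives in the server package
// directory, e.g. ./server):
//
//	go test ./server -run 'TestTokenize|TestDetokenize'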