package server

import (
	"encoding/json"
	"fmt"
	"net/http"
	"testing"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
)
- type mockModelLoader struct {
- LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
- }
- func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
- if ml.LoadModelFn != nil {
- return ml.LoadModelFn(name, params)
- }
- return nil, nil
- }
- type mockModel struct {
- llama.Model
- TokenizeFn func(text string, addBos bool, addEos bool) ([]int, error)
- TokenToPieceFn func(token int) string
- }
- func (mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
- return []int{1, 2, 3}, nil
- }
- func (mockModel) TokenToPiece(token int) string {
- return fmt.Sprint(token)
- }
- func TestTokenizeHandler(t *testing.T) {
- gin.SetMode(gin.TestMode)
- mockModel := mockModel{}
- s := Server{
- sched: &Scheduler{
- pendingReqCh: make(chan *LlmRequest, 1),
- finishedReqCh: make(chan *LlmRequest, 1),
- expiredCh: make(chan *runnerRef, 1),
- unloadedCh: make(chan any, 1),
- loaded: make(map[string]*runnerRef),
- newServerFn: newMockServer(&mockRunner{}),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
- reschedDelay: 250 * time.Millisecond,
- loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
- time.Sleep(time.Millisecond)
- req.successCh <- &runnerRef{
- llama: &mockRunner{},
- }
- },
- },
- ml: mockLoader,
- }
- t.Run("method not allowed", func(t *testing.T) {
- w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
- if w.Code != http.StatusMethodNotAllowed {
- t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
- }
- })
- t.Run("missing body", func(t *testing.T) {
- w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
- }
- })
- t.Run("missing text", func(t *testing.T) {
- w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
- Model: "test",
- })
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
- }
- })
- t.Run("missing model", func(t *testing.T) {
- w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
- Text: "test text",
- })
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
- }
- })
- t.Run("model not found", func(t *testing.T) {
- w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
- Model: "nonexistent",
- Text: "test text",
- })
- if w.Code != http.StatusInternalServerError {
- t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
- }
- })
- t.Run("successful tokenization", func(t *testing.T) {
- w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
- Model: "test",
- Text: "test text",
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
- }
- var resp api.TokenizeResponse
- if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
- t.Fatal(err)
- }
- expectedTokens := []int{0, 1}
- if len(resp.Tokens) != len(expectedTokens) {
- t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
- }
- for i, token := range resp.Tokens {
- if token != expectedTokens[i] {
- t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
- }
- }
- })
- }