Blake Mizerany 8 miesięcy temu
rodzic
commit
e7254617e3
4 zmienionych plików z 50 dodań i 11 usunięć
  1. 7 0
      main.go
  2. 19 4
      server/images.go
  3. 23 6
      server/routes.go
  4. 1 1
      server/sched.go

+ 7 - 0
main.go

@@ -6,8 +6,15 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/cobra"
 
 
 	"github.com/ollama/ollama/cmd"
 	"github.com/ollama/ollama/cmd"
+
+	"net/http"
+	_ "net/http/pprof"
 )
 )
 
 
 func main() {
 func main() {
+	go func() {
+		http.ListenAndServe("localhost:6060", nil)
+	}()
+
 	cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
 	cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
 }
 }

+ 19 - 4
server/images.go

@@ -21,6 +21,7 @@ import (
 	"slices"
 	"slices"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"sync"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/auth"
@@ -209,13 +210,25 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 	DiffIDs []string `json:"diff_ids"`
 }
 }
 
 
+var manifestCache struct {
+	sync.Mutex
+	cache map[string]*Manifest
+}
+
 func GetManifest(mp ModelPath) (*Manifest, string, error) {
 func GetManifest(mp ModelPath) (*Manifest, string, error) {
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return nil, "", err
+	manifestCache.Lock()
+	defer manifestCache.Unlock()
+
+	if manifestCache.cache == nil {
+		manifestCache.cache = make(map[string]*Manifest)
 	}
 	}
 
 
-	if _, err = os.Stat(fp); err != nil {
+	if manifest, ok := manifestCache.cache[mp.GetFullTagname()]; ok {
+		return manifest, "", nil
+	}
+
+	fp, err := mp.GetManifestPath()
+	if err != nil {
 		return nil, "", err
 		return nil, "", err
 	}
 	}
 
 
@@ -233,6 +246,8 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) {
 		return nil, "", err
 		return nil, "", err
 	}
 	}
 
 
+	manifestCache.cache[mp.GetFullTagname()] = manifest
+
 	return manifest, shaStr, nil
 	return manifest, shaStr, nil
 }
 }
 
 

+ 23 - 6
server/routes.go

@@ -18,6 +18,7 @@ import (
 	"path/filepath"
 	"path/filepath"
 	"slices"
 	"slices"
 	"strings"
 	"strings"
+	"sync"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
 
 
@@ -42,6 +43,9 @@ var mode string = gin.DebugMode
 type Server struct {
 type Server struct {
 	addr  net.Addr
 	addr  net.Addr
 	sched *Scheduler
 	sched *Scheduler
+
+	mu                  sync.Mutex
+	contextLengthLookup map[string]int
 }
 }
 
 
 func init() {
 func init() {
@@ -343,11 +347,24 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
-	kvData, err := getKVData(m.ModelPath, false)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	contextLength, err := func() (int, error) {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+		if s.contextLengthLookup == nil {
+			s.contextLengthLookup = make(map[string]int)
+		}
+		contextLength, ok := s.contextLengthLookup[m.ModelPath]
+		if !ok {
+			kvData, err := getKVData(m.ModelPath, false)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return 0, err
+			}
+			contextLength = int(kvData.ContextLength())
+			s.contextLengthLookup[m.ModelPath] = int(kvData.ContextLength())
+		}
+		return contextLength, nil
+	}()
 
 
 	var count int
 	var count int
 	for i, s := range input {
 	for i, s := range input {
@@ -357,7 +374,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			return
 			return
 		}
 		}
 
 
-		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		ctxLen := min(opts.NumCtx, int(contextLength))
 		if len(tokens) > ctxLen {
 		if len(tokens) > ctxLen {
 			if !truncate {
 			if !truncate {
 				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
 				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})

+ 1 - 1
server/sched.go

@@ -66,7 +66,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 		pendingReqCh:  make(chan *LlmRequest, maxQueue),
 		pendingReqCh:  make(chan *LlmRequest, maxQueue),
 		finishedReqCh: make(chan *LlmRequest, maxQueue),
 		finishedReqCh: make(chan *LlmRequest, maxQueue),
 		expiredCh:     make(chan *runnerRef, maxQueue),
 		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
+		unloadedCh:    make(chan any, maxQueue),
 		loaded:        make(map[string]*runnerRef),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      gpu.GetGPUInfo,
 		getGpuFn:      gpu.GetGPUInfo,