|
@@ -18,6 +18,7 @@ import (
|
|
|
"path/filepath"
|
|
|
"slices"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"syscall"
|
|
|
"time"
|
|
|
|
|
@@ -42,6 +43,9 @@ var mode string = gin.DebugMode
|
|
|
type Server struct {
|
|
|
addr net.Addr
|
|
|
sched *Scheduler
|
|
|
+
|
|
|
+ mu sync.Mutex
|
|
|
+ contextLengthLookup map[string]int
|
|
|
}
|
|
|
|
|
|
func init() {
|
|
@@ -343,11 +347,24 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- kvData, err := getKVData(m.ModelPath, false)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
+ contextLength, err := func() (int, error) {
|
|
|
+ s.mu.Lock()
|
|
|
+ defer s.mu.Unlock()
|
|
|
+ if s.contextLengthLookup == nil {
|
|
|
+ s.contextLengthLookup = make(map[string]int)
|
|
|
+ }
|
|
|
+ contextLength, ok := s.contextLengthLookup[m.ModelPath]
|
|
|
+ if !ok {
|
|
|
+ kvData, err := getKVData(m.ModelPath, false)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ contextLength = int(kvData.ContextLength())
|
|
|
+ s.contextLengthLookup[m.ModelPath] = int(kvData.ContextLength())
|
|
|
+ }
|
|
|
+ return contextLength, nil
|
|
|
+ }()
|
|
|
|
|
|
var count int
|
|
|
for i, s := range input {
|
|
@@ -357,7 +374,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
|
|
+ ctxLen := min(opts.NumCtx, int(contextLength))
|
|
|
if len(tokens) > ctxLen {
|
|
|
if !truncate {
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|