소스 검색

add session expiration

Michael Yang 1 년 전
부모
커밋
f62a882760
3개의 변경된 파일100개의 추가작업 그리고 20개의 파일을 삭제
  1. 40 5
      api/types.go
  2. 14 0
      llama/llama.go
  3. 46 15
      server/routes.go

+ 40 - 5
api/types.go

@@ -1,7 +1,9 @@
 package api
 
 import (
+	"encoding/json"
 	"fmt"
+	"math"
 	"os"
 	"runtime"
 	"time"
@@ -28,10 +30,12 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
-	SessionID int64  `json:"session_id"`
-	Model     string `json:"model"`
-	Prompt    string `json:"prompt"`
-	Context   []int  `json:"context,omitempty"`
+	SessionID       int64    `json:"session_id"`
+	SessionDuration Duration `json:"session_duration,omitempty"`
+
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -82,7 +86,9 @@ type ListResponseModel struct {
 }
 
 type GenerateResponse struct {
-	SessionID int64     `json:"session_id"`
+	SessionID        int64     `json:"session_id"`
+	SessionExpiresAt time.Time `json:"session_expires_at"`
+
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
@@ -195,3 +201,32 @@ func DefaultOptions() Options {
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type Duration struct {
+	time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		if t < 0 {
+			t = math.MaxFloat64
+		}
+
+		d.Duration = time.Duration(t)
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}

+ 14 - 0
llama/llama.go

@@ -92,6 +92,7 @@ import (
 	"log"
 	"os"
 	"strings"
+	"sync"
 	"unicode/utf8"
 	"unsafe"
 
@@ -107,6 +108,9 @@ type LLM struct {
 	embd   []C.llama_token
 	cursor int
 
+	mu sync.Mutex
+	gc bool
+
 	api.Options
 }
 
@@ -156,6 +160,11 @@ func New(model string, opts api.Options) (*LLM, error) {
 }
 
 func (llm *LLM) Close() {
+	llm.gc = true
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
@@ -163,6 +172,9 @@ func (llm *LLM) Close() {
 }
 
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	C.llama_reset_timings(llm.ctx)
 
 	tokens := make([]C.llama_token, len(ctx))
@@ -185,6 +197,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			break
 		} else if err != nil {
 			return err
+		} else if llm.gc {
+			return io.EOF
 		}
 
 		b.WriteString(llm.detokenize(token))

+ 46 - 15
server/routes.go

@@ -22,16 +22,19 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
-var mu sync.Mutex
-
 var activeSession struct {
-	ID int64
-	*llama.LLM
+	mu sync.Mutex
+
+	id  int64
+	llm *llama.LLM
+
+	expireAt    time.Time
+	expireTimer *time.Timer
 }
 
 func GenerateHandler(c *gin.Context) {
-	mu.Lock()
-	defer mu.Unlock()
+	activeSession.mu.Lock()
+	defer activeSession.mu.Unlock()
 
 	checkpointStart := time.Now()
 
@@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	if req.SessionID == 0 || req.SessionID != activeSession.ID {
-		if activeSession.LLM != nil {
-			activeSession.Close()
-			activeSession.LLM = nil
+	if req.SessionID == 0 || req.SessionID != activeSession.id {
+		if activeSession.llm != nil {
+			activeSession.llm.Close()
+			activeSession.llm = nil
 		}
 
 		opts := api.DefaultOptions()
@@ -70,9 +73,33 @@ func GenerateHandler(c *gin.Context) {
 			return
 		}
 
-		activeSession.ID = time.Now().UnixNano()
-		activeSession.LLM = llm
+		activeSession.id = time.Now().UnixNano()
+		activeSession.llm = llm
+	}
+
+	sessionDuration := req.SessionDuration
+	sessionID := activeSession.id
+
+	activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+	if activeSession.expireTimer == nil {
+		activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
+			activeSession.mu.Lock()
+			defer activeSession.mu.Unlock()
+
+			if sessionID != activeSession.id {
+				return
+			}
+
+			if time.Now().Before(activeSession.expireAt) {
+				return
+			}
+
+			activeSession.llm.Close()
+			activeSession.llm = nil
+			activeSession.id = 0
+		})
 	}
+	activeSession.expireTimer.Reset(sessionDuration.Duration)
 
 	checkpointLoaded := time.Now()
 
@@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 		fn := func(r api.GenerateResponse) {
+			activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+			activeSession.expireTimer.Reset(sessionDuration.Duration)
+
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
-			r.SessionID = activeSession.ID
+			r.SessionID = activeSession.id
+			r.SessionExpiresAt = activeSession.expireAt.UTC()
 			if r.Done {
 				r.TotalDuration = time.Since(checkpointStart)
 				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
+		if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
 		return
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{models})
+	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 
 func CopyModelHandler(c *gin.Context) {