|
@@ -22,16 +22,19 @@ import (
|
|
"github.com/jmorganca/ollama/llama"
|
|
"github.com/jmorganca/ollama/llama"
|
|
)
|
|
)
|
|
|
|
|
|
-var mu sync.Mutex
|
|
|
|
-
|
|
|
|
var activeSession struct {
|
|
var activeSession struct {
|
|
- ID int64
|
|
|
|
- *llama.LLM
|
|
|
|
|
|
+ mu sync.Mutex
|
|
|
|
+
|
|
|
|
+ id int64
|
|
|
|
+ llm *llama.LLM
|
|
|
|
+
|
|
|
|
+ expireAt time.Time
|
|
|
|
+ expireTimer *time.Timer
|
|
}
|
|
}
|
|
|
|
|
|
func GenerateHandler(c *gin.Context) {
|
|
func GenerateHandler(c *gin.Context) {
|
|
- mu.Lock()
|
|
|
|
- defer mu.Unlock()
|
|
|
|
|
|
+ activeSession.mu.Lock()
|
|
|
|
+ defer activeSession.mu.Unlock()
|
|
|
|
|
|
checkpointStart := time.Now()
|
|
checkpointStart := time.Now()
|
|
|
|
|
|
@@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- if req.SessionID == 0 || req.SessionID != activeSession.ID {
|
|
|
|
- if activeSession.LLM != nil {
|
|
|
|
- activeSession.Close()
|
|
|
|
- activeSession.LLM = nil
|
|
|
|
|
|
+ if req.SessionID == 0 || req.SessionID != activeSession.id {
|
|
|
|
+ if activeSession.llm != nil {
|
|
|
|
+ activeSession.llm.Close()
|
|
|
|
+ activeSession.llm = nil
|
|
}
|
|
}
|
|
|
|
|
|
opts := api.DefaultOptions()
|
|
opts := api.DefaultOptions()
|
|
@@ -70,9 +73,33 @@ func GenerateHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- activeSession.ID = time.Now().UnixNano()
|
|
|
|
- activeSession.LLM = llm
|
|
|
|
|
|
+ activeSession.id = time.Now().UnixNano()
|
|
|
|
+ activeSession.llm = llm
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ sessionDuration := req.SessionDuration
|
|
|
|
+ sessionID := activeSession.id
|
|
|
|
+
|
|
|
|
+ activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
|
|
|
+ if activeSession.expireTimer == nil {
|
|
|
|
+ activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
|
|
|
|
+ activeSession.mu.Lock()
|
|
|
|
+ defer activeSession.mu.Unlock()
|
|
|
|
+
|
|
|
|
+ if sessionID != activeSession.id {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if time.Now().Before(activeSession.expireAt) {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ activeSession.llm.Close()
|
|
|
|
+ activeSession.llm = nil
|
|
|
|
+ activeSession.id = 0
|
|
|
|
+ })
|
|
}
|
|
}
|
|
|
|
+ activeSession.expireTimer.Reset(sessionDuration.Duration)
|
|
|
|
|
|
checkpointLoaded := time.Now()
|
|
checkpointLoaded := time.Now()
|
|
|
|
|
|
@@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) {
|
|
go func() {
|
|
go func() {
|
|
defer close(ch)
|
|
defer close(ch)
|
|
fn := func(r api.GenerateResponse) {
|
|
fn := func(r api.GenerateResponse) {
|
|
|
|
+ activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
|
|
|
+ activeSession.expireTimer.Reset(sessionDuration.Duration)
|
|
|
|
+
|
|
r.Model = req.Model
|
|
r.Model = req.Model
|
|
r.CreatedAt = time.Now().UTC()
|
|
r.CreatedAt = time.Now().UTC()
|
|
- r.SessionID = activeSession.ID
|
|
|
|
|
|
+ r.SessionID = activeSession.id
|
|
|
|
+ r.SessionExpiresAt = activeSession.expireAt.UTC()
|
|
if r.Done {
|
|
if r.Done {
|
|
r.TotalDuration = time.Since(checkpointStart)
|
|
r.TotalDuration = time.Since(checkpointStart)
|
|
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
@@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) {
|
|
ch <- r
|
|
ch <- r
|
|
}
|
|
}
|
|
|
|
|
|
- if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
|
|
|
|
|
|
+ if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
|
|
ch <- gin.H{"error": err.Error()}
|
|
ch <- gin.H{"error": err.Error()}
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
@@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- c.JSON(http.StatusOK, api.ListResponse{models})
|
|
|
|
|
|
+ c.JSON(http.StatusOK, api.ListResponse{Models: models})
|
|
}
|
|
}
|
|
|
|
|
|
func CopyModelHandler(c *gin.Context) {
|
|
func CopyModelHandler(c *gin.Context) {
|