
session id

Michael Yang · 1 year ago
commit 35af37a2cb
4 changed files with 67 additions and 36 deletions
  1. api/types.go (+5 -3)
  2. cmd/cmd.go (+16 -8)
  3. llama/llama.go (+9 -9)
  4. server/routes.go (+37 -16)

api/types.go (+5 -3)

@@ -28,9 +28,10 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Context []int  `json:"context,omitempty"`
+	SessionID int64  `json:"session_id"`
+	Model     string `json:"model"`
+	Prompt    string `json:"prompt"`
+	Context   []int  `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -81,6 +82,7 @@ type ListResponseModel struct {
 }
 
 type GenerateResponse struct {
+	SessionID int64     `json:"session_id"`
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
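
This commit threads a server-assigned session ID through both the request and the response: the client echoes back the session_id it last received, and a mismatch is the server's cue to reload the model. A minimal client-side sketch of that round trip, assuming the usual streaming POST /api/generate endpoint, the default localhost:11434 address, and a hypothetical model name:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/jmorganca/ollama/api"
)

func main() {
	var last api.GenerateResponse

	for _, prompt := range []string{"hello", "and in one sentence?"} {
		req := api.GenerateRequest{
			Model:     "llama2", // hypothetical model name
			Prompt:    prompt,
			Context:   last.Context,   // token context from the previous turn
			SessionID: last.SessionID, // zero on the first call
		}

		body, err := json.Marshal(req)
		if err != nil {
			panic(err)
		}

		resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}

		// The endpoint streams JSON objects; decode them until the stream ends.
		dec := json.NewDecoder(resp.Body)
		for {
			var r api.GenerateResponse
			if err := dec.Decode(&r); err != nil {
				break
			}
			fmt.Print(r.Response)
			last = r
		}
		resp.Body.Close()
	}
}

Sending SessionID: 0 on the first call (the zero value) forces the server to create a fresh session.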

cmd/cmd.go (+16 -8)

@@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
-var generateContextKey struct{}
+type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
@@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 		var latest api.GenerateResponse
 
-		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
 		if !ok {
 			generateContext = []int{}
 		}
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-		fn := func(resp api.GenerateResponse) error {
+		generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
+		if !ok {
+			generateSession = 0
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
+		fn := func(response api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
 			}
 
-			latest = resp
+			latest = response
 
-			fmt.Print(resp.Response)
-
-			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
+			fmt.Print(response.Response)
 			return nil
 		}
 
@@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		if verbose {
 			latest.Summary()
 		}
+
+		ctx := cmd.Context()
+		ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+		ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
+		cmd.SetContext(ctx)
 	}
 
 	return nil
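
The key change from `var generateContextKey struct{}` to `type generateContextKey string` turns the single opaque key into a named key type, so the CLI can stash two values ("context" and "session") under distinct keys without colliding with keys from other packages. A standalone sketch of the pattern, with illustrative names:

package main

import (
	"context"
	"fmt"
)

// An unexported named type keeps these keys distinct from any other
// package's context keys, even ones built from the same strings.
type ctxKey string

func main() {
	ctx := context.Background()
	ctx = context.WithValue(ctx, ctxKey("context"), []int{1, 2, 3})
	ctx = context.WithValue(ctx, ctxKey("session"), int64(42))

	// Type-asserting the values back out with fallbacks when absent
	// mirrors the generateContext/generateSession lookups above.
	tokens, ok := ctx.Value(ctxKey("context")).([]int)
	if !ok {
		tokens = []int{}
	}
	session, ok := ctx.Value(ctxKey("session")).(int64)
	if !ok {
		session = 0
	}

	fmt.Println(tokens, session)
}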

llama/llama.go (+9 -9)

@@ -91,7 +91,7 @@ import (
 	"github.com/jmorganca/ollama/api"
 )
 
-type llama struct {
+type LLM struct {
 	params *C.struct_llama_context_params
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
@@ -99,12 +99,12 @@ type llama struct {
 	api.Options
 }
 
-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := llama{Options: opts}
+	llm := LLM{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))
 
@@ -144,14 +144,14 @@ func New(model string, opts api.Options) (*llama, error) {
 	return &llm, nil
 }
 
-func (llm *llama) Close() {
+func (llm *LLM) Close() {
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	if input := llm.tokenize(prompt); input != nil {
 		embd := make([]C.llama_token, len(ctx))
 		for i := range ctx {
@@ -164,7 +164,7 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
 	return errors.New("llama: tokenize")
 }
 
-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
@@ -176,7 +176,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@@ -185,7 +185,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
+func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
 	var opts C.struct_llama_sample_options
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@@ -256,7 +256,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	return nil
 }
 
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
+func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
 	numVocab := int(C.llama_n_vocab(llm.ctx))
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)

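Every hunk in this file is the same mechanical change: renaming the unexported `llama` struct to `LLM` exports it, which is what lets server/routes.go hold a `*llama.LLM` in its `activeSession` across requests. A rough sketch of what the exported type looks like from a consuming package (hypothetical model path; error handling abbreviated):

package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llama"
)

func main() {
	// Before this change, the struct behind *llama.llama was unexported
	// and could not be named in another package's field or variable.
	llm, err := llama.New("/path/to/model.bin", api.DefaultOptions()) // hypothetical path
	if err != nil {
		panic(err)
	}
	defer llm.Close()

	// Predict streams tokens through the callback; an empty context starts fresh.
	err = llm.Predict([]int{}, "why is the sky blue?", func(r api.GenerateResponse) {
		fmt.Print(r.Response)
	})
	if err != nil {
		panic(err)
	}
}
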
server/routes.go (+37 -16)

@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"time"
 
 	"dario.cat/mergo"
@@ -21,7 +22,17 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
+var mu sync.Mutex
+
+var activeSession struct {
+	ID int64
+	*llama.LLM
+}
+
 func GenerateHandler(c *gin.Context) {
+	mu.Lock()
+	defer mu.Unlock()
+
 	start := time.Now()
 
 	var req api.GenerateRequest
@@ -36,29 +47,38 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	opts := api.DefaultOptions()
-	if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	if req.SessionID == 0 || req.SessionID != activeSession.ID {
+		if activeSession.LLM != nil {
+			activeSession.Close()
+			activeSession.LLM = nil
+		}
 
-	if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+		opts := api.DefaultOptions()
+		if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
 
-	prompt, err := model.Prompt(req)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		llm, err := llama.New(model.ModelPath, opts)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		activeSession.ID = time.Now().UnixNano()
+		activeSession.LLM = llm
 	}
 
-	llm, err := llama.New(model.ModelPath, opts)
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	defer llm.Close()
 
 	ch := make(chan any)
 	go func() {
@@ -66,6 +86,7 @@
 		fn := func(r api.GenerateResponse) {
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
+			r.SessionID = activeSession.ID
 			if r.Done {
 				r.TotalDuration = time.Since(start)
 			}
@@ -73,7 +94,7 @@
 			ch <- r
 		}
 
-		if err := llm.Predict(req.Context, prompt, fn); err != nil {
+		if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
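
Taken together, the handler now keeps one model resident in `activeSession` and only pays the load cost when the request's `session_id` is zero or stale, while the package-level mutex serializes handlers because a single llama context is not safe for concurrent use. A condensed sketch of just the reuse decision, refactored into a hypothetical `loadSession` helper (not in the commit; it leans on the `mu` and `activeSession` globals this diff introduces):

package server

import (
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llama"
)

// loadSession is an illustrative helper, not part of this change.
func loadSession(req api.GenerateRequest, modelPath string, opts api.Options) (*llama.LLM, int64, error) {
	mu.Lock()
	defer mu.Unlock()

	// Reuse the resident model only when the client echoes the live session ID.
	if req.SessionID != 0 && req.SessionID == activeSession.ID {
		return activeSession.LLM, activeSession.ID, nil
	}

	// Stale or missing ID: tear down the old model and load a fresh one.
	if activeSession.LLM != nil {
		activeSession.Close()
		activeSession.LLM = nil
	}

	llm, err := llama.New(modelPath, opts)
	if err != nil {
		return nil, 0, err
	}

	// A nanosecond timestamp as the ID guarantees a reload invalidates
	// whatever session_id earlier clients are still holding.
	activeSession.ID = time.Now().UnixNano()
	activeSession.LLM = llm
	return activeSession.LLM, activeSession.ID, nil
}

Because the ID comes from time.Now().UnixNano(), any client holding a session_id from a previous load will mismatch on its next request and transparently trigger a reload, which is why GenerateHandler no longer needs the old `defer llm.Close()`.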