Pārlūkot izejas kodu

continue conversation

feed responses back into the llm
Michael Yang 1 gadu atpakaļ
vecāks
revīzija
1775647f76

+ 6 - 4
api/types.go

@@ -18,8 +18,9 @@ type PullProgress struct {
 }
 
 type GenerateRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -29,7 +30,8 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
 
-	Done bool `json:"done"`
+	Done    bool  `json:"done"`
+	Context []int `json:"context,omitempty"`
 
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
@@ -104,7 +106,7 @@ func DefaultOptions() Options {
 
 		UseNUMA: false,
 
-		NumCtx:   512,
+		NumCtx:   2048,
 		NumBatch: 512,
 		NumGPU:   1,
 		LowVRAM:  false,

+ 10 - 1
cmd/cmd.go

@@ -85,6 +85,8 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
+var generateContextKey struct{}
+
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
@@ -110,7 +112,12 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 		var latest api.GenerateResponse
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt}
+		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		if !ok {
+			generateContext = []int{}
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
 		fn := func(resp api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
@@ -119,6 +126,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 			latest = resp
 
 			fmt.Print(resp.Response)
+
+			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
 			return nil
 		}
 

+ 15 - 3
llama/llama.go

@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}
 
 	return errors.New("llama: tokenize")
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 
 	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})
 
 		output.PushLeft(token)
+		context.PushLeft(int(token))
 
 		input = []C.llama_token{token}
 	}
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),

+ 3 - 1
main.go

@@ -1,9 +1,11 @@
 package main
 
 import (
+	"context"
+
 	"github.com/jmorganca/ollama/cmd"
 )
 
 func main() {
-	cmd.NewCLI().Execute()
+	cmd.NewCLI().ExecuteContext(context.Background())
 }

+ 1 - 1
server/routes.go

@@ -94,7 +94,7 @@ func generate(c *gin.Context) {
 		ch <- r
 	}
 
-	if err := llm.Predict(req.Prompt, fn); err != nil {
+	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}

+ 2 - 0
server/templates/alpaca.prompt

@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request.
+{{- end }}
 
 ### Instruction:
 {{ .Prompt }}

+ 2 - 0
server/templates/falcon.prompt

@@ -1,3 +1,5 @@
+{{- if not .Context }}
 A helpful assistant who helps the user with any questions asked.
+{{- end }}
 User: {{ .Prompt }}
 Assistant:

+ 2 - 0
server/templates/mpt.prompt

@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text.
+{{- end }}
 ### Instruction:
 {{ .Prompt }}
 ### Response:

+ 2 - 0
server/templates/orca.prompt

@@ -1,5 +1,7 @@
+{{- if not .Context }}
 ### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
+{{- end }}
 
 ### User:
 {{ .Prompt }}

+ 2 - 0
server/templates/vicuna.prompt

@@ -1,4 +1,6 @@
+{{ if not .Context }}
 A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+{{- end }}
 
 USER: {{ .Prompt }}
 ASSISTANT:

+ 2 - 0
server/templates/wizardcoder.prompt

@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request
+{{- end }}
 
 ### Instruction: {{ .Prompt }}