
continue conversation

feed responses back into the llm
Michael Yang 1 year ago
parent
commit
1775647f76

+ 6 - 4
api/types.go

@@ -18,8 +18,9 @@ type PullProgress struct {
 }
 
 type GenerateRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -29,7 +30,8 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
 
-	Done bool `json:"done"`
+	Done    bool  `json:"done"`
+	Context []int `json:"context,omitempty"`
 
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
@@ -104,7 +106,7 @@ func DefaultOptions() Options {
 
 		UseNUMA: false,
 
-		NumCtx:   512,
+		NumCtx:   2048,
 		NumBatch: 512,
 		NumGPU:   1,
 		LowVRAM:  false,
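The two new Context fields are the conversation mechanism at the API level: the final GenerateResponse of one call carries the model's accumulated tokens, and the client echoes them back in the next GenerateRequest. A minimal sketch of that round trip follows; the structs mirror the fields added here, while the server address, the /api/generate path, the newline-delimited streaming format, and the model name are assumptions about the running server rather than part of this diff.

// Sketch of continuing a conversation over the HTTP API. Everything beyond
// the request/response fields added in api/types.go is an assumption.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

type generateRequest struct {
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Context []int  `json:"context,omitempty"`
}

type generateResponse struct {
	Response string `json:"response,omitempty"`
	Done     bool   `json:"done"`
	Context  []int  `json:"context,omitempty"`
}

// generate sends one prompt plus any prior context and returns the Context
// from the final streamed chunk so the caller can pass it to the next turn.
func generate(prompt string, prior []int) ([]int, error) {
	body, _ := json.Marshal(generateRequest{Model: "orca", Prompt: prompt, Context: prior})
	resp, err := http.Post("http://127.0.0.1:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var chunk generateResponse
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			return nil, err
		}
		fmt.Print(chunk.Response)
		if chunk.Done {
			return chunk.Context, nil
		}
	}
	return nil, scanner.Err()
}

func main() {
	ctx, err := generate("Hi, my name is Ada.", nil)
	if err != nil {
		log.Fatal(err)
	}
	// Second turn: sending the returned Context is what continues the conversation.
	if _, err := generate("What is my name?", ctx); err != nil {
		log.Fatal(err)
	}
}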

+ 10 - 1
cmd/cmd.go

@@ -85,6 +85,8 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
+var generateContextKey struct{}
+
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
@@ -110,7 +112,12 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 		var latest api.GenerateResponse
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt}
+		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		if !ok {
+			generateContext = []int{}
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
 		fn := func(resp api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
@@ -119,6 +126,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 			latest = resp
 
 			fmt.Print(resp.Response)
+
+			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
 			return nil
 		}
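The CLI keeps the conversation alive between prompts by stashing the latest Context on the cobra command's context.Context: each call to generate reads it back with a type assertion (falling back to an empty slice on the first turn), and the streaming callback overwrites it with resp.Context after every response. A standalone sketch of that read-then-SetContext pattern, with illustrative names rather than the repo's:

// Stand-alone sketch of carrying state across turns on a cobra command's
// context. contextKey and the fake token values are illustrative only.
package main

import (
	"context"
	"fmt"

	"github.com/spf13/cobra"
)

type contextKey struct{}

func main() {
	root := &cobra.Command{
		Use: "demo",
		RunE: func(cmd *cobra.Command, _ []string) error {
			for turn := 0; turn < 3; turn++ {
				// Read whatever the previous turn stored; the first turn finds nothing.
				prior, ok := cmd.Context().Value(contextKey{}).([]int)
				if !ok {
					prior = []int{}
				}
				fmt.Println("prior tokens:", prior)

				// Pretend the model returned new context tokens and stash them.
				next := append(prior, turn)
				cmd.SetContext(context.WithValue(cmd.Context(), contextKey{}, next))
			}
			return nil
		},
	}

	// ExecuteContext (see the main.go change below) supplies the base
	// context that SetContext keeps replacing.
	if err := root.ExecuteContext(context.Background()); err != nil {
		fmt.Println(err)
	}
}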
 
 

+ 15 - 3
llama/llama.go

@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}
 
 	return errors.New("llama: tokenize")
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 
 	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})
 
 		output.PushLeft(token)
+		context.PushLeft(int(token))
 
 		input = []C.llama_token{token}
 	}
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
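Predict now accepts the prior context tokens and prepends them to the freshly tokenized prompt before evaluation, and generate records every token it consumes or samples into a second deque whose contents become GenerateResponse.Context (capped at NumCtx / 2 entries). A cgo-free sketch of that bookkeeping follows; the bounded slice stands in for the repo's deque, on the assumption that the deque keeps only the most recent tokens once it reaches capacity.

// Simplified sketch of the token bookkeeping added above, without cgo.
package main

import "fmt"

// history is a stand-in for the repo's bounded deque: it keeps at most
// `capacity` of the most recently pushed tokens.
type history struct {
	capacity int
	tokens   []int
}

func (h *history) push(t int) {
	h.tokens = append(h.tokens, t)
	if len(h.tokens) > h.capacity {
		h.tokens = h.tokens[1:] // drop the oldest token
	}
}

func main() {
	prior := []int{101, 102, 103} // Context returned by the previous response
	prompt := []int{7, 8, 9}      // tokens for the new prompt

	// Predict: evaluation starts from prior context followed by the new prompt.
	input := append(append([]int{}, prior...), prompt...)

	// generate: record the input, then every sampled token, in the bounded history.
	h := &history{capacity: 8} // stands in for NumCtx / 2
	for _, t := range input {
		h.push(t)
	}
	for _, sampled := range []int{42, 43} { // pretend the model produced these
		h.push(sampled)
	}

	fmt.Println("next Context:", h.tokens) // what the final GenerateResponse would carry
}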

+ 3 - 1
main.go

@@ -1,9 +1,11 @@
 package main
 
 import (
+	"context"
+
 	"github.com/jmorganca/ollama/cmd"
 )
 
 func main() {
-	cmd.NewCLI().Execute()
+	cmd.NewCLI().ExecuteContext(context.Background())
 }
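Switching Execute to ExecuteContext is what lets the cmd/cmd.go change above work end to end: the context.Background() passed here is the base context that cmd.Context() returns inside generate, and that SetContext keeps replacing with a copy carrying the latest Context tokens.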

+ 1 - 1
server/routes.go

@@ -94,7 +94,7 @@ func generate(c *gin.Context) {
 		ch <- r
 	}
 
-	if err := llm.Predict(req.Prompt, fn); err != nil {
+	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}

+ 2 - 0
server/templates/alpaca.prompt

@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request.
+{{- end }}
 
 ### Instruction:
 {{ .Prompt }}

+ 2 - 0
server/templates/falcon.prompt

@@ -1,3 +1,5 @@
+{{- if not .Context }}
 A helpful assistant who helps the user with any questions asked.
+{{- end }}
 User: {{ .Prompt }}
 Assistant:

+ 2 - 0
server/templates/mpt.prompt

@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text.
+{{- end }}
 ### Instruction:
 {{ .Prompt }}
 ### Response:

+ 2 - 0
server/templates/orca.prompt

@@ -1,5 +1,7 @@
+{{- if not .Context }}
 ### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
+{{- end }}
 
 ### User:
 {{ .Prompt }}

+ 2 - 0
server/templates/vicuna.prompt

@@ -1,4 +1,6 @@
+{{ if not .Context }}
 A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+{{- end }}
 
 USER: {{ .Prompt }}
 ASSISTANT:

+ 2 - 0
server/templates/wizardcoder.prompt

@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request
+{{- end }}
 
 ### Instruction: {{ .Prompt }}
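Every prompt template gets the same guard: the system preamble is rendered only when .Context is empty, i.e. on the first turn of a conversation, so follow-up requests that carry context do not repeat it. These .prompt files are Go text/template files, so the effect can be checked in isolation; the sketch below reuses the alpaca template verbatim, and the map passed to Execute is a stand-in for whatever the server actually supplies beyond .Prompt and .Context.

// Quick check of the {{ if not .Context }} guard using Go's text/template.
package main

import (
	"os"
	"text/template"
)

const alpaca = `{{- if not .Context }}
Below is an instruction that describes a task. Write a response that appropriately completes the request.
{{- end }}

### Instruction:
{{ .Prompt }}
`

func main() {
	tmpl := template.Must(template.New("alpaca").Parse(alpaca))

	// First turn: no prior context, so the preamble is rendered.
	tmpl.Execute(os.Stdout, map[string]any{"Prompt": "Say hi.", "Context": []int{}})

	// Follow-up turn: context is present, so the preamble is skipped.
	tmpl.Execute(os.Stdout, map[string]any{"Prompt": "Now say bye.", "Context": []int{1, 2, 3}})
}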