Explorar o código

return more info in generate response

Michael Yang hai 1 ano
pai
achega
05e08d2310
Modificáronse 4 ficheiros con 116 adicións e 31 borrados
  1. 41 2
      api/types.go
  2. 24 9
      cmd/cmd.go
  3. 40 18
      llama/llama.go
  4. 11 2
      server/routes.go

+ 41 - 2
api/types.go

@@ -1,6 +1,11 @@
 package api
 package api
 
 
-import "runtime"
+import (
+	"fmt"
+	"os"
+	"runtime"
+	"time"
+)
 
 
 type PullRequest struct {
 type PullRequest struct {
 	Model string `json:"model"`
 	Model string `json:"model"`
@@ -20,7 +25,41 @@ type GenerateRequest struct {
 }
 }
 
 
 type GenerateResponse struct {
 type GenerateResponse struct {
-	Response string `json:"response"`
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Response  string    `json:"response,omitempty"`
+
+	Done bool `json:"done"`
+
+	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+	EvalCount          int           `json:"eval_count,omitempty"`
+	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+}
+
+func (r *GenerateResponse) Summary() {
+	if r.TotalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "total duration:       %v\n", r.TotalDuration)
+	}
+
+	if r.PromptEvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", r.PromptEvalCount)
+	}
+
+	if r.PromptEvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
+		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
+	}
+
+	if r.EvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", r.EvalCount)
+	}
+
+	if r.EvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "eval duraiton:        %s\n", r.EvalDuration)
+		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
+	}
 }
 }
 
 
 type Options struct {
 type Options struct {

+ 24 - 9
cmd/cmd.go

@@ -72,20 +72,20 @@ func pull(model string) error {
 	)
 	)
 }
 }
 
 
-func RunGenerate(_ *cobra.Command, args []string) error {
+func RunGenerate(cmd *cobra.Command, args []string) error {
 	if len(args) > 1 {
 	if len(args) > 1 {
 		// join all args into a single prompt
 		// join all args into a single prompt
-		return generate(args[0], strings.Join(args[1:], " "))
+		return generate(cmd, args[0], strings.Join(args[1:], " "))
 	}
 	}
 
 
 	if term.IsTerminal(int(os.Stdin.Fd())) {
 	if term.IsTerminal(int(os.Stdin.Fd())) {
-		return generateInteractive(args[0])
+		return generateInteractive(cmd, args[0])
 	}
 	}
 
 
-	return generateBatch(args[0])
+	return generateBatch(cmd, args[0])
 }
 }
 
 
-func generate(model, prompt string) error {
+func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
 		client := api.NewClient()
 
 
@@ -108,12 +108,16 @@ func generate(model, prompt string) error {
 			}
 			}
 		}()
 		}()
 
 
+		var latest api.GenerateResponse
+
 		request := api.GenerateRequest{Model: model, Prompt: prompt}
 		request := api.GenerateRequest{Model: model, Prompt: prompt}
 		fn := func(resp api.GenerateResponse) error {
 		fn := func(resp api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 			if !spinner.IsFinished() {
 				spinner.Finish()
 				spinner.Finish()
 			}
 			}
 
 
+			latest = resp
+
 			fmt.Print(resp.Response)
 			fmt.Print(resp.Response)
 			return nil
 			return nil
 		}
 		}
@@ -124,16 +128,25 @@ func generate(model, prompt string) error {
 
 
 		fmt.Println()
 		fmt.Println()
 		fmt.Println()
 		fmt.Println()
+
+		verbose, err := cmd.Flags().GetBool("verbose")
+		if err != nil {
+			return err
+		}
+
+		if verbose {
+			latest.Summary()
+		}
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
-func generateInteractive(model string) error {
+func generateInteractive(cmd *cobra.Command, model string) error {
 	fmt.Print(">>> ")
 	fmt.Print(">>> ")
 	scanner := bufio.NewScanner(os.Stdin)
 	scanner := bufio.NewScanner(os.Stdin)
 	for scanner.Scan() {
 	for scanner.Scan() {
-		if err := generate(model, scanner.Text()); err != nil {
+		if err := generate(cmd, model, scanner.Text()); err != nil {
 			return err
 			return err
 		}
 		}
 
 
@@ -143,12 +156,12 @@ func generateInteractive(model string) error {
 	return nil
 	return nil
 }
 }
 
 
-func generateBatch(model string) error {
+func generateBatch(cmd *cobra.Command, model string) error {
 	scanner := bufio.NewScanner(os.Stdin)
 	scanner := bufio.NewScanner(os.Stdin)
 	for scanner.Scan() {
 	for scanner.Scan() {
 		prompt := scanner.Text()
 		prompt := scanner.Text()
 		fmt.Printf(">>> %s\n", prompt)
 		fmt.Printf(">>> %s\n", prompt)
-		if err := generate(model, prompt); err != nil {
+		if err := generate(cmd, model, prompt); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -200,6 +213,8 @@ func NewCLI() *cobra.Command {
 		RunE:  RunRun,
 		RunE:  RunRun,
 	}
 	}
 
 
+	runCmd.Flags().Bool("verbose", false, "Show timings for response")
+
 	serveCmd := &cobra.Command{
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Use:     "serve",
 		Aliases: []string{"start"},
 		Aliases: []string{"start"},

+ 40 - 18
llama/llama.go

@@ -79,9 +79,11 @@ llama_token llama_sample(
 import "C"
 import "C"
 import (
 import (
 	"errors"
 	"errors"
+	"fmt"
 	"io"
 	"io"
 	"os"
 	"os"
 	"strings"
 	"strings"
+	"time"
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
@@ -147,7 +149,7 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 	C.llama_print_timings(llm.ctx)
 }
 }
 
 
-func (llm *llama) Predict(prompt string, fn func(string)) error {
+func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
 	if tokens := llm.tokenize(prompt); tokens != nil {
 	if tokens := llm.tokenize(prompt); tokens != nil {
 		return llm.generate(tokens, fn)
 		return llm.generate(tokens, fn)
 	}
 	}
@@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 	return sb.String()
 }
 }
 
 
-func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
+func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
 	var opts C.struct_llama_sample_options
 	var opts C.struct_llama_sample_options
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
 	opts.mirostat_tau = C.float(llm.MirostatTau)
 	opts.mirostat_tau = C.float(llm.MirostatTau)
 	opts.mirostat_eta = C.float(llm.MirostatEta)
 	opts.mirostat_eta = C.float(llm.MirostatEta)
 
 
-	pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN}
+	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
 
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
 			return errors.New("llama: eval")
 		}
 		}
 
 
-		token, err := llm.sample(pastTokens, &opts)
-		switch {
-		case errors.Is(err, io.EOF):
-			return nil
-		case err != nil:
+		token, err := llm.sample(output, &opts)
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
 			return err
 			return err
 		}
 		}
 
 
-		fn(llm.detokenize(token))
+		// call the callback
+		fn(api.GenerateResponse{
+			Response: llm.detokenize(token),
+		})
+
+		output.PushLeft(token)
+
+		input = []C.llama_token{token}
+	}
 
 
-		tokens = []C.llama_token{token}
+	dur := func(ms float64) time.Duration {
+		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+		if err != nil {
+			panic(err)
+		}
 
 
-		pastTokens.PushLeft(token)
+		return d
 	}
 	}
 
 
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       dur(float64(timings.t_eval_ms)),
+	})
+
 	return nil
 	return nil
 }
 }
 
 
-func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
+func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
 	numVocab := int(C.llama_n_vocab(llm.ctx))
 	numVocab := int(C.llama_n_vocab(llm.ctx))
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 
 
-	candidates := make([]C.struct_llama_token_data, 0, numVocab)
-	for i := 0; i < numVocab; i++ {
-		candidates = append(candidates, C.llama_token_data{
+	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
+	for i := 0; i < candidates.Cap(); i++ {
+		candidates.PushLeft(C.struct_llama_token_data{
 			id:    C.int(i),
 			id:    C.int(i),
 			logit: logits[i],
 			logit: logits[i],
 			p:     0,
 			p:     0,
@@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 
 
 	token := C.llama_sample(
 	token := C.llama_sample(
 		llm.ctx,
 		llm.ctx,
-		unsafe.SliceData(candidates), C.ulong(len(candidates)),
-		unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()),
+		unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
+		unsafe.SliceData(output.Data()), C.ulong(output.Len()),
 		opts)
 		opts)
 	if token != C.llama_token_eos() {
 	if token != C.llama_token_eos() {
 		return token, nil
 		return token, nil

+ 11 - 2
server/routes.go

@@ -13,6 +13,7 @@ import (
 	"path"
 	"path"
 	"strings"
 	"strings"
 	"text/template"
 	"text/template"
+	"time"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"github.com/lithammer/fuzzysearch/fuzzy"
 	"github.com/lithammer/fuzzysearch/fuzzy"
@@ -35,6 +36,8 @@ func cacheDir() string {
 }
 }
 
 
 func generate(c *gin.Context) {
 func generate(c *gin.Context) {
+	start := time.Now()
+
 	req := api.GenerateRequest{
 	req := api.GenerateRequest{
 		Options: api.DefaultOptions(),
 		Options: api.DefaultOptions(),
 	}
 	}
@@ -81,8 +84,14 @@ func generate(c *gin.Context) {
 	}
 	}
 	defer llm.Close()
 	defer llm.Close()
 
 
-	fn := func(s string) {
-		ch <- api.GenerateResponse{Response: s}
+	fn := func(r api.GenerateResponse) {
+		r.Model = req.Model
+		r.CreatedAt = time.Now().UTC()
+		if r.Done {
+			r.TotalDuration = time.Since(start)
+		}
+
+		ch <- r
 	}
 	}
 
 
 	if err := llm.Predict(req.Prompt, fn); err != nil {
 	if err := llm.Predict(req.Prompt, fn); err != nil {