
Implement timings response in Go server

This implements the fields necessary for `run --verbose`
to generate timing information.
Daniel Hiltgen, 9 months ago
parent commit 8527028bf4
1 file changed, 65 insertions(+), 9 deletions(-)

llama/runner/runner.go  +65 -9
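Four values drive the timing report: prompt token count and duration, and generated token count and duration. A verbose client can turn these into throughput figures; the following is a minimal sketch that assumes the rate formulas (they are not part of this commit), using the `Timings` struct added below:

```go
// rates derives tokens-per-second figures from the Timings payload that
// this commit adds; the formulas are an illustrative assumption.
func rates(t Timings) (promptTokSec, genTokSec float64) {
	if t.PromptMS > 0 {
		promptTokSec = float64(t.PromptN) / t.PromptMS * 1000
	}
	if t.PredictedMS > 0 {
		genTokSec = float64(t.PredictedN) / t.PredictedMS * 1000
	}
	return promptTokSec, genTokSec
}
```

For example, `prompt_n: 26` over `prompt_ms: 130` works out to 200 tokens/s of prompt evaluation, and `predicted_n: 100` over `predicted_ms: 2500` to 40 tokens/s of generation.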

@@ -14,6 +14,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
@@ -50,6 +51,12 @@ type Sequence struct {
 	embeddingOnly bool
 
 	doneReason string
+
+	// Metrics
+	t_start_process_prompt time.Time
+	t_start_generation     time.Time
+	n_decoded              int
+	n_prompt_tokens        int
 }
 
 // prompt returns true if the prompt is still being processed
@@ -80,12 +87,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	}
 
 	return &Sequence{
-		tokens:        tokens,
-		responses:     make(chan string, 1),
-		embedding:     make(chan []float32, 1),
-		samplingCtx:   sc,
-		embeddingOnly: embedding,
-		stop:          stop,
+		tokens:          tokens,
+		n_prompt_tokens: len(tokens),
+		responses:       make(chan string, 1),
+		embedding:       make(chan []float32, 1),
+		samplingCtx:     sc,
+		embeddingOnly:   embedding,
+		stop:            stop,
 	}
 }
 
@@ -161,6 +169,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				if seq.t_start_process_prompt.IsZero() {
+					seq.t_start_process_prompt = time.Now()
+				}
+
 				for j, t := range seq.tokens {
 					// todo: make this n_batch
 					if j > s.batchSize {
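The `IsZero` check works as a set-once guard: `run` may pass over the same sequence on several batch iterations while its prompt is still being processed, and `time.Time`'s zero value distinguishes "not started" from any real timestamp. A standalone illustration of the pattern (not from this commit):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// start is assigned exactly once: on the first iteration that sees it unset.
	var start time.Time
	for i := 0; i < 3; i++ {
		if start.IsZero() {
			start = time.Now()
		}
		time.Sleep(10 * time.Millisecond)
	}
	fmt.Println("elapsed since first assignment:", time.Since(start))
}
```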
@@ -207,6 +219,10 @@
 				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
 
 				seq.samplingCtx.Accept(s.lc, token, true)
+				seq.n_decoded += 1
+				if seq.n_decoded == 1 {
+					seq.t_start_generation = time.Now()
+				}
 				piece := s.model.TokenToPiece(token)
 
 				seq.numPredicted++
@@ -278,8 +294,26 @@ type CompletionRequest struct {
 	api.Options
 }
 
+type Timings struct {
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
+}
+
 type CompletionResponse struct {
-	Token string `json:"token"`
+	Content string `json:"content"`
+	Stop    bool   `json:"stop"`
+
+	Model        string  `json:"model,omitempty"`
+	Prompt       string  `json:"prompt,omitempty"`
+	StoppedLimit bool    `json:"stopped_limit,omitempty"`
+	PredictedN   int     `json:"predicted_n,omitempty"`
+	PredictedMS  float64 `json:"predicted_ms,omitempty"`
+	PromptN      int     `json:"prompt_n,omitempty"`
+	PromptMS     float64 `json:"prompt_ms,omitempty"`
+
+	Timings Timings `json:"timings"`
 }
 
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
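Since every flushed chunk is one JSON-encoded `CompletionResponse`, as the handler below shows, a client can consume the stream with a plain `json.Decoder`. A client-side sketch, not part of this commit; the endpoint and request shape are assumptions, and `bytes`, `encoding/json`, `fmt`, and `net/http` are assumed imported:

```go
func streamCompletion(url, prompt string) error {
	body, err := json.Marshal(map[string]any{"prompt": prompt})
	if err != nil {
		return err
	}
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Decode CompletionResponse values until the stop message arrives.
	dec := json.NewDecoder(resp.Body)
	for {
		var cr CompletionResponse
		if err := dec.Decode(&cr); err != nil {
			return err // io.EOF: the stream ended without a stop message
		}
		fmt.Print(cr.Content)
		if cr.Stop {
			fmt.Printf("\n%d prompt + %d generated tokens\n",
				cr.Timings.PromptN, cr.Timings.PredictedN)
			return nil
		}
	}
}
```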
@@ -326,9 +360,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 
 	// stream the response
-	for token := range seq.responses {
+	for content := range seq.responses {
 		if err := json.NewEncoder(w).Encode(&CompletionResponse{
-			Token: token,
+			Content: content,
 		}); err != nil {
 			log.Println("Failed to encode result:", err)
 			return
@@ -342,6 +376,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 		flusher.Flush()
 	}
+
+	// Send the stop
+	if err := json.NewEncoder(w).Encode(&CompletionResponse{
+		Stop: true,
+		Timings: Timings{
+			PromptN:     seq.n_prompt_tokens,
+			PromptMS:    float64(seq.t_start_generation.Sub(seq.t_start_process_prompt).Milliseconds()),
+			PredictedN:  seq.n_decoded,
+			PredictedMS: float64(time.Since(seq.t_start_generation).Milliseconds()),
+		},
+	}); err != nil {
+		log.Println("Failed to encode result:", err)
+		return
+	}
+
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+		return
+	}
+
+	flusher.Flush()
 }
 
 type EmbeddingRequest struct {
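For reference, encoding the final message the way `completion` does produces a single line of JSON on the wire. A minimal sketch with illustrative values, assuming the types above are in scope and `encoding/json` and `os` are imported:

```go
func main() {
	// Prints (field order follows the struct; omitempty drops the unset fields):
	// {"content":"","stop":true,"timings":{"predicted_n":100,"predicted_ms":2500,"prompt_n":26,"prompt_ms":130}}
	json.NewEncoder(os.Stdout).Encode(&CompletionResponse{
		Stop:    true,
		Timings: Timings{PredictedN: 100, PredictedMS: 2500, PromptN: 26, PromptMS: 130},
	})
}
```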