@@ -104,6 +104,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	logprobs       int
 }
 
 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -164,6 +165,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		logprobs:      params.logprobs,
 	}, nil
 }
 
@@ -285,37 +287,34 @@ func flushPending(seq *Sequence) bool {
 	if len(seq.pendingResponses) == 0 {
 		return true
 	}
-	content := ""
+	resps := []CompletionResponse{}
 	for _, resp := range seq.pendingResponses {
-		content += resp.Content
+		resps = append(resps, resp)
 	}
 	seq.pendingResponses = []CompletionResponse{}
 
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(content) {
-		content = content[:len(content)-1]
-	}
+	// TODO: figure out this result logic
+	result := false
+	for _, resp := range resps {
+		// Check if there are any partial UTF-8 characters remaining.
+		// We already check and queue as we are generating but some may
+		// still make it here:
+		// - Sequence is ending, e.g. generation limit has been hit
+		// - Invalid characters in the middle of a string
+		// This is a stricter check to ensure we never output invalid Unicode.
+		for !utf8.ValidString(resp.Content) {
+			resp.Content = resp.Content[:len(resp.Content)-1]
+		}
 
-	// Add logits if requested and available
-	wantLogits := true
-	if wantLogits && seq.logits != nil {
-		// resp.Logits = seq.logits
-		seq.logits = nil
+		select {
+		case seq.responses <- resp:
+			result = true
+		case <-seq.quit:
+			result = false
+		}
 	}
 
-	select {
-	case seq.responses <- CompletionResponse{
-		Content: content,
-	}:
-		return true
-	case <-seq.quit:
-		return false
-	}
+	return result
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -371,10 +370,11 @@ func (s *Server) run(ctx context.Context) {
 
 // TokenProbs represents probability information for a token
 type TokenProbs struct {
-	TokenID int
-	Logit   float32
-	Prob    float32
-	LogProb float32
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
 }
 
 // probs returns sorted token probabilities for a specific token index
@@ -553,9 +553,17 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.numPredicted++
 
+		resp := CompletionResponse{Content: piece}
+
 		if seq.logprobs > 0 {
 			// TODO: return selected token in logprobs always
-			// probs := s.probs(seq)
+			resp.LogProbs = s.probs(seq)
+			// TODO: fix this logprobs limit
+			resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
+			for i := range resp.LogProbs {
+				// decode the token id to a piece
+				resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
+			}
 		}
 
 		// if it's an end of sequence token, break
@@ -571,7 +579,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		seq.inputs = []input{{token: token}}
 
 		// TODO: add probs here
-		seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, resp)
 		var sequence string
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -580,10 +588,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
 
+			// TODO: fix this stop sequence caching
 			var tokenTruncated bool
-			origLen := len(seq.pendingResponses)
-			seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
-			newLen := len(seq.pendingResponses)
+			origLen := len(sequence)
+			sequence, tokenTruncated = truncateStop(sequence, stop)
+			newLen := len(sequence)
 
 			// Update the cache based on the tokens that will be returned:
 			// - We have 1 token more than is currently in the cache because
@@ -654,6 +663,7 @@ type CompletionRequest struct {
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
 	CachePrompt bool        `json:"cache_prompt"`
+	Logprobs    int         `json:"logprobs,omitempty"`
 
 	Options
 }
@@ -669,8 +679,10 @@ type CompletionResponse struct {
 	Content string `json:"content"`
 	Stop    bool   `json:"stop"`
 
-	Model  string `json:"model,omitempty"`
-	Prompt string `json:"prompt,omitempty"`
+	Model    string       `json:"model,omitempty"`
+	Prompt   string       `json:"prompt,omitempty"`
+	LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
 	StoppedLimit bool    `json:"stopped_limit,omitempty"`
 	PredictedN   int     `json:"predicted_n,omitempty"`
 	PredictedMS  float64 `json:"predicted_ms,omitempty"`
@@ -688,10 +700,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Set the headers to indicate streaming
-	w.Header().Set("Content-Type", "application/json")
-	w.Header().Set("Transfer-Encoding", "chunked")
-
 	flusher, ok := w.(http.Flusher)
 	if !ok {
 		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -720,6 +728,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		logprobs:       req.Logprobs,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -769,6 +778,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case resp, ok := <-seq.responses:
 			if ok {
+				fmt.Println("response", resp)
 				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
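
---

Usage sketch (illustrative, not part of the diff): a minimal Go client exercising the new logprobs request field. The localhost:8080 address and the /completion route are assumptions for this example; with logprobs > 0, each streamed CompletionResponse chunk should carry a logprobs array of {id, logit, prob, logprob, token} entries.

	// Hypothetical client; the endpoint address and route are assumed.
	package main

	import (
		"bufio"
		"bytes"
		"fmt"
		"net/http"
	)

	func main() {
		body := []byte(`{"prompt": "Why is the sky blue?", "logprobs": 3}`)
		resp, err := http.Post("http://localhost:8080/completion", "application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		// The handler streams one JSON-encoded CompletionResponse per line,
		// so scanning line by line yields each chunk as it arrives.
		scanner := bufio.NewScanner(resp.Body)
		for scanner.Scan() {
			fmt.Println(scanner.Text())
		}
	}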