@@ -24,6 +24,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/runner/common"
 )
 
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
 	embedding bool
 }
 
-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
 	startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // generating image embeddings for each image
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
 	var inputs []input
 	var parts []string
 	var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
 	image *ImageContext
 
 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus
 
 	// current progress on loading the model
 	progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	return nil
 }
 
-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
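The handler now decodes into the shared request type from the llm package instead of the local structs deleted above. As a rough sketch only, inferred from the field accesses in this diff rather than the llm package's actual definitions, the shape the runner relies on is roughly:

	// Sketch, not the authoritative llm definitions: field names and types are
	// inferred from how req is used in this handler (req.Prompt, req.Images,
	// req.Grammar, req.Options.*), with Options pointing at api.Options so a
	// nil value can be defaulted via api.DefaultOptions().
	type CompletionRequest struct {
		Prompt  string
		Images  []ImageData
		Grammar string
		Options *api.Options
	}

	type ImageData struct {
		Data          []byte
		ID            int
		AspectRatioID int
	}
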
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var samplingParams llama.SamplingParams
-	samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar
+	// Extract options from the CompletionRequest
+	samplingParams := llama.SamplingParams{
+		TopK:           req.Options.TopK,
+		TopP:           req.Options.TopP,
+		MinP:           req.Options.MinP,
+		TypicalP:       req.Options.TypicalP,
+		Temp:           req.Options.Temperature,
+		RepeatLastN:    req.Options.RepeatLastN,
+		PenaltyRepeat:  req.Options.RepeatPenalty,
+		PenaltyFreq:    req.Options.FrequencyPenalty,
+		PenaltyPresent: req.Options.PresencePenalty,
+		Mirostat:       req.Options.Mirostat,
+		MirostatTau:    req.Options.MirostatTau,
+		MirostatEta:    req.Options.MirostatEta,
+		Seed:           uint32(req.Options.Seed),
+		Grammar:        req.Grammar,
+	}
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict: req.NumPredict,
-		stop:       req.Stop,
-		numKeep:    req.NumKeep,
+		numPredict:     req.Options.NumPredict,
+		stop:           req.Options.Stop,
+		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
 	})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numDecoded,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numDecoded,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
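Two behavioral notes fall out of this hunk: the old Stop/StoppedLimit pair collapses into Done plus a DoneReason string ("stop" or "length"), and the timing fields are no longer millisecond floats, since time.Since and Time.Sub both return time.Duration values that are passed through directly. A sketch of the assumed response shape, inferred only from the fields set above:

	// Sketch, inferred from this hunk; the real definition lives in the llm package.
	type CompletionResponse struct {
		Content            string
		Done               bool
		DoneReason         string        // "stop" or "length" in this handler
		PromptEvalCount    int
		PromptEvalDuration time.Duration // was PromptMS float64 in the old Timings
		EvalCount          int
		EvalDuration       time.Duration // was PredictedMS float64
	}
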
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	var req EmbeddingRequest
+	var req llm.EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
 		return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	embedding := <-seq.embedding
 
-	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
 		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
 }
 
-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
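The health endpoint now encodes the llm.ServerStatus value directly, so the string conversion that the deleted ToString method provided is assumed to live in the llm package (for example via a String or MarshalJSON method on ServerStatus). A sketch of what this handler assumes llm exports, based only on the identifiers referenced in this diff and the local definitions removed above:

	// Sketch of the assumed llm-side declarations; only ServerStatusReady,
	// ServerStatusLoadingModel, and ServerStatusResponse are actually
	// referenced by this diff.
	type ServerStatus int

	const (
		ServerStatusReady ServerStatus = iota
		ServerStatusLoadingModel
		ServerStatusError
	)

	type ServerStatusResponse struct {
		Status   ServerStatus `json:"status"`
		Progress float32      `json:"progress"`
	}
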
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }
 
@@ -937,7 +852,7 @@ func Execute(args []string) error {
 		parallel: *parallel,
 		seqs:     make([]*Sequence, *parallel),
 		seqsSem:  semaphore.NewWeighted(int64(*parallel)),
-		status:   ServerStatusLoadingModel,
+		status:   llm.ServerStatusLoadingModel,
 	}
 
 	var tensorSplitFloats []float32