llm: remove internal subprocess req and resp types (#9324)

This commit refactors the LLM subsystem by removing the internal
subprocess request and response types. Duplicate type definitions are
consolidated into the llm package, the interfaces between components are
standardized, the ServerStatusResp struct is simplified (and renamed
ServerStatusResponse), and the ParseDurationMs helper moves to a common
package. The cleanup removes code that was duplicated between the two
runner implementations (llamarunner and ollamarunner).
Bruce MacDonald, 1 month ago
commit 3892c3a703
4 changed files with 125 additions and 354 deletions
  1. llm/server.go (+39, -97)
  2. runner/llamarunner/runner.go (+50, -135)
  3. runner/ollamarunner/cache.go (+1, -0)
  4. runner/ollamarunner/runner.go (+35, -122)
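To make the consolidation concrete, here is a minimal sketch of how a caller now drives a completion through the shared llm types. It is illustrative only: the completer interface, function name, and prompt are assumptions, not part of this commit; only the llm.CompletionRequest and llm.CompletionResponse fields and the Completion signature come from the diffs below.

package example

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
)

// completer is a local stand-in for the server value the llm package
// provides; only the Completion method shown in the diff is assumed.
type completer interface {
	Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error
}

// streamCompletion sends one consolidated request and prints the streamed
// responses, including the final timing fields.
func streamCompletion(ctx context.Context, s completer, prompt string) error {
	opts := api.DefaultOptions()
	req := llm.CompletionRequest{
		Prompt:  prompt,
		Options: &opts, // may be nil; Completion now falls back to api.DefaultOptions()
	}
	return s.Completion(ctx, req, func(resp llm.CompletionResponse) {
		fmt.Print(resp.Content)
		if resp.Done {
			fmt.Printf("\n%d tokens in %s (done reason: %s)\n",
				resp.EvalCount, resp.EvalDuration, resp.DoneReason)
		}
	})
}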

+ 39 - 97
llm/server.go

@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
 		}

-		slog.Info("starting llama server", "cmd", s.cmd.String())
+		slog.Info("starting llama server", "cmd", s.cmd)
 		if envconfig.Debug() {
 			filteredEnv := []string{}
 			for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
 	ServerStatusError
 )

-func (s ServerStatus) ToString() string {
+func (s ServerStatus) String() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
 	}
 }

-type ServerStatusResp struct {
-	Status          string  `json:"status"`
-	SlotsIdle       int     `json:"slots_idle"`
-	SlotsProcessing int     `json:"slots_processing"`
-	Error           string  `json:"error"`
-	Progress        float32 `json:"progress"`
+type ServerStatusResponse struct {
+	Status   ServerStatus `json:"status"`
+	Progress float32      `json:"progress"`
 }

 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		}
 		if s.cmd.ProcessState.ExitCode() == -1 {
 			// Most likely a signal killed it, log some more details to try to help troubleshoot
-			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
 		}
 		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 	}
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		return ServerStatusError, fmt.Errorf("read health request: %w", err)
 	}

-	var status ServerStatusResp
-	if err := json.Unmarshal(body, &status); err != nil {
+	var ssr ServerStatusResponse
+	if err := json.Unmarshal(body, &ssr); err != nil {
 		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
 	}

-	switch status.Status {
-	case "ok":
-		return ServerStatusReady, nil
-	case "no slot available":
-		return ServerStatusNoSlotsAvailable, nil
-	case "loading model":
-		s.loadProgress = status.Progress
-		return ServerStatusLoadingModel, nil
+	switch ssr.Status {
+	case ServerStatusLoadingModel:
+		s.loadProgress = ssr.Progress
+		return ssr.Status, nil
+	case ServerStatusReady, ServerStatusNoSlotsAvailable:
+		return ssr.Status, nil
 	default:
-		return ServerStatusError, fmt.Errorf("server error: %+v", status)
+		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
 	}
 }

@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 		status, _ := s.getServerStatus(ctx)
 		if lastStatus != status && status != ServerStatusReady {
 			// Only log on status changes
-			slog.Info("waiting for server to become available", "status", status.ToString())
+			slog.Info("waiting for server to become available", "status", status)
 		}
 		switch status {
 		case ServerStatusReady:
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 				stallTimer = time.Now().Add(stallDuration)
 			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
-				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+				slog.Debug("model load completed, waiting for server to become available", "status", status)
 				stallTimer = time.Now().Add(stallDuration)
 				fullyLoaded = true
 			}
@@ -671,63 +666,26 @@ type ImageData struct {
 	AspectRatioID int    `json:"aspect_ratio_id"`
 }

-type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
-
-	Timings struct {
-		PredictedN  int     `json:"predicted_n"`
-		PredictedMS float64 `json:"predicted_ms"`
-		PromptN     int     `json:"prompt_n"`
-		PromptMS    float64 `json:"prompt_ms"`
-	}
-}
-
 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
+
+	Grammar string // set before sending the request to the subprocess
 }

 type CompletionResponse struct {
-	Content            string
-	DoneReason         string
-	Done               bool
-	PromptEvalCount    int
-	PromptEvalDuration time.Duration
-	EvalCount          int
-	EvalDuration       time.Duration
+	Content            string        `json:"content"`
+	DoneReason         string        `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-	request := map[string]any{
-		"prompt":            req.Prompt,
-		"stream":            true,
-		"n_predict":         req.Options.NumPredict,
-		"n_keep":            req.Options.NumKeep,
-		"main_gpu":          req.Options.MainGPU,
-		"temperature":       req.Options.Temperature,
-		"top_k":             req.Options.TopK,
-		"top_p":             req.Options.TopP,
-		"min_p":             req.Options.MinP,
-		"typical_p":         req.Options.TypicalP,
-		"repeat_last_n":     req.Options.RepeatLastN,
-		"repeat_penalty":    req.Options.RepeatPenalty,
-		"presence_penalty":  req.Options.PresencePenalty,
-		"frequency_penalty": req.Options.FrequencyPenalty,
-		"mirostat":          req.Options.Mirostat,
-		"mirostat_tau":      req.Options.MirostatTau,
-		"mirostat_eta":      req.Options.MirostatEta,
-		"seed":              req.Options.Seed,
-		"stop":              req.Options.Stop,
-		"image_data":        req.Images,
-		"cache_prompt":      true,
-	}
-
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			// these as "not set".
 			break
 		case `"json"`:
-			request["grammar"] = grammarJSON
+			req.Grammar = grammarJSON
 		default:
 			if req.Format[0] != '{' {
 				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if g == nil {
 				return fmt.Errorf("invalid JSON schema in format")
 			}
-			request["grammar"] = string(g)
+			req.Grammar = string(g)
 		}
 	}

+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
@@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %s", status.ToString())
+		return fmt.Errorf("unexpected server status: %s", status)
 	}

 	// Handling JSON marshaling with special characters unescaped.
@@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)

-	if err := enc.Encode(request); err != nil {
+	if err := enc.Encode(req); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}

@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				evt = line
 			}

-			var c completion
+			var c CompletionResponse
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				})
 			}

-			if c.Stop {
-				doneReason := "stop"
-				if c.StoppedLimit {
-					doneReason = "length"
-				}
-
-				fn(CompletionResponse{
-					Done:               true,
-					DoneReason:         doneReason,
-					PromptEvalCount:    c.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-					EvalCount:          c.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-				})
+			if c.Done {
+				fn(c)
 				return nil
 			}
 		}
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}

 	data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
 	}
 	return 0
 }
-
-func parseDurationMs(ms float64) time.Duration {
-	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-	if err != nil {
-		panic(err)
-	}
-
-	return dur
-}
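The parseDurationMs helper removed above does not reappear elsewhere in this diff; per the commit message it moves to a shared package as ParseDurationMs. A minimal sketch of what the relocated function presumably looks like, assuming it lands in the runner/common package that llamarunner already imports; package path and exported name are assumptions, the body mirrors the code removed above:

package common

import (
	"fmt"
	"time"
)

// ParseDurationMs converts a millisecond count reported by the runner
// subprocess into a time.Duration.
func ParseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
		panic(err)
	}

	return dur
}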

+ 50 - 135
runner/llamarunner/runner.go

@@ -24,6 +24,7 @@ import (

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/runner/common"
 )

@@ -99,7 +100,7 @@ type NewSequenceParams struct {
 	embedding      bool
 }

-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

 	startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // generating image embeddings for each image
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
 	var inputs []input
 	var parts []string
 	var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
 	image *ImageContext

 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus

 	// current progress on loading the model
 	progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	return nil
 }

-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}

+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	var samplingParams llama.SamplingParams
-	samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar
+	// Extract options from the CompletionRequest
+	samplingParams := llama.SamplingParams{
+		TopK:           req.Options.TopK,
+		TopP:           req.Options.TopP,
+		MinP:           req.Options.MinP,
+		TypicalP:       req.Options.TypicalP,
+		Temp:           req.Options.Temperature,
+		RepeatLastN:    req.Options.RepeatLastN,
+		PenaltyRepeat:  req.Options.RepeatPenalty,
+		PenaltyFreq:    req.Options.FrequencyPenalty,
+		PenaltyPresent: req.Options.PresencePenalty,
+		Mirostat:       req.Options.Mirostat,
+		MirostatTau:    req.Options.MirostatTau,
+		MirostatEta:    req.Options.MirostatEta,
+		Seed:           uint32(req.Options.Seed),
+		Grammar:        req.Grammar,
+	}

 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict:     req.NumPredict,
-		stop:           req.Stop,
-		numKeep:        req.NumKeep,
+		numPredict:     req.Options.NumPredict,
+		stop:           req.Options.Stop,
+		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
 	})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numDecoded,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numDecoded,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }

-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	var req EmbeddingRequest
+	var req llm.EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
 		return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@

 	embedding := <-seq.embedding

-	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
 		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
 }

-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}

-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }

@@ -937,7 +852,7 @@ func Execute(args []string) error {
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
 		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
-		status:    ServerStatusLoadingModel,
+		status:    llm.ServerStatusLoadingModel,
 	}

 	var tensorSplitFloats []float32

+ 1 - 0
runner/ollamarunner/cache.go

@@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}

+	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
 	if !cachePrompt {
 		numPast = 0
 	}
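A small, self-contained illustration of the behavior the TODO above refers to: when cachePrompt is false (the embeddings path in the runner diffs), any cached prefix is discarded and the prompt is processed from position zero; when it is true (the completion path), the cached prefix is reused. This is a hypothetical standalone sketch, not the InputCache implementation.

package example

// effectivePast mirrors the cachePrompt check in LoadCacheSlot: a false flag
// drops the reusable prefix length back to zero.
func effectivePast(numPast int32, cachePrompt bool) int32 {
	if !cachePrompt {
		return 0
	}
	return numPast
}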

+ 35 - 122
runner/ollamarunner/runner.go

@@ -24,6 +24,7 @@ import (
 	"golang.org/x/sync/semaphore"
 	"golang.org/x/sync/semaphore"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 	"github.com/ollama/ollama/model/input"
@@ -94,7 +95,7 @@ type NewSequenceParams struct {
 	embedding  bool
 }

-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()

 	startTime := time.Now()
@@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -222,7 +223,7 @@ type Server struct {
 	model model.Model

 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus

 	// current progress on loading the model
 	progress float32
@@ -501,75 +502,18 @@ func (s *Server) processBatch() error {
 	return nil
 }

-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}

+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}

 	sampler := sample.NewSampler(
-		req.Temperature,
-		req.TopK,
-		req.TopP,
-		req.MinP,
-		req.Seed,
+		req.Options.Temperature,
+		req.Options.TopK,
+		req.Options.TopP,
+		req.Options.MinP,
+		req.Options.Seed,
 		grammar,
 	)

 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict: req.NumPredict,
-		stop:       req.Stop,
-		numKeep:    int32(req.NumKeep),
+		numPredict: req.Options.NumPredict,
+		stop:       req.Options.Stop,
+		numKeep:    int32(req.Options.NumKeep),
 		sampler:    sampler,
 		embedding:  false,
 	})
@@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numPredicted,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numPredicted,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
 	}
 	}
 }

-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -772,7 +685,7 @@ func (s *Server) loadModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }

@@ -824,7 +737,7 @@ func Execute(args []string) error {

 	server := &Server{
 		batchSize: *batchSize,
-		status:    ServerStatusLoadingModel,
+		status:    llm.ServerStatusLoadingModel,
 	}

 	// TODO(jessegross): Parameters that need to be implemented: