Forráskód Böngészése

llm: remove internal subprocess req and resp types (#9324)

This commit refactors the LLM subsystem by removing internal subprocess
request and response types. It consolidates duplicate type definitions
across the codebase, moving them to centralized locations. The change also
standardizes interfaces between components, simplifies the ServerStatusResp
struct, and moves the ParseDurationMs function to a common package. This
cleanup reduces code duplication between different runner implementations
(llamarunner and ollamarunner).
Bruce MacDonald 1 hónapja
szülő
commit
3892c3a703
4 módosított fájl, 125 hozzáadás és 354 törlés
  1. 39 97
      llm/server.go
  2. 50 135
      runner/llamarunner/runner.go
  3. 1 0
      runner/ollamarunner/cache.go
  4. 35 122
      runner/ollamarunner/runner.go

+ 39 - 97
llm/server.go

@@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
 		}
 
-		slog.Info("starting llama server", "cmd", s.cmd.String())
+		slog.Info("starting llama server", "cmd", s.cmd)
 		if envconfig.Debug() {
 			filteredEnv := []string{}
 			for _, ev := range s.cmd.Env {
@@ -470,7 +470,7 @@ const ( // iota is reset to 0
 	ServerStatusError
 )
 
-func (s ServerStatus) ToString() string {
+func (s ServerStatus) String() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
@@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string {
 	}
 }
 
-type ServerStatusResp struct {
-	Status          string  `json:"status"`
-	SlotsIdle       int     `json:"slots_idle"`
-	SlotsProcessing int     `json:"slots_processing"`
-	Error           string  `json:"error"`
-	Progress        float32 `json:"progress"`
+type ServerStatusResponse struct {
+	Status   ServerStatus `json:"status"`
+	Progress float32      `json:"progress"`
 }
 
 func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		}
 		if s.cmd.ProcessState.ExitCode() == -1 {
 			// Most likely a signal killed it, log some more details to try to help troubleshoot
-			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
 		}
 		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 	}
@@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 		return ServerStatusError, fmt.Errorf("read health request: %w", err)
 	}
 
-	var status ServerStatusResp
-	if err := json.Unmarshal(body, &status); err != nil {
+	var ssr ServerStatusResponse
+	if err := json.Unmarshal(body, &ssr); err != nil {
 		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
 	}
 
-	switch status.Status {
-	case "ok":
-		return ServerStatusReady, nil
-	case "no slot available":
-		return ServerStatusNoSlotsAvailable, nil
-	case "loading model":
-		s.loadProgress = status.Progress
-		return ServerStatusLoadingModel, nil
+	switch ssr.Status {
+	case ServerStatusLoadingModel:
+		s.loadProgress = ssr.Progress
+		return ssr.Status, nil
+	case ServerStatusReady, ServerStatusNoSlotsAvailable:
+		return ssr.Status, nil
 	default:
-		return ServerStatusError, fmt.Errorf("server error: %+v", status)
+		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
 	}
 }
 
@@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 		status, _ := s.getServerStatus(ctx)
 		if lastStatus != status && status != ServerStatusReady {
 			// Only log on status changes
-			slog.Info("waiting for server to become available", "status", status.ToString())
+			slog.Info("waiting for server to become available", "status", status)
 		}
 		switch status {
 		case ServerStatusReady:
@@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 				stallTimer = time.Now().Add(stallDuration)
 			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
-				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+				slog.Debug("model load completed, waiting for server to become available", "status", status)
 				stallTimer = time.Now().Add(stallDuration)
 				fullyLoaded = true
 			}
@@ -671,63 +666,26 @@ type ImageData struct {
 	AspectRatioID int    `json:"aspect_ratio_id"`
 }
 
-type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
-
-	Timings struct {
-		PredictedN  int     `json:"predicted_n"`
-		PredictedMS float64 `json:"predicted_ms"`
-		PromptN     int     `json:"prompt_n"`
-		PromptMS    float64 `json:"prompt_ms"`
-	}
-}
-
 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
+
+	Grammar string // set before sending the request to the subprocess
 }
 
 type CompletionResponse struct {
-	Content            string
-	DoneReason         string
-	Done               bool
-	PromptEvalCount    int
-	PromptEvalDuration time.Duration
-	EvalCount          int
-	EvalDuration       time.Duration
+	Content            string        `json:"content"`
+	DoneReason         string        `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-	request := map[string]any{
-		"prompt":            req.Prompt,
-		"stream":            true,
-		"n_predict":         req.Options.NumPredict,
-		"n_keep":            req.Options.NumKeep,
-		"main_gpu":          req.Options.MainGPU,
-		"temperature":       req.Options.Temperature,
-		"top_k":             req.Options.TopK,
-		"top_p":             req.Options.TopP,
-		"min_p":             req.Options.MinP,
-		"typical_p":         req.Options.TypicalP,
-		"repeat_last_n":     req.Options.RepeatLastN,
-		"repeat_penalty":    req.Options.RepeatPenalty,
-		"presence_penalty":  req.Options.PresencePenalty,
-		"frequency_penalty": req.Options.FrequencyPenalty,
-		"mirostat":          req.Options.Mirostat,
-		"mirostat_tau":      req.Options.MirostatTau,
-		"mirostat_eta":      req.Options.MirostatEta,
-		"seed":              req.Options.Seed,
-		"stop":              req.Options.Stop,
-		"image_data":        req.Images,
-		"cache_prompt":      true,
-	}
-
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			// these as "not set".
 			break
 		case `"json"`:
-			request["grammar"] = grammarJSON
+			req.Grammar = grammarJSON
 		default:
 			if req.Format[0] != '{' {
 				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if g == nil {
 				return fmt.Errorf("invalid JSON schema in format")
 			}
-			request["grammar"] = string(g)
+			req.Grammar = string(g)
 		}
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
@@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %s", status.ToString())
+		return fmt.Errorf("unexpected server status: %s", status)
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
@@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(request); err != nil {
+	if err := enc.Encode(req); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
@@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				evt = line
 			}
 
-			var c completion
+			var c CompletionResponse
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
@@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				})
 			}
 
-			if c.Stop {
-				doneReason := "stop"
-				if c.StoppedLimit {
-					doneReason = "length"
-				}
-
-				fn(CompletionResponse{
-					Done:               true,
-					DoneReason:         doneReason,
-					PromptEvalCount:    c.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-					EvalCount:          c.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-				})
+			if c.Done {
+				fn(c)
 				return nil
 			}
 		}
@@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+		return nil, fmt.Errorf("unexpected server status: %s", status)
 	}
 
 	data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
 	}
 	return 0
 }
-
-func parseDurationMs(ms float64) time.Duration {
-	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-	if err != nil {
-		panic(err)
-	}
-
-	return dur
-}

+ 50 - 135
runner/llamarunner/runner.go

@@ -24,6 +24,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/runner/common"
 )
 
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
 	embedding      bool
 }
 
-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
 	startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // generating image embeddings for each image
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
 	var inputs []input
 	var parts []string
 	var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
 	image *ImageContext
 
 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus
 
 	// current progress on loading the model
 	progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	return nil
 }
 
-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var samplingParams llama.SamplingParams
-	samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar
+	// Extract options from the CompletionRequest
+	samplingParams := llama.SamplingParams{
+		TopK:           req.Options.TopK,
+		TopP:           req.Options.TopP,
+		MinP:           req.Options.MinP,
+		TypicalP:       req.Options.TypicalP,
+		Temp:           req.Options.Temperature,
+		RepeatLastN:    req.Options.RepeatLastN,
+		PenaltyRepeat:  req.Options.RepeatPenalty,
+		PenaltyFreq:    req.Options.FrequencyPenalty,
+		PenaltyPresent: req.Options.PresencePenalty,
+		Mirostat:       req.Options.Mirostat,
+		MirostatTau:    req.Options.MirostatTau,
+		MirostatEta:    req.Options.MirostatEta,
+		Seed:           uint32(req.Options.Seed),
+		Grammar:        req.Grammar,
+	}
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict:     req.NumPredict,
-		stop:           req.Stop,
-		numKeep:        req.NumKeep,
+		numPredict:     req.Options.NumPredict,
+		stop:           req.Options.Stop,
+		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
 	})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numDecoded,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numDecoded,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	var req EmbeddingRequest
+	var req llm.EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
 		return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	embedding := <-seq.embedding
 
-	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+	if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
 		Embedding: embedding,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 	}
 }
 
-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }
 
@@ -937,7 +852,7 @@ func Execute(args []string) error {
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
 		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
-		status:    ServerStatusLoadingModel,
+		status:    llm.ServerStatusLoadingModel,
 	}
 
 	var tensorSplitFloats []float32

+ 1 - 0
runner/ollamarunner/cache.go

@@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}
 
+	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
 	if !cachePrompt {
 		numPast = 0
 	}

+ 35 - 122
runner/ollamarunner/runner.go

@@ -24,6 +24,7 @@ import (
 	"golang.org/x/sync/semaphore"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
@@ -94,7 +95,7 @@ type NewSequenceParams struct {
 	embedding  bool
 }
 
-func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
 	startTime := time.Now()
@@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -222,7 +223,7 @@ type Server struct {
 	model model.Model
 
 	// status for external health reporting - loading, ready to serve, etc.
-	status ServerStatus
+	status llm.ServerStatus
 
 	// current progress on loading the model
 	progress float32
@@ -501,75 +502,18 @@ func (s *Server) processBatch() error {
 	return nil
 }
 
-// TODO (jmorganca): use structs from the api package to avoid duplication
-// this way the api acts as a proxy instead of using a different api for the
-// runner
-type Options struct {
-	api.Runner
-
-	NumKeep          int      `json:"n_keep"`
-	Seed             int      `json:"seed"`
-	NumPredict       int      `json:"n_predict"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	MinP             float32  `json:"min_p"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	Temperature      float32  `json:"temperature"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	Stop             []string `json:"stop"`
-}
-
-type ImageData struct {
-	Data          []byte `json:"data"`
-	ID            int    `json:"id"`
-	AspectRatioID int    `json:"aspect_ratio_id"`
-}
-
-type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
-
-	Options
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
-	var req CompletionRequest
-	req.Options = Options(api.DefaultOptions())
+	var req llm.CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
 	}
 
+	if req.Options == nil {
+		opts := api.DefaultOptions()
+		req.Options = &opts
+	}
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 
 	sampler := sample.NewSampler(
-		req.Temperature,
-		req.TopK,
-		req.TopP,
-		req.MinP,
-		req.Seed,
+		req.Options.Temperature,
+		req.Options.TopK,
+		req.Options.TopP,
+		req.Options.MinP,
+		req.Options.Seed,
 		grammar,
 	)
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict: req.NumPredict,
-		stop:       req.Stop,
-		numKeep:    int32(req.NumKeep),
+		numPredict: req.Options.NumPredict,
+		stop:       req.Options.Stop,
+		numKeep:    int32(req.Options.NumKeep),
 		sampler:    sampler,
 		embedding:  false,
 	})
@@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
 				// Send the final response
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Stop:         true,
-					StoppedLimit: seq.doneReason == "limit",
-					Timings: Timings{
-						PromptN:     seq.numPromptInputs,
-						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
-						PredictedN:  seq.numPredicted,
-						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
-					},
+				doneReason := "stop"
+				if seq.doneReason == "limit" {
+					doneReason = "length"
+				}
+				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
+					Done:               true,
+					DoneReason:         doneReason,
+					PromptEvalCount:    seq.numPromptInputs,
+					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
+					EvalCount:          seq.numPredicted,
+					EvalDuration:       time.Since(seq.startGenerationTime),
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
 				}
@@ -682,43 +628,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-type EmbeddingRequest struct {
-	Content     string `json:"content"`
-	CachePrompt bool   `json:"cache_prompt"`
-}
-
-type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
-}
-
-type HealthResponse struct {
-	Status   string  `json:"status"`
-	Progress float32 `json:"progress"`
-}
-
-type ServerStatus int
-
-const (
-	ServerStatusReady ServerStatus = iota
-	ServerStatusLoadingModel
-	ServerStatusError
-)
-
-func (s ServerStatus) ToString() string {
-	switch s {
-	case ServerStatusReady:
-		return "ok"
-	case ServerStatusLoadingModel:
-		return "loading model"
-	default:
-		return "server error"
-	}
-}
-
 func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
-	if err := json.NewEncoder(w).Encode(&HealthResponse{
-		Status:   s.status.ToString(),
+	if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
+		Status:   s.status,
 		Progress: s.progress,
 	}); err != nil {
 		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -772,7 +685,7 @@ func (s *Server) loadModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	s.status = ServerStatusReady
+	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }
 
@@ -824,7 +737,7 @@ func Execute(args []string) error {
 
 	server := &Server{
 		batchSize: *batchSize,
-		status:    ServerStatusLoadingModel,
+		status:    llm.ServerStatusLoadingModel,
 	}
 
 	// TODO(jessegross): Parameters that need to be implemented: