ParthSareen преди 4 месеца
родител
ревизия
1d529d8b7b
променени са 4 файла, в които са добавени 66 реда и са изтрити 0 реда
  1. 8 0
      api/client.go
  2. 10 0
      api/types.go
  3. 15 0
      server/prompt.go
  4. 33 0
      server/routes.go

+ 8 - 0
api/client.go

@@ -360,6 +360,14 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 	return &resp, nil
 	return &resp, nil
 }
 }
 
 
+func (c *Client) Template(ctx context.Context, req *TemplateRequest) (*TemplateResponse, error) {
+	var resp TemplateResponse
+	if err := c.do(ctx, http.MethodPost, "/api/template", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
 // CreateBlob creates a blob from a file on the server. digest is the
 // CreateBlob creates a blob from a file on the server. digest is the
 // expected SHA256 digest of the file, and r represents the file.
 // expected SHA256 digest of the file, and r represents the file.
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {

+ 10 - 0
api/types.go

@@ -310,6 +310,16 @@ type CreateRequest struct {
 	Quantization string `json:"quantization,omitempty"`
 	Quantization string `json:"quantization,omitempty"`
 }
 }
 
 
+type TemplateRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Tools    []Tool    `json:"tools"`
+}
+
+type TemplateResponse struct {
+	TemplatedPrompt string `json:"templated_prompt"`
+}
+
 // DeleteRequest is the request passed to [Client.Delete].
 // DeleteRequest is the request passed to [Client.Delete].
 type DeleteRequest struct {
 type DeleteRequest struct {
 	Model string `json:"model"`
 	Model string `json:"model"`

+ 15 - 0
server/prompt.go

@@ -142,6 +142,21 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	return b.String(), images, nil
 	return b.String(), images, nil
 }
 }
 
 
+func applyTemplate(m *Model, msgs []api.Message, tools []api.Tool) (string, error) {
+	isMllama := checkMllamaModelFamily(m)
+	for _, msg := range msgs {
+		if isMllama && len(msg.Images) > 1 {
+			return "", errTooManyImages
+		}
+	}
+
+	var b bytes.Buffer
+	if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools}); err != nil {
+		return "", err
+	}
+	return b.String(), nil
+}
+
 func checkMllamaModelFamily(m *Model) bool {
 func checkMllamaModelFamily(m *Model) bool {
 	for _, arch := range m.Config.ModelFamilies {
 	for _, arch := range m.Config.ModelFamilies {
 		if arch == "mllama" {
 		if arch == "mllama" {

+ 33 - 0
server/routes.go

@@ -1228,6 +1228,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
 	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 	r.GET("/api/ps", s.PsHandler)
 	r.GET("/api/ps", s.PsHandler)
+	r.Any("/api/template", gin.WrapF(s.TemplateHandler))
 
 
 	// Compatibility endpoints
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
@@ -1451,6 +1452,38 @@ func (s *Server) PsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 }
 
 
+func (s *Server) TemplateHandler(w http.ResponseWriter, r *http.Request) {
+	var req api.TemplateRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	model, err := GetModel(req.Model)
+	if err != nil {
+		switch {
+		case os.IsNotExist(err):
+			http.Error(w, fmt.Sprintf("model '%s' not found", req.Model), http.StatusNotFound)
+		case err.Error() == "invalid model name":
+			http.Error(w, err.Error(), http.StatusBadRequest)
+		default:
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+		}
+		return
+	}
+
+	prompt, err := applyTemplate(model, req.Messages, req.Tools)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	if err := json.NewEncoder(w).Encode(api.TemplateResponse{TemplatedPrompt: prompt}); err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+}
+
 func (s *Server) ChatHandler(c *gin.Context) {
 func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 	checkpointStart := time.Now()