Roy Han пре 10 месеци
родитељ
комит
5213c12354
3 измењених фајлова са 7 додато и 10 уклоњено
  1. 2 6
      api/types.go
  2. 4 4
      llm/ext_server/server.cpp
  3. 1 0
      server/routes.go

+ 2 - 6
api/types.go

@@ -226,10 +226,7 @@ type EmbeddingRequest struct {
 	Model string `json:"model"`
 	Model string `json:"model"`
 
 
 	// Prompt is the textual prompt to embed.
 	// Prompt is the textual prompt to embed.
-	Prompt string `json:"prompt,omitempty"`
-
-	// PromptBatch is a list of prompts to embed.
-	PromptBatch []string `json:"prompt_batch,omitempty"`
+	Prompt string `json:"prompt"`
 
 
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// this request.
 	// this request.
@@ -246,8 +243,7 @@ type EmbedResponse struct {
 
 
 // EmbeddingResponse is the response from [Client.Embeddings].
 // EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
 type EmbeddingResponse struct {
-	Embedding      []float64   `json:"embedding,omitempty"`
-	EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
+	Embedding []float64 `json:"embedding"`
 }
 }
 
 
 // CreateRequest is the request passed to [Client.Create].
 // CreateRequest is the request passed to [Client.Create].

+ 4 - 4
llm/ext_server/server.cpp

@@ -3156,14 +3156,14 @@ int main(int argc, char **argv) {
             {
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
                 const json body = json::parse(req.body);
-                json input;
+                json prompt;
                 if (body.count("content") != 0)
                 if (body.count("content") != 0)
                 {
                 {
-                    input = body["content"];
+                    prompt = body["content"];
                 }
                 }
                 else
                 else
                 {
                 {
-                    input = "";
+                    prompt = "";
                 }
                 }
 
 
                 // create and queue the task
                 // create and queue the task
@@ -3171,7 +3171,7 @@ int main(int argc, char **argv) {
                 {
                 {
                     const int id_task = llama.queue_tasks.get_new_id();
                     const int id_task = llama.queue_tasks.get_new_id();
                     llama.queue_results.add_waiting_task_id(id_task);
                     llama.queue_results.add_waiting_task_id(id_task);
-                    llama.request_completion(id_task, {{"prompt", input}}, true, -1);
+                    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
 
 
                     // get the result
                     // get the result
                     task_result result = llama.queue_results.recv(id_task);
                     task_result result = llama.queue_results.recv(id_task);

+ 1 - 0
server/routes.go

@@ -473,6 +473,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
 		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
 		return
 		return
 	}
 	}
+
 	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))