Bläddra i källkod

playing around with truncate stuff

Roy Han 10 månader sedan
förälder
incheckning
80c1a3f812
4 ändrade filer med 16 tillägg och 1 borttagningar
  1. 2 0
      api/types.go
  2. 8 0
      llm/ext_server/server.cpp
  3. 1 1
      llm/ext_server/utils.hpp
  4. 5 0
      server/routes.go

+ 2 - 0
api/types.go

@@ -216,6 +216,8 @@ type EmbedRequest struct {
 	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	Truncate *bool `json:"truncate,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }

+ 8 - 0
llm/ext_server/server.cpp

@@ -1206,6 +1206,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
+                {"truncated", slot.truncated}
             };
         }
         else
@@ -1223,6 +1224,7 @@ struct llama_server_context
                         res.result_json = json
                         {
                             {"embedding", std::vector<float>(n_embd, 0.0f)},
+                            {"truncated", slot.truncated}
                         };
                         continue;
                     }
@@ -1231,6 +1233,7 @@ struct llama_server_context
                 res.result_json = json
                 {
                     {"embedding", std::vector<float>(embd, embd + n_embd)},
+                    {"truncated", slot.truncated}
                 };
             }
         }
@@ -3060,6 +3063,7 @@ int main(int argc, char **argv) {
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
+                    LOG_INFO("completion", {{"result", result.result_json}});
                     if (!result.error && result.stop) {
                         res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
                     }
@@ -3075,6 +3079,7 @@ int main(int argc, char **argv) {
                         while (true)
                         {
                             task_result result = llama.queue_results.recv(task_id);
+                            LOG_INFO("completion", {{"result", result.result_json}});
                             if (!result.error) {
                                 const std::string str =
                                     "data: " +
@@ -3180,6 +3185,7 @@ int main(int argc, char **argv) {
                         if (result.result_json.count("results")) {
                             // result for multi-task
                             responses = result.result_json.at("results");
+                            LOG_INFO("results", {result.result_json});
                         } else {
                             // result for single task
                             responses = std::vector<json>(1, result.result_json);
@@ -3198,6 +3204,8 @@ int main(int argc, char **argv) {
                         }
                         // send the result
                         json result = json{{"embedding", embeddings}};
+                        // log result
+
                         return res.set_content(result.dump(), "application/json; charset=utf-8");
                     } else {
                         // return error

+ 1 - 1
llm/ext_server/utils.hpp

@@ -658,7 +658,7 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
 }
 
 // normalize a vector
-std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
     double sum = 0.0;
     for (float value : vec) {
         sum += value * value;

+ 5 - 0
server/routes.go

@@ -356,6 +356,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	if req.Truncate == nil {
+		truncate := true
+		req.Truncate = &truncate
+	}
+
 	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError