
uses input prompt

Josh Yan committed 8 months ago
commit a33e56cddb

1 changed file with 20 additions and 12 deletions:
    llm/ext_server/server.cpp  +20 -12

llm/ext_server/server.cpp  +20 -12

@@ -1271,6 +1271,7 @@ struct llama_server_context
         }
     }
 
+    // 1 image only
     bool prepare_pali(server_slot &slot, int n_batch)
     {
         int n_past = 0;
@@ -1288,8 +1289,6 @@ struct llama_server_context
         // generate user_prompt -> this should contain image tokens prepended and a new line appended:
         // batch.n_tokens += (int)slot.images.size() * llama_n_embd(model);
         std::vector<llama_token> tokens;
-        std::string prompt = "How much ketchup is in this image?";
-        std::vector<llama_token> text = ::llama_tokenize(ctx, prompt, false, true);
 
         for (int i = 0; i < (int)slot.images.size() * 256; i++)
         {
@@ -1298,18 +1297,31 @@ struct llama_server_context
 
         tokens.push_back(2);
 
-        for (int i = 0; i < text.size(); i++)
+        // move prefix prompt behind image tokens
+        for (int i = 0; i < batch.n_tokens; i++)
         {
-            tokens.push_back(text[i]);
+            tokens.push_back(batch.token[i]);
         }
 
-        tokens.push_back(108);
-
+        llama_batch_clear(batch);
         for (int i = 0; i < (int)tokens.size(); ++i)
         {
             llama_batch_add(batch, tokens[i], system_tokens.size() + slot.n_past, {slot.id}, true);
             slot.n_past += 1;
         }
+
+        // append prefix of next image
+        const auto json_prompt = slot.params.input_suffix;
+
+        std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
+        append_tokens.push_back(108);
+
+        for (int i = 0; i < (int)append_tokens.size(); ++i)
+        {
+            llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, {slot.id}, true);
+            slot.n_past += 1;
+        }
+        // llama_set_causal_attn(ctx, false);
         return true;
     }
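
For readability, here is a minimal standalone sketch (not the server code itself) of the token layout that prepare_pali builds in the hunk above: 256 placeholder positions per image, a separator token, the prefix prompt that was already queued in the batch, then the tokenized slot.params.input_suffix followed by a trailing newline token. The hardcoded ids 2 and 108 are assumed to be the BOS and newline tokens of the PaliGemma vocabulary, and the placeholder value for the image positions is illustrative.

#include <cstdint>
#include <vector>

using token_id = int32_t;

// Conceptual layout only; the real code streams these tokens into a llama_batch
// via llama_batch_add, advancing slot.n_past as it goes.
std::vector<token_id> pali_token_layout(int n_images,
                                        const std::vector<token_id> & prefix_prompt,  // tokens already in the batch
                                        const std::vector<token_id> & suffix_prompt)  // tokenize(slot.params.input_suffix)
{
    const int      image_tokens_per_image = 256;  // placeholder slots the image embeddings will occupy
    const token_id bos_token              = 2;    // assumed id behind the literal 2 above
    const token_id newline_token          = 108;  // assumed id behind the literal 108 above

    std::vector<token_id> seq;
    for (int i = 0; i < n_images * image_tokens_per_image; ++i) {
        seq.push_back(0);  // value is irrelevant here; image embeddings occupy these positions
    }
    seq.push_back(bos_token);                                            // separates the image block from the text
    seq.insert(seq.end(), prefix_prompt.begin(), prefix_prompt.end());   // prompt moved behind the image tokens
    seq.insert(seq.end(), suffix_prompt.begin(), suffix_prompt.end());   // input_suffix, which may reference the next image
    seq.push_back(newline_token);
    return seq;
}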
 
@@ -1400,6 +1412,7 @@ struct llama_server_context
         case 25:
             return prepare_pali(slot, n_batch);
         default:
+            LOG_TEE("%s : failed to retrieve model architecture\n", __func__);
             return false;
         }
     }
@@ -1893,12 +1906,6 @@ struct llama_server_context
                         llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                         slot_npast++;
                     }
-                    LOG_DEBUG("hi gpt params processing images", {
-                                                                     {"gpt_params.model", params.model.c_str()},
-                                                                     {"model alias", params.model_alias.c_str()},
-                                                                 });
-                    printf("gpt_params model is %s\n", params.model.c_str());
-                    printf("gpt_params model is %s\n", params.model.c_str());
 
                     if (has_images && !ingest_images(slot, n_batch))
                     {
@@ -1977,6 +1984,7 @@ struct llama_server_context
             };
 
             const int ret = llama_decode(ctx, batch_view);
+            llama_set_causal_attn(ctx, true);
 
             if (ret != 0)
             {
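
The last hunk re-enables llama_set_causal_attn(ctx, true) immediately after llama_decode, and a matching llama_set_causal_attn(ctx, false) is left commented out in prepare_pali above. PaliGemma-style models attend bidirectionally over the image and prefix tokens and only causally over the generated suffix, so the intended pattern is presumably the toggle sketched below; the function and variable names are illustrative, not from the server.

#include "llama.h"

// Prefill the image + prefix tokens with full (non-causal) attention, then
// restore causal attention before the suffix is generated autoregressively.
// `ctx` and `prefix_batch` stand in for whatever the server actually builds.
static bool prefill_with_full_attention(llama_context * ctx, llama_batch prefix_batch) {
    llama_set_causal_attn(ctx, false);   // bidirectional attention over image + prefix
    const int ret = llama_decode(ctx, prefix_batch);
    llama_set_causal_attn(ctx, true);    // back to causal attention for generation
    return ret == 0;
}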