|
@@ -394,6 +394,21 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ truncate := func(s string) (string, error) {
|
|
|
+ tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(tokens) > opts.NumCtx {
|
|
|
+ tokens = tokens[len(tokens)-opts.NumCtx:]
|
|
|
+ return runner.llama.Detokenize(c.Request.Context(), tokens)
|
|
|
+ }
|
|
|
+
|
|
|
+ return s, nil
|
|
|
+ }
|
|
|
+
|
|
|
embeddings := [][]float64{}
|
|
|
|
|
|
switch reqEmbed := req.Input.(type) {
|
|
@@ -402,6 +417,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
|
|
return
|
|
|
}
|
|
|
+ if *req.Truncate {
|
|
|
+ reqEmbed, err = truncate(reqEmbed)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
|
|
case []any:
|
|
|
if reqEmbed == nil {
|
|
@@ -412,6 +434,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
reqEmbedArray := make([]string, len(reqEmbed))
|
|
|
for i, v := range reqEmbed {
|
|
|
if s, ok := v.(string); ok {
|
|
|
+ if *req.Truncate {
|
|
|
+ s, err = truncate(s)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
reqEmbedArray[i] = s
|
|
|
} else {
|
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|