Roy Han 9 月之前
父节点
当前提交
d2b25c1bfb
共有 2 个文件被更改,包括 26 次插入1 次删除
  1. 4 0
      server/model.go
  2. 22 1
      server/routes.go

+ 4 - 0
server/model.go

@@ -15,6 +15,7 @@ import (
 	"slices"
 	"strings"
 	"text/template/parse"
+	"time"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
@@ -312,6 +313,7 @@ func detectContentType(r io.Reader) (string, error) {
 // mxyng: this only really works if the input contains tool calls in some JSON format
 func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	// create a subtree from the node that ranges over .ToolCalls
+	start := time.Now()
 	tmpl := m.Template.Subtree(func(n parse.Node) bool {
 		if t, ok := n.(*parse.RangeNode); ok {
 			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
@@ -415,5 +417,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		}
 	}
 
+	end := time.Now()
+	slog.Debug("parseToolCalls", "duration", end.Sub(start).String())
 	return toolCalls, len(toolCalls) > 0
 }

+ 22 - 1
server/routes.go

@@ -1369,7 +1369,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 	}()
 
-	if req.Stream != nil && !*req.Stream {
+	if (req.Stream != nil && !*req.Stream) || ((req.Stream == nil || *req.Stream) && len(req.Tools) > 0) {
 		var resp api.ChatResponse
 		var sb strings.Builder
 		for rr := range ch {
@@ -1400,6 +1400,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}
 
+		if (req.Stream == nil || *req.Stream) && len(resp.Message.ToolCalls) > 0 {
+			toolCh := make(chan any)
+			go func() {
+				toolCalls := resp.Message.ToolCalls
+				for _, toolCall := range toolCalls[:len(toolCalls)-1] {
+					chunk := api.ChatResponse{
+						Model:      resp.Model,
+						CreatedAt:  resp.CreatedAt,
+						Message:    api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}},
+						DoneReason: "tool_calls",
+					}
+					toolCh <- chunk
+				}
+				resp.Message.ToolCalls = []api.ToolCall{toolCalls[len(toolCalls)-1]}
+				toolCh <- resp
+				close(toolCh)
+			}()
+			streamResponse(c, toolCh)
+			return
+		}
+
 		c.JSON(http.StatusOK, resp)
 		return
 	}