Преглед на файлове

add keep_alive to generate/chat/embedding api endpoints (#2146)

Patrick Devine преди 1 година
родител
ревизия
b5cf31b460
променени са 2 файла, в които са добавени 48 реда и са изтрити 20 реда
  1. 25 17
      api/types.go
  2. 23 3
      server/routes.go

+ 25 - 17
api/types.go

@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
 type ImageData []byte
 
 type GenerateRequest struct {
-	Model    string      `json:"model"`
-	Prompt   string      `json:"prompt"`
-	System   string      `json:"system"`
-	Template string      `json:"template"`
-	Context  []int       `json:"context,omitempty"`
-	Stream   *bool       `json:"stream,omitempty"`
-	Raw      bool        `json:"raw,omitempty"`
-	Format   string      `json:"format"`
-	Images   []ImageData `json:"images,omitempty"`
+	Model     string      `json:"model"`
+	Prompt    string      `json:"prompt"`
+	System    string      `json:"system"`
+	Template  string      `json:"template"`
+	Context   []int       `json:"context,omitempty"`
+	Stream    *bool       `json:"stream,omitempty"`
+	Raw       bool        `json:"raw,omitempty"`
+	Format    string      `json:"format"`
+	KeepAlive *Duration   `json:"keep_alive,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
 
 type ChatRequest struct {
-	Model    string    `json:"model"`
-	Messages []Message `json:"messages"`
-	Stream   *bool     `json:"stream,omitempty"`
-	Format   string    `json:"format"`
+	Model     string    `json:"model"`
+	Messages  []Message `json:"messages"`
+	Stream    *bool     `json:"stream,omitempty"`
+	Format    string    `json:"format"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
@@ -126,8 +128,9 @@ type Runner struct {
 }
 
 type EmbeddingRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model     string    `json:"model"`
+	Prompt    string    `json:"prompt"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
@@ -413,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	case float64:
 		if t < 0 {
 			t = math.MaxFloat64
+			d.Duration = time.Duration(t)
+		} else {
+			d.Duration = time.Duration(t * float64(time.Second))
 		}
-
-		d.Duration = time.Duration(t)
 	case string:
 		d.Duration, err = time.ParseDuration(t)
 		if err != nil {
 			return err
 		}
+		if d.Duration < 0 {
+			mf := math.MaxFloat64
+			d.Duration = time.Duration(mf)
+		}
 	}
 
 	return nil

+ 23 - 3
server/routes.go

@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	sessionDuration := defaultSessionDuration
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1074,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	sessionDuration := defaultSessionDuration
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return