Browse Source

Default Keep Alive environment variable (#3094)

---------

Co-authored-by: Chris-AS1 <8493773+Chris-AS1@users.noreply.github.com>
Patrick Devine 1 year ago
parent
commit
47cfe58af5
2 changed files with 81 additions and 3 deletions
  1. 50 0
      api/types_test.go
  2. 31 3
      server/routes.go

+ 50 - 0
api/types_test.go

@@ -0,0 +1,50 @@
+package api
+
+import (
+	"encoding/json"
+	"math"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestKeepAliveParsingFromJSON(t *testing.T) {
+	tests := []struct {
+		name string
+		req  string
+		exp  *Duration
+	}{
+		{
+			name: "Positive Integer",
+			req:  `{ "keep_alive": 42 }`,
+			exp:  &Duration{42 * time.Second},
+		},
+		{
+			name: "Positive Integer String",
+			req:  `{ "keep_alive": "42m" }`,
+			exp:  &Duration{42 * time.Minute},
+		},
+		{
+			name: "Negative Integer",
+			req:  `{ "keep_alive": -1 }`,
+			exp:  &Duration{math.MaxInt64},
+		},
+		{
+			name: "Negative Integer String",
+			req:  `{ "keep_alive": "-1m" }`,
+			exp:  &Duration{math.MaxInt64},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var dec ChatRequest
+			err := json.Unmarshal([]byte(test.req), &dec)
+			require.NoError(t, err)
+
+			assert.Equal(t, test.exp, dec.KeepAlive)
+		})
+	}
+}

+ 31 - 3
server/routes.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -16,6 +17,7 @@ import (
 	"path/filepath"
 	"reflect"
 	"runtime"
+	"strconv"
 	"strings"
 	"sync"
 	"syscall"
@@ -207,7 +209,7 @@ func GenerateHandler(c *gin.Context) {
 
 	var sessionDuration time.Duration
 	if req.KeepAlive == nil {
-		sessionDuration = defaultSessionDuration
+		sessionDuration = getDefaultSessionDuration()
 	} else {
 		sessionDuration = req.KeepAlive.Duration
 	}
@@ -384,6 +386,32 @@ func GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func getDefaultSessionDuration() time.Duration {
+	if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists {
+		v, err := strconv.Atoi(t)
+		if err != nil {
+			d, err := time.ParseDuration(t)
+			if err != nil {
+				return defaultSessionDuration
+			}
+
+			if d < 0 {
+				return time.Duration(math.MaxInt64)
+			}
+
+			return d
+		}
+
+		d := time.Duration(v) * time.Second
+		if d < 0 {
+			return time.Duration(math.MaxInt64)
+		}
+		return d
+	}
+
+	return defaultSessionDuration
+}
+
 func EmbeddingsHandler(c *gin.Context) {
 	loaded.mu.Lock()
 	defer loaded.mu.Unlock()
@@ -427,7 +455,7 @@ func EmbeddingsHandler(c *gin.Context) {
 
 	var sessionDuration time.Duration
 	if req.KeepAlive == nil {
-		sessionDuration = defaultSessionDuration
+		sessionDuration = getDefaultSessionDuration()
 	} else {
 		sessionDuration = req.KeepAlive.Duration
 	}
@@ -1228,7 +1256,7 @@ func ChatHandler(c *gin.Context) {
 
 	var sessionDuration time.Duration
 	if req.KeepAlive == nil {
-		sessionDuration = defaultSessionDuration
+		sessionDuration = getDefaultSessionDuration()
 	} else {
 		sessionDuration = req.KeepAlive.Duration
 	}