فهرست منبع

increase streaming buffer size (#692)

Bruce MacDonald 1 سال پیش
والد
کامیت
9e2de1bd2c
2 فایل تغییر یافته به همراه 11 افزوده شده و 3 حذف شده
  1. 6 3
      api/client.go
  2. 5 0
      llm/llama.go

+ 6 - 3
api/client.go

@@ -18,9 +18,7 @@ import (
 
 const DefaultHost = "127.0.0.1:11434"
 
-var (
-	envHost = os.Getenv("OLLAMA_HOST")
-)
+var envHost = os.Getenv("OLLAMA_HOST")
 
 type Client struct {
 	Base    url.URL
@@ -123,6 +121,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
+const maxBufferSize = 512 * 1024 // 512KB
+
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
 	if data != nil {
@@ -151,6 +151,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 	defer response.Body.Close()
 
 	scanner := bufio.NewScanner(response.Body)
+	// increase the buffer size to avoid running out of space
+	scanBuf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(scanBuf, maxBufferSize)
 	for scanner.Scan() {
 		var errorResponse struct {
 			Error string `json:"error,omitempty"`

+ 5 - 0
llm/llama.go

@@ -438,6 +438,8 @@ type PredictRequest struct {
 	Stop             []string `json:"stop,omitempty"`
 }
 
+const maxBufferSize = 512 * 1024 // 512KB
+
 func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
 	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
@@ -498,6 +500,9 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	}
 
 	scanner := bufio.NewScanner(resp.Body)
+	// increase the buffer size to avoid running out of space
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)
 	for scanner.Scan() {
 		select {
 		case <-ctx.Done():