Просмотр исходного кода

api: document client stream behavior with a test (#8996)

Added unit tests to verify error handling behavior in the Client.stream and Client.do methods.
Tests cover various error scenarios including:
- Error responses with status codes >= 400
- Error messages with successful status codes
- Empty error messages
- Successful responses
Bruce MacDonald 2 месяцев назад
Родитель
Сommit
14b5a9a150
2 измененных файлов с 211 добавлено и 1 удалено
  1. 1 1
      api/client.go
  2. 210 0
      api/client_test.go

+ 1 - 1
api/client.go

@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 const maxBufferSize = 512 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
-	var buf *bytes.Buffer
+	var buf io.Reader
 	if data != nil {
 		bts, err := json.Marshal(data)
 		if err != nil {

+ 210 - 0
api/client_test.go

@@ -1,6 +1,13 @@
 package api
 
 import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strings"
 	"testing"
 )
 
@@ -43,3 +50,206 @@ func TestClientFromEnvironment(t *testing.T) {
 		})
 	}
 }
+
+// testError represents an internal error type with status code and message
+// this is used since the error response from the server is not a standard error struct
+type testError struct {
+	message    string
+	statusCode int
+}
+
+func (e testError) Error() string {
+	return e.message
+}
+
+func TestClientStream(t *testing.T) {
+	testCases := []struct {
+		name      string
+		responses []any
+		wantErr   string
+	}{
+		{
+			name: "immediate error response",
+			responses: []any{
+				testError{
+					message:    "test error message",
+					statusCode: http.StatusBadRequest,
+				},
+			},
+			wantErr: "test error message",
+		},
+		{
+			name: "error after successful chunks, ok response",
+			responses: []any{
+				ChatResponse{Message: Message{Content: "partial response 1"}},
+				ChatResponse{Message: Message{Content: "partial response 2"}},
+				testError{
+					message:    "mid-stream error",
+					statusCode: http.StatusOK,
+				},
+			},
+			wantErr: "mid-stream error",
+		},
+		{
+			name: "successful stream completion",
+			responses: []any{
+				ChatResponse{Message: Message{Content: "chunk 1"}},
+				ChatResponse{Message: Message{Content: "chunk 2"}},
+				ChatResponse{
+					Message:    Message{Content: "final chunk"},
+					Done:       true,
+					DoneReason: "stop",
+				},
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				flusher, ok := w.(http.Flusher)
+				if !ok {
+					t.Fatal("expected http.Flusher")
+				}
+
+				w.Header().Set("Content-Type", "application/x-ndjson")
+
+				for _, resp := range tc.responses {
+					if errResp, ok := resp.(testError); ok {
+						w.WriteHeader(errResp.statusCode)
+						err := json.NewEncoder(w).Encode(map[string]string{
+							"error": errResp.message,
+						})
+						if err != nil {
+							t.Fatal("failed to encode error response:", err)
+						}
+						return
+					}
+
+					if err := json.NewEncoder(w).Encode(resp); err != nil {
+						t.Fatalf("failed to encode response: %v", err)
+					}
+					flusher.Flush()
+				}
+			}))
+			defer ts.Close()
+
+			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
+
+			var receivedChunks []ChatResponse
+			err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
+				var resp ChatResponse
+				if err := json.Unmarshal(chunk, &resp); err != nil {
+					return fmt.Errorf("failed to unmarshal chunk: %w", err)
+				}
+				receivedChunks = append(receivedChunks, resp)
+				return nil
+			})
+
+			if tc.wantErr != "" {
+				if err == nil {
+					t.Fatal("expected error but got nil")
+				}
+				if !strings.Contains(err.Error(), tc.wantErr) {
+					t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Errorf("unexpected error: %v", err)
+			}
+		})
+	}
+}
+
+func TestClientDo(t *testing.T) {
+	testCases := []struct {
+		name     string
+		response any
+		wantErr  string
+	}{
+		{
+			name: "immediate error response",
+			response: testError{
+				message:    "test error message",
+				statusCode: http.StatusBadRequest,
+			},
+			wantErr: "test error message",
+		},
+		{
+			name: "server error response",
+			response: testError{
+				message:    "internal error",
+				statusCode: http.StatusInternalServerError,
+			},
+			wantErr: "internal error",
+		},
+		{
+			name: "successful response",
+			response: struct {
+				ID      string `json:"id"`
+				Success bool   `json:"success"`
+			}{
+				ID:      "msg_123",
+				Success: true,
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if errResp, ok := tc.response.(testError); ok {
+					w.WriteHeader(errResp.statusCode)
+					err := json.NewEncoder(w).Encode(map[string]string{
+						"error": errResp.message,
+					})
+					if err != nil {
+						t.Fatal("failed to encode error response:", err)
+					}
+					return
+				}
+
+				w.Header().Set("Content-Type", "application/json")
+				if err := json.NewEncoder(w).Encode(tc.response); err != nil {
+					t.Fatalf("failed to encode response: %v", err)
+				}
+			}))
+			defer ts.Close()
+
+			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
+
+			var resp struct {
+				ID      string `json:"id"`
+				Success bool   `json:"success"`
+			}
+			err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
+
+			if tc.wantErr != "" {
+				if err == nil {
+					t.Fatalf("got nil, want error %q", tc.wantErr)
+				}
+				if err.Error() != tc.wantErr {
+					t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
+				}
+				return
+			}
+
+			if err != nil {
+				t.Fatalf("got error %q, want nil", err)
+			}
+
+			if expectedResp, ok := tc.response.(struct {
+				ID      string `json:"id"`
+				Success bool   `json:"success"`
+			}); ok {
+				if resp.ID != expectedResp.ID {
+					t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
+				}
+				if resp.Success != expectedResp.Success {
+					t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
+				}
+			}
+		})
+	}
+}