Browse Source

Add tests

ParthSareen 3 months ago
parent
commit
5c2f35d846
1 changed files with 73 additions and 15 deletions
  1. 73 15
      openai/openai_test.go

+ 73 - 15
openai/openai_test.go

@@ -7,7 +7,6 @@ import (
 	"io"
 	"net/http"
 	"net/http/httptest"
-	"reflect"
 	"strings"
 	"testing"
 	"time"
@@ -82,7 +81,7 @@ func TestChatMiddleware(t *testing.T) {
 					{"role": "user", "content": "Hello"}
 				],
 				"stream":            true,
-				"max_tokens":        999,
+				"max_completion_tokens":        999,
 				"seed":              123,
 				"stop":              ["\n", "stop"],
 				"temperature":       3.0,
@@ -315,6 +314,61 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &True,
 			},
 		},
+		{
+			name: "chat handler with num_ctx",
+			body: `{
+				"model": "test-model",
+				"messages": [{"role": "user", "content": "Hello"}],
+				"num_ctx": 4096 
+			}`,
+			req: api.ChatRequest{
+				Model:    "test-model",
+				Messages: []api.Message{{Role: "user", Content: "Hello"}},
+				Options: map[string]any{
+					"num_ctx":     4096.0, // float because JSON doesn't distinguish between float and int
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &False,
+			},
+		},
+		{
+			name: "chat handler with max_completion_tokens < num_ctx",
+			body: `{
+				"model": "test-model",
+				"messages": [{"role": "user", "content": "Hello"}],
+				"max_completion_tokens": 2
+			}`,
+			req: api.ChatRequest{
+				Model:    "test-model",
+				Messages: []api.Message{{Role: "user", Content: "Hello"}},
+				Options: map[string]any{
+					"num_predict": 2.0, // float because JSON doesn't distinguish between float and int
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &False,
+			},
+		},
+		{
+			name: "chat handler with max_completion_tokens > num_ctx",
+			body: `{
+				"model": "test-model",
+				"messages": [{"role": "user", "content": "Hello"}],
+				"max_completion_tokens": 4096
+			}`,
+			req: api.ChatRequest{
+				Model:    "test-model",
+				Messages: []api.Message{{Role: "user", Content: "Hello"}},
+				Options: map[string]any{
+					"num_predict": 4096.0, // float because JSON doesn't distinguish between float and int
+					"num_ctx":     4096.0,
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &False,
+			},
+		},
 		{
 			name: "chat handler error forwarding",
 			body: `{
@@ -359,7 +413,7 @@ func TestChatMiddleware(t *testing.T) {
 				return
 			}
 			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
-				t.Fatalf("requests did not match: %+v", diff)
+				t.Fatalf("requests did not match (-want +got):\n%s", diff)
 			}
 			if diff := cmp.Diff(tc.err, errResp); diff != "" {
 				t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
@@ -493,12 +547,14 @@ func TestCompletionsMiddleware(t *testing.T) {
 				}
 			}
 
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
+			if capturedRequest != nil {
+				if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
+					t.Fatalf("requests did not match (-want +got):\n%s", diff)
+				}
 			}
 
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
+			if diff := cmp.Diff(tc.err, errResp); diff != "" {
+				t.Fatalf("errors did not match (-want +got):\n%s", diff)
 			}
 
 			capturedRequest = nil
@@ -577,12 +633,14 @@ func TestEmbeddingsMiddleware(t *testing.T) {
 				}
 			}
 
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
+			if capturedRequest != nil {
+				if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
+					t.Fatalf("requests did not match (-want +got):\n%s", diff)
+				}
 			}
 
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
+			if diff := cmp.Diff(tc.err, errResp); diff != "" {
+				t.Fatalf("errors did not match (-want +got):\n%s", diff)
 			}
 
 			capturedRequest = nil
@@ -656,8 +714,8 @@ func TestListMiddleware(t *testing.T) {
 			t.Fatalf("failed to unmarshal actual response: %v", err)
 		}
 
-		if !reflect.DeepEqual(expected, actual) {
-			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		if diff := cmp.Diff(expected, actual); diff != "" {
+			t.Errorf("responses did not match (-want +got):\n%s", diff)
 		}
 	}
 }
@@ -722,8 +780,8 @@ func TestRetrieveMiddleware(t *testing.T) {
 			t.Fatalf("failed to unmarshal actual response: %v", err)
 		}
 
-		if !reflect.DeepEqual(expected, actual) {
-			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		if diff := cmp.Diff(expected, actual); diff != "" {
+			t.Errorf("responses did not match (-want +got):\n%s", diff)
 		}
 	}
 }