client_test.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. package api
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
)
  11. func TestClientFromEnvironment(t *testing.T) {
  12. type testCase struct {
  13. value string
  14. expect string
  15. err error
  16. }
  17. testCases := map[string]*testCase{
  18. "empty": {value: "", expect: "http://127.0.0.1:11434"},
  19. "only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
  20. "only port": {value: ":1234", expect: "http://:1234"},
  21. "address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
  22. "scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
  23. "scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
  24. "scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
  25. "hostname": {value: "example.com", expect: "http://example.com:11434"},
  26. "hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"},
  27. "scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"},
  28. "scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"},
  29. "scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
  30. "trailing slash": {value: "example.com/", expect: "http://example.com:11434"},
  31. "trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"},
  32. }
  33. for k, v := range testCases {
  34. t.Run(k, func(t *testing.T) {
  35. t.Setenv("OLLAMA_HOST", v.value)
  36. client, err := ClientFromEnvironment()
  37. if err != v.err {
  38. t.Fatalf("expected %s, got %s", v.err, err)
  39. }
  40. if client.base.String() != v.expect {
  41. t.Fatalf("expected %s, got %s", v.expect, client.base.String())
  42. }
  43. })
  44. }
  45. }
  46. // testError represents an internal error type for testing different error formats
  47. type testError struct {
  48. message string // basic error message
  49. structured *ErrorResponse // structured error response, nil for basic format
  50. statusCode int
  51. }
  52. func (e testError) Error() string {
  53. return e.message
  54. }
  55. func TestClientStream(t *testing.T) {
  56. testCases := []struct {
  57. name string
  58. responses []any
  59. wantErr string
  60. }{
  61. {
  62. name: "basic error format",
  63. responses: []any{
  64. testError{
  65. message: "test error message",
  66. statusCode: http.StatusBadRequest,
  67. },
  68. },
  69. wantErr: "test error message",
  70. },
  71. {
  72. name: "structured error format",
  73. responses: []any{
  74. testError{
  75. message: "test structured error",
  76. structured: &ErrorResponse{
  77. Err: "test structured error",
  78. Hint: "test hint",
  79. },
  80. statusCode: http.StatusBadRequest,
  81. },
  82. },
  83. wantErr: "test structured error\ntest hint",
  84. },
  85. {
  86. name: "error after chunks - basic format",
  87. responses: []any{
  88. ChatResponse{Message: Message{Content: "partial 1"}},
  89. ChatResponse{Message: Message{Content: "partial 2"}},
  90. testError{
  91. message: "mid-stream basic error",
  92. statusCode: http.StatusOK,
  93. },
  94. },
  95. wantErr: "mid-stream basic error",
  96. },
  97. {
  98. name: "error after chunks - structured format",
  99. responses: []any{
  100. ChatResponse{Message: Message{Content: "partial 1"}},
  101. ChatResponse{Message: Message{Content: "partial 2"}},
  102. testError{
  103. message: "mid-stream structured error",
  104. structured: &ErrorResponse{
  105. Err: "mid-stream structured error",
  106. Hint: "additional context",
  107. },
  108. statusCode: http.StatusOK,
  109. },
  110. },
  111. wantErr: "mid-stream structured error\nadditional context",
  112. },
  113. {
  114. name: "successful stream completion",
  115. responses: []any{
  116. ChatResponse{Message: Message{Content: "chunk 1"}},
  117. ChatResponse{Message: Message{Content: "chunk 2"}},
  118. ChatResponse{
  119. Message: Message{Content: "final chunk"},
  120. Done: true,
  121. DoneReason: "stop",
  122. },
  123. },
  124. },
  125. }
  126. for _, tc := range testCases {
  127. t.Run(tc.name, func(t *testing.T) {
  128. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  129. flusher, ok := w.(http.Flusher)
  130. if !ok {
  131. t.Fatal("expected http.Flusher")
  132. }
  133. w.Header().Set("Content-Type", "application/x-ndjson")
  134. for _, resp := range tc.responses {
  135. if errResp, ok := resp.(testError); ok {
  136. w.WriteHeader(errResp.statusCode)
  137. var err error
  138. if errResp.structured != nil {
  139. err = json.NewEncoder(w).Encode(errResp.structured)
  140. } else {
  141. err = json.NewEncoder(w).Encode(map[string]string{
  142. "error": errResp.message,
  143. })
  144. }
  145. if err != nil {
  146. t.Fatal("failed to encode error response:", err)
  147. }
  148. return
  149. }
  150. if err := json.NewEncoder(w).Encode(resp); err != nil {
  151. t.Fatalf("failed to encode response: %v", err)
  152. }
  153. flusher.Flush()
  154. }
  155. }))
  156. defer ts.Close()
  157. client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
  158. var receivedChunks []ChatResponse
  159. err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
  160. var resp ChatResponse
  161. if err := json.Unmarshal(chunk, &resp); err != nil {
  162. return fmt.Errorf("failed to unmarshal chunk: %w", err)
  163. }
  164. receivedChunks = append(receivedChunks, resp)
  165. return nil
  166. })
  167. if tc.wantErr != "" {
  168. if err == nil {
  169. t.Fatalf("got nil, want error %q", tc.wantErr)
  170. }
  171. if err.Error() != tc.wantErr {
  172. t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
  173. }
  174. } else {
  175. if err != nil {
  176. t.Errorf("got error %q, want nil", err)
  177. }
  178. }
  179. })
  180. }
  181. }
  182. func TestClientDo(t *testing.T) {
  183. testCases := []struct {
  184. name string
  185. response any
  186. wantErr string
  187. }{
  188. {
  189. name: "basic error format",
  190. response: testError{
  191. message: "test error message",
  192. statusCode: http.StatusBadRequest,
  193. },
  194. wantErr: "test error message",
  195. },
  196. {
  197. name: "structured error format",
  198. response: testError{
  199. message: "test structured error",
  200. structured: &ErrorResponse{
  201. Err: "test structured error",
  202. Hint: "test hint",
  203. },
  204. statusCode: http.StatusBadRequest,
  205. },
  206. wantErr: "test structured error",
  207. },
  208. {
  209. name: "server error - basic format",
  210. response: testError{
  211. message: "internal error",
  212. statusCode: http.StatusInternalServerError,
  213. },
  214. wantErr: "internal error",
  215. },
  216. {
  217. name: "server error - structured format",
  218. response: testError{
  219. message: "internal server error",
  220. structured: &ErrorResponse{
  221. Err: "internal server error",
  222. Hint: "please try again later",
  223. },
  224. statusCode: http.StatusInternalServerError,
  225. },
  226. wantErr: "internal server error",
  227. },
  228. {
  229. name: "successful response",
  230. response: struct {
  231. ID string `json:"id"`
  232. Success bool `json:"success"`
  233. }{
  234. ID: "msg_123",
  235. Success: true,
  236. },
  237. },
  238. }
  239. for _, tc := range testCases {
  240. t.Run(tc.name, func(t *testing.T) {
  241. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  242. if errResp, ok := tc.response.(testError); ok {
  243. w.WriteHeader(errResp.statusCode)
  244. var err error
  245. if errResp.structured != nil {
  246. err = json.NewEncoder(w).Encode(errResp.structured)
  247. } else {
  248. err = json.NewEncoder(w).Encode(map[string]string{
  249. "error": errResp.message,
  250. })
  251. }
  252. if err != nil {
  253. t.Fatal("failed to encode error response:", err)
  254. }
  255. return
  256. }
  257. w.Header().Set("Content-Type", "application/json")
  258. if err := json.NewEncoder(w).Encode(tc.response); err != nil {
  259. t.Fatalf("failed to encode response: %v", err)
  260. }
  261. }))
  262. defer ts.Close()
  263. client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
  264. var resp struct {
  265. ID string `json:"id"`
  266. Success bool `json:"success"`
  267. }
  268. err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
  269. if tc.wantErr != "" {
  270. if err == nil {
  271. t.Fatalf("got nil, want error %q", tc.wantErr)
  272. }
  273. if err.Error() != tc.wantErr {
  274. t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
  275. }
  276. return
  277. }
  278. if err != nil {
  279. t.Errorf("got error %q, want nil", err)
  280. }
  281. if expectedResp, ok := tc.response.(struct {
  282. ID string `json:"id"`
  283. Success bool `json:"success"`
  284. }); ok {
  285. if resp.ID != expectedResp.ID {
  286. t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
  287. }
  288. if resp.Success != expectedResp.Success {
  289. t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
  290. }
  291. }
  292. })
  293. }
  294. }