client.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package api
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "net/http"
  11. "sync"
  12. )
  // Client talks to an HTTP JSON API rooted at URL.
//
// Request paths such as "/api/generate" are appended to URL verbatim, so a
// trailing slash on URL would produce a doubled slash in the request path.
type Client struct {
	// URL is the base address that every request path is appended to.
	URL string

	// HTTP sends every request. It is held by value, so the zero Client uses
	// a zero http.Client (default transport, no client-level timeout);
	// per-call deadlines come from the ctx passed to each method.
	HTTP http.Client
}
  17. func checkError(resp *http.Response, body []byte) error {
  18. if resp.StatusCode >= 200 && resp.StatusCode < 400 {
  19. return nil
  20. }
  21. apiError := Error{Code: int32(resp.StatusCode)}
  22. if err := json.Unmarshal(body, &apiError); err != nil {
  23. // Use the full body as the message if we fail to decode a response.
  24. apiError.Message = string(body)
  25. }
  26. return apiError
  27. }
  28. func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
  29. var reqBody io.Reader
  30. var data []byte
  31. var err error
  32. if reqData != nil {
  33. data, err = json.Marshal(reqData)
  34. if err != nil {
  35. return err
  36. }
  37. reqBody = bytes.NewReader(data)
  38. }
  39. url := fmt.Sprintf("%s%s", c.URL, path)
  40. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  41. if err != nil {
  42. return err
  43. }
  44. req.Header.Set("Content-Type", "application/json")
  45. req.Header.Set("Accept", "application/json")
  46. res, err := c.HTTP.Do(req)
  47. if err != nil {
  48. return err
  49. }
  50. defer res.Body.Close()
  51. reader := bufio.NewReader(res.Body)
  52. for {
  53. line, err := reader.ReadBytes('\n')
  54. switch {
  55. case errors.Is(err, io.EOF):
  56. return nil
  57. case err != nil:
  58. return err
  59. }
  60. if err := checkError(res, line); err != nil {
  61. return err
  62. }
  63. callback(bytes.TrimSuffix(line, []byte("\n")))
  64. }
  65. }
  66. func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
  67. var reqBody io.Reader
  68. var data []byte
  69. var err error
  70. if reqData != nil {
  71. data, err = json.Marshal(reqData)
  72. if err != nil {
  73. return err
  74. }
  75. reqBody = bytes.NewReader(data)
  76. }
  77. url := fmt.Sprintf("%s%s", c.URL, path)
  78. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  79. if err != nil {
  80. return err
  81. }
  82. req.Header.Set("Content-Type", "application/json")
  83. req.Header.Set("Accept", "application/json")
  84. respObj, err := c.HTTP.Do(req)
  85. if err != nil {
  86. return err
  87. }
  88. defer respObj.Body.Close()
  89. respBody, err := io.ReadAll(respObj.Body)
  90. if err != nil {
  91. return err
  92. }
  93. if err := checkError(respObj, respBody); err != nil {
  94. return err
  95. }
  96. if len(respBody) > 0 && respData != nil {
  97. if err := json.Unmarshal(respBody, respData); err != nil {
  98. return err
  99. }
  100. }
  101. return nil
  102. }
  103. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(bts []byte)) (*GenerateResponse, error) {
  104. var res GenerateResponse
  105. if err := c.stream(ctx, http.MethodPost, "/api/generate", req, callback); err != nil {
  106. return nil, err
  107. }
  108. return &res, nil
  109. }
  110. func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
  111. var wg sync.WaitGroup
  112. wg.Add(1)
  113. if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
  114. var progress PullProgress
  115. if err := json.Unmarshal(progressBytes, &progress); err != nil {
  116. fmt.Println(err)
  117. return
  118. }
  119. if progress.Completed >= progress.Total {
  120. wg.Done()
  121. }
  122. callback(progress)
  123. }); err != nil {
  124. return err
  125. }
  126. wg.Wait()
  127. return nil
  128. }