client.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package api
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. "sync"
  12. )
  // Client is a minimal client for the JSON HTTP API rooted at URL.
//
// Request paths such as "/api/generate" are appended to URL verbatim (no
// separator is inserted), so URL should not end with a slash.
type Client struct {
	// URL is the base address that request paths are appended to.
	URL string
	// HTTP is the client used to send every request. It is held by value;
	// its zero value is usable and, per net/http, applies no client-side
	// timeout, leaving cancellation to each request's context.
	HTTP http.Client
}
  17. func checkError(resp *http.Response, body []byte) error {
  18. if resp.StatusCode >= 200 && resp.StatusCode < 400 {
  19. return nil
  20. }
  21. apiError := Error{Code: int32(resp.StatusCode)}
  22. if err := json.Unmarshal(body, &apiError); err != nil {
  23. // Use the full body as the message if we fail to decode a response.
  24. apiError.Message = string(body)
  25. }
  26. return apiError
  27. }
  28. func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
  29. var reqBody io.Reader
  30. var data []byte
  31. var err error
  32. if reqData != nil {
  33. data, err = json.Marshal(reqData)
  34. if err != nil {
  35. return err
  36. }
  37. reqBody = bytes.NewReader(data)
  38. }
  39. url := fmt.Sprintf("%s%s", c.URL, path)
  40. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  41. if err != nil {
  42. return err
  43. }
  44. req.Header.Set("Content-Type", "application/json")
  45. req.Header.Set("Accept", "application/json")
  46. res, err := c.HTTP.Do(req)
  47. if err != nil {
  48. return err
  49. }
  50. defer res.Body.Close()
  51. reader := bufio.NewReader(res.Body)
  52. for {
  53. line, err := reader.ReadBytes('\n')
  54. if err != nil {
  55. if err == io.EOF {
  56. break
  57. } else {
  58. return err // Handle other errors
  59. }
  60. }
  61. if err := checkError(res, line); err != nil {
  62. return err
  63. }
  64. callback(bytes.TrimSuffix(line, []byte("\n")))
  65. }
  66. return nil
  67. }
  68. func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
  69. var reqBody io.Reader
  70. var data []byte
  71. var err error
  72. if reqData != nil {
  73. data, err = json.Marshal(reqData)
  74. if err != nil {
  75. return err
  76. }
  77. reqBody = bytes.NewReader(data)
  78. }
  79. url := fmt.Sprintf("%s%s", c.URL, path)
  80. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  81. if err != nil {
  82. return err
  83. }
  84. req.Header.Set("Content-Type", "application/json")
  85. req.Header.Set("Accept", "application/json")
  86. respObj, err := c.HTTP.Do(req)
  87. if err != nil {
  88. return err
  89. }
  90. defer respObj.Body.Close()
  91. respBody, err := io.ReadAll(respObj.Body)
  92. if err != nil {
  93. return err
  94. }
  95. if err := checkError(respObj, respBody); err != nil {
  96. return err
  97. }
  98. if len(respBody) > 0 && respData != nil {
  99. if err := json.Unmarshal(respBody, respData); err != nil {
  100. return err
  101. }
  102. }
  103. return nil
  104. }
  105. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) {
  106. var res GenerateResponse
  107. if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) {
  108. callback(string(token))
  109. }); err != nil {
  110. return nil, err
  111. }
  112. return &res, nil
  113. }
  114. func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
  115. var wg sync.WaitGroup
  116. wg.Add(1)
  117. if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
  118. /*
  119. Events have the following format for progress:
  120. event:progress
  121. data:{"total":123,"completed":123,"percent":0.1}
  122. Need to parse out the data part and unmarshal it.
  123. */
  124. eventParts := strings.Split(string(progressBytes), "data:")
  125. if len(eventParts) < 2 {
  126. // no data part, ignore
  127. return
  128. }
  129. eventData := eventParts[1]
  130. var progress PullProgress
  131. if err := json.Unmarshal([]byte(eventData), &progress); err != nil {
  132. fmt.Println(err)
  133. return
  134. }
  135. if progress.Completed >= progress.Total {
  136. wg.Done()
  137. }
  138. callback(progress)
  139. }); err != nil {
  140. return err
  141. }
  142. wg.Wait()
  143. return nil
  144. }