client.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. package api
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. )
  // Client is an HTTP client for the API server.
//
// URL is the server's base address. Request paths such as "/api/generate"
// are appended to it verbatim, so it should not end with a trailing slash.
// HTTP is the transport used for every request. Its zero value is a usable
// http.Client, so a Client only needs URL set before use.
type Client struct {
	URL  string
	HTTP http.Client
}
  16. func checkError(resp *http.Response, body []byte) error {
  17. if resp.StatusCode >= 200 && resp.StatusCode < 400 {
  18. return nil
  19. }
  20. apiError := Error{Code: int32(resp.StatusCode)}
  21. if err := json.Unmarshal(body, &apiError); err != nil {
  22. // Use the full body as the message if we fail to decode a response.
  23. apiError.Message = string(body)
  24. }
  25. return apiError
  26. }
  27. func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
  28. var reqBody io.Reader
  29. var data []byte
  30. var err error
  31. if reqData != nil {
  32. data, err = json.Marshal(reqData)
  33. if err != nil {
  34. return err
  35. }
  36. reqBody = bytes.NewReader(data)
  37. }
  38. url := fmt.Sprintf("%s%s", c.URL, path)
  39. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  40. if err != nil {
  41. return err
  42. }
  43. req.Header.Set("Content-Type", "application/json")
  44. req.Header.Set("Accept", "application/json")
  45. res, err := c.HTTP.Do(req)
  46. if err != nil {
  47. return err
  48. }
  49. defer res.Body.Close()
  50. reader := bufio.NewReader(res.Body)
  51. for {
  52. line, err := reader.ReadBytes('\n')
  53. if err != nil {
  54. break
  55. }
  56. callback(bytes.TrimSuffix(line, []byte("\n")))
  57. }
  58. return nil
  59. }
  60. func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
  61. var reqBody io.Reader
  62. var data []byte
  63. var err error
  64. if reqData != nil {
  65. data, err = json.Marshal(reqData)
  66. if err != nil {
  67. return err
  68. }
  69. reqBody = bytes.NewReader(data)
  70. }
  71. url := fmt.Sprintf("%s%s", c.URL, path)
  72. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  73. if err != nil {
  74. return err
  75. }
  76. req.Header.Set("Content-Type", "application/json")
  77. req.Header.Set("Accept", "application/json")
  78. respObj, err := c.HTTP.Do(req)
  79. if err != nil {
  80. return err
  81. }
  82. defer respObj.Body.Close()
  83. respBody, err := io.ReadAll(respObj.Body)
  84. if err != nil {
  85. return err
  86. }
  87. if err := checkError(respObj, respBody); err != nil {
  88. return err
  89. }
  90. if len(respBody) > 0 && respData != nil {
  91. if err := json.Unmarshal(respBody, respData); err != nil {
  92. return err
  93. }
  94. }
  95. return nil
  96. }
  97. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) {
  98. var res GenerateResponse
  99. if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) {
  100. callback(string(token))
  101. }); err != nil {
  102. return nil, err
  103. }
  104. return &res, nil
  105. }
  106. func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) (*PullResponse, error) {
  107. var res PullResponse
  108. if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
  109. /*
  110. Events have the following format for progress:
  111. event:progress
  112. data:{"total":123,"completed":123,"percent":0.1}
  113. Need to parse out the data part and unmarshal it.
  114. */
  115. eventParts := strings.Split(string(progressBytes), "data:")
  116. if len(eventParts) < 2 {
  117. // no data part, ignore
  118. return
  119. }
  120. eventData := eventParts[1]
  121. var progress PullProgress
  122. if err := json.Unmarshal([]byte(eventData), &progress); err != nil {
  123. fmt.Println(err)
  124. return
  125. }
  126. callback(progress)
  127. }); err != nil {
  128. return nil, err
  129. }
  130. return &res, nil
  131. }