123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- package api
- import (
- "bufio"
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- )
// Client holds the configuration used to talk to the API over HTTP.
type Client struct {
	// Name and Version are descriptive strings; URL is the base address that
	// request paths are appended to.
	Name, Version, URL string

	// HTTP is the underlying client used to send every request.
	HTTP http.Client

	// Headers are copied onto each outgoing request, replacing any default
	// header that has the same key.
	Headers http.Header

	// PrivateKey is raw key material; nothing visible in this file reads it.
	PrivateKey []byte
}
- func checkError(resp *http.Response, body []byte) error {
- if resp.StatusCode >= 200 && resp.StatusCode < 400 {
- return nil
- }
- apiError := Error{Code: int32(resp.StatusCode)}
- err := json.Unmarshal(body, &apiError)
- if err != nil {
- // Use the full body as the message if we fail to decode a response.
- apiError.Message = string(body)
- }
- return apiError
- }
- func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func (data []byte)) error {
- var reqBody io.Reader
- var data []byte
- var err error
- if reqData != nil {
- data, err = json.Marshal(reqData)
- if err != nil {
- return err
- }
- reqBody = bytes.NewReader(data)
- }
- url := fmt.Sprintf("%s%s", c.URL, path)
- req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
- if err != nil {
- return err
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- for k, v := range c.Headers {
- req.Header[k] = v
- }
- res, err := c.HTTP.Do(req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- reader := bufio.NewReader(res.Body)
- for {
- line, err := reader.ReadBytes('\n')
- if err != nil {
- break
- }
- callback(bytes.TrimSuffix(line, []byte("\n")))
- }
- return nil
- }
- func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
- var reqBody io.Reader
- var data []byte
- var err error
- if reqData != nil {
- data, err = json.Marshal(reqData)
- if err != nil {
- return err
- }
- reqBody = bytes.NewReader(data)
- }
- url := fmt.Sprintf("%s%s", c.URL, path)
- req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
- if err != nil {
- return err
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- for k, v := range c.Headers {
- req.Header[k] = v
- }
- respObj, err := c.HTTP.Do(req)
- if err != nil {
- return err
- }
- defer respObj.Body.Close()
- respBody, err := io.ReadAll(respObj.Body)
- if err != nil {
- return err
- }
- if err := checkError(respObj, respBody); err != nil {
- return err
- }
- if len(respBody) > 0 && respData != nil {
- if err := json.Unmarshal(respBody, respData); err != nil {
- return err
- }
- }
- return nil
- }
- func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) {
- var res GenerateResponse
- if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) {
- callback(string(token))
- }); err != nil {
- return nil, err
- }
- return &res, nil
- }
|