1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- package api
- import (
- "bufio"
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- )
// Client talks to the API server over HTTP.
type Client struct {
	// URL is the base address of the server. Request paths are appended to
	// it verbatim, with no separator inserted between the two.
	URL string

	// HTTP is the underlying client used to send every request.
	HTTP http.Client
}
- func (c *Client) stream(ctx context.Context, method string, path string, reqData any, fn func(bts []byte) error) error {
- var reqBody io.Reader
- var data []byte
- var err error
- if reqData != nil {
- data, err = json.Marshal(reqData)
- if err != nil {
- return err
- }
- reqBody = bytes.NewReader(data)
- }
- url := fmt.Sprintf("%s%s", c.URL, path)
- req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
- if err != nil {
- return err
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- res, err := c.HTTP.Do(req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- scanner := bufio.NewScanner(res.Body)
- for scanner.Scan() {
- if err := fn(scanner.Bytes()); err != nil {
- return err
- }
- }
- return nil
- }
- type GenerateResponseFunc func(GenerateResponse) error
- func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
- return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
- var resp GenerateResponse
- if err := json.Unmarshal(bts, &resp); err != nil {
- return err
- }
- return fn(resp)
- })
- }
- type PullProgressFunc func(PullProgress) error
- func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
- return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
- var resp PullProgress
- if err := json.Unmarshal(bts, &resp); err != nil {
- return err
- }
- return fn(resp)
- })
- }
|