client.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package api
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/url"
  11. "os"
  12. )
  // DefaultHost is the host:port the client falls back to when no host is
// supplied explicitly (NewClient) or through the environment (Host).
const DefaultHost = "localhost:11434"

// envHost caches the OLLAMA_HOST environment variable. It is read once, when
// the package is initialized, so later changes to the environment are not
// observed.
var envHost = os.Getenv("OLLAMA_HOST")
  // Client is an HTTP client for the API server. Construct one with NewClient
// or FromEnv.
type Client struct {
	// Base is the server URL that request paths are joined onto.
	Base url.URL
	// HTTP is the HTTP client configured for this Client; its zero value is
	// ready to use.
	HTTP http.Client
	// Headers holds extra HTTP headers configured for this Client.
	Headers http.Header
}
  22. func checkError(resp *http.Response, body []byte) error {
  23. if resp.StatusCode >= 200 && resp.StatusCode < 400 {
  24. return nil
  25. }
  26. apiError := StatusError{StatusCode: resp.StatusCode}
  27. err := json.Unmarshal(body, &apiError)
  28. if err != nil {
  29. // Use the full body as the message if we fail to decode a response.
  30. apiError.ErrorMessage = string(body)
  31. }
  32. return apiError
  33. }
  34. // Host returns the default host to use for the client. It is determined in the following order:
  35. // 1. The OLLAMA_HOST environment variable
  36. // 2. The default host (localhost:11434)
  37. func Host() string {
  38. if envHost != "" {
  39. return envHost
  40. }
  41. return DefaultHost
  42. }
  43. // FromEnv creates a new client using Host() as the host. An error is returns
  44. // if the host is invalid.
  45. func FromEnv() (*Client, error) {
  46. u, err := url.Parse(Host())
  47. if err != nil {
  48. return nil, err
  49. }
  50. return &Client{Base: *u}, nil
  51. }
  52. func NewClient(hosts ...string) *Client {
  53. host := DefaultHost
  54. if len(hosts) > 0 {
  55. host = hosts[0]
  56. }
  57. return &Client{
  58. Base: url.URL{Scheme: "http", Host: host},
  59. HTTP: http.Client{},
  60. }
  61. }
  62. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  63. var reqBody io.Reader
  64. var data []byte
  65. var err error
  66. if reqData != nil {
  67. data, err = json.Marshal(reqData)
  68. if err != nil {
  69. return err
  70. }
  71. reqBody = bytes.NewReader(data)
  72. }
  73. url := c.Base.JoinPath(path).String()
  74. req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
  75. if err != nil {
  76. return err
  77. }
  78. req.Header.Set("Content-Type", "application/json")
  79. req.Header.Set("Accept", "application/json")
  80. for k, v := range c.Headers {
  81. req.Header[k] = v
  82. }
  83. respObj, err := c.HTTP.Do(req)
  84. if err != nil {
  85. return err
  86. }
  87. defer respObj.Body.Close()
  88. respBody, err := io.ReadAll(respObj.Body)
  89. if err != nil {
  90. return err
  91. }
  92. if err := checkError(respObj, respBody); err != nil {
  93. return err
  94. }
  95. if len(respBody) > 0 && respData != nil {
  96. if err := json.Unmarshal(respBody, respData); err != nil {
  97. return err
  98. }
  99. }
  100. return nil
  101. }
  102. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  103. var buf *bytes.Buffer
  104. if data != nil {
  105. bts, err := json.Marshal(data)
  106. if err != nil {
  107. return err
  108. }
  109. buf = bytes.NewBuffer(bts)
  110. }
  111. request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf)
  112. if err != nil {
  113. return err
  114. }
  115. request.Header.Set("Content-Type", "application/json")
  116. request.Header.Set("Accept", "application/json")
  117. response, err := http.DefaultClient.Do(request)
  118. if err != nil {
  119. return err
  120. }
  121. defer response.Body.Close()
  122. scanner := bufio.NewScanner(response.Body)
  123. for scanner.Scan() {
  124. var errorResponse struct {
  125. Error string `json:"error,omitempty"`
  126. }
  127. bts := scanner.Bytes()
  128. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  129. return fmt.Errorf("unmarshal: %w", err)
  130. }
  131. if errorResponse.Error != "" {
  132. return fmt.Errorf(errorResponse.Error)
  133. }
  134. if response.StatusCode >= 400 {
  135. return StatusError{
  136. StatusCode: response.StatusCode,
  137. Status: response.Status,
  138. ErrorMessage: errorResponse.Error,
  139. }
  140. }
  141. if err := fn(bts); err != nil {
  142. return err
  143. }
  144. }
  145. return nil
  146. }
  147. type GenerateResponseFunc func(GenerateResponse) error
  148. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  149. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  150. var resp GenerateResponse
  151. if err := json.Unmarshal(bts, &resp); err != nil {
  152. return err
  153. }
  154. return fn(resp)
  155. })
  156. }
  157. type PullProgressFunc func(ProgressResponse) error
  158. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  159. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  160. var resp ProgressResponse
  161. if err := json.Unmarshal(bts, &resp); err != nil {
  162. return err
  163. }
  164. return fn(resp)
  165. })
  166. }
  167. type PushProgressFunc func(ProgressResponse) error
  168. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  169. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  170. var resp ProgressResponse
  171. if err := json.Unmarshal(bts, &resp); err != nil {
  172. return err
  173. }
  174. return fn(resp)
  175. })
  176. }
  177. type CreateProgressFunc func(ProgressResponse) error
  178. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  179. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  180. var resp ProgressResponse
  181. if err := json.Unmarshal(bts, &resp); err != nil {
  182. return err
  183. }
  184. return fn(resp)
  185. })
  186. }
  187. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  188. var lr ListResponse
  189. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  190. return nil, err
  191. }
  192. return &lr, nil
  193. }
  194. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  195. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  196. return err
  197. }
  198. return nil
  199. }
  200. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  201. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  202. return err
  203. }
  204. return nil
  205. }
  206. func (c *Client) Heartbeat(ctx context.Context) error {
  207. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  208. return err
  209. }
  210. return nil
  211. }