client.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. package api
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/url"
  12. "os"
  13. "runtime"
  14. "strings"
  15. "github.com/ollama/ollama/format"
  16. "github.com/ollama/ollama/version"
  17. )
  // Client issues requests against the API rooted at base.
// Construct one with ClientFromEnvironment.
type Client struct {
	// base holds the scheme and host:port that request paths are joined onto
	// (see JoinPath in do and stream).
	base *url.URL
	// http executes every request; ClientFromEnvironment sets it to http.DefaultClient.
	http *http.Client
}
  22. func checkError(resp *http.Response, body []byte) error {
  23. if resp.StatusCode < http.StatusBadRequest {
  24. return nil
  25. }
  26. apiError := StatusError{StatusCode: resp.StatusCode}
  27. err := json.Unmarshal(body, &apiError)
  28. if err != nil {
  29. // Use the full body as the message if we fail to decode a response.
  30. apiError.ErrorMessage = string(body)
  31. }
  32. return apiError
  33. }
  34. func ClientFromEnvironment() (*Client, error) {
  35. defaultPort := "11434"
  36. scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
  37. switch {
  38. case !ok:
  39. scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
  40. case scheme == "http":
  41. defaultPort = "80"
  42. case scheme == "https":
  43. defaultPort = "443"
  44. }
  45. // trim trailing slashes
  46. hostport = strings.TrimRight(hostport, "/")
  47. host, port, err := net.SplitHostPort(hostport)
  48. if err != nil {
  49. host, port = "127.0.0.1", defaultPort
  50. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  51. host = ip.String()
  52. } else if hostport != "" {
  53. host = hostport
  54. }
  55. }
  56. return &Client{
  57. base: &url.URL{
  58. Scheme: scheme,
  59. Host: net.JoinHostPort(host, port),
  60. },
  61. http: http.DefaultClient,
  62. }, nil
  63. }
  64. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  65. var reqBody io.Reader
  66. var data []byte
  67. var err error
  68. switch reqData := reqData.(type) {
  69. case io.Reader:
  70. // reqData is already an io.Reader
  71. reqBody = reqData
  72. case nil:
  73. // noop
  74. default:
  75. data, err = json.Marshal(reqData)
  76. if err != nil {
  77. return err
  78. }
  79. reqBody = bytes.NewReader(data)
  80. }
  81. requestURL := c.base.JoinPath(path)
  82. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
  83. if err != nil {
  84. return err
  85. }
  86. request.Header.Set("Content-Type", "application/json")
  87. request.Header.Set("Accept", "application/json")
  88. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  89. respObj, err := c.http.Do(request)
  90. if err != nil {
  91. return err
  92. }
  93. defer respObj.Body.Close()
  94. respBody, err := io.ReadAll(respObj.Body)
  95. if err != nil {
  96. return err
  97. }
  98. if err := checkError(respObj, respBody); err != nil {
  99. return err
  100. }
  101. if len(respBody) > 0 && respData != nil {
  102. if err := json.Unmarshal(respBody, respData); err != nil {
  103. return err
  104. }
  105. }
  106. return nil
  107. }
  108. const maxBufferSize = 512 * format.KiloByte
  109. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  110. var buf *bytes.Buffer
  111. if data != nil {
  112. bts, err := json.Marshal(data)
  113. if err != nil {
  114. return err
  115. }
  116. buf = bytes.NewBuffer(bts)
  117. }
  118. requestURL := c.base.JoinPath(path)
  119. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
  120. if err != nil {
  121. return err
  122. }
  123. request.Header.Set("Content-Type", "application/json")
  124. request.Header.Set("Accept", "application/x-ndjson")
  125. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  126. response, err := c.http.Do(request)
  127. if err != nil {
  128. return err
  129. }
  130. defer response.Body.Close()
  131. scanner := bufio.NewScanner(response.Body)
  132. // increase the buffer size to avoid running out of space
  133. scanBuf := make([]byte, 0, maxBufferSize)
  134. scanner.Buffer(scanBuf, maxBufferSize)
  135. for scanner.Scan() {
  136. var errorResponse struct {
  137. Error string `json:"error,omitempty"`
  138. }
  139. bts := scanner.Bytes()
  140. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  141. return fmt.Errorf("unmarshal: %w", err)
  142. }
  143. if errorResponse.Error != "" {
  144. return fmt.Errorf(errorResponse.Error)
  145. }
  146. if response.StatusCode >= http.StatusBadRequest {
  147. return StatusError{
  148. StatusCode: response.StatusCode,
  149. Status: response.Status,
  150. ErrorMessage: errorResponse.Error,
  151. }
  152. }
  153. if err := fn(bts); err != nil {
  154. return err
  155. }
  156. }
  157. return nil
  158. }
  159. type GenerateResponseFunc func(GenerateResponse) error
  160. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  161. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  162. var resp GenerateResponse
  163. if err := json.Unmarshal(bts, &resp); err != nil {
  164. return err
  165. }
  166. return fn(resp)
  167. })
  168. }
  169. type ChatResponseFunc func(ChatResponse) error
  170. func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
  171. return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
  172. var resp ChatResponse
  173. if err := json.Unmarshal(bts, &resp); err != nil {
  174. return err
  175. }
  176. return fn(resp)
  177. })
  178. }
  179. type PullProgressFunc func(ProgressResponse) error
  180. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  181. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  182. var resp ProgressResponse
  183. if err := json.Unmarshal(bts, &resp); err != nil {
  184. return err
  185. }
  186. return fn(resp)
  187. })
  188. }
  189. type PushProgressFunc func(ProgressResponse) error
  190. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  191. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  192. var resp ProgressResponse
  193. if err := json.Unmarshal(bts, &resp); err != nil {
  194. return err
  195. }
  196. return fn(resp)
  197. })
  198. }
  199. type CreateProgressFunc func(ProgressResponse) error
  200. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  201. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  202. var resp ProgressResponse
  203. if err := json.Unmarshal(bts, &resp); err != nil {
  204. return err
  205. }
  206. return fn(resp)
  207. })
  208. }
  209. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  210. var lr ListResponse
  211. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  212. return nil, err
  213. }
  214. return &lr, nil
  215. }
  216. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  217. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  218. return err
  219. }
  220. return nil
  221. }
  222. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  223. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  224. return err
  225. }
  226. return nil
  227. }
  228. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
  229. var resp ShowResponse
  230. if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
  231. return nil, err
  232. }
  233. return &resp, nil
  234. }
  235. func (c *Client) Heartbeat(ctx context.Context) error {
  236. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  237. return err
  238. }
  239. return nil
  240. }
  241. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
  242. var resp EmbeddingResponse
  243. if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
  244. return nil, err
  245. }
  246. return &resp, nil
  247. }
  248. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
  249. return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
  250. }
  251. func (c *Client) Version(ctx context.Context) (string, error) {
  252. var version struct {
  253. Version string `json:"version"`
  254. }
  255. if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
  256. return "", err
  257. }
  258. return version.Version, nil
  259. }