client.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. // Package api implements the client-side API for code wishing to interact
  2. // with the ollama service. The methods of the [Client] type correspond to
  3. // the ollama REST API as described in [the API documentation].
  4. // The ollama command-line client itself uses this package to interact with
  5. // the backend service.
  6. //
  7. // # Examples
  8. //
  9. // Several examples of using this package are available [in the GitHub
  10. // repository].
  11. //
  12. // [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
  13. // [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
  14. package api
  15. import (
  16. "bufio"
  17. "bytes"
  18. "context"
  19. "encoding/json"
  20. "fmt"
  21. "io"
  22. "net/http"
  23. "net/url"
  24. "runtime"
  25. "github.com/ollama/ollama/envconfig"
  26. "github.com/ollama/ollama/format"
  27. "github.com/ollama/ollama/version"
  28. )
  29. // StatusError is an error with an HTTP status code and message,
  30. // it is parsed on the client-side and not returned from the API
  31. type StatusError struct {
  32. StatusCode int // e.g. 200
  33. Status string // e.g. "200 OK"
  34. ErrorResponse
  35. }
  36. func (e StatusError) Error() string {
  37. switch {
  38. case e.Status != "" && e.Err != "":
  39. return fmt.Sprintf("%s: %s", e.Status, e.Err)
  40. case e.Status != "":
  41. return e.Status
  42. case e.Err != "":
  43. return e.Err
  44. default:
  45. // this should not happen
  46. return "something went wrong, please see the ollama server logs for details"
  47. }
  48. }
  49. // Client encapsulates client state for interacting with the ollama
  50. // service. Use [ClientFromEnvironment] to create new Clients.
  51. type Client struct {
  52. base *url.URL
  53. http *http.Client
  54. }
  55. func checkError(resp *http.Response, body []byte) error {
  56. if resp.StatusCode < http.StatusBadRequest {
  57. return nil
  58. }
  59. apiError := StatusError{StatusCode: resp.StatusCode}
  60. err := json.Unmarshal(body, &apiError)
  61. if err != nil {
  62. // Use the full body as the message if we fail to decode a response.
  63. apiError.Err = string(body)
  64. }
  65. return apiError
  66. }
  67. // ClientFromEnvironment creates a new [Client] using configuration from the
  68. // environment variable OLLAMA_HOST, which points to the network host and
  69. // port on which the ollama service is listening. The format of this variable
  70. // is:
  71. //
  72. // <scheme>://<host>:<port>
  73. //
  74. // If the variable is not specified, a default ollama host and port will be
  75. // used.
  76. func ClientFromEnvironment() (*Client, error) {
  77. return &Client{
  78. base: envconfig.Host(),
  79. http: http.DefaultClient,
  80. }, nil
  81. }
  82. func NewClient(base *url.URL, http *http.Client) *Client {
  83. return &Client{
  84. base: base,
  85. http: http,
  86. }
  87. }
  88. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  89. var reqBody io.Reader
  90. var data []byte
  91. var err error
  92. switch reqData := reqData.(type) {
  93. case io.Reader:
  94. // reqData is already an io.Reader
  95. reqBody = reqData
  96. case nil:
  97. // noop
  98. default:
  99. data, err = json.Marshal(reqData)
  100. if err != nil {
  101. return err
  102. }
  103. reqBody = bytes.NewReader(data)
  104. }
  105. requestURL := c.base.JoinPath(path)
  106. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
  107. if err != nil {
  108. return err
  109. }
  110. request.Header.Set("Content-Type", "application/json")
  111. request.Header.Set("Accept", "application/json")
  112. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  113. respObj, err := c.http.Do(request)
  114. if err != nil {
  115. return err
  116. }
  117. defer respObj.Body.Close()
  118. respBody, err := io.ReadAll(respObj.Body)
  119. if err != nil {
  120. return err
  121. }
  122. if err := checkError(respObj, respBody); err != nil {
  123. return err
  124. }
  125. if len(respBody) > 0 && respData != nil {
  126. if err := json.Unmarshal(respBody, respData); err != nil {
  127. return err
  128. }
  129. }
  130. return nil
  131. }
  132. const maxBufferSize = 512 * format.KiloByte
  133. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  134. var buf io.Reader
  135. if data != nil {
  136. bts, err := json.Marshal(data)
  137. if err != nil {
  138. return err
  139. }
  140. buf = bytes.NewBuffer(bts)
  141. }
  142. requestURL := c.base.JoinPath(path)
  143. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
  144. if err != nil {
  145. return err
  146. }
  147. request.Header.Set("Content-Type", "application/json")
  148. request.Header.Set("Accept", "application/x-ndjson")
  149. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  150. response, err := c.http.Do(request)
  151. if err != nil {
  152. return err
  153. }
  154. defer response.Body.Close()
  155. scanner := bufio.NewScanner(response.Body)
  156. // increase the buffer size to avoid running out of space
  157. scanBuf := make([]byte, 0, maxBufferSize)
  158. scanner.Buffer(scanBuf, maxBufferSize)
  159. for scanner.Scan() {
  160. bts := scanner.Bytes()
  161. var errorResponse ErrorResponse
  162. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  163. return fmt.Errorf("unmarshal: %w", err)
  164. }
  165. if errorResponse.Err != "" {
  166. return errorResponse
  167. }
  168. if response.StatusCode >= http.StatusBadRequest {
  169. return StatusError{
  170. StatusCode: response.StatusCode,
  171. Status: response.Status,
  172. ErrorResponse: errorResponse,
  173. }
  174. }
  175. if err := fn(bts); err != nil {
  176. return err
  177. }
  178. }
  179. return nil
  180. }
  181. // GenerateResponseFunc is a function that [Client.Generate] invokes every time
  182. // a response is received from the service. If this function returns an error,
  183. // [Client.Generate] will stop generating and return this error.
  184. type GenerateResponseFunc func(GenerateResponse) error
  185. // Generate generates a response for a given prompt. The req parameter should
  186. // be populated with prompt details. fn is called for each response (there may
  187. // be multiple responses, e.g. in case streaming is enabled).
  188. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  189. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  190. var resp GenerateResponse
  191. if err := json.Unmarshal(bts, &resp); err != nil {
  192. return err
  193. }
  194. return fn(resp)
  195. })
  196. }
  197. // ChatResponseFunc is a function that [Client.Chat] invokes every time
  198. // a response is received from the service. If this function returns an error,
  199. // [Client.Chat] will stop generating and return this error.
  200. type ChatResponseFunc func(ChatResponse) error
  201. // Chat generates the next message in a chat. [ChatRequest] may contain a
  202. // sequence of messages which can be used to maintain chat history with a model.
  203. // fn is called for each response (there may be multiple responses, e.g. if case
  204. // streaming is enabled).
  205. func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
  206. return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
  207. var resp ChatResponse
  208. if err := json.Unmarshal(bts, &resp); err != nil {
  209. return err
  210. }
  211. return fn(resp)
  212. })
  213. }
  214. // PullProgressFunc is a function that [Client.Pull] invokes every time there
  215. // is progress with a "pull" request sent to the service. If this function
  216. // returns an error, [Client.Pull] will stop the process and return this error.
  217. type PullProgressFunc func(ProgressResponse) error
  218. // Pull downloads a model from the ollama library. fn is called each time
  219. // progress is made on the request and can be used to display a progress bar,
  220. // etc.
  221. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  222. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  223. var resp ProgressResponse
  224. if err := json.Unmarshal(bts, &resp); err != nil {
  225. return err
  226. }
  227. return fn(resp)
  228. })
  229. }
  230. // PushProgressFunc is a function that [Client.Push] invokes when progress is
  231. // made.
  232. // It's similar to other progress function types like [PullProgressFunc].
  233. type PushProgressFunc func(ProgressResponse) error
  234. // Push uploads a model to the model library; requires registering for ollama.ai
  235. // and adding a public key first. fn is called each time progress is made on
  236. // the request and can be used to display a progress bar, etc.
  237. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  238. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  239. var resp ProgressResponse
  240. if err := json.Unmarshal(bts, &resp); err != nil {
  241. return err
  242. }
  243. return fn(resp)
  244. })
  245. }
  246. // CreateProgressFunc is a function that [Client.Create] invokes when progress
  247. // is made.
  248. // It's similar to other progress function types like [PullProgressFunc].
  249. type CreateProgressFunc func(ProgressResponse) error
  250. // Create creates a model from a [Modelfile]. fn is a progress function that
  251. // behaves similarly to other methods (see [Client.Pull]).
  252. //
  253. // [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
  254. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  255. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  256. var resp ProgressResponse
  257. if err := json.Unmarshal(bts, &resp); err != nil {
  258. return err
  259. }
  260. return fn(resp)
  261. })
  262. }
  263. // List lists models that are available locally.
  264. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  265. var lr ListResponse
  266. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  267. return nil, err
  268. }
  269. return &lr, nil
  270. }
  271. // ListRunning lists running models.
  272. func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
  273. var lr ProcessResponse
  274. if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
  275. return nil, err
  276. }
  277. return &lr, nil
  278. }
  279. // Copy copies a model - creating a model with another name from an existing
  280. // model.
  281. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  282. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  283. return err
  284. }
  285. return nil
  286. }
  287. // Delete deletes a model and its data.
  288. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  289. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  290. return err
  291. }
  292. return nil
  293. }
  294. // Show obtains model information, including details, modelfile, license etc.
  295. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
  296. var resp ShowResponse
  297. if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
  298. return nil, err
  299. }
  300. return &resp, nil
  301. }
  302. // Heartbeat checks if the server has started and is responsive; if yes, it
  303. // returns nil, otherwise an error.
  304. func (c *Client) Heartbeat(ctx context.Context) error {
  305. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  306. return err
  307. }
  308. return nil
  309. }
  310. // Embed generates embeddings from a model.
  311. func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
  312. var resp EmbedResponse
  313. if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
  314. return nil, err
  315. }
  316. return &resp, nil
  317. }
  318. // Embeddings generates an embedding from a model.
  319. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
  320. var resp EmbeddingResponse
  321. if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
  322. return nil, err
  323. }
  324. return &resp, nil
  325. }
  326. // CreateBlob creates a blob from a file on the server. digest is the
  327. // expected SHA256 digest of the file, and r represents the file.
  328. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
  329. return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
  330. }
  331. // Version returns the Ollama server version as a string.
  332. func (c *Client) Version(ctx context.Context) (string, error) {
  333. var version struct {
  334. Version string `json:"version"`
  335. }
  336. if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
  337. return "", err
  338. }
  339. return version.Version, nil
  340. }