client.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. // Package api implements the client-side API for code wishing to interact
  2. // with the ollama service. The methods of the [Client] type correspond to
  3. // the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
  4. //
  5. // The ollama command-line client itself uses this package to interact with
  6. // the backend service.
  7. package api
  8. import (
  9. "bufio"
  10. "bytes"
  11. "context"
  12. "encoding/json"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "net/url"
  18. "os"
  19. "runtime"
  20. "strings"
  21. "github.com/ollama/ollama/format"
  22. "github.com/ollama/ollama/version"
  23. )
  // Client encapsulates client state for interacting with the ollama
// service. Use [ClientFromEnvironment] to create new Clients.
type Client struct {
	// base holds the scheme and host:port that every request path is
	// joined onto.
	base *url.URL

	// http is the HTTP client used to send every request.
	http *http.Client
}
  30. func checkError(resp *http.Response, body []byte) error {
  31. if resp.StatusCode < http.StatusBadRequest {
  32. return nil
  33. }
  34. apiError := StatusError{StatusCode: resp.StatusCode}
  35. err := json.Unmarshal(body, &apiError)
  36. if err != nil {
  37. // Use the full body as the message if we fail to decode a response.
  38. apiError.ErrorMessage = string(body)
  39. }
  40. return apiError
  41. }
  42. // ClientFromEnvironment creates a new [Client] using configuration from the
  43. // environment variable OLLAMA_HOST, which points to the network host and
  44. // port on which the ollama service is listenting. The format of this variable
  45. // is:
  46. //
  47. // <scheme>://<host>:<port>
  48. //
  49. // If the variable is not specified, a default ollama host and port will be
  50. // used.
  51. func ClientFromEnvironment() (*Client, error) {
  52. defaultPort := "11434"
  53. scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
  54. switch {
  55. case !ok:
  56. scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
  57. case scheme == "http":
  58. defaultPort = "80"
  59. case scheme == "https":
  60. defaultPort = "443"
  61. }
  62. // trim trailing slashes
  63. hostport = strings.TrimRight(hostport, "/")
  64. host, port, err := net.SplitHostPort(hostport)
  65. if err != nil {
  66. host, port = "127.0.0.1", defaultPort
  67. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  68. host = ip.String()
  69. } else if hostport != "" {
  70. host = hostport
  71. }
  72. }
  73. return &Client{
  74. base: &url.URL{
  75. Scheme: scheme,
  76. Host: net.JoinHostPort(host, port),
  77. },
  78. http: http.DefaultClient,
  79. }, nil
  80. }
  81. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  82. var reqBody io.Reader
  83. var data []byte
  84. var err error
  85. switch reqData := reqData.(type) {
  86. case io.Reader:
  87. // reqData is already an io.Reader
  88. reqBody = reqData
  89. case nil:
  90. // noop
  91. default:
  92. data, err = json.Marshal(reqData)
  93. if err != nil {
  94. return err
  95. }
  96. reqBody = bytes.NewReader(data)
  97. }
  98. requestURL := c.base.JoinPath(path)
  99. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
  100. if err != nil {
  101. return err
  102. }
  103. request.Header.Set("Content-Type", "application/json")
  104. request.Header.Set("Accept", "application/json")
  105. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  106. respObj, err := c.http.Do(request)
  107. if err != nil {
  108. return err
  109. }
  110. defer respObj.Body.Close()
  111. respBody, err := io.ReadAll(respObj.Body)
  112. if err != nil {
  113. return err
  114. }
  115. if err := checkError(respObj, respBody); err != nil {
  116. return err
  117. }
  118. if len(respBody) > 0 && respData != nil {
  119. if err := json.Unmarshal(respBody, respData); err != nil {
  120. return err
  121. }
  122. }
  123. return nil
  124. }
  125. const maxBufferSize = 512 * format.KiloByte
  126. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  127. var buf *bytes.Buffer
  128. if data != nil {
  129. bts, err := json.Marshal(data)
  130. if err != nil {
  131. return err
  132. }
  133. buf = bytes.NewBuffer(bts)
  134. }
  135. requestURL := c.base.JoinPath(path)
  136. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
  137. if err != nil {
  138. return err
  139. }
  140. request.Header.Set("Content-Type", "application/json")
  141. request.Header.Set("Accept", "application/x-ndjson")
  142. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  143. response, err := c.http.Do(request)
  144. if err != nil {
  145. return err
  146. }
  147. defer response.Body.Close()
  148. scanner := bufio.NewScanner(response.Body)
  149. // increase the buffer size to avoid running out of space
  150. scanBuf := make([]byte, 0, maxBufferSize)
  151. scanner.Buffer(scanBuf, maxBufferSize)
  152. for scanner.Scan() {
  153. var errorResponse struct {
  154. Error string `json:"error,omitempty"`
  155. }
  156. bts := scanner.Bytes()
  157. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  158. return fmt.Errorf("unmarshal: %w", err)
  159. }
  160. if errorResponse.Error != "" {
  161. return fmt.Errorf(errorResponse.Error)
  162. }
  163. if response.StatusCode >= http.StatusBadRequest {
  164. return StatusError{
  165. StatusCode: response.StatusCode,
  166. Status: response.Status,
  167. ErrorMessage: errorResponse.Error,
  168. }
  169. }
  170. if err := fn(bts); err != nil {
  171. return err
  172. }
  173. }
  174. return nil
  175. }
  176. // GenerateResponseFunc is a function that [Client.Generate] invokes every time
  177. // a response is received from the service. If this function returns an error,
  178. // [Client.Generate] will stop generating and return this error.
  179. type GenerateResponseFunc func(GenerateResponse) error
  180. // Generate generates a response for a given prompt. The req parameter should
  181. // be populated with prompt details. fn is called for each response (there may
  182. // be multiple responses, e.g. in case streaming is enabled).
  183. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  184. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  185. var resp GenerateResponse
  186. if err := json.Unmarshal(bts, &resp); err != nil {
  187. return err
  188. }
  189. return fn(resp)
  190. })
  191. }
  192. // ChatResponseFunc is a function that [Client.Chat] invokes every time
  193. // a response is received from the service. If this function returns an error,
  194. // [Client.Chat] will stop generating and return this error.
  195. type ChatResponseFunc func(ChatResponse) error
  196. // Chat generates the next message in a chat. [ChatRequest] may contain a
  197. // sequence of messages which can be used to maintain chat history with a model.
  198. // fn is called for each response (there may be multiple responses, e.g. if case
  199. // streaming is enabled).
  200. func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
  201. return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
  202. var resp ChatResponse
  203. if err := json.Unmarshal(bts, &resp); err != nil {
  204. return err
  205. }
  206. return fn(resp)
  207. })
  208. }
  209. // PullProgressFunc is a function that [Client.Pull] invokes every time there
  210. // is progress with a "pull" request sent to the service. If this function
  211. // returns an error, [Client.Pull] will stop the process and return this error.
  212. type PullProgressFunc func(ProgressResponse) error
  213. // Pull downloads a model from the ollama library. fn is called each time
  214. // progress is made on the request and can be used to display a progress bar,
  215. // etc.
  216. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  217. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  218. var resp ProgressResponse
  219. if err := json.Unmarshal(bts, &resp); err != nil {
  220. return err
  221. }
  222. return fn(resp)
  223. })
  224. }
  225. type PushProgressFunc func(ProgressResponse) error
  226. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  227. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  228. var resp ProgressResponse
  229. if err := json.Unmarshal(bts, &resp); err != nil {
  230. return err
  231. }
  232. return fn(resp)
  233. })
  234. }
  235. type CreateProgressFunc func(ProgressResponse) error
  236. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  237. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  238. var resp ProgressResponse
  239. if err := json.Unmarshal(bts, &resp); err != nil {
  240. return err
  241. }
  242. return fn(resp)
  243. })
  244. }
  245. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  246. var lr ListResponse
  247. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  248. return nil, err
  249. }
  250. return &lr, nil
  251. }
  252. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  253. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  254. return err
  255. }
  256. return nil
  257. }
  258. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  259. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  260. return err
  261. }
  262. return nil
  263. }
  264. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
  265. var resp ShowResponse
  266. if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
  267. return nil, err
  268. }
  269. return &resp, nil
  270. }
  271. func (c *Client) Heartbeat(ctx context.Context) error {
  272. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  273. return err
  274. }
  275. return nil
  276. }
  277. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
  278. var resp EmbeddingResponse
  279. if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
  280. return nil, err
  281. }
  282. return &resp, nil
  283. }
  284. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
  285. return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
  286. }
  287. func (c *Client) Version(ctx context.Context) (string, error) {
  288. var version struct {
  289. Version string `json:"version"`
  290. }
  291. if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
  292. return "", err
  293. }
  294. return version.Version, nil
  295. }