client.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. // Package api implements the client-side API for code wishing to interact
  2. // with the ollama service. The methods of the [Client] type correspond to
  3. // the ollama REST API as described in [the API documentation].
  4. // The ollama command-line client itself uses this package to interact with
  5. // the backend service.
  6. //
  7. // # Examples
  8. //
  9. // Several examples of using this package are available [in the GitHub
  10. // repository].
  11. //
  12. // [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
  13. // [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
  14. package api
  15. import (
  16. "bufio"
  17. "bytes"
  18. "context"
  19. "encoding/json"
  20. "fmt"
  21. "io"
  22. "net"
  23. "net/http"
  24. "net/url"
  25. "os"
  26. "runtime"
  27. "strconv"
  28. "strings"
  29. "github.com/ollama/ollama/format"
  30. "github.com/ollama/ollama/version"
  31. )
  32. // Client encapsulates client state for interacting with the ollama
  33. // service. Use [ClientFromEnvironment] to create new Clients.
  34. type Client struct {
  35. base *url.URL
  36. http *http.Client
  37. }
  38. func checkError(resp *http.Response, body []byte) error {
  39. if resp.StatusCode < http.StatusBadRequest {
  40. return nil
  41. }
  42. apiError := StatusError{StatusCode: resp.StatusCode}
  43. err := json.Unmarshal(body, &apiError)
  44. if err != nil {
  45. // Use the full body as the message if we fail to decode a response.
  46. apiError.ErrorMessage = string(body)
  47. }
  48. return apiError
  49. }
  50. // ClientFromEnvironment creates a new [Client] using configuration from the
  51. // environment variable OLLAMA_HOST, which points to the network host and
  52. // port on which the ollama service is listenting. The format of this variable
  53. // is:
  54. //
  55. // <scheme>://<host>:<port>
  56. //
  57. // If the variable is not specified, a default ollama host and port will be
  58. // used.
  59. func ClientFromEnvironment() (*Client, error) {
  60. ollamaHost, err := GetOllamaHost()
  61. if err != nil {
  62. return nil, err
  63. }
  64. return &Client{
  65. base: &url.URL{
  66. Scheme: ollamaHost.Scheme,
  67. Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
  68. },
  69. http: http.DefaultClient,
  70. }, nil
  71. }
  72. type OllamaHost struct {
  73. Scheme string
  74. Host string
  75. Port string
  76. }
  77. func GetOllamaHost() (OllamaHost, error) {
  78. defaultPort := "11434"
  79. hostVar := os.Getenv("OLLAMA_HOST")
  80. hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
  81. scheme, hostport, ok := strings.Cut(hostVar, "://")
  82. switch {
  83. case !ok:
  84. scheme, hostport = "http", hostVar
  85. case scheme == "http":
  86. defaultPort = "80"
  87. case scheme == "https":
  88. defaultPort = "443"
  89. }
  90. // trim trailing slashes
  91. hostport = strings.TrimRight(hostport, "/")
  92. host, port, err := net.SplitHostPort(hostport)
  93. if err != nil {
  94. host, port = "127.0.0.1", defaultPort
  95. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  96. host = ip.String()
  97. } else if hostport != "" {
  98. host = hostport
  99. }
  100. }
  101. if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
  102. return OllamaHost{}, ErrInvalidHostPort
  103. }
  104. return OllamaHost{
  105. Scheme: scheme,
  106. Host: host,
  107. Port: port,
  108. }, nil
  109. }
  110. func NewClient(base *url.URL, http *http.Client) *Client {
  111. return &Client{
  112. base: base,
  113. http: http,
  114. }
  115. }
  116. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  117. var reqBody io.Reader
  118. var data []byte
  119. var err error
  120. switch reqData := reqData.(type) {
  121. case io.Reader:
  122. // reqData is already an io.Reader
  123. reqBody = reqData
  124. case nil:
  125. // noop
  126. default:
  127. data, err = json.Marshal(reqData)
  128. if err != nil {
  129. return err
  130. }
  131. reqBody = bytes.NewReader(data)
  132. }
  133. requestURL := c.base.JoinPath(path)
  134. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
  135. if err != nil {
  136. return err
  137. }
  138. request.Header.Set("Content-Type", "application/json")
  139. request.Header.Set("Accept", "application/json")
  140. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  141. respObj, err := c.http.Do(request)
  142. if err != nil {
  143. return err
  144. }
  145. defer respObj.Body.Close()
  146. respBody, err := io.ReadAll(respObj.Body)
  147. if err != nil {
  148. return err
  149. }
  150. if err := checkError(respObj, respBody); err != nil {
  151. return err
  152. }
  153. if len(respBody) > 0 && respData != nil {
  154. if err := json.Unmarshal(respBody, respData); err != nil {
  155. return err
  156. }
  157. }
  158. return nil
  159. }
  160. const maxBufferSize = 512 * format.KiloByte
  161. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  162. var buf *bytes.Buffer
  163. if data != nil {
  164. bts, err := json.Marshal(data)
  165. if err != nil {
  166. return err
  167. }
  168. buf = bytes.NewBuffer(bts)
  169. }
  170. requestURL := c.base.JoinPath(path)
  171. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
  172. if err != nil {
  173. return err
  174. }
  175. request.Header.Set("Content-Type", "application/json")
  176. request.Header.Set("Accept", "application/x-ndjson")
  177. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  178. response, err := c.http.Do(request)
  179. if err != nil {
  180. return err
  181. }
  182. defer response.Body.Close()
  183. scanner := bufio.NewScanner(response.Body)
  184. // increase the buffer size to avoid running out of space
  185. scanBuf := make([]byte, 0, maxBufferSize)
  186. scanner.Buffer(scanBuf, maxBufferSize)
  187. for scanner.Scan() {
  188. var errorResponse struct {
  189. Error string `json:"error,omitempty"`
  190. }
  191. bts := scanner.Bytes()
  192. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  193. return fmt.Errorf("unmarshal: %w", err)
  194. }
  195. if errorResponse.Error != "" {
  196. return fmt.Errorf(errorResponse.Error)
  197. }
  198. if response.StatusCode >= http.StatusBadRequest {
  199. return StatusError{
  200. StatusCode: response.StatusCode,
  201. Status: response.Status,
  202. ErrorMessage: errorResponse.Error,
  203. }
  204. }
  205. if err := fn(bts); err != nil {
  206. return err
  207. }
  208. }
  209. return nil
  210. }
  211. // GenerateResponseFunc is a function that [Client.Generate] invokes every time
  212. // a response is received from the service. If this function returns an error,
  213. // [Client.Generate] will stop generating and return this error.
  214. type GenerateResponseFunc func(GenerateResponse) error
  215. // Generate generates a response for a given prompt. The req parameter should
  216. // be populated with prompt details. fn is called for each response (there may
  217. // be multiple responses, e.g. in case streaming is enabled).
  218. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  219. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  220. var resp GenerateResponse
  221. if err := json.Unmarshal(bts, &resp); err != nil {
  222. return err
  223. }
  224. return fn(resp)
  225. })
  226. }
  227. // ChatResponseFunc is a function that [Client.Chat] invokes every time
  228. // a response is received from the service. If this function returns an error,
  229. // [Client.Chat] will stop generating and return this error.
  230. type ChatResponseFunc func(ChatResponse) error
  231. // Chat generates the next message in a chat. [ChatRequest] may contain a
  232. // sequence of messages which can be used to maintain chat history with a model.
  233. // fn is called for each response (there may be multiple responses, e.g. if case
  234. // streaming is enabled).
  235. func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
  236. return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
  237. var resp ChatResponse
  238. if err := json.Unmarshal(bts, &resp); err != nil {
  239. return err
  240. }
  241. return fn(resp)
  242. })
  243. }
  244. // PullProgressFunc is a function that [Client.Pull] invokes every time there
  245. // is progress with a "pull" request sent to the service. If this function
  246. // returns an error, [Client.Pull] will stop the process and return this error.
  247. type PullProgressFunc func(ProgressResponse) error
  248. // Pull downloads a model from the ollama library. fn is called each time
  249. // progress is made on the request and can be used to display a progress bar,
  250. // etc.
  251. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  252. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  253. var resp ProgressResponse
  254. if err := json.Unmarshal(bts, &resp); err != nil {
  255. return err
  256. }
  257. return fn(resp)
  258. })
  259. }
  260. // PushProgressFunc is a function that [Client.Push] invokes when progress is
  261. // made.
  262. // It's similar to other progress function types like [PullProgressFunc].
  263. type PushProgressFunc func(ProgressResponse) error
  264. // Push uploads a model to the model library; requires registering for ollama.ai
  265. // and adding a public key first. fn is called each time progress is made on
  266. // the request and can be used to display a progress bar, etc.
  267. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  268. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  269. var resp ProgressResponse
  270. if err := json.Unmarshal(bts, &resp); err != nil {
  271. return err
  272. }
  273. return fn(resp)
  274. })
  275. }
  276. // CreateProgressFunc is a function that [Client.Create] invokes when progress
  277. // is made.
  278. // It's similar to other progress function types like [PullProgressFunc].
  279. type CreateProgressFunc func(ProgressResponse) error
  280. // Create creates a model from a [Modelfile]. fn is a progress function that
  281. // behaves similarly to other methods (see [Client.Pull]).
  282. //
  283. // [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
  284. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  285. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  286. var resp ProgressResponse
  287. if err := json.Unmarshal(bts, &resp); err != nil {
  288. return err
  289. }
  290. return fn(resp)
  291. })
  292. }
  293. // List lists models that are available locally.
  294. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  295. var lr ListResponse
  296. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  297. return nil, err
  298. }
  299. return &lr, nil
  300. }
  301. // List running models.
  302. func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
  303. var lr ListResponse
  304. if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
  305. return nil, err
  306. }
  307. return &lr, nil
  308. }
  309. // Copy copies a model - creating a model with another name from an existing
  310. // model.
  311. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  312. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  313. return err
  314. }
  315. return nil
  316. }
  317. // Delete deletes a model and its data.
  318. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  319. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  320. return err
  321. }
  322. return nil
  323. }
  324. // Show obtains model information, including details, modelfile, license etc.
  325. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
  326. var resp ShowResponse
  327. if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
  328. return nil, err
  329. }
  330. return &resp, nil
  331. }
  332. // Hearbeat checks if the server has started and is responsive; if yes, it
  333. // returns nil, otherwise an error.
  334. func (c *Client) Heartbeat(ctx context.Context) error {
  335. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  336. return err
  337. }
  338. return nil
  339. }
  340. // Embeddings generates embeddings from a model.
  341. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
  342. var resp EmbeddingResponse
  343. if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
  344. return nil, err
  345. }
  346. return &resp, nil
  347. }
  348. // CreateBlob creates a blob from a file on the server. digest is the
  349. // expected SHA256 digest of the file, and r represents the file.
  350. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
  351. return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
  352. }
  353. // Version returns the Ollama server version as a string.
  354. func (c *Client) Version(ctx context.Context) (string, error) {
  355. var version struct {
  356. Version string `json:"version"`
  357. }
  358. if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
  359. return "", err
  360. }
  361. return version.Version, nil
  362. }