client.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. // Package api implements the client-side API for code wishing to interact
  2. // with the ollama service. The methods of the [Client] type correspond to
  3. // the ollama REST API as described in [the API documentation].
  4. // The ollama command-line client itself uses this package to interact with
  5. // the backend service.
  6. //
  7. // # Examples
  8. //
  9. // Several examples of using this package are available [in the GitHub
  10. // repository].
  11. //
  12. // [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
  13. // [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
  14. package api
  import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"

	"github.com/ollama/ollama/auth"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/version"
)
  // Client encapsulates client state for interacting with the ollama
// service. Use [ClientFromEnvironment] or [NewClient] to create new Clients.
type Client struct {
	// base is the root URL onto which every request path is joined.
	base *url.URL
	// http is the HTTP client used to perform every request.
	http *http.Client
}
  42. func checkError(resp *http.Response, body []byte) error {
  43. if resp.StatusCode < http.StatusBadRequest {
  44. return nil
  45. }
  46. apiError := StatusError{StatusCode: resp.StatusCode}
  47. err := json.Unmarshal(body, &apiError)
  48. if err != nil {
  49. // Use the full body as the message if we fail to decode a response.
  50. apiError.ErrorMessage = string(body)
  51. }
  52. return apiError
  53. }
  54. // ClientFromEnvironment creates a new [Client] using configuration from the
  55. // environment variable OLLAMA_HOST, which points to the network host and
  56. // port on which the ollama service is listenting. The format of this variable
  57. // is:
  58. //
  59. // <scheme>://<host>:<port>
  60. //
  61. // If the variable is not specified, a default ollama host and port will be
  62. // used.
  63. func ClientFromEnvironment() (*Client, error) {
  64. ollamaHost := envconfig.Host
  65. return &Client{
  66. base: &url.URL{
  67. Scheme: ollamaHost.Scheme,
  68. Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
  69. },
  70. http: http.DefaultClient,
  71. }, nil
  72. }
  73. func NewClient(base *url.URL, http *http.Client) *Client {
  74. return &Client{
  75. base: base,
  76. http: http,
  77. }
  78. }
  79. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  80. var reqBody io.Reader
  81. var data []byte
  82. var err error
  83. switch reqData := reqData.(type) {
  84. case io.Reader:
  85. // reqData is already an io.Reader
  86. reqBody = reqData
  87. case nil:
  88. // noop
  89. default:
  90. data, err = json.Marshal(reqData)
  91. if err != nil {
  92. return err
  93. }
  94. reqBody = bytes.NewReader(data)
  95. }
  96. requestURL := c.base.JoinPath(path)
  97. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
  98. if err != nil {
  99. return err
  100. }
  101. request.Header.Set("Content-Type", "application/json")
  102. request.Header.Set("Accept", "application/json")
  103. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  104. respObj, err := c.http.Do(request)
  105. if err != nil {
  106. return err
  107. }
  108. defer respObj.Body.Close()
  109. respBody, err := io.ReadAll(respObj.Body)
  110. if err != nil {
  111. return err
  112. }
  113. if err := checkError(respObj, respBody); err != nil {
  114. return err
  115. }
  116. if len(respBody) > 0 && respData != nil {
  117. if err := json.Unmarshal(respBody, respData); err != nil {
  118. return err
  119. }
  120. }
  121. return nil
  122. }
  123. const maxBufferSize = 512 * format.KiloByte
  124. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  125. var buf *bytes.Buffer
  126. if data != nil {
  127. bts, err := json.Marshal(data)
  128. if err != nil {
  129. return err
  130. }
  131. buf = bytes.NewBuffer(bts)
  132. }
  133. requestURL := c.base.JoinPath(path)
  134. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
  135. if err != nil {
  136. return err
  137. }
  138. request.Header.Set("Content-Type", "application/json")
  139. request.Header.Set("Accept", "application/x-ndjson")
  140. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  141. response, err := c.http.Do(request)
  142. if err != nil {
  143. return err
  144. }
  145. defer response.Body.Close()
  146. scanner := bufio.NewScanner(response.Body)
  147. // increase the buffer size to avoid running out of space
  148. scanBuf := make([]byte, 0, maxBufferSize)
  149. scanner.Buffer(scanBuf, maxBufferSize)
  150. for scanner.Scan() {
  151. var errorResponse struct {
  152. Error string `json:"error,omitempty"`
  153. }
  154. bts := scanner.Bytes()
  155. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  156. return fmt.Errorf("unmarshal: %w", err)
  157. }
  158. if errorResponse.Error != "" {
  159. return fmt.Errorf(errorResponse.Error)
  160. }
  161. if response.StatusCode >= http.StatusBadRequest {
  162. return StatusError{
  163. StatusCode: response.StatusCode,
  164. Status: response.Status,
  165. ErrorMessage: errorResponse.Error,
  166. }
  167. }
  168. if err := fn(bts); err != nil {
  169. return err
  170. }
  171. }
  172. return nil
  173. }
  174. // GenerateResponseFunc is a function that [Client.Generate] invokes every time
  175. // a response is received from the service. If this function returns an error,
  176. // [Client.Generate] will stop generating and return this error.
  177. type GenerateResponseFunc func(GenerateResponse) error
  178. // Generate generates a response for a given prompt. The req parameter should
  179. // be populated with prompt details. fn is called for each response (there may
  180. // be multiple responses, e.g. in case streaming is enabled).
  181. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  182. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  183. var resp GenerateResponse
  184. if err := json.Unmarshal(bts, &resp); err != nil {
  185. return err
  186. }
  187. return fn(resp)
  188. })
  189. }
  190. // ChatResponseFunc is a function that [Client.Chat] invokes every time
  191. // a response is received from the service. If this function returns an error,
  192. // [Client.Chat] will stop generating and return this error.
  193. type ChatResponseFunc func(ChatResponse) error
  194. // Chat generates the next message in a chat. [ChatRequest] may contain a
  195. // sequence of messages which can be used to maintain chat history with a model.
  196. // fn is called for each response (there may be multiple responses, e.g. if case
  197. // streaming is enabled).
  198. func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
  199. return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
  200. var resp ChatResponse
  201. if err := json.Unmarshal(bts, &resp); err != nil {
  202. return err
  203. }
  204. return fn(resp)
  205. })
  206. }
  207. // PullProgressFunc is a function that [Client.Pull] invokes every time there
  208. // is progress with a "pull" request sent to the service. If this function
  209. // returns an error, [Client.Pull] will stop the process and return this error.
  210. type PullProgressFunc func(ProgressResponse) error
  211. // Pull downloads a model from the ollama library. fn is called each time
  212. // progress is made on the request and can be used to display a progress bar,
  213. // etc.
  214. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  215. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  216. var resp ProgressResponse
  217. if err := json.Unmarshal(bts, &resp); err != nil {
  218. return err
  219. }
  220. return fn(resp)
  221. })
  222. }
  223. // PushProgressFunc is a function that [Client.Push] invokes when progress is
  224. // made.
  225. // It's similar to other progress function types like [PullProgressFunc].
  226. type PushProgressFunc func(ProgressResponse) error
  227. // Push uploads a model to the model library; requires registering for ollama.ai
  228. // and adding a public key first. fn is called each time progress is made on
  229. // the request and can be used to display a progress bar, etc.
  230. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  231. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  232. var resp ProgressResponse
  233. if err := json.Unmarshal(bts, &resp); err != nil {
  234. return err
  235. }
  236. return fn(resp)
  237. })
  238. }
  239. // CreateProgressFunc is a function that [Client.Create] invokes when progress
  240. // is made.
  241. // It's similar to other progress function types like [PullProgressFunc].
  242. type CreateProgressFunc func(ProgressResponse) error
  243. // Create creates a model from a [Modelfile]. fn is a progress function that
  244. // behaves similarly to other methods (see [Client.Pull]).
  245. //
  246. // [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
  247. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  248. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  249. var resp ProgressResponse
  250. if err := json.Unmarshal(bts, &resp); err != nil {
  251. return err
  252. }
  253. return fn(resp)
  254. })
  255. }
  256. // List lists models that are available locally.
  257. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  258. var lr ListResponse
  259. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  260. return nil, err
  261. }
  262. return &lr, nil
  263. }
  264. // List running models.
  265. func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
  266. var lr ProcessResponse
  267. if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
  268. return nil, err
  269. }
  270. return &lr, nil
  271. }
  272. // Copy copies a model - creating a model with another name from an existing
  273. // model.
  274. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  275. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  276. return err
  277. }
  278. return nil
  279. }
  280. // Delete deletes a model and its data.
  281. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  282. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  283. return err
  284. }
  285. return nil
  286. }
  287. // Show obtains model information, including details, modelfile, license etc.
  288. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
  289. var resp ShowResponse
  290. if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
  291. return nil, err
  292. }
  293. return &resp, nil
  294. }
  295. // Hearbeat checks if the server has started and is responsive; if yes, it
  296. // returns nil, otherwise an error.
  297. func (c *Client) Heartbeat(ctx context.Context) error {
  298. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  299. return err
  300. }
  301. return nil
  302. }
  303. // Embeddings generates embeddings from a model.
  304. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
  305. var resp EmbeddingResponse
  306. if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
  307. return nil, err
  308. }
  309. return &resp, nil
  310. }
  311. // CreateBlob creates a blob from a file on the server. digest is the
  312. // expected SHA256 digest of the file, and r represents the file.
  313. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
  314. return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
  315. }
  316. // Version returns the Ollama server version as a string.
  317. func (c *Client) Version(ctx context.Context) (string, error) {
  318. var version struct {
  319. Version string `json:"version"`
  320. }
  321. if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
  322. return "", err
  323. }
  324. return version.Version, nil
  325. }
  326. func Authorization(ctx context.Context, request *http.Request) (string, error) {
  327. data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
  328. home, err := os.UserHomeDir()
  329. if err != nil {
  330. return "", err
  331. }
  332. knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
  333. if err != nil {
  334. return "", err
  335. }
  336. defer knownHostsFile.Close()
  337. token, err := auth.Sign(ctx, data)
  338. if err != nil {
  339. return "", err
  340. }
  341. // interleave request data into the token
  342. key, sig, _ := strings.Cut(token, ":")
  343. return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
  344. }
  345. // EnvConfig returns the environment configuration for the server.
  346. func (c *Client) ServerConfig(ctx context.Context) (*ServerConfig, error) {
  347. var config ServerConfig
  348. if err := c.do(ctx, http.MethodGet, "/api/config", nil, &config); err != nil {
  349. return nil, err
  350. }
  351. return &config, nil
  352. }
  353. func Authorization(ctx context.Context, request *http.Request) (string, error) {
  354. data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
  355. home, err := os.UserHomeDir()
  356. if err != nil {
  357. return "", err
  358. }
  359. knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
  360. if err != nil {
  361. return "", err
  362. }
  363. defer knownHostsFile.Close()
  364. token, err := auth.Sign(ctx, data)
  365. if err != nil {
  366. return "", err
  367. }
  368. // interleave request data into the token
  369. key, sig, _ := strings.Cut(token, ":")
  370. return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
  371. }