client.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. // Package api implements the client-side API for code wishing to interact
  2. // with the ollama service. The methods of the [Client] type correspond to
  3. // the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
  4. //
  5. // The ollama command-line client itself uses this package to interact with
  6. // the backend service.
  7. package api
  8. import (
  9. "bufio"
  10. "bytes"
  11. "context"
  12. "encoding/json"
  13. "fmt"
  14. "io"
  15. "net"
  16. "net/http"
  17. "net/url"
  18. "os"
  19. "runtime"
  20. "strconv"
  21. "strings"
  22. "github.com/ollama/ollama/format"
  23. "github.com/ollama/ollama/version"
  24. )
  // Client holds the state needed to talk to the ollama service: the base URL
  // of the server and the HTTP client used to reach it. Use
  // [ClientFromEnvironment] to create new Clients.
  type Client struct {
  	// base is the URL that request paths are joined onto.
  	base *url.URL

  	// http is the underlying HTTP client used to send requests.
  	http *http.Client
  }
  31. func checkError(resp *http.Response, body []byte) error {
  32. if resp.StatusCode < http.StatusBadRequest {
  33. return nil
  34. }
  35. apiError := StatusError{StatusCode: resp.StatusCode}
  36. err := json.Unmarshal(body, &apiError)
  37. if err != nil {
  38. // Use the full body as the message if we fail to decode a response.
  39. apiError.ErrorMessage = string(body)
  40. }
  41. return apiError
  42. }
  43. // ClientFromEnvironment creates a new [Client] using configuration from the
  44. // environment variable OLLAMA_HOST, which points to the network host and
  45. // port on which the ollama service is listenting. The format of this variable
  46. // is:
  47. //
  48. // <scheme>://<host>:<port>
  49. //
  50. // If the variable is not specified, a default ollama host and port will be
  51. // used.
  52. func ClientFromEnvironment() (*Client, error) {
  53. ollamaHost, err := GetOllamaHost()
  54. if err != nil {
  55. return nil, err
  56. }
  57. return &Client{
  58. base: &url.URL{
  59. Scheme: ollamaHost.Scheme,
  60. Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
  61. },
  62. http: http.DefaultClient,
  63. }, nil
  64. }
  // OllamaHost describes the address of an ollama service, split into its URL
  // scheme, host and port. All three parts are kept as strings.
  type OllamaHost struct {
  	Scheme string // URL scheme, e.g. "http" or "https"
  	Host   string // host name or IP address, without the port
  	Port   string // port number in decimal notation
  }
  70. func GetOllamaHost() (OllamaHost, error) {
  71. defaultPort := "11434"
  72. hostVar := os.Getenv("OLLAMA_HOST")
  73. hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
  74. scheme, hostport, ok := strings.Cut(hostVar, "://")
  75. switch {
  76. case !ok:
  77. scheme, hostport = "http", hostVar
  78. case scheme == "http":
  79. defaultPort = "80"
  80. case scheme == "https":
  81. defaultPort = "443"
  82. }
  83. // trim trailing slashes
  84. hostport = strings.TrimRight(hostport, "/")
  85. host, port, err := net.SplitHostPort(hostport)
  86. if err != nil {
  87. host, port = "127.0.0.1", defaultPort
  88. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  89. host = ip.String()
  90. } else if hostport != "" {
  91. host = hostport
  92. }
  93. }
  94. if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
  95. return OllamaHost{}, ErrInvalidHostPort
  96. }
  97. return OllamaHost{
  98. Scheme: scheme,
  99. Host: host,
  100. Port: port,
  101. }, nil
  102. }
  103. func NewClient(base *url.URL, http *http.Client) *Client {
  104. return &Client{
  105. base: base,
  106. http: http,
  107. }
  108. }
  109. func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
  110. var reqBody io.Reader
  111. var data []byte
  112. var err error
  113. switch reqData := reqData.(type) {
  114. case io.Reader:
  115. // reqData is already an io.Reader
  116. reqBody = reqData
  117. case nil:
  118. // noop
  119. default:
  120. data, err = json.Marshal(reqData)
  121. if err != nil {
  122. return err
  123. }
  124. reqBody = bytes.NewReader(data)
  125. }
  126. requestURL := c.base.JoinPath(path)
  127. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
  128. if err != nil {
  129. return err
  130. }
  131. request.Header.Set("Content-Type", "application/json")
  132. request.Header.Set("Accept", "application/json")
  133. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  134. respObj, err := c.http.Do(request)
  135. if err != nil {
  136. return err
  137. }
  138. defer respObj.Body.Close()
  139. respBody, err := io.ReadAll(respObj.Body)
  140. if err != nil {
  141. return err
  142. }
  143. if err := checkError(respObj, respBody); err != nil {
  144. return err
  145. }
  146. if len(respBody) > 0 && respData != nil {
  147. if err := json.Unmarshal(respBody, respData); err != nil {
  148. return err
  149. }
  150. }
  151. return nil
  152. }
  153. const maxBufferSize = 512 * format.KiloByte
  154. func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
  155. var buf *bytes.Buffer
  156. if data != nil {
  157. bts, err := json.Marshal(data)
  158. if err != nil {
  159. return err
  160. }
  161. buf = bytes.NewBuffer(bts)
  162. }
  163. requestURL := c.base.JoinPath(path)
  164. request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
  165. if err != nil {
  166. return err
  167. }
  168. request.Header.Set("Content-Type", "application/json")
  169. request.Header.Set("Accept", "application/x-ndjson")
  170. request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
  171. response, err := c.http.Do(request)
  172. if err != nil {
  173. return err
  174. }
  175. defer response.Body.Close()
  176. scanner := bufio.NewScanner(response.Body)
  177. // increase the buffer size to avoid running out of space
  178. scanBuf := make([]byte, 0, maxBufferSize)
  179. scanner.Buffer(scanBuf, maxBufferSize)
  180. for scanner.Scan() {
  181. var errorResponse struct {
  182. Error string `json:"error,omitempty"`
  183. }
  184. bts := scanner.Bytes()
  185. if err := json.Unmarshal(bts, &errorResponse); err != nil {
  186. return fmt.Errorf("unmarshal: %w", err)
  187. }
  188. if errorResponse.Error != "" {
  189. return fmt.Errorf(errorResponse.Error)
  190. }
  191. if response.StatusCode >= http.StatusBadRequest {
  192. return StatusError{
  193. StatusCode: response.StatusCode,
  194. Status: response.Status,
  195. ErrorMessage: errorResponse.Error,
  196. }
  197. }
  198. if err := fn(bts); err != nil {
  199. return err
  200. }
  201. }
  202. return nil
  203. }
  204. // GenerateResponseFunc is a function that [Client.Generate] invokes every time
  205. // a response is received from the service. If this function returns an error,
  206. // [Client.Generate] will stop generating and return this error.
  207. type GenerateResponseFunc func(GenerateResponse) error
  208. // Generate generates a response for a given prompt. The req parameter should
  209. // be populated with prompt details. fn is called for each response (there may
  210. // be multiple responses, e.g. in case streaming is enabled).
  211. func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
  212. return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
  213. var resp GenerateResponse
  214. if err := json.Unmarshal(bts, &resp); err != nil {
  215. return err
  216. }
  217. return fn(resp)
  218. })
  219. }
  220. // ChatResponseFunc is a function that [Client.Chat] invokes every time
  221. // a response is received from the service. If this function returns an error,
  222. // [Client.Chat] will stop generating and return this error.
  223. type ChatResponseFunc func(ChatResponse) error
  224. // Chat generates the next message in a chat. [ChatRequest] may contain a
  225. // sequence of messages which can be used to maintain chat history with a model.
  226. // fn is called for each response (there may be multiple responses, e.g. if case
  227. // streaming is enabled).
  228. func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
  229. return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
  230. var resp ChatResponse
  231. if err := json.Unmarshal(bts, &resp); err != nil {
  232. return err
  233. }
  234. return fn(resp)
  235. })
  236. }
  237. // PullProgressFunc is a function that [Client.Pull] invokes every time there
  238. // is progress with a "pull" request sent to the service. If this function
  239. // returns an error, [Client.Pull] will stop the process and return this error.
  240. type PullProgressFunc func(ProgressResponse) error
  241. // Pull downloads a model from the ollama library. fn is called each time
  242. // progress is made on the request and can be used to display a progress bar,
  243. // etc.
  244. func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
  245. return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
  246. var resp ProgressResponse
  247. if err := json.Unmarshal(bts, &resp); err != nil {
  248. return err
  249. }
  250. return fn(resp)
  251. })
  252. }
  253. type PushProgressFunc func(ProgressResponse) error
  254. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
  255. return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
  256. var resp ProgressResponse
  257. if err := json.Unmarshal(bts, &resp); err != nil {
  258. return err
  259. }
  260. return fn(resp)
  261. })
  262. }
  263. type CreateProgressFunc func(ProgressResponse) error
  264. func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
  265. return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
  266. var resp ProgressResponse
  267. if err := json.Unmarshal(bts, &resp); err != nil {
  268. return err
  269. }
  270. return fn(resp)
  271. })
  272. }
  273. func (c *Client) List(ctx context.Context) (*ListResponse, error) {
  274. var lr ListResponse
  275. if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
  276. return nil, err
  277. }
  278. return &lr, nil
  279. }
  280. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
  281. if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
  282. return err
  283. }
  284. return nil
  285. }
  286. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
  287. if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
  288. return err
  289. }
  290. return nil
  291. }
  292. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
  293. var resp ShowResponse
  294. if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
  295. return nil, err
  296. }
  297. return &resp, nil
  298. }
  299. func (c *Client) Heartbeat(ctx context.Context) error {
  300. if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
  301. return err
  302. }
  303. return nil
  304. }
  305. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
  306. var resp EmbeddingResponse
  307. if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
  308. return nil, err
  309. }
  310. return &resp, nil
  311. }
  312. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
  313. return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
  314. }
  315. func (c *Client) Version(ctx context.Context) (string, error) {
  316. var version struct {
  317. Version string `json:"version"`
  318. }
  319. if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
  320. return "", err
  321. }
  322. return version.Version, nil
  323. }