123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- package server
- import (
- "encoding/json"
- "io"
- "log"
- "net"
- "net/http"
- "os"
- "path/filepath"
- "strings"
- "text/template"
- "time"
- "dario.cat/mergo"
- "github.com/gin-gonic/gin"
- "github.com/jmorganca/ollama/api"
- "github.com/jmorganca/ollama/llama"
- )
// cacheDir returns the path of the ".ollama" directory located directly
// under the current user's home directory.
//
// It panics when os.UserHomeDir cannot determine the home directory.
func cacheDir() string {
	homeDir, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(homeDir, ".ollama")
	}

	panic(err)
}
- func generate(c *gin.Context) {
- start := time.Now()
- var req api.GenerateRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- model, err := GetModel(req.Model)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- opts := api.DefaultOptions()
- if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- templ, err := template.New("").Parse(model.Prompt)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- var sb strings.Builder
- if err = templ.Execute(&sb, req); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- req.Prompt = sb.String()
- llm, err := llama.New(model.ModelPath, opts)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- defer llm.Close()
- ch := make(chan any)
- go func() {
- defer close(ch)
- llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
- r.Model = req.Model
- r.CreatedAt = time.Now().UTC()
- if r.Done {
- r.TotalDuration = time.Since(start)
- }
- ch <- r
- })
- }()
- streamResponse(c, ch)
- }
- func pull(c *gin.Context) {
- var req api.PullRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- ch := make(chan any)
- go func() {
- defer close(ch)
- fn := func(status, digest string, total, completed int, percent float64) {
- ch <- api.PullProgress{
- Status: status,
- Digest: digest,
- Total: total,
- Completed: completed,
- Percent: percent,
- }
- }
- if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- }()
- streamResponse(c, ch)
- }
- func push(c *gin.Context) {
- var req api.PushRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- ch := make(chan any)
- go func() {
- defer close(ch)
- fn := func(status, digest string, total, completed int, percent float64) {
- ch <- api.PushProgress{
- Status: status,
- Digest: digest,
- Total: total,
- Completed: completed,
- Percent: percent,
- }
- }
- if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- return
- }
- }()
- streamResponse(c, ch)
- }
- func create(c *gin.Context) {
- var req api.CreateRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
- return
- }
- // NOTE consider passing the entire Modelfile in the json instead of the path to it
- file, err := os.Open(req.Path)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
- return
- }
- defer file.Close()
- ch := make(chan any)
- go func() {
- defer close(ch)
- fn := func(status string) {
- ch <- api.CreateProgress{
- Status: status,
- }
- }
- if err := CreateModel(req.Name, file, fn); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
- return
- }
- }()
- streamResponse(c, ch)
- }
- func Serve(ln net.Listener) error {
- r := gin.Default()
- r.GET("/", func(c *gin.Context) {
- c.String(http.StatusOK, "Ollama is running")
- })
- r.POST("/api/pull", pull)
- r.POST("/api/generate", generate)
- r.POST("/api/create", create)
- r.POST("/api/push", push)
- log.Printf("Listening on %s", ln.Addr())
- s := &http.Server{
- Handler: r,
- }
- return s.Serve(ln)
- }
- func streamResponse(c *gin.Context, ch chan any) {
- c.Stream(func(w io.Writer) bool {
- val, ok := <-ch
- if !ok {
- return false
- }
- bts, err := json.Marshal(val)
- if err != nil {
- return false
- }
- bts = append(bts, '\n')
- if _, err := w.Write(bts); err != nil {
- return false
- }
- return true
- })
- }
|