@@ -3,12 +3,14 @@ package server
 import (
     "embed"
     "encoding/json"
+    "errors"
     "fmt"
     "io"
     "log"
     "math"
     "net"
     "net/http"
+    "os"
     "path"
     "runtime"
     "strings"
@@ -25,6 +27,15 @@ import (
 var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))

+func cacheDir() string {
+    home, err := os.UserHomeDir()
+    if err != nil {
+        panic(err)
+    }
+
+    return path.Join(home, ".ollama")
+}
+
 func generate(c *gin.Context) {
     var req api.GenerateRequest
     req.ModelOptions = api.DefaultModelOptions
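
For orientation (not part of the patch): cacheDir resolves to $HOME/.ollama, and the generate hunk below joins it with models/<name>.bin. A minimal standalone sketch of that composition; the model name "llama" and the main wrapper are invented for illustration:

package main

import (
    "fmt"
    "os"
    "path"
)

// Mirrors the cacheDir helper added in this patch: resolve the user's
// home directory and append ".ollama". The patch panics on error, so
// this sketch does the same to keep the behavior identical.
func cacheDir() string {
    home, err := os.UserHomeDir()
    if err != nil {
        panic(err)
    }
    return path.Join(home, ".ollama")
}

func main() {
    // Prints e.g. /home/alice/.ollama/models/llama.bin
    fmt.Println(path.Join(cacheDir(), "models", "llama"+".bin"))
}

Note that panicking on os.UserHomeDir failure matches the patch as written; the error is never surfaced to the HTTP caller.
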
@@ -37,9 +48,16 @@ func generate(c *gin.Context) {
     if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
         req.Model = remoteModel.FullName()
     }
+    if _, err := os.Stat(req.Model); err != nil {
+        if !errors.Is(err, os.ErrNotExist) {
+            c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+            return
+        }
+        req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
+    }

     modelOpts := getModelOpts(req)
-    modelOpts.NGPULayers = 1 // hard-code this for now
+    modelOpts.NGPULayers = 1 // hard-code this for now

     model, err := llama.New(req.Model, modelOpts)
     if err != nil {
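
The new lookup has three outcomes: a stat error other than not-exist becomes a 400 response, a missing path is reinterpreted as a bare model name under the cache, and an existing file path is passed to llama.New unchanged. A self-contained sketch of the same branching; resolveModel and the cacheRoot parameter are invented here so it runs without the server:

package main

import (
    "errors"
    "fmt"
    "os"
    "path"
)

// resolveModel mirrors the lookup added in this hunk: an existing file
// path is used as-is, while a missing path is treated as a model name
// under the cache. cacheRoot stands in for cacheDir().
func resolveModel(name, cacheRoot string) (string, error) {
    if _, err := os.Stat(name); err != nil {
        if !errors.Is(err, os.ErrNotExist) {
            // e.g. permission errors surface to the caller (400 in the patch)
            return "", err
        }
        // Not a file on disk: treat it as a model name in the cache.
        return path.Join(cacheRoot, "models", name+".bin"), nil
    }
    // The request already named an existing file; use it directly.
    return name, nil
}

func main() {
    p, err := resolveModel("llama", "/tmp/.ollama")
    fmt.Println(p, err) // /tmp/.ollama/models/llama.bin <nil>
}
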
@@ -118,6 +136,17 @@ func Serve(ln net.Listener) error {
         go func() {
             defer close(progressCh)
             if err := pull(req.Model, progressCh); err != nil {
+                var opError *net.OpError
+                if errors.As(err, &opError) {
+                    result := api.PullProgress{
+                        Error: api.Error{
+                            Code:    http.StatusBadGateway,
+                            Message: "failed to get models from directory",
+                        },
+                    }
+                    c.JSON(http.StatusBadGateway, result)
+                    return
+                }
                 c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
                 return
             }
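
The 502 branch depends on errors.As locating a *net.OpError anywhere in pull's wrapped error chain, which distinguishes network failures from ordinary bad requests. A runnable sketch of that unwrapping; the dial error constructed in main is fabricated for illustration:

package main

import (
    "errors"
    "fmt"
    "net"
)

func main() {
    // Simulate a pull failure wrapping a network-level error with %w.
    err := fmt.Errorf("pull failed: %w",
        &net.OpError{Op: "dial", Err: errors.New("connection refused")})

    // errors.As walks the wrapped chain looking for a *net.OpError,
    // exactly as the handler above does before answering 502.
    var opError *net.OpError
    if errors.As(err, &opError) {
        fmt.Println("network error, would return 502:", opError.Op)
    }
}
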