|
@@ -38,35 +38,17 @@ var loaded struct {
|
|
|
options api.Options
|
|
|
}
|
|
|
|
|
|
-func GenerateHandler(c *gin.Context) {
|
|
|
- loaded.mu.Lock()
|
|
|
- defer loaded.mu.Unlock()
|
|
|
-
|
|
|
- checkpointStart := time.Now()
|
|
|
-
|
|
|
- var req api.GenerateRequest
|
|
|
- if err := c.ShouldBindJSON(&req); err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- model, err := GetModel(req.Model)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
+// load loads a model into memory if it is not already loaded. The caller must lock loaded.mu before calling this function.
|
|
|
+func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
|
|
|
opts := api.DefaultOptions()
|
|
|
if err := opts.FromMap(model.Options); err != nil {
|
|
|
log.Printf("could not load model options: %v", err)
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
- if err := opts.FromMap(req.Options); err != nil {
|
|
|
+ if err := opts.FromMap(reqOpts); err != nil {
|
|
|
log.Printf("could not merge model options: %v", err)
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
|
|
@@ -83,21 +65,18 @@ func GenerateHandler(c *gin.Context) {
|
|
|
|
|
|
llm, err := llama.New(model.ModelPath, opts)
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
if opts.NumKeep < 0 {
|
|
|
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "")
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
tokensWithSystem := llm.Encode(promptWithSystem)
|
|
@@ -110,9 +89,8 @@ func GenerateHandler(c *gin.Context) {
|
|
|
loaded.digest = model.Digest
|
|
|
loaded.options = opts
|
|
|
}
|
|
|
- sessionDuration := 5 * time.Minute
|
|
|
-
|
|
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
|
|
+
|
|
|
if loaded.expireTimer == nil {
|
|
|
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
|
|
loaded.mu.Lock()
|
|
@@ -132,6 +110,32 @@ func GenerateHandler(c *gin.Context) {
|
|
|
})
|
|
|
}
|
|
|
loaded.expireTimer.Reset(sessionDuration)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func GenerateHandler(c *gin.Context) {
|
|
|
+ loaded.mu.Lock()
|
|
|
+ defer loaded.mu.Unlock()
|
|
|
+
|
|
|
+ checkpointStart := time.Now()
|
|
|
+
|
|
|
+ var req api.GenerateRequest
|
|
|
+ if err := c.ShouldBindJSON(&req); err != nil {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ model, err := GetModel(req.Model)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ sessionDuration := 5 * time.Minute
|
|
|
+ if err := load(model, req.Options, sessionDuration); err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
checkpointLoaded := time.Now()
|
|
|
|
|
@@ -181,6 +185,44 @@ func GenerateHandler(c *gin.Context) {
|
|
|
streamResponse(c, ch)
|
|
|
}
|
|
|
|
|
|
+func EmbeddingHandler(c *gin.Context) {
|
|
|
+ loaded.mu.Lock()
|
|
|
+ defer loaded.mu.Unlock()
|
|
|
+
|
|
|
+ var req api.EmbeddingRequest
|
|
|
+ if err := c.ShouldBindJSON(&req); err != nil {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ model, err := GetModel(req.Model)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if err := load(model, req.Options, 5*time.Minute); err != nil {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if !loaded.options.EmbeddingOnly {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ embedding, err := loaded.llm.Embedding(req.Prompt)
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("embedding generation failed: %v", err)
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ resp := api.EmbeddingResponse{
|
|
|
+ Embedding: embedding,
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusOK, resp)
|
|
|
+}
|
|
|
+
|
|
|
func PullModelHandler(c *gin.Context) {
|
|
|
var req api.PullRequest
|
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
@@ -381,6 +423,7 @@ func Serve(ln net.Listener, extraOrigins []string) error {
|
|
|
|
|
|
r.POST("/api/pull", PullModelHandler)
|
|
|
r.POST("/api/generate", GenerateHandler)
|
|
|
+ r.POST("/api/embeddings", EmbeddingHandler)
|
|
|
r.POST("/api/create", CreateModelHandler)
|
|
|
r.POST("/api/push", PushModelHandler)
|
|
|
r.POST("/api/copy", CopyModelHandler)
|