@@ -26,12 +26,9 @@ var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
 
 func generate(c *gin.Context) {
-	// TODO: these should be request parameters
-	gpulayers := 1
-	tokens := 512
-	threads := runtime.NumCPU()
-
 	var req api.GenerateRequest
+	req.ModelOptions = api.DefaultModelOptions
+	req.PredictOptions = api.DefaultPredictOptions
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 		return
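A note on the two added assignments above: because they run before c.ShouldBindJSON, JSON decoding only overwrites the fields the client actually sends, and everything else keeps its default. A minimal standalone sketch of that defaults-then-bind pattern (the Options type and its fields here are illustrative, not part of this change):

package main

import (
	"encoding/json"
	"fmt"
)

// Options stands in for api.ModelOptions / api.PredictOptions.
type Options struct {
	Threads int `json:"threads"`
	TopK    int `json:"top_k"`
}

func main() {
	opts := Options{Threads: -1, TopK: 40} // seed defaults first
	// Unmarshal only touches keys present in the payload,
	// so top_k keeps its default of 40 here.
	_ = json.Unmarshal([]byte(`{"threads": 8}`), &opts)
	fmt.Printf("%+v\n", opts) // prints {Threads:8 TopK:40}
}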
@@ -41,7 +38,10 @@ func generate(c *gin.Context) {
 		req.Model = remoteModel.FullName()
 	}
 
-	model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
+	modelOpts := getModelOpts(req)
+	modelOpts.NGPULayers = 1 // hard-code this for now
+
+	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
 		fmt.Println("Loading the model failed:", err.Error())
 		return
@@ -65,13 +65,16 @@ func generate(c *gin.Context) {
 	}
 
 	ch := make(chan string)
+	model.SetTokenCallback(func(token string) bool {
+		ch <- token
+		return true
+	})
+
+	predictOpts := getPredictOpts(req)
 
 	go func() {
		defer close(ch)
-		_, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
-			ch <- token
-			return true
-		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		_, err := model.Predict(req.Prompt, predictOpts)
 		if err != nil {
 			panic(err)
 		}
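The hunk above moves the token callback off the Predict call and onto the model itself: each generated token is pushed into ch, and the goroutine closes the channel once Predict returns. The consumer of ch sits outside the hunks shown; a sketch of how it might be drained with gin's Stream helper (an assumption, not code from this PR):

import (
	"fmt"
	"io"

	"github.com/gin-gonic/gin"
)

// streamTokens is a hypothetical consumer for ch: it writes tokens to
// the response as they arrive and stops once the predict goroutine
// closes the channel.
func streamTokens(c *gin.Context, ch <-chan string) {
	c.Stream(func(w io.Writer) bool {
		token, ok := <-ch
		if !ok {
			return false // channel closed: end the stream
		}
		fmt.Fprint(w, token)
		return true // keep the connection open for the next token
	})
}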
@@ -161,3 +164,53 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank int)
 
 	return
 }
+
+func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
+	var opts llama.ModelOptions
+	opts.ContextSize = req.ModelOptions.ContextSize
+	opts.Seed = req.ModelOptions.Seed
+	opts.F16Memory = req.ModelOptions.F16Memory
+	opts.MLock = req.ModelOptions.MLock
+	opts.Embeddings = req.ModelOptions.Embeddings
+	opts.MMap = req.ModelOptions.MMap
+	opts.LowVRAM = req.ModelOptions.LowVRAM
+
+	opts.NBatch = req.ModelOptions.NBatch
+	opts.VocabOnly = req.ModelOptions.VocabOnly
+	opts.NUMA = req.ModelOptions.NUMA
+	opts.NGPULayers = req.ModelOptions.NGPULayers
+	opts.MainGPU = req.ModelOptions.MainGPU
+	opts.TensorSplit = req.ModelOptions.TensorSplit
+
+	return opts
+}
+
+func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
+	var opts llama.PredictOptions
+
+	if req.PredictOptions.Threads == -1 {
+		opts.Threads = runtime.NumCPU()
+	} else {
+		opts.Threads = req.PredictOptions.Threads
+	}
+
+	opts.Seed = req.PredictOptions.Seed
+	opts.Tokens = req.PredictOptions.Tokens
+	opts.Penalty = req.PredictOptions.Penalty
+	opts.Repeat = req.PredictOptions.Repeat
+	opts.Batch = req.PredictOptions.Batch
+	opts.NKeep = req.PredictOptions.NKeep
+	opts.TopK = req.PredictOptions.TopK
+	opts.TopP = req.PredictOptions.TopP
+	opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
+	opts.TypicalP = req.PredictOptions.TypicalP
+	opts.Temperature = req.PredictOptions.Temperature
+	opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
+	opts.PresencePenalty = req.PredictOptions.PresencePenalty
+	opts.Mirostat = req.PredictOptions.Mirostat
+	opts.MirostatTAU = req.PredictOptions.MirostatTAU
+	opts.MirostatETA = req.PredictOptions.MirostatETA
+	opts.MMap = req.PredictOptions.MMap
+
+	return opts
+}
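One detail worth calling out in getPredictOpts: Threads == -1 is a sentinel meaning "use every available CPU core", replacing the old hard-coded threads := runtime.NumCPU(). Presumably api.DefaultPredictOptions sets Threads to -1, though the defaults live in the api package, outside this diff. The sentinel in isolation:

package main

import (
	"fmt"
	"runtime"
)

// resolveThreads mirrors the branch at the top of getPredictOpts:
// -1 requests all available CPUs, anything else passes through.
func resolveThreads(requested int) int {
	if requested == -1 {
		return runtime.NumCPU()
	}
	return requested
}

func main() {
	fmt.Println(resolveThreads(-1)) // CPU count of this machine
	fmt.Println(resolveThreads(4))  // 4
}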