@@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options on success, or an error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
-		return nil, fmt.Errorf("model %w", errRequired)
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
 
 	model, err := GetModel(name)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	if err := model.CheckCapabilities(caps...); err != nil {
-		return nil, fmt.Errorf("%s %w", name, err)
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
 
 	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
@@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	select {
 	case runner = <-runnerCh:
 	case err = <-errCh:
-		return nil, err
+		return nil, nil, nil, err
 	}
 
-	return runner, nil
+	return runner.llama, model, &opts, nil
 }
 
 func (s *Server) GenerateHandler(c *gin.Context) {
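With the widened signature, callers receive the loaded server, the model, and the resolved options directly instead of reaching through runnerRef fields. A minimal sketch of the new calling convention (the surrounding variable names are illustrative, not part of this diff):

	// Illustrative caller: consumes the four results of scheduleRunner.
	r, m, opts, err := s.scheduleRunner(ctx, name, []Capability{CapabilityCompletion}, nil, nil)
	if err != nil {
		return err
	}
	_ = r    // llm.LlamaServer, previously reached as runner.llama
	_ = m    // *Model metadata, previously runner.model
	_ = opts // *api.Options, previously runner.Options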
@@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
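The errors.Is check works because scheduleRunner wraps its sentinel errors with %w, which keeps the sentinel reachable behind the added prefix. A small standalone sketch of the mechanism (assuming errRequired and errCapabilityCompletion are package-level sentinels, as the diff implies):

	err := fmt.Errorf("model %w", errRequired) // shape of scheduleRunner's empty-name error
	if errors.Is(err, errRequired) {
		// matches despite the "model " prefix, because %w preserves the error chain
	}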
@@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	var msgs []api.Message
 	if req.System != "" {
 		msgs = append(msgs, api.Message{Role: "system", Content: req.System})
-	} else if r.model.System != "" {
-		msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
+	} else if m.System != "" {
+		msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 	}
 
 	for _, i := range images {
@@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
-	tmpl := r.model.Template
+	tmpl := m.Template
 	if req.Template != "" {
 		tmpl, err = template.Parse(req.Template)
 		if err != nil {
@@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	var b bytes.Buffer
 	if req.Context != nil {
-		s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
+		s, err := r.Detokenize(c.Request.Context(), req.Context)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.GenerateResponse{
 				Model: req.Model,
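Both streaming handlers use the same producer/consumer shape: a goroutine pushes one value into ch per completion callback and closes the channel when Completion returns, which is what lets the HTTP side drain it with a plain range loop. A reduced sketch of that pattern outside gin (values simplified to ints; not the handler code itself):

	ch := make(chan any)
	go func() {
		defer close(ch) // unblocks the consumer whether Completion succeeds or fails
		for i := 0; i < 3; i++ {
			ch <- i // stands in for ch <- api.GenerateResponse{...} per callback
		}
	}()
	for v := range ch {
		_ = v // the real handler streams each value to the client as JSON
	}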
@@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
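EmbeddingsHandler needs only the llm.LlamaServer, so the model and options results are discarded with blank identifiers; Go requires every result of a multi-value call to be bound, and _ is the idiomatic way to drop the unused ones.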
@@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
@@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.ChatResponse{
 				Model: req.Model,
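One detail worth noting: both Completion call sites now pass opts, the *api.Options returned by scheduleRunner, where the old code dereferenced r.Options. That only compiles if llm.CompletionRequest.Options becomes a pointer in a part of the change not shown in this excerpt; the presumed companion edit would look roughly like:

	// Presumed change in the llm package (not part of this excerpt; field types are inferred).
	type CompletionRequest struct {
		Prompt  string
		Images  []ImageData  // element type assumed from the existing Images field
		Format  string       // type assumed
		Options *api.Options // previously api.Options, matching the old *r.Options dereference
	}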