@@ -674,21 +674,6 @@ type CompletionResponse struct {
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
-	if err := s.sem.Acquire(ctx, 1); err != nil {
-		if errors.Is(err, context.Canceled) {
-			slog.Info("aborting completion request due to client closing the connection")
-		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
-		}
-		return err
-	}
-	defer s.sem.Release(1)
-
-	// put an upper limit on num_predict to avoid the model running on forever
-	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
-		req.Options.NumPredict = 10 * s.options.NumCtx
-	}
-
 	request := map[string]any{
 		"prompt": req.Prompt,
 		"stream": true,
@@ -714,6 +699,39 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"cache_prompt": true,
 	}
 
+	if len(req.Format) > 0 {
+		switch {
+		case bytes.Equal(req.Format, []byte(`""`)):
+			// Format was set to an empty string; accept it as "not set" and apply no grammar.
+		case bytes.Equal(req.Format, []byte(`"json"`)):
+			request["grammar"] = grammarJSON
+		case bytes.HasPrefix(req.Format, []byte("{")):
+			// User provided a JSON schema
+			g := llama.SchemaToGrammar(req.Format)
+			if g == nil {
+				return fmt.Errorf("invalid JSON schema in format")
+			}
+			request["grammar"] = string(g)
+		default:
+			return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format)
+		}
+	}
+
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		if errors.Is(err, context.Canceled) {
+			slog.Info("aborting completion request due to client closing the connection")
+		} else {
+			slog.Error("Failed to acquire semaphore", "error", err)
+		}
+		return err
+	}
+	defer s.sem.Release(1)
+
+	// put an upper limit on num_predict to avoid the model running on forever
+	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
+		req.Options.NumPredict = 10 * s.options.NumCtx
+	}
+
 	// Make sure the server is ready
 	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
@@ -722,16 +740,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	if bytes.Equal(req.Format, []byte(`"json"`)) {
-		request["grammar"] = grammarJSON
-	} else if bytes.HasPrefix(req.Format, []byte("{")) {
-		g := llama.SchemaToGrammar(req.Format)
-		if g == nil {
-			return fmt.Errorf("invalid JSON schema in format")
-		}
-		request["grammar"] = string(g)
-	}
-
 	// Handling JSON marshaling with special characters unescaped.
 	buffer := &bytes.Buffer{}
 	enc := json.NewEncoder(buffer)