Browse Source

api: structured outputs - chat endpoint (#7900)

Adds structured outputs to chat endpoint
---------

Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Hieu Nguyen <hieunguyen1053@outlook.com>
Parth Sareen 4 months ago
parent
commit
630e7dc6ff
10 changed files with 180 additions and 25 deletions
  1. 1 1
      api/types.go
  2. 2 1
      cmd/cmd.go
  3. 33 0
      llama/llama.go
  4. 69 0
      llama/llama_test.go
  5. 27 2
      llama/sampling_ext.cpp
  6. 2 0
      llama/sampling_ext.h
  7. 17 10
      llm/server.go
  8. 21 4
      openai/openai.go
  9. 7 6
      openai/openai_test.go
  10. 1 1
      server/routes.go

+ 1 - 1
api/types.go

@@ -94,7 +94,7 @@ type ChatRequest struct {
 	Stream *bool `json:"stream,omitempty"`
 
 	// Format is the format to return the response in (e.g. "json").
-	Format string `json:"format"`
+	Format json.RawMessage `json:"format,omitempty"`
 
 	// KeepAlive controls how long the model will stay loaded into memory
 	// following the request.

+ 2 - 1
cmd/cmd.go

@@ -8,6 +8,7 @@ import (
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/sha256"
+	"encoding/json"
 	"encoding/pem"
 	"errors"
 	"fmt"
@@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	req := &api.ChatRequest{
 		Model:    opts.Model,
 		Messages: opts.Messages,
-		Format:   opts.Format,
+		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
 	}
 

+ 33 - 0
llama/llama.go

@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
 import "C"
 
 import (
+	"bytes"
 	_ "embed"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"log/slog"
 	"runtime"
 	"runtime/cgo"
 	"slices"
@@ -699,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
 func (s *SamplingContext) Accept(id int, applyGrammar bool) {
 	C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }
+
+type JsonSchema struct {
+	Defs       map[string]any `json:"$defs,omitempty"`
+	Properties map[string]any `json:"properties,omitempty"`
+	Required   []string       `json:"required,omitempty"`
+	Title      string         `json:"title,omitempty"`
+	Type       string         `json:"type,omitempty"`
+}
+
+func (js JsonSchema) AsGrammar() string {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(js); err != nil {
+		return ""
+	}
+
+	cStr := C.CString(b.String())
+	defer C.free(unsafe.Pointer(cStr))
+
+	// Allocate buffer for grammar output with reasonable size
+	const maxLen = 32768 // 32KB
+	buf := make([]byte, maxLen)
+
+	// Call C function to convert schema to grammar
+	length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
+	if length == 0 {
+		slog.Warn("unable to convert schema to grammar")
+	}
+
+	return string(buf[:length])
+}

+ 69 - 0
llama/llama_test.go

@@ -1 +1,70 @@
 package llama
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestJsonSchema(t *testing.T) {
+	testCases := []struct {
+		name     string
+		schema   JsonSchema
+		expected string
+	}{
+		{
+			name: "empty schema",
+			schema: JsonSchema{
+				Type: "object",
+			},
+			expected: `array ::= "[" space ( value ("," space value)* )? "]" space
+boolean ::= ("true" | "false") space
+char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+decimal-part ::= [0-9]{1,16}
+integral-part ::= [0] | [1-9] [0-9]{0,15}
+null ::= "null" space
+number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
+root ::= object
+space ::= | " " | "\n" [ \t]{0,20}
+string ::= "\"" char* "\"" space
+value ::= object | array | string | number | boolean | null`,
+		},
+		{
+			name: "invalid schema with circular reference",
+			schema: JsonSchema{
+				Type: "object",
+				Properties: map[string]any{
+					"self": map[string]any{
+						"$ref": "#", // Self reference
+					},
+				},
+			},
+			expected: "", // Should return empty string for invalid schema
+		},
+		{
+			name: "schema with invalid type",
+			schema: JsonSchema{
+				Type: "invalid_type", // Invalid type
+				Properties: map[string]any{
+					"foo": map[string]any{
+						"type": "string",
+					},
+				},
+			},
+			expected: "", // Should return empty string for invalid schema
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			result := tc.schema.AsGrammar()
+			if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
+				if diff := cmp.Diff(tc.expected, result); diff != "" {
+					t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
+				}
+			}
+		})
+	}
+}

+ 27 - 2
llama/sampling_ext.cpp

@@ -1,11 +1,13 @@
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 #include "sampling.h"
 #include "sampling_ext.h"
+#include "json-schema-to-grammar.h"
 
 struct gpt_sampler *gpt_sampler_cinit(
     const struct llama_model *model, struct gpt_sampler_cparams *params)
 {
-    try {
+    try
+    {
         gpt_sampler_params sparams;
         sparams.top_k = params->top_k;
         sparams.top_p = params->top_p;
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
         sparams.seed = params->seed;
         sparams.grammar = params->grammar;
         return gpt_sampler_init(model, sparams);
-    } catch (const std::exception & err) {
+    }
+    catch (const std::exception &err)
+    {
         return nullptr;
     }
 }
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
 {
     gpt_sampler_accept(sampler, id, apply_grammar);
 }
+
+int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
+{
+    try
+    {
+        nlohmann::json schema = nlohmann::json::parse(json_schema);
+        std::string grammar_str = json_schema_to_grammar(schema);
+        size_t len = grammar_str.length();
+        if (len >= max_len)
+        {
+            len = max_len - 1;
+        }
+        strncpy(grammar, grammar_str.c_str(), len);
+        return len;
+    }
+    catch (const std::exception &e)
+    {
+        strncpy(grammar, "", max_len - 1);
+        return 0;
+    }
+}

+ 2 - 0
llama/sampling_ext.h

@@ -47,6 +47,8 @@ extern "C"
         llama_token id,
         bool apply_grammar);
 
+    int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
+
 #ifdef __cplusplus
 }
 #endif

+ 17 - 10
llm/server.go

@@ -634,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 const jsonGrammar = `
 root   ::= object
 value  ::= object | array | string | number | ("true" | "false" | "null") ws
-
 object ::=
   "{" ws (
             string ":" ws value
     ("," ws string ":" ws value)*
   )? "}" ws
-
 array  ::=
   "[" ws (
             value
     ("," ws value)*
   )? "]" ws
-
 string ::=
   "\"" (
     [^"\\\x7F\x00-\x1F] |
     "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
   )* "\"" ws
-
 number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
-
 # Optional space: by convention, applied in this grammar after literal chars when allowed
 ws ::= ([ \t\n] ws)?
 `
@@ -684,7 +679,7 @@ type completion struct {
 
 type CompletionRequest struct {
 	Prompt  string
-	Format  string
+	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
 }
@@ -749,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	if req.Format == "json" {
-		request["grammar"] = jsonGrammar
-		if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-			slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+	// TODO (parthsareen): Move conversion to grammar with sampling logic
+	// API should do error handling for invalid formats
+	if req.Format != nil {
+		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
+			request["grammar"] = jsonGrammar
+			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
+				slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+			}
+		} else if schema, err := func() (llama.JsonSchema, error) {
+			var schema llama.JsonSchema
+			err := json.Unmarshal(req.Format, &schema)
+			return schema, err
+		}(); err == nil {
+			request["grammar"] = schema.AsGrammar()
+		} else {
+			slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
 		}
 	}
 

+ 21 - 4
openai/openai.go

@@ -62,7 +62,12 @@ type Usage struct {
 }
 
 type ResponseFormat struct {
-	Type string `json:"type"`
+	Type       string      `json:"type"`
+	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
+}
+
+type JsonSchema struct {
+	Schema map[string]any `json:"schema"`
 }
 
 type EmbedRequest struct {
@@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["top_p"] = 1.0
 	}
 
-	var format string
-	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
-		format = "json"
+	var format json.RawMessage
+	if r.ResponseFormat != nil {
+		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
+		// Support the old "json_object" type for OpenAI compatibility
+		case "json_object":
+			format = json.RawMessage(`"json"`)
+		case "json_schema":
+			if r.ResponseFormat.JsonSchema != nil {
+				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+				if err != nil {
+					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
+				}
+				format = schema
+			}
+		}
 	}
 
 	return &api.ChatRequest{

+ 7 - 6
openai/openai_test.go

@@ -13,6 +13,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
 )
@@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
 					"presence_penalty":  5.0,
 					"top_p":             6.0,
 				},
-				Format: "json",
+				Format: json.RawMessage(`"json"`),
 				Stream: &True,
 			},
 		},
@@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 					t.Fatal(err)
 				}
+				return
 			}
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
+			if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+				t.Fatalf("requests did not match: %+v", diff)
 			}
-
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
+			if diff := cmp.Diff(tc.err, errResp); diff != "" {
+				t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
 			}
 		})
 	}

+ 1 - 1
server/routes.go

@@ -278,7 +278,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
-			Format:  req.Format,
+			Format:  json.RawMessage(req.Format),
 			Options: opts,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{