Browse Source

add capabilities

Michael Yang 10 months ago
parent
commit
a30915bde1
3 changed files with 26 additions and 10 deletions
  1. 18 2
      server/images.go
  2. 4 4
      server/routes.go
  3. 4 4
      template/template_test.go

+ 18 - 2
server/images.go

@@ -34,6 +34,10 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
+type Capability string
+
+const CapabilityCompletion = Capability("completion")
+
 type registryOptions struct {
 	Insecure bool
 	Username string
@@ -58,8 +62,20 @@ type Model struct {
 	Template *template.Template
 }
 
-func (m *Model) IsEmbedding() bool {
-	return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
+func (m *Model) Has(caps ...Capability) bool {
+	for _, cap := range caps {
+		switch cap {
+		case CapabilityCompletion:
+			if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") {
+				return false
+			}
+		default:
+			slog.Error("unknown capability", "capability", cap)
+			return false
+		}
+	}
+
+	return true
 }
 
 func (m *Model) String() string {

+ 4 - 4
server/routes.go

@@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
+	if !model.Has(CapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
 		return
 	}
 
@@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
+	if !model.Has(CapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
 		return
 	}
 

+ 4 - 4
template/template_test.go

@@ -61,8 +61,8 @@ func TestNamed(t *testing.T) {
 
 func TestParse(t *testing.T) {
 	cases := []struct {
-		template     string
-		capabilities []string
+		template string
+		vars     []string
 	}{
 		{"{{ .Prompt }}", []string{"prompt"}},
 		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
@@ -81,8 +81,8 @@ func TestParse(t *testing.T) {
 			}
 
 			vars := tmpl.Vars()
-			if !slices.Equal(tt.capabilities, vars) {
-				t.Errorf("expected %v, got %v", tt.capabilities, vars)
+			if !slices.Equal(tt.vars, vars) {
+				t.Errorf("expected %v, got %v", tt.vars, vars)
 			}
 		})
 	}