Browse Source

server: group routes by category and purpose (#9270)

The route assembly in Handler lacked clear organization making it
difficult scan for routes and their relationships to each other. This
commit aims to fix that by reordering the assembly of routes to group
them by category and purpose.

Also, be more specific about what "config" refers to (it is about CORS
if you were wondering... I was.)
Blake Mizerany 2 months ago
parent
commit
68bac1e0a6
3 changed files with 50 additions and 32 deletions
  1. 3 3
      envconfig/config.go
  2. 1 1
      envconfig/config_test.go
  3. 46 28
      server/routes.go

+ 3 - 3
envconfig/config.go

@@ -53,8 +53,8 @@ func Host() *url.URL {
 	}
 }
 
-// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
-func Origins() (origins []string) {
+// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
+func AllowedOrigins() (origins []string) {
 	if s := Var("OLLAMA_ORIGINS"); s != "" {
 		origins = strings.Split(s, ",")
 	}
@@ -249,7 +249,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
 		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
-		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
+		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_MULTIUSER_CACHE":   {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
 		"OLLAMA_NEW_ENGINE":        {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},

+ 1 - 1
envconfig/config_test.go

@@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) {
 		t.Run(tt.value, func(t *testing.T) {
 			t.Setenv("OLLAMA_ORIGINS", tt.value)
 
-			if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
+			if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
 				t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
 			}
 		})

+ 46 - 28
server/routes.go

@@ -1127,54 +1127,72 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 }
 
 func (s *Server) GenerateRoutes() http.Handler {
-	config := cors.DefaultConfig()
-	config.AllowWildcard = true
-	config.AllowBrowserExtensions = true
-	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
-	openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"}
-	for _, prop := range openAIProperties {
-		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
-	}
-	config.AllowOrigins = envconfig.Origins()
+	corsConfig := cors.DefaultConfig()
+	corsConfig.AllowWildcard = true
+	corsConfig.AllowBrowserExtensions = true
+	corsConfig.AllowHeaders = []string{
+		"Authorization",
+		"Content-Type",
+		"User-Agent",
+		"Accept",
+		"X-Requested-With",
+
+		// OpenAI compatibility headers
+		"x-stainless-lang",
+		"x-stainless-package-version",
+		"x-stainless-os",
+		"x-stainless-arch",
+		"x-stainless-retry-count",
+		"x-stainless-runtime",
+		"x-stainless-runtime-version",
+		"x-stainless-async",
+		"x-stainless-helper-method",
+		"x-stainless-poll-helper",
+		"x-stainless-custom-poll-interval",
+		"x-stainless-timeout",
+	}
+	corsConfig.AllowOrigins = envconfig.AllowedOrigins()
 
 	r := gin.Default()
 	r.Use(
-		cors.New(config),
+		cors.New(corsConfig),
 		allowedHostsMiddleware(s.addr),
 	)
 
+	// General
+	r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
+	r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
+	r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+	r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+
+	// Local model cache management
 	r.POST("/api/pull", s.PullHandler)
-	r.POST("/api/generate", s.GenerateHandler)
-	r.POST("/api/chat", s.ChatHandler)
-	r.POST("/api/embed", s.EmbedHandler)
-	r.POST("/api/embeddings", s.EmbeddingsHandler)
-	r.POST("/api/create", s.CreateHandler)
 	r.POST("/api/push", s.PushHandler)
-	r.POST("/api/copy", s.CopyHandler)
 	r.DELETE("/api/delete", s.DeleteHandler)
+	r.HEAD("/api/tags", s.ListHandler)
+	r.GET("/api/tags", s.ListHandler)
 	r.POST("/api/show", s.ShowHandler)
+
+	// Create
+	r.POST("/api/create", s.CreateHandler)
 	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
+	r.POST("/api/copy", s.CopyHandler)
+
+	// Inference
 	r.GET("/api/ps", s.PsHandler)
+	r.POST("/api/generate", s.GenerateHandler)
+	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embed", s.EmbedHandler)
+	r.POST("/api/embeddings", s.EmbeddingsHandler)
 
-	// Compatibility endpoints
+	// Inference (OpenAI compatibility)
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
 	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
 
-	for _, method := range []string{http.MethodGet, http.MethodHead} {
-		r.Handle(method, "/", func(c *gin.Context) {
-			c.String(http.StatusOK, "Ollama is running")
-		})
-
-		r.Handle(method, "/api/tags", s.ListHandler)
-		r.Handle(method, "/api/version", func(c *gin.Context) {
-			c.JSON(http.StatusOK, gin.H{"version": version.Version})
-		})
-	}
-
 	return r
 }