瀏覽代碼

allow ollama.com to call inference and info endpoints

- By default, allow ollama.com to call inference and info endpoints. This can be overridden by setting the OLLAMA_ORIGINS env var.
Bruce MacDonald 8 月之前
父節點
當前提交
f84cc9939c
共有 2 個文件被更改,包括 65 次插入和 39 次删除
  1. 5 0
      envconfig/config.go
  2. 60 39
      server/routes.go

+ 5 - 0
envconfig/config.go

@@ -57,6 +57,11 @@ func Host() *url.URL {
 	}
 }
 
+// HasCustomOrigins reports whether custom origins are configured, i.e. whether Var("OLLAMA_ORIGINS") returns a non-empty string. Origins are configured via the OLLAMA_ORIGINS environment variable.
+func HasCustomOrigins() bool {
+	return Var("OLLAMA_ORIGINS") != ""
+}
+
 // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
 func Origins() (origins []string) {
 	if s := Var("OLLAMA_ORIGINS"); s != "" {

+ 60 - 39
server/routes.go

@@ -1051,52 +1051,73 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 }
 
 func (s *Server) GenerateRoutes() http.Handler {
-	config := cors.DefaultConfig()
-	config.AllowWildcard = true
-	config.AllowBrowserExtensions = true
-	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
+	baseConfig := cors.DefaultConfig()
+	baseConfig.AllowWildcard = true
+	baseConfig.AllowBrowserExtensions = true
+	baseConfig.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
 	openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
 	for _, prop := range openAIProperties {
-		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
+		baseConfig.AllowHeaders = append(baseConfig.AllowHeaders, "x-stainless-"+prop)
 	}
-	config.AllowOrigins = envconfig.Origins()
 
 	r := gin.Default()
-	r.Use(
-		cors.New(config),
-		allowedHostsMiddleware(s.addr),
-	)
-
-	r.POST("/api/pull", s.PullModelHandler)
-	r.POST("/api/generate", s.GenerateHandler)
-	r.POST("/api/chat", s.ChatHandler)
-	r.POST("/api/embed", s.EmbedHandler)
-	r.POST("/api/embeddings", s.EmbeddingsHandler)
-	r.POST("/api/create", s.CreateModelHandler)
-	r.POST("/api/push", s.PushModelHandler)
-	r.POST("/api/copy", s.CopyModelHandler)
-	r.DELETE("/api/delete", s.DeleteModelHandler)
-	r.POST("/api/show", s.ShowModelHandler)
-	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
-	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
-	r.GET("/api/ps", s.ProcessHandler)
-
-	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
-	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
-	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
-	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
-	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
-
-	for _, method := range []string{http.MethodGet, http.MethodHead} {
-		r.Handle(method, "/", func(c *gin.Context) {
-			c.String(http.StatusOK, "Ollama is running")
-		})
 
-		r.Handle(method, "/api/tags", s.ListModelsHandler)
-		r.Handle(method, "/api/version", func(c *gin.Context) {
-			c.JSON(http.StatusOK, gin.H{"version": version.Version})
+	openConfig := baseConfig
+	openConfig.AllowOrigins = envconfig.Origins()
+	if !envconfig.HasCustomOrigins() {
+		openConfig.AllowOrigins = append(openConfig.AllowOrigins, "https://ollama.com")
+		openConfig.AllowOrigins = append(openConfig.AllowOrigins, "https://www.ollama.com")
+	}
+
+	openBaseGroup := r.Group("/")
+	openBaseGroup.Use(cors.New(openConfig), allowedHostsMiddleware(s.addr))
+	{
+		openBaseGroup.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
+		openBaseGroup.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
+	}
+
+	openAPIGroup := r.Group("/api")
+	openAPIGroup.Use(cors.New(openConfig), allowedHostsMiddleware(s.addr))
+	{
+		openAPIGroup.OPTIONS("/*path", func(c *gin.Context) {
+			c.Status(http.StatusOK)
 		})
+		openAPIGroup.POST("/pull", s.PullModelHandler)
+		openAPIGroup.POST("/generate", s.GenerateHandler)
+		openAPIGroup.POST("/chat", s.ChatHandler)
+		openAPIGroup.POST("/embed", s.EmbedHandler)
+		openAPIGroup.POST("/embeddings", s.EmbeddingsHandler)
+		openAPIGroup.POST("/show", s.ShowModelHandler)
+		openAPIGroup.GET("/tags", s.ListModelsHandler)
+		openAPIGroup.HEAD("/tags", s.ListModelsHandler)
+		openAPIGroup.GET("/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+		openAPIGroup.HEAD("/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+	}
+
+	restrictedConfig := baseConfig
+	restrictedConfig.AllowOrigins = envconfig.Origins()
+	restrictedAPIGroup := r.Group("/api")
+	restrictedAPIGroup.Use(cors.New(restrictedConfig), allowedHostsMiddleware(s.addr))
+	{
+		restrictedAPIGroup.POST("/create", s.CreateModelHandler)
+		restrictedAPIGroup.POST("/push", s.PushModelHandler)
+		restrictedAPIGroup.POST("/copy", s.CopyModelHandler)
+		restrictedAPIGroup.DELETE("/delete", s.DeleteModelHandler)
+		restrictedAPIGroup.POST("/blobs/:digest", s.CreateBlobHandler)
+		restrictedAPIGroup.HEAD("/blobs/:digest", s.HeadBlobHandler)
+		restrictedAPIGroup.GET("/ps", s.ProcessHandler)
+	}
+
+	openAIConfig := baseConfig
+	openAIConfig.AllowOrigins = envconfig.Origins()
+	openAIGroup := r.Group("/v1")
+	openAIGroup.Use(cors.New(openAIConfig), allowedHostsMiddleware(s.addr))
+	{
+		openAIGroup.POST("/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+		openAIGroup.POST("/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+		openAIGroup.POST("/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
+		openAIGroup.GET("/models", openai.ListMiddleware(), s.ListModelsHandler)
+		openAIGroup.GET("/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 	}
 
 	return r