1 Star 3 Fork 2

gin-ecosystem/gin-middleware

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
cors.go 6.20 KB
一键复制 编辑 原始数据 按行查看 历史
aesoper 提交于 2020-05-20 15:06 . 添加日志中间件
/**
* @Author: aesoper
* @Description: CORS (cross-origin resource sharing) middleware for gin.
* @File: cors
* @Version: 1.0.0
* @Date: 2020/5/19 19:23
*/
package gin_middleware
import (
"errors"
"gitee.com/gin-ecosystem/gin-middleware/consts"
"github.com/gin-gonic/gin"
"log"
"net/http"
"strconv"
"strings"
)
// CORSConfig configures the middleware returned by NewCORS. The mapstructure
// tags let the tagged fields be decoded from a generic configuration map;
// Skipper and AllowOriginFunc are untagged and are expected to be set in code.
type CORSConfig struct {
	// Skipper is called by the middleware for each request; when it returns
	// true the middleware does nothing for that request. NewCORS substitutes
	// DefaultCORSConfig.Skipper when it is nil.
	Skipper Skipper
	// AllowAllOrigins accepts any Origin and answers with
	// Access-Control-Allow-Origin "*". Validate reports a conflict if it is
	// combined with AllowOriginFunc or explicit AllowOrigins.
	AllowAllOrigins bool `mapstructure:"allowAllOrigins"`
	// AllowOrigins lists origins accepted by exact string comparison; each
	// entry is trimmed, lowercased and deduplicated first. Validate requires
	// entries to be "*" or to start with http:// or https://.
	// NOTE(review): the runtime matcher compares literally, so a "*" entry
	// is not treated as a wildcard there — confirm this is intended.
	AllowOrigins []string `mapstructure:"allowOrigins"`
	// AllowOriginFunc is consulted for origins not found in AllowOrigins;
	// its result decides whether the origin is accepted.
	AllowOriginFunc func(origin string) bool
	// AllowMethods is sent (uppercased, comma-joined) as the
	// Access-Control-Allow-Methods header on both preflight and normal
	// responses. NewCORS substitutes DefaultCORSConfig.AllowMethods when empty.
	AllowMethods []string `mapstructure:"allowMethods"`
	// AllowHeaders is sent (canonicalized, comma-joined) as the
	// Access-Control-Allow-Headers header on preflight responses only.
	AllowHeaders []string `mapstructure:"allowHeaders"`
	// AllowCredentials adds Access-Control-Allow-Credentials: true to both
	// preflight and normal responses.
	AllowCredentials bool `mapstructure:"allowCredentials"`
	// ExposeHeaders is sent (canonicalized, comma-joined) as the
	// Access-Control-Expose-Headers header on normal (non-preflight) responses.
	ExposeHeaders []string `mapstructure:"exposeHeaders"`
	// MaxAge, when greater than zero, is sent as the Access-Control-Max-Age
	// header on preflight responses (that header is expressed in seconds).
	MaxAge int64 `mapstructure:"maxAge"`
}
// cors holds the per-middleware state that NewCORS precomputes from a
// CORSConfig; applyCORS consults it for every request.
type cors struct {
	// allowAllOrigins mirrors CORSConfig.AllowAllOrigins.
	allowAllOrigins bool
	// allowCredentials mirrors CORSConfig.AllowCredentials; it is not read by
	// the methods in this file (the credentials header is precomputed).
	allowCredentials bool
	// allowOriginFunc is the fallback origin check; may be nil.
	allowOriginFunc func(string) bool
	// allowOrigins is CORSConfig.AllowOrigins after normalize (trimmed,
	// lowercased, deduplicated).
	allowOrigins []string
	// normalHeaders are copied onto non-preflight CORS responses.
	normalHeaders http.Header
	// preflightHeaders are copied onto OPTIONS (preflight) responses.
	preflightHeaders http.Header
}
// converter maps one string to another; convert applies it to every element
// of a slice (this file passes strings.ToUpper and http.CanonicalHeaderKey).
type converter func(string) string
// Validate checks a user-supplied configuration for contradictory or unusable
// settings.
//
// It returns an error when:
//   - AllowAllOrigins is combined with AllowOriginFunc, or with an
//     AllowOrigins entry other than "*" (such settings would never be
//     consulted, because allowing all origins short-circuits the check);
//   - AllowAllOrigins is false and neither AllowOriginFunc nor AllowOrigins is
//     set, so every cross-origin request would be rejected;
//   - an AllowOrigins entry is neither "*" nor starts with "http://" or
//     "https://".
//
// AllowOrigins entries equal to "*" are tolerated next to AllowAllOrigins:
// DefaultCORSConfig uses exactly that combination, and NewCORS fills an empty
// AllowOrigins with DefaultCORSConfig.AllowOrigins, so rejecting it made the
// package's own defaults fail validation.
func (c *CORSConfig) Validate() error {
	if c.AllowAllOrigins {
		// Only non-wildcard entries actually conflict with "allow all".
		hasExplicitOrigin := false
		for _, origin := range c.AllowOrigins {
			if origin != "*" {
				hasExplicitOrigin = true
				break
			}
		}
		if c.AllowOriginFunc != nil || hasExplicitOrigin {
			return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
		}
	}
	if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
		return errors.New("conflict settings: all origins disabled")
	}
	for _, origin := range c.AllowOrigins {
		if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
			return errors.New("bad origin: origins must either be '*' or include http:// or https://")
		}
	}
	return nil
}
// DefaultCORSConfig is the configuration DefaultCORS uses, and the source of
// the fallback values NewCORS fills in when a config leaves Skipper,
// AllowOrigins or AllowMethods unset. It accepts every origin and allows the
// GET, HEAD, PUT, PATCH, POST and DELETE methods.
var DefaultCORSConfig = CORSConfig{
	AllowAllOrigins: true,
	AllowOrigins:    []string{"*"},
	AllowMethods: []string{
		http.MethodGet, http.MethodHead, http.MethodPut,
		http.MethodPatch, http.MethodPost, http.MethodDelete,
	},
	Skipper: DefaultSkipper,
}
// DefaultCORS returns CORS middleware built from a copy of DefaultCORSConfig.
//
// A copy is passed so the middleware never holds a pointer to the exported
// package-level variable. NewCORS may write defaults into the config it
// receives and consult its Skipper on every request, so handing it
// &DefaultCORSConfig would let later assignments to DefaultCORSConfig change,
// and race with, middleware that is already serving requests.
func DefaultCORS() gin.HandlerFunc {
	cfg := DefaultCORSConfig
	return NewCORS(&cfg)
}
// NewCORS builds a gin middleware that applies the CORS policy described by
// config.
//
// A nil config is treated as DefaultCORSConfig. An unset Skipper, AllowOrigins
// or AllowMethods is filled from DefaultCORSConfig. The config is copied
// first, so the caller's struct is not modified by these defaults. It is also
// not read again after NewCORS returns, so later changes to it neither affect
// nor race with the running middleware.
//
// All response headers are precomputed once here; the returned handler only
// copies them onto each response.
func NewCORS(config *CORSConfig) gin.HandlerFunc {
	var cfg CORSConfig
	if config != nil {
		cfg = *config
	} else {
		cfg = DefaultCORSConfig
	}
	if cfg.Skipper == nil {
		cfg.Skipper = DefaultCORSConfig.Skipper
	}
	if len(cfg.AllowOrigins) == 0 {
		cfg.AllowOrigins = DefaultCORSConfig.AllowOrigins
	}
	if len(cfg.AllowMethods) == 0 {
		cfg.AllowMethods = DefaultCORSConfig.AllowMethods
	}

	// Captured once so the per-request path never touches cfg (or the
	// caller's struct) again.
	skipper := cfg.Skipper
	policy := &cors{
		allowOriginFunc:  cfg.AllowOriginFunc,
		allowAllOrigins:  cfg.AllowAllOrigins,
		allowCredentials: cfg.AllowCredentials,
		allowOrigins:     normalize(cfg.AllowOrigins),
		normalHeaders:    generateNormalHeaders(&cfg),
		preflightHeaders: generatePreflightHeaders(&cfg),
	}
	return func(ctx *gin.Context) {
		if skipper(ctx) {
			return
		}
		policy.applyCORS(ctx)
	}
}
// applyCORS enforces the CORS policy for a single request.
//
// Requests pass through untouched when they have no Origin header, or when
// Origin is http://<Host> or https://<Host> of this request. Browsers also
// send Origin on same-origin POST/fetch requests, and those must not be
// rejected just because the server's own host is not in the allow-list.
//
// A cross-origin request whose origin validateOrigin rejects is logged and
// aborted with 403. For an allowed origin, the precomputed headers are
// copied onto the response. When not all origins are allowed, the request's
// origin is echoed in Access-Control-Allow-Origin. OPTIONS requests are
// treated as preflight and aborted with 200 so later handlers do not run.
func (cors *cors) applyCORS(c *gin.Context) {
	origin := c.Request.Header.Get("Origin")
	if len(origin) == 0 {
		// request is not a CORS request
		return
	}
	host := c.Request.Host
	if origin == "http://"+host || origin == "https://"+host {
		// Same-origin request that still carries an Origin header; CORS does
		// not apply.
		return
	}
	if !cors.validateOrigin(origin) {
		log.Printf("The request's Origin header `%s` does not match any of allowed origins.", origin)
		c.AbortWithStatus(http.StatusForbidden)
		return
	}

	preflight := c.Request.Method == http.MethodOptions
	if preflight {
		cors.handlePreflight(c)
	} else {
		cors.handleNormal(c)
	}
	if !cors.allowAllOrigins {
		c.Writer.Header().Set(consts.HeaderAccessControlAllowOrigin, origin)
	}
	if preflight {
		// AbortWithStatus writes the status and headers immediately, so it
		// must run only after every header above has been set.
		c.AbortWithStatus(http.StatusOK)
	}
}
// validateOrigin reports whether origin may make cross-origin requests.
//
// The checks run in order:
//  1. allowAllOrigins accepts everything;
//  2. an exact, case-sensitive match against the normalized allowOrigins list
//     (a "*" entry only matches the literal string "*");
//  3. allowOriginFunc, if one was configured.
func (cors *cors) validateOrigin(origin string) bool {
	if cors.allowAllOrigins {
		return true
	}
	matched := false
	for i := 0; i < len(cors.allowOrigins) && !matched; i++ {
		matched = cors.allowOrigins[i] == origin
	}
	if matched {
		return true
	}
	return cors.allowOriginFunc != nil && cors.allowOriginFunc(origin)
}
// handlePreflight copies the precomputed preflight headers onto the response,
// replacing any existing values for those keys.
func (cors *cors) handlePreflight(c *gin.Context) {
	copyCORSHeaders(c.Writer.Header(), cors.preflightHeaders)
}

// handleNormal copies the precomputed headers for actual (non-preflight) CORS
// responses onto the response, replacing any existing values for those keys.
func (cors *cors) handleNormal(c *gin.Context) {
	copyCORSHeaders(c.Writer.Header(), cors.normalHeaders)
}

// copyCORSHeaders sets every key of src on dst, replacing dst's values for
// that key. Each value slice is copied because src is shared by every request
// the middleware handles. Assigning it directly would let a later
// Header().Add on one response append into a backing array that concurrent
// responses also reference. An in-place edit would likewise corrupt the
// cached headers.
func copyCORSHeaders(dst, src http.Header) {
	for key, values := range src {
		dst[key] = append([]string(nil), values...)
	}
}
// generateNormalHeaders precomputes the headers attached to actual
// (non-preflight) CORS responses. These are the credentials flag, the allowed
// methods (uppercased) and the exposed headers (canonicalized). Origin
// handling depends on AllowAllOrigins: when true, the headers carry a
// wildcard Access-Control-Allow-Origin; otherwise they carry "Vary: Origin",
// since applyCORS echoes the request's origin per request.
func generateNormalHeaders(c *CORSConfig) http.Header {
	h := http.Header{}
	if c.AllowCredentials {
		h.Set(consts.HeaderAccessControlAllowCredentials, "true")
	}
	// Allowed methods are also sent on normal responses as backport support
	// for early browsers.
	if methods := c.AllowMethods; len(methods) > 0 {
		upper := convert(normalize(methods), strings.ToUpper)
		h.Set(consts.HeaderAccessControlAllowMethods, strings.Join(upper, ","))
	}
	if exposed := c.ExposeHeaders; len(exposed) > 0 {
		canonical := convert(normalize(exposed), http.CanonicalHeaderKey)
		h.Set(consts.HeaderAccessControlExposeHeaders, strings.Join(canonical, ","))
	}
	if !c.AllowAllOrigins {
		h.Set(consts.HeaderVary, consts.HeaderOrigin)
		return h
	}
	h.Set(consts.HeaderAccessControlAllowOrigin, "*")
	return h
}
// generatePreflightHeaders precomputes the headers attached to preflight
// (OPTIONS) responses:
//   - the credentials flag;
//   - the allowed methods (uppercased);
//   - the allowed request headers (canonicalized);
//   - the max age, when positive.
//
// Origin handling depends on AllowAllOrigins. When true, a wildcard
// Access-Control-Allow-Origin is set. Otherwise Vary entries are added for
// Origin and the two Access-Control-Request-* headers.
func generatePreflightHeaders(c *CORSConfig) http.Header {
	h := http.Header{}
	if c.AllowCredentials {
		h.Set(consts.HeaderAccessControlAllowCredentials, "true")
	}
	if methods := c.AllowMethods; len(methods) > 0 {
		upper := convert(normalize(methods), strings.ToUpper)
		h.Set(consts.HeaderAccessControlAllowMethods, strings.Join(upper, ","))
	}
	if requestHeaders := c.AllowHeaders; len(requestHeaders) > 0 {
		canonical := convert(normalize(requestHeaders), http.CanonicalHeaderKey)
		h.Set(consts.HeaderAccessControlAllowHeaders, strings.Join(canonical, ","))
	}
	if c.MaxAge > 0 {
		h.Set(consts.HeaderAccessControlMaxAge, strconv.FormatInt(c.MaxAge, 10))
	}
	if c.AllowAllOrigins {
		h.Set(consts.HeaderAccessControlAllowOrigin, "*")
		return h
	}
	for _, varyOn := range []string{
		consts.HeaderOrigin,
		consts.HeaderAccessControlRequestMethod,
		consts.HeaderAccessControlRequestHeaders,
	} {
		h.Add(consts.HeaderVary, varyOn)
	}
	return h
}
// normalize trims surrounding whitespace from each value, lowercases it and
// removes duplicates, keeping values in first-occurrence order.
//
// Entries that are empty after trimming are dropped. Previously they
// survived and produced malformed comma-joined header lists such as
// "GET,,POST".
//
// A nil input yields nil; any other input yields a non-nil (possibly empty)
// slice.
func normalize(values []string) []string {
	if values == nil {
		return nil
	}
	seen := make(map[string]struct{}, len(values))
	normalized := make([]string, 0, len(values))
	for _, value := range values {
		value = strings.ToLower(strings.TrimSpace(value))
		if value == "" {
			continue
		}
		if _, dup := seen[value]; dup {
			continue
		}
		seen[value] = struct{}{}
		normalized = append(normalized, value)
	}
	return normalized
}
// convert returns a new slice holding c applied to each element of s, in
// order. An empty or nil s yields nil.
func convert(s []string, c converter) []string {
	if len(s) == 0 {
		return nil
	}
	out := make([]string, len(s))
	for i := range s {
		out[i] = c(s[i])
	}
	return out
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/gin-ecosystem/gin-middleware.git
git@gitee.com:gin-ecosystem/gin-middleware.git
gin-ecosystem
gin-middleware
gin-middleware
master

搜索帮助