diff --git a/.gitattributes b/.gitattributes index 28981b84a..49b63e526 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,3 +13,8 @@ *.js text eol=lf *.json text eol=lf LICENSE text eol=lf + +# Exclude `website` and `cookbook` from GitHub's language statistics +# https://github.com/github/linguist#using-gitattributes +cookbook/* linguist-documentation +website/* linguist-documentation diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 1a76adca7..82220c0a1 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -6,7 +6,7 @@ package main import ( - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "net/http" "net/http/httptest" "testing" @@ -15,7 +15,7 @@ import ( func TestExample(t *testing.T) { e := echo.New() - e.GET("/", func(c *echo.Context) error { + e.GET("/", func(c echo.Context) error { return c.String(http.StatusOK, "Hello, World!") }) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index cfe4ca563..89fb3ce85 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -4,9 +4,11 @@ on: push: branches: - master + - v4 pull_request: branches: - master + - v4 workflow_dispatch: permissions: @@ -21,10 +23,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} check-latest: true @@ -44,4 +46,3 @@ jobs: go version go install golang.org/x/vuln/cmd/govulncheck@latest govulncheck ./... - diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index e11a029fc..5f4d53f2c 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -4,9 +4,11 @@ on: push: branches: - master + - v4 pull_request: branches: - master + - v4 workflow_dispatch: permissions: @@ -23,17 +25,17 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases (unless there are pressing vulnerabilities) - # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations when - # we derive from the last four major releases promise. + # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when + # we derive from last four major releases promise. go: ["1.25", "1.26"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - name: Checkout Code - uses: actions/checkout@v6 + uses: actions/checkout@v5 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} @@ -42,7 +44,7 @@ jobs: - name: Upload coverage to Codecov if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v6 + uses: codecov/codecov-action@v5 with: token: fail_ci_if_error: false @@ -53,18 +55,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code (Previous) - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v6 + uses: actions/checkout@v5 with: path: new - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} diff --git a/.golangci.yaml b/.golangci.yaml deleted file mode 100644 index eab6aa2d3..000000000 --- a/.golangci.yaml +++ /dev/null @@ -1,24 +0,0 @@ -version: "2" -linters: -# default: none - enable: - - revive - disable: - - errcheck - settings: - revive: - rules: - - name: exported - exclusions: - generated: lax - presets: - - common-false-positives - - legacy - - std-error-handling - paths: - - _test\.go$ -formatters: - exclusions: - generated: lax - paths: - - _test\.go$ diff --git a/API_CHANGES_V5.md b/API_CHANGES_V5.md deleted file mode 100644 index d3ca81560..000000000 --- a/API_CHANGES_V5.md +++ /dev/null @@ -1,1178 +0,0 @@ -# Echo v5 Public API Changes - -**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches** - -Generated: 2026-01-01 - ---- - -## Executive Summary (by authors) - -Echo `v5` is maintenance release with **major breaking changes** -- `Context` is now struct instead of interface and we can add method to it in the future in minor versions. -- Adds new `Router` interface for possible new routing implementations. -- Drops old logging interface and uses moderm `log/slog` instead. -- Rearranges alot of methods/function signatures to make them more consistent. - -## Executive Summary (by LLMs) - -Echo v5 represents a **major breaking release** with significant architectural changes focused on: -- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*` -- **Simplified API surface** by moving Context from interface to concrete struct -- **Modern Go patterns** including slog.Logger integration -- **Enhanced routing** with explicit RouteInfo and Routes types -- **Better error handling** with simplified HTTPError -- **New test helpers** via the `echotest` package - -### Change Statistics - -- **Major Breaking Changes**: 15+ -- **New Functions Added**: 30+ -- **Type Signature Changes**: 20+ -- **Removed APIs**: 10+ -- **New Packages Added**: 1 (`echotest`) -- **Version Change**: `4.15.0` β†’ `5.0.0-alpha` - ---- - -## Critical Breaking Changes - -### 1. **Context: Interface β†’ Concrete Struct** - -**v4 (master):** -```go -type Context interface { - Request() *http.Request - // ... many methods -} - -// Handler signature -func handler(c echo.Context) error -``` - -**v5:** -```go -type Context struct { - // Has unexported fields -} - -// Handler signature - NOW USES POINTER! -func handler(c *echo.Context) error -``` - -**Impact:** πŸ”΄ **CRITICAL BREAKING CHANGE** -- ALL handlers must change from `echo.Context` to `*echo.Context` -- Context is now a concrete struct, not an interface -- This affects every single handler function in user code - -**Migration:** -```go -// Before (v4) -func MyHandler(c echo.Context) error { - return c.JSON(200, map[string]string{"hello": "world"}) -} - -// After (v5) -func MyHandler(c *echo.Context) error { - return c.JSON(200, map[string]string{"hello": "world"}) -} -``` - ---- - -### 2. **Logger: Custom Interface β†’ slog.Logger** - -**v4:** -```go -type Echo struct { - Logger Logger // Custom interface with Print, Debug, Info, etc. -} - -type Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - // ... many custom methods -} - -// Context returns Logger interface -func (c Context) Logger() Logger -``` - -**v5:** -```go -type Echo struct { - Logger *slog.Logger // Standard library structured logger -} - -// Context returns slog.Logger -func (c *Context) Logger() *slog.Logger -func (c *Context) SetLogger(logger *slog.Logger) -``` - -**Impact:** πŸ”΄ **BREAKING CHANGE** -- Must use Go's standard `log/slog` package -- Logger interface completely removed -- All logging code needs updating - ---- - -### 3. **Router: From Router to DefaultRouter** - -**v4:** -```go -type Router struct { ... } - -func NewRouter(e *Echo) *Router -func (e *Echo) Router() *Router -``` - -**v5:** -```go -type DefaultRouter struct { ... } - -func NewRouter(config RouterConfig) *DefaultRouter -func (e *Echo) Router() Router // Returns interface -``` - -**Changes:** -- New `Router` interface introduced -- `DefaultRouter` is the concrete implementation -- `NewRouter()` now takes `RouterConfig` instead of `*Echo` -- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing - ---- - -### 4. **Route Return Types Changed** - -**v4:** -```go -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route -func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route -func (e *Echo) Routes() []*Route -``` - -**v5:** -```go -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo -func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo -func (e *Echo) Match(...) Routes // Returns Routes type -func (e *Echo) Router() Router // Returns interface -``` - -**New Types:** -```go -type RouteInfo struct { - Name string - Method string - Path string - Parameters []string -} - -type Routes []RouteInfo // Collection with helper methods -``` - -**Impact:** πŸ”΄ **BREAKING CHANGE** -- Route registration methods return `RouteInfo` instead of `*Route` -- New `Routes` collection type with filtering methods -- `Route` struct still exists but used differently - ---- - -### 5. **Response Type Changed** - -**v4:** -```go -func (c Context) Response() *Response -type Response struct { - Writer http.ResponseWriter - Status int - Size int64 - Committed bool -} -func NewResponse(w http.ResponseWriter, e *Echo) *Response -``` - -**v5:** -```go -func (c *Context) Response() http.ResponseWriter -type Response struct { - http.ResponseWriter // Embedded - Status int - Size int64 - Committed bool -} -func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response -func UnwrapResponse(rw http.ResponseWriter) (*Response, error) -``` - -**Changes:** -- Context.Response() returns `http.ResponseWriter` instead of `*Response` -- Response now embeds `http.ResponseWriter` -- NewResponse takes `*slog.Logger` instead of `*Echo` -- New `UnwrapResponse()` helper function - ---- - -### 6. **HTTPError Simplified** - -**v4:** -```go -type HTTPError struct { - Internal error - Message interface{} // Can be any type - Code int -} - -func NewHTTPError(code int, message ...interface{}) *HTTPError -``` - -**v5:** -```go -type HTTPError struct { - Code int - Message string // Now string only - // Has unexported fields (Internal moved) -} - -func NewHTTPError(code int, message string) *HTTPError -func (he HTTPError) Wrap(err error) error // New method -func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder -``` - -**Changes:** -- `Message` field changed from `interface{}` to `string` -- `NewHTTPError()` now takes `string` instead of `...interface{}` -- Added `HTTPStatusCoder` interface and `StatusCode()` method -- Added `Wrap(err error)` method for error wrapping - ---- - -### 7. **HTTPErrorHandler Signature Changed** - -**v4:** -```go -type HTTPErrorHandler func(err error, c Context) - -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) -``` - -**v5:** -```go -type HTTPErrorHandler func(c *Context, err error) // Parameters swapped! - -func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory -``` - -**Impact:** πŸ”΄ **BREAKING CHANGE** -- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)` -- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler -- Takes `exposeError` bool to control error message exposure - ---- - -## Notable API Changes in v5 - -### 1. **Generic Parameter Extraction Functions (Updated Signatures)** - -These helpers keep the same generic API but now accept `*Context`, and the -form helpers are renamed from `FormParam*` to `FormValue*`: - -```go -// Query Parameters -func QueryParam[T any](c *Context, key string, opts ...any) (T, error) -func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) -func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) -func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) - -// Path Parameters -func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) -func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) - -// Form Values -func FormValue[T any](c *Context, key string, opts ...any) (T, error) -func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) -func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) -func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) - -// Generic Parsing -func ParseValue[T any](value string, opts ...any) (T, error) -func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) -func ParseValues[T any](values []string, opts ...any) ([]T, error) -func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) -``` - -`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`. - -**Supported Types:** -- bool, string -- int, int8, int16, int32, int64 -- uint, uint8, uint16, uint32, uint64 -- float32, float64 -- time.Time, time.Duration -- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler - -**Example Usage:** -```go -// v5 - Type-safe parameter binding -id, err := echo.PathParam[int](c, "id") -page, err := echo.QueryParamOr[int](c, "page", 1) -tags, err := echo.QueryParams[string](c, "tags") -``` - ---- - -### 2. **Context Store Helpers Now Use `*Context`** - -```go -// Type-safe context value retrieval -func ContextGet[T any](c *Context, key string) (T, error) -func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) - -// Error types -var ErrNonExistentKey = errors.New("non existent key") -var ErrInvalidKeyType = errors.New("invalid key type") -``` - -These helpers existed in v4 with `Context` and now accept `*Context`. - -**Example:** -```go -// v5 -user, err := echo.ContextGet[*User](c, "user") -count, err := echo.ContextGetOr[int](c, "count", 0) -``` - ---- - -### 3. **PathValues Type** - -New structured path parameter handling: - -```go -type PathValue struct { - Name string - Value string -} - -type PathValues []PathValue - -func (p PathValues) Get(name string) (string, bool) -func (p PathValues) GetOr(name string, defaultValue string) string - -// Context methods -func (c *Context) PathValues() PathValues -func (c *Context) SetPathValues(pathValues PathValues) -``` - ---- - -### 4. **Time Parsing Options** - -```go -type TimeLayout string - -const ( - TimeLayoutUnixTime = TimeLayout("UnixTime") - TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") - TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") -) - -type TimeOpts struct { - Layout TimeLayout - ParseInLocation *time.Location - ToInLocation *time.Location -} -``` - ---- - -### 5. **StartConfig for Server Configuration** - -```go -type StartConfig struct { - Address string - HideBanner bool - HidePort bool - CertFilesystem fs.FS - TLSConfig *tls.Config - ListenerNetwork string - ListenerAddrFunc func(addr net.Addr) - GracefulTimeout time.Duration - OnShutdownError func(err error) - BeforeServeFunc func(s *http.Server) error -} - -func (sc StartConfig) Start(ctx context.Context, h http.Handler) error -func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error -``` - -**Example:** -```go -// v5 - More control over server startup -ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) -defer cancel() - -sc := echo.StartConfig{ - Address: ":8080", - GracefulTimeout: 10 * time.Second, -} -if err := sc.Start(ctx, e); err != nil { - log.Fatal(err) -} -``` - ---- - -### 6. **Echo Config and Constructors** - -```go -type Config struct { - // Configuration for Echo (logger, binder, renderer, etc.) -} - -func NewWithConfig(config Config) *Echo -``` - -This adds a configuration struct for creating an `Echo` instance without -mutating fields after `New()`. - ---- - -### 7. **Enhanced Routing Features** - -```go -// New route methods -func (e *Echo) AddRoute(route Route) (RouteInfo, error) -func (e *Echo) Middlewares() []MiddlewareFunc -func (e *Echo) PreMiddlewares() []MiddlewareFunc -type AddRouteError struct{ ... } - -// Routes collection with filters -type Routes []RouteInfo - -func (r Routes) Clone() Routes -func (r Routes) FilterByMethod(method string) (Routes, error) -func (r Routes) FilterByName(name string) (Routes, error) -func (r Routes) FilterByPath(path string) (Routes, error) -func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) -func (r Routes) Reverse(routeName string, pathValues ...any) (string, error) - -// RouteInfo operations -func (r RouteInfo) Clone() RouteInfo -func (r RouteInfo) Reverse(pathValues ...any) string -``` - ---- - -### 8. **Middleware Configuration Interface** - -```go -type MiddlewareConfigurator interface { - ToMiddleware() (MiddlewareFunc, error) -} -``` - -Allows middleware configs to be converted to middleware without panicking. - ---- - -### 9. **New Context Methods** - -```go -// v5 additions -func (c *Context) FileFS(file string, filesystem fs.FS) error -func (c *Context) FormValueOr(name, defaultValue string) string -func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) -func (c *Context) ParamOr(name, defaultValue string) string -func (c *Context) QueryParamOr(name, defaultValue string) string -func (c *Context) RouteInfo() RouteInfo -``` - ---- - -### 10. **Virtual Host Support** - -```go -func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo -``` - -Creates an Echo instance that routes requests to different Echo instances based on host. - ---- - -### 11. **New Binder Functions** - -```go -func BindBody(c *Context, target any) error -func BindHeaders(c *Context, target any) error -func BindPathValues(c *Context, target any) error // Renamed from BindPathParams -func BindQueryParams(c *Context, target any) error -``` - -Top-level binding functions that work with `*Context`. - ---- - -### 12. **New echotest Package** - -```go -package echotest // import "github.com/labstack/echo/v5/echotest" - -func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte -func TrimNewlineEnd(bytes []byte) []byte -type ContextConfig struct{ ... } -type MultipartForm struct{ ... } -type MultipartFormFile struct{ ... } -``` - -Helpers for loading fixtures and constructing test contexts. - ---- - -## Removed APIs in v5 - -### Constants - -```go -// v4 - Removed in v5 -const CONNECT = http.MethodConnect // Use http.MethodConnect directly -``` - -**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead. - ---- - -### Constants Added in v5 - -```go -// v5 additions -const ( - NotFoundRouteName = "echo_route_not_found_name" -) -``` - ---- - -### Error Variable Changes - -**v4 exports:** -```go -ErrBadRequest -ErrInvalidKeyType -ErrNonExistentKey -``` - -**v5 exports:** -```go -ErrBadRequest // Now backed by unexported httpError type -ErrValidatorNotRegistered // New -ErrInvalidKeyType -ErrNonExistentKey -``` - -**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set -of predefined HTTP error variables. - ---- - -### Functions Removed - -```go -// v4 - Removed in v5 -func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath -``` - -### Variables Removed - -```go -// v4 - Removed in v5 -var MethodNotAllowedHandler = func(c Context) error { ... } -var NotFoundHandler = func(c Context) error { ... } -``` - -### Functions Renamed - -```go -// v4 -func FormParam[T any](c Context, key string, opts ...any) (T, error) -func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) -func FormParams[T any](c Context, key string, opts ...any) ([]T, error) -func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) - -// v5 -func FormValue[T any](c *Context, key string, opts ...any) (T, error) -func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) -func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) -func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) -``` - ---- - -### Type Methods Removed/Changed - -**Echo struct changes:** -```go -// v4 fields removed in v5 -type Echo struct { - StdLogger *stdLog.Logger // Removed - Server *http.Server // Removed (use StartConfig) - TLSServer *http.Server // Removed (use StartConfig) - Listener net.Listener // Removed (use StartConfig) - TLSListener net.Listener // Removed (use StartConfig) - AutoTLSManager autocert.Manager // Removed - ListenerNetwork string // Removed - OnAddRouteHandler func(...) // Changed to OnAddRoute - DisableHTTP2 bool // Removed (use StartConfig) - Debug bool // Removed - HideBanner bool // Removed (use StartConfig) - HidePort bool // Removed (use StartConfig) -} - -// v5 Echo struct (simplified) -type Echo struct { - Binder Binder - Filesystem fs.FS // NEW - Renderer Renderer - Validator Validator - JSONSerializer JSONSerializer - IPExtractor IPExtractor - OnAddRoute func(route Route) error // Simplified - HTTPErrorHandler HTTPErrorHandler - Logger *slog.Logger // Changed from Logger interface -} -``` - ---- - -**Context interface β†’ struct:** -```go -// v4 -type Context interface { - // Had: SetResponse(*Response) - Response() *Response - - // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues() - // These are removed in v5 (use PathValues() instead) -} - -// v5 -type Context struct { - // Concrete struct with unexported fields -} - -func (c *Context) Response() http.ResponseWriter // Changed return type -func (c *Context) PathValues() PathValues // Replaces ParamNames/Values -``` - ---- - -**Types removed:** -```go -// v4 -type Map map[string]interface{} -``` - -**Group changes:** -```go -// v4 -func (g *Group) File(path, file string) // No return value -func (g *Group) Static(pathPrefix, fsRoot string) // No return value -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value - -// v5 -func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo -func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo -``` - -Now return `RouteInfo` and accept middleware. - ---- - -### Value Binder Factory Name Changes - -```go -// v4 -func PathParamsBinder(c Context) *ValueBinder -func QueryParamsBinder(c Context) *ValueBinder -func FormFieldBinder(c Context) *ValueBinder - -// v5 -func PathValuesBinder(c *Context) *ValueBinder // Renamed -func QueryParamsBinder(c *Context) *ValueBinder -func FormFieldBinder(c *Context) *ValueBinder -``` - ---- - -## Type Signature Changes - -### Binder Interface - -```go -// v4 -type Binder interface { - Bind(i interface{}, c Context) error -} - -// v5 -type Binder interface { - Bind(c *Context, target any) error // Parameters swapped! -} -``` - ---- - -### DefaultBinder Methods - -```go -// v4 -func (b *DefaultBinder) Bind(i interface{}, c Context) error -func (b *DefaultBinder) BindBody(c Context, i interface{}) error -func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error - -// v5 -func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params -// BindBody, BindPathParams, etc. are now top-level functions -``` - ---- - -### JSONSerializer Interface - -```go -// v4 -type JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error -} - -// v5 -type JSONSerializer interface { - Serialize(c *Context, target any, indent string) error - Deserialize(c *Context, target any) error -} -``` - ---- - -### Renderer Interface - -```go -// v4 -type Renderer interface { - Render(io.Writer, string, interface{}, Context) error -} - -// v5 -type Renderer interface { - Render(c *Context, w io.Writer, templateName string, data any) error -} -``` - -Parameters reordered with Context first. - ---- - -### NewBindingError - -```go -// v4 -func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error - -// v5 -func NewBindingError(sourceParam string, values []string, message string, err error) error -``` - -Message parameter changed from `interface{}` to `string`. - ---- - -### HandlerName - -```go -// v5 only -func HandlerName(h HandlerFunc) string -``` - -New utility function to get handler function name. - ---- - -## Middleware Package Changes - -### Signature and Type Updates - -```go -// CORS now accepts optional allow-origins -func CORS(allowOrigins ...string) echo.MiddlewareFunc - -// BodyLimit now accepts bytes -func BodyLimit(limitBytes int64) echo.MiddlewareFunc - -// DefaultSkipper now uses *echo.Context -func DefaultSkipper(c *echo.Context) bool - -// Trailing slash configs renamed/split -func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc -func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc -type AddTrailingSlashConfig struct{ ... } -type RemoveTrailingSlashConfig struct{ ... } - -// Auth + extractor signatures now use *echo.Context and add ExtractorSource -type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) -type Extractor func(c *echo.Context) (string, error) -type ExtractorSource string -type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) -type KeyAuthErrorHandler func(c *echo.Context, err error) error - -// BodyDump handler now includes err -type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) - -// ValuesExtractor now returns extractor source and CreateExtractors takes a limit -type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) -func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) -type ValueExtractorError struct{ ... } - -// New constants -const KB = 1024 - -// Rate limiter store now takes a float64 limit -func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) -``` - -### Added Middleware Exports - -```go -var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") -var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") -var RedirectHTTPSConfig = RedirectConfig{ ... } -var RedirectHTTPSWWWConfig = RedirectConfig{ ... } -var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... } -var RedirectNonWWWConfig = RedirectConfig{ ... } -var RedirectWWWConfig = RedirectConfig{ ... } -``` - -### Removed/Consolidated Middleware Exports - -```go -// Removed in v5 -func Logger() echo.MiddlewareFunc -func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc -func Timeout() echo.MiddlewareFunc -func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc -type ErrKeyAuthMissing struct{ ... } -type CSRFErrorHandler func(err error, c echo.Context) error -type LoggerConfig struct{ ... } -type LogErrorFunc func(c echo.Context, err error, stack []byte) error -type TargetProvider interface{ ... } -type TrailingSlashConfig struct{ ... } -type TimeoutConfig struct{ ... } -``` - -Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`, -`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`, -`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`, -`DefaultTrailingSlashConfig`. - ---- - -## Router Interface Changes - -### v4 Router (Concrete Struct) - -```go -type Router struct { ... } - -func NewRouter(e *Echo) *Router -func (r *Router) Add(method, path string, h HandlerFunc) -func (r *Router) Find(method, path string, c Context) -func (r *Router) Reverse(name string, params ...interface{}) string -func (r *Router) Routes() []*Route -``` - -### v5 Router (Interface + DefaultRouter) - -```go -type Router interface { - Add(routable Route) (RouteInfo, error) - Remove(method string, path string) error - Routes() Routes - Route(c *Context) HandlerFunc -} - -type DefaultRouter struct { ... } - -func NewRouter(config RouterConfig) *DefaultRouter -func NewConcurrentRouter(r Router) Router // NEW - -type RouterConfig struct { - NotFoundHandler HandlerFunc - MethodNotAllowedHandler HandlerFunc - OptionsMethodHandler HandlerFunc - AllowOverwritingRoute bool - UnescapePathParamValues bool - UseEscapedPathForMatching bool -} -``` - -**Key Changes:** -- Router is now an interface -- DefaultRouter is the concrete implementation -- Add() returns `(RouteInfo, error)` instead of being void -- New `Remove()` method -- New `Route()` method replaces `Find()` -- Configuration through `RouterConfig` - ---- - -## Echo Instance Method Changes - -### Route Registration - -```go -// v4 -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route - -// v5 -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo -func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW -``` - -### Static File Serving - -```go -// v4 -func (e *Echo) Static(pathPrefix, fsRoot string) *Route -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route - -// v5 -func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo -func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo -``` - -Return type changed from `*Route` to `RouteInfo`. - -### Server Management - -```go -// v4 -func (e *Echo) Start(address string) error -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error -func (e *Echo) StartAutoTLS(address string) error -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error -func (e *Echo) StartServer(s *http.Server) error -func (e *Echo) Shutdown(ctx context.Context) error -func (e *Echo) Close() error -func (e *Echo) ListenerAddr() net.Addr -func (e *Echo) TLSListenerAddr() net.Addr -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) - -// v5 -func (e *Echo) Start(address string) error // Simplified -func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) - -// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer -// Use StartConfig instead for advanced server configuration -// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr -// Removed: DefaultHTTPErrorHandler (now a top-level factory function) -``` - -**v5 provides** `StartConfig` type for all advanced server configuration. - -### Router Access - -```go -// v4 -func (e *Echo) Router() *Router -func (e *Echo) Routers() map[string]*Router // For multi-host -func (e *Echo) Routes() []*Route -func (e *Echo) Reverse(name string, params ...interface{}) string -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string -func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group - -// v5 -func (e *Echo) Router() Router // Returns interface -// Removed: Routers(), Reverse(), URI(), URL(), Host() -// Use router.Routes() and Routes.Reverse() instead -``` - ---- - -## NewContext Changes - -```go -// v4 -func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context -func NewResponse(w http.ResponseWriter, e *Echo) *Response - -// v5 -func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context -func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone -func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response -``` - ---- - -## Migration Guide Summary - -If you are using Linux you can migrate easier parts like that: -```bash -find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} + -find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} + -``` -or in your favorite IDE - -Replace all: -1. ` echo.Context` -> ` *echo.Context` -2. `echo/v4` -> `echo/v5` - - -### 1. Update All Handler Signatures - -```go -// Before -func MyHandler(c echo.Context) error { ... } - -// After -func MyHandler(c *echo.Context) error { ... } -``` - -### 2. Update Logger Usage - -```go -// Before -e.Logger.Info("Server started") -c.Logger().Error("Something went wrong") - -// After -e.Logger.Info("Server started") -c.Logger().Error("Something went wrong") // Same API, different logger -``` - -### 3. Use Type-Safe Parameter Extraction - -```go -// Before -idStr := c.Param("id") -id, err := strconv.Atoi(idStr) - -// After -id, err := echo.PathParam[int](c, "id") -``` - -### 4. Update Error Handler - -```go -// Before -e.HTTPErrorHandler = func(err error, c echo.Context) { - // handle error -} - -// After -e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped! - // handle error -} - -// Or use factory -e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true -``` - -### 5. Update Server Startup - -```go -// Before -e.Start(":8080") -e.StartTLS(":443", "cert.pem", "key.pem") - -// After -// Simple -e.Start(":8080") - -// Advanced with graceful shutdown -ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) -defer cancel() -sc := echo.StartConfig{Address: ":8080"} -sc.Start(ctx, e) -``` - -### 6. Update Route Info Access - -```go -// Before -routes := e.Routes() -for _, r := range routes { - fmt.Println(r.Method, r.Path) -} - -// After -routes := e.Router().Routes() -for _, r := range routes { - fmt.Println(r.Method, r.Path) -} -``` - -### 7. Update HTTPError Creation - -```go -// Before -return echo.NewHTTPError(400, "invalid request", someDetail) - -// After -return echo.NewHTTPError(400, "invalid request") -``` - -### 8. Update Custom Binder - -```go -// Before -type MyBinder struct{} -func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... } - -// After -type MyBinder struct{} -func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped! -``` - -### 9. Path Parameters - -```go -// Before -names := c.ParamNames() -values := c.ParamValues() - -// After -pathValues := c.PathValues() -for _, pv := range pathValues { - fmt.Println(pv.Name, pv.Value) -} -``` - -### 10. Response Access - -```go -// Before -resp := c.Response() -resp.Header().Set("X-Custom", "value") - -// After -c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter - -// To get *echo.Response -resp, err := echo.UnwrapResponse(c.Response()) -``` - -### Go Version Requirements - -- **v4**: Go 1.24.0 (per `go.mod`) -- **v5**: Go 1.25.0 (per `go.mod`) - ---- - -**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches** diff --git a/CHANGELOG.md b/CHANGELOG.md index e129befa2..86feea16b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,169 +1,28 @@ # Changelog -## v5.2.0 - 2026-06-14 +## v4.15.3 - 2026-06-14 **Security** -* fix(static): reject encoded path separators that bypass route-level middleware by @vishr in https://github.com/labstack/echo/pull/3009 -* fix(middleware/static): don't double-unescape request path (#2599) by @vishr in https://github.com/labstack/echo/pull/3006 +* fix(static): reject encoded path separators that bypass route-level middleware by @vishr in https://github.com/labstack/echo/pull/3011 -Fixes [GHSA-vfp3-v2gw-7wfq](https://github.com/labstack/echo/security/advisories/GHSA-vfp3-v2gw-7wfq): an encoded path separator (`%2F` or `%5C`) in a static file URL could bypass route-level middleware (e.g. authentication on a sibling route) and disclose static files. Both `StaticDirectoryHandler`/`StaticFS` and the `Static` middleware are affected. Thanks to @a-tt-om and @oran-gugu for reporting. +Fixes [GHSA-vfp3-v2gw-7wfq](https://github.com/labstack/echo/security/advisories/GHSA-vfp3-v2gw-7wfq): an encoded path separator (`%2F` or `%5C`) in a static file URL could bypass route-level middleware (e.g. authentication on a sibling route) and disclose static files. Both `StaticDirectoryHandler` (used by `Static`/`StaticFS`) and the `Static` middleware are affected. Backport of the v5 fix (#3009). Thanks to @a-tt-om and @oran-gugu for reporting. -**Enhancements** -* feat(middleware): optional RateLimiterStoreContext for response headers (#2961) by @vishr in https://github.com/labstack/echo/pull/3007 -* perf: optimize core hot paths (chain, context, binding, responses) by @vishr in https://github.com/labstack/echo/pull/3008 -* fix(binder): include field name in bind conversion errors (#2629) by @vishr in https://github.com/labstack/echo/pull/3005 -* fix(binder): serialize BindingError to structured JSON (#2771) by @vishr in https://github.com/labstack/echo/pull/3004 -* fix(binder): MustUnixTime docs say time.Time, not time.Duration by @c-tonneslan in https://github.com/labstack/echo/pull/2988 -* fix(middleware): reset ContentLength after gzip decompression by @shblue21 in https://github.com/labstack/echo/pull/3000 -* fix(middleware/proxy): append RealIP to X-Forwarded-For for WebSocket requests by @kawaway in https://github.com/labstack/echo/pull/2994 -* Fix proxy panic when balancer has no targets by @shblue21 in https://github.com/labstack/echo/pull/2977 -* fix(middleware): correct documented KeyAuth KeyLookup default by @leestana01 in https://github.com/labstack/echo/pull/2992 -* test: lock in v5 group route method-handling (405 + OPTIONS) by @vishr in https://github.com/labstack/echo/pull/3003 -* docs: liveness signals in README + public ROADMAP by @vishr in https://github.com/labstack/echo/pull/3002 -* Fix typos in CSRFConfig comments by @shblue21 in https://github.com/labstack/echo/pull/2979 -* refactor: modernize code usage using gofix by @kumapower17 in https://github.com/labstack/echo/pull/2970 -* refactor: replace Split in loops with more efficient SplitSeq by @box4wangjing in https://github.com/labstack/echo/pull/2969 -* refactor: use the built-in max/min to simplify the code by @criciss in https://github.com/labstack/echo/pull/2966 -* Update GitHub actions deps versions by @aldas in https://github.com/labstack/echo/pull/2971 - -**New Contributors** - -* @criciss made their first contribution in https://github.com/labstack/echo/pull/2966 -* @box4wangjing made their first contribution in https://github.com/labstack/echo/pull/2969 -* @shblue21 made their first contribution in https://github.com/labstack/echo/pull/2977 -* @c-tonneslan made their first contribution in https://github.com/labstack/echo/pull/2988 -* @leestana01 made their first contribution in https://github.com/labstack/echo/pull/2992 -* @kawaway made their first contribution in https://github.com/labstack/echo/pull/2994 - -**Full Changelog**: https://github.com/labstack/echo/compare/v5.1.1...v5.2.0 - - -## v5.1.1 - 2026-05-01 +## v4.15.2 - 2026-05-01 **Security** -* `Context.Scheme()` should validate values taken from header by @aldas in https://github.com/labstack/echo/pull/2953 +* `Context.Scheme()` should validate values taken from header by @aldas in https://github.com/labstack/echo/pull/2962 Thanks to @shblue21 for reporting this [issue](https://github.com/labstack/echo/issues/2952). -**Enhancements** - -* Add golangci linter configuration by @aldas in https://github.com/labstack/echo/pull/2930 -* Make StartConfig listener creation context-aware by @EricGusmao in https://github.com/labstack/echo/pull/2936 -* fix(lint): resolve staticcheck issues and improve code quality by @itsllyaz in https://github.com/labstack/echo/pull/2941 -* Context.Scheme should validate values taken from header by @aldas in https://github.com/labstack/echo/pull/2953 -* chore: fix typos in httperror.go by @tisonkun in https://github.com/labstack/echo/pull/2958 -* Context.Json should not unwrap response by @aldas in https://github.com/labstack/echo/pull/2964 - - -## v5.1.0 - 2026-03-31 - -**Security** - -This change does not break the API contract, but it does introduce breaking changes in logic/behavior. -If your application is using `c.RealIP()` beware and read https://echo.labstack.com/docs/ip-address - -`v4` behavior can be restored with: -```go -e := echo.New() -e.IPExtractor = echo.LegacyIPExtractor() -``` - -* Remove legacy IP extraction logic from context.RealIP method by @aldas in https://github.com/labstack/echo/pull/2933 +## v4.15.1 - 2026-02-22 **Enhancements** -* Add echo-opentelemetry to the README.md by @aldas in https://github.com/labstack/echo/pull/2908 -* fix: correct spelling mistakes in comments and field name by @crawfordxx in https://github.com/labstack/echo/pull/2916 -* Add https://github.com/labstack/echo-prometheus to the middleware list in README.md by @aldas in https://github.com/labstack/echo/pull/2919 -* Add StartConfig.Listener so server with custom Listener is easier to create by @aldas in https://github.com/labstack/echo/pull/2920 -* Fix rate limiter documentation for default burst value by @karesansui-u in https://github.com/labstack/echo/pull/2925 -* Add doc comments to clarify usage of File related methods and leading slash handling by @aldas in https://github.com/labstack/echo/pull/2928 -* Add NewDefaultFS function to help create filesystem that allows absolute paths by @aldas in https://github.com/labstack/echo/pull/2931 -* Do not set http.Server.WriteTimeout in StartConfig by @aldas in https://github.com/labstack/echo/pull/2932 - - - -## v5.0.4 - 2026-02-15 - -**Enhancements** - -* Remove unused import 'errors' from README example by @kumapower17 in https://github.com/labstack/echo/pull/2889 -* Fix Graceful shutdown: after `http.Server.Serve` returns we need to wait for graceful shutdown goroutine to finish by @aldas in https://github.com/labstack/echo/pull/2898 -* Update location of oapi-codegen in README by @mromaszewicz in https://github.com/labstack/echo/pull/2896 -* Add Go 1.26 to CI flow by @aldas in https://github.com/labstack/echo/pull/2899 -* Add new function `echo.StatusCode` by @suwakei in https://github.com/labstack/echo/pull/2892 -* CSRF: support older token-based CSRF protection handler that want to render token into template by @aldas in https://github.com/labstack/echo/pull/2894 -* Add `echo.ResolveResponseStatus` function to help middleware/handlers determine HTTP status code and echo.Response by @aldas in https://github.com/labstack/echo/pull/2900 - - -## v5.0.3 - 2026-02-06 - -**Security** - -* Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21. - -This applies to cases when: -- Windows is used as OS -- `middleware.StaticConfig.Filesystem` is `nil` (default) -- `echo.Filesystem` is has not been set explicitly (default) - -Exposure is restricted to the active process working directory and its subfolders. - - -## v5.0.2 - 2026-02-02 - -**Security** - -* Fix Static middleware with `config.Browse=true` lists all files/subfolders from `config.Filesystem` root and not starting from `config.Root` in https://github.com/labstack/echo/pull/2887 - - -## v5.0.1 - 2026-01-28 - -* Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871 -* Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878 -* improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875 -* fix: Context.Json() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877 - - -## v5.0.0 - 2026-01-18 - -Echo `v5` is maintenance release with **major breaking changes** -- `Context` is now struct instead of interface and we can add method to it in the future in minor versions. -- Adds new `Router` interface for possible new routing implementations. -- Drops old logging interface and uses moderm `log/slog` instead. -- Rearranges alot of methods/function signatures to make them more consistent. - -Upgrade notes and `v4` support: -- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31** -- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading. -- Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning. - -See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on **upgrading**. - -Upgrading TLDR: - -If you are using Linux you can migrate easier parts like that: -```bash -find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} + -find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} + -``` -macOS -```bash -find . -type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} + -find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} + -``` - -or in your favorite IDE - -Replace all: -1. ` echo.Context` -> ` *echo.Context` -2. `echo/v4` -> `echo/v5` - -This should solve most of the issues. Probably the hardest part is updating all the tests. +* CSRF: support older token-based CSRF protection handler that want to render token into template by @aldas in https://github.com/labstack/echo/pull/2905 ## v4.15.0 - 2026-01-01 diff --git a/LICENSE b/LICENSE index 2f18411bd..c46d0105f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2022 LabStack +Copyright (c) 2021 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index bd075bbae..cbd78f1bf 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,10 @@ PKG := "github.com/labstack/echo" PKG_LIST := $(shell go list ${PKG}/...) +tag: + @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'` + @git tag|grep -v ^v + .DEFAULT_GOAL := check check: lint vet race ## Check project @@ -22,11 +26,12 @@ race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks - @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} + @go test -run="-" -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.25" -test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25 - @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" +goversion ?= "1.22" +docker_user ?= "1000" +test_version: ## Run tests inside Docker with given version (defaults to 1.22 oldest supported). Example: make test_version goversion=1.22 + @docker run --rm -it --user $(docker_user) -e HOME=/tmp -e GOCACHE=/tmp/go-cache -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "mkdir -p /tmp/go-cache /tmp/.cache && cd /project && make init check" diff --git a/README.md b/README.md index cd6abcbf6..5e52d1d4e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ -[![Latest release](https://img.shields.io/github/v/release/labstack/echo?style=flat-square&label=release&color=00afd1)](https://github.com/labstack/echo/releases) -[![Last commit](https://img.shields.io/github/last-commit/labstack/echo/master?style=flat-square)](https://github.com/labstack/echo/commits/master) [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) -[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v5) +[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) [![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) @@ -13,14 +11,13 @@ High performance, extensible, minimalist Go web framework. -Echo is built on Go's standard `net/http` β€” and interoperates with it via `echo.WrapHandler` / `echo.WrapMiddleware` β€” adding the parts the standard library leaves to you: a fast radix-tree router, request binding (with a pluggable validator), a deep middleware ecosystem, and centralized error handling. Actively maintained, with `v5` as the current release line (see badges above for the latest version and most recent commit). - * [Official website](https://echo.labstack.com) * [Quick start](https://echo.labstack.com/docs/quick-start) * [Middlewares](https://echo.labstack.com/docs/category/middleware) Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions) + ### Feature Overview - Optimized HTTP router which smartly prioritize routes @@ -51,23 +48,13 @@ Click [here](https://github.com/sponsors/labstack) for more information on spons ## [Guide](https://echo.labstack.com/guide) -### Supported Echo versions - -- Latest major version of Echo is `v5` as of 2026-01-18. - - See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on upgrading. -- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31** - -See [ROADMAP.md](./ROADMAP.md) for where Echo is heading and the version support policy. - ### Installation ```sh // go get github.com/labstack/echo/{version} -go get github.com/labstack/echo/v5 +go get github.com/labstack/echo/v4 ``` - -Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with -older versions. +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. ### Example @@ -75,8 +62,8 @@ older versions. package main import ( - "github.com/labstack/echo/v5" - "github.com/labstack/echo/v5/middleware" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" "log/slog" "net/http" ) @@ -86,20 +73,20 @@ func main() { e := echo.New() // Middleware - e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger - e.Use(middleware.Recover()) // recover panics as errors for proper error handling + e.Use(middleware.RequestLogger()) // use the default RequestLogger middleware with slog logger + e.Use(middleware.Recover()) // recover panics as errors for proper error handling // Routes e.GET("/", hello) // Start server - if err := e.Start(":8080"); err != nil { + if err := e.Start(":8080"); err != nil && !errors.Is(err, http.ErrServerClosed) { slog.Error("failed to start server", "error", err) } } // Handler -func hello(c *echo.Context) error { +func hello(c echo.Context) error { return c.String(http.StatusOK, "Hello, World!") } ``` @@ -108,12 +95,10 @@ func hello(c *echo.Context) error { Following list of middleware is maintained by Echo team. -| Repository | Description | -|------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [pprof](https://pkg.go.dev/net/http/pprof)) middlewares | -| [github.com/labstack/echo-opentelemetry](https://github.com/labstack/echo-opentelemetry) | [OpenTelemetry](https://opentelemetry.io/) middleware for tracing and metrics | -| [github.com/labstack/echo-prometheus](https://github.com/labstack/echo-prometheus) | [Prometheus](https://github.com/prometheus/client_golang/) middleware for Echo | +| Repository | Description | +|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | # Third-party middleware repositories @@ -122,11 +107,11 @@ of middlewares in this list. | Repository | Description | |------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | | [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | | [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | UberΒ΄s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | -| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. | +| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. | | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | | [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | | [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | diff --git a/ROADMAP.md b/ROADMAP.md deleted file mode 100644 index 874d50b0c..000000000 --- a/ROADMAP.md +++ /dev/null @@ -1,50 +0,0 @@ -# Echo Roadmap - -> **DRAFT** β€” this is a starting point for maintainers to edit, not a commitment. -> Dates and priorities are owned by the Echo team. Open a discussion to propose changes. - -This document exists so the community can see where Echo is heading. Echo is -**actively maintained**. We publish releases regularly across two supported -lines β€” see [README](./README.md) badges for the latest version and most recent commit. - -## Version policy - -| Line | Status | Support | -|------|--------|---------| -| `v5` | **Current** (since 2026-01-18) | New features, fixes, and improvements. | -| `v4` | Maintenance / LTS | **Security and bug fixes until 2026-12-31.** No new features. | - -Upgrading from v4? See [API_CHANGES_V5.md](./API_CHANGES_V5.md). - -Echo supports the **latest four Go major releases** and may work with older versions. - -## Now (in progress) - -- Stabilizing the `v5` API surface through point releases. -- Documentation catch-up for v5 behavior changes (e.g. CORS / `RouteNotFound` - behavior on groups β€” see #2950). -- Triaging and reducing the open issue / PR backlog. - -## Next (under consideration) - -These are frequently-requested items being discussed. Inclusion here is **not** a -commitment β€” each still needs design agreement before implementation: - -- **Automatic `HEAD` for `GET` routes** (#2895; see #2949) β€” opt-in, likely via an - `OnAddRoute` hook so users keep control. -- **Rate limiter response metadata** β€” expose `Retry-After` / remaining quota - through the store interface (#2961). -- **Real-IP / `Forwarded` header handling** improvements (#2744). -- **Proxy middleware** authorization-header handling (#2787). - -## Later / exploratory - -- Continued alignment with the Go standard library (`net/http`, `slog`). -- Reducing third-party surface where the stdlib now covers the need. - -## How to influence the roadmap - -- **Discuss before large PRs** β€” open a [Discussion](https://github.com/labstack/echo/discussions) - or issue so we can agree on the design first. -- πŸ‘ reactions on issues help us gauge demand. -- See [README β†’ Contribute](./README.md#contribute) for contribution guidelines. diff --git a/SECURITY.md b/SECURITY.md deleted file mode 100644 index efb618697..000000000 --- a/SECURITY.md +++ /dev/null @@ -1,15 +0,0 @@ -# Security Policy - -## Supported Versions - -| Version | Supported | -|-----------|-------------------------------------| -| 5.x.x | :white_check_mark: | -| >= 4.15.x | :white_check_mark: until 2026.12.31 | -| < 4.15 | :x: | - -## Reporting a Vulnerability - -https://github.com/labstack/echo/security/advisories/new - -or look for maintainers email(s) in commits and email them. diff --git a/_fixture/dist/private.txt b/_fixture/dist/private.txt deleted file mode 100644 index 0f9d2435b..000000000 --- a/_fixture/dist/private.txt +++ /dev/null @@ -1 +0,0 @@ -private file diff --git a/_fixture/dist/public/assets/readme.md b/_fixture/dist/public/assets/readme.md deleted file mode 100644 index 50590f554..000000000 --- a/_fixture/dist/public/assets/readme.md +++ /dev/null @@ -1 +0,0 @@ -readme in assets diff --git a/_fixture/dist/public/assets/subfolder/subfolder.md b/_fixture/dist/public/assets/subfolder/subfolder.md deleted file mode 100644 index 74c928b2f..000000000 --- a/_fixture/dist/public/assets/subfolder/subfolder.md +++ /dev/null @@ -1 +0,0 @@ -file inside subfolder diff --git a/_fixture/dist/public/index.html b/_fixture/dist/public/index.html deleted file mode 100644 index df6d9015a..000000000 --- a/_fixture/dist/public/index.html +++ /dev/null @@ -1 +0,0 @@ -

Hello from index

diff --git a/_fixture/dist/public/test.txt b/_fixture/dist/public/test.txt deleted file mode 100644 index dd937160d..000000000 --- a/_fixture/dist/public/test.txt +++ /dev/null @@ -1 +0,0 @@ -test.txt contents diff --git a/bind.go b/bind.go index 4713afde2..1e1043c07 100644 --- a/bind.go +++ b/bind.go @@ -13,13 +13,12 @@ import ( "reflect" "strconv" "strings" - "sync" "time" ) // Binder is the interface that wraps the Bind method. type Binder interface { - Bind(c *Context, target any) error + Bind(i any, c Context) error } // DefaultBinder is the default implementation of the Binder interface. @@ -40,22 +39,31 @@ type bindMultipleUnmarshaler interface { UnmarshalParams(params []string) error } -// BindPathValues binds path parameter values to bindable object -func BindPathValues(c *Context, target any) error { +// BindPathParams binds path params to bindable object +// +// Time format support: time.Time fields can use `format` tags to specify custom parsing layouts. +// Example: `param:"created" format:"2006-01-02T15:04"` for datetime-local format +// Example: `param:"date" format:"2006-01-02"` for date format +// Uses Go's standard time format reference time: Mon Jan 2 15:04:05 MST 2006 +// Works with form data, query parameters, and path parameters (not JSON body) +// Falls back to default time.Time parsing if no format tag is specified +func (b *DefaultBinder) BindPathParams(c Context, i any) error { + names := c.ParamNames() + values := c.ParamValues() params := map[string][]string{} - for _, param := range c.PathValues() { - params[param.Name] = []string{param.Value} + for i, name := range names { + params[name] = []string{values[i]} } - if err := bindData(target, params, "param", nil); err != nil { - return ErrBadRequest.Wrap(err) + if err := b.bindData(i, params, "param", nil); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil } // BindQueryParams binds query params to bindable object -func BindQueryParams(c *Context, target any) error { - if err := bindData(target, c.QueryParams(), "query", nil); err != nil { - return ErrBadRequest.Wrap(err) +func (b *DefaultBinder) BindQueryParams(c Context, i any) error { + if err := b.bindData(i, c.QueryParams(), "query", nil); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil } @@ -65,7 +73,7 @@ func BindQueryParams(c *Context, target any) error { // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm -func BindBody(c *Context, target any) (err error) { +func (b *DefaultBinder) BindBody(c Context, i any) (err error) { req := c.Request() if req.ContentLength == 0 { return @@ -77,52 +85,58 @@ func BindBody(c *Context, target any) (err error) { switch mediatype { case MIMEApplicationJSON: - if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil { - var hErr *HTTPError - if errors.As(err, &hErr) { + if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil { + switch err.(type) { + case *HTTPError: return err + default: + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - return ErrBadRequest.Wrap(err) } case MIMEApplicationXML, MIMETextXML: - if err = xml.NewDecoder(req.Body).Decode(target); err != nil { - return ErrBadRequest.Wrap(err) + if err = xml.NewDecoder(req.Body).Decode(i); err != nil { + if ute, ok := err.(*xml.UnsupportedTypeError); ok { + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) + } else if se, ok := err.(*xml.SyntaxError); ok { + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) + } + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } case MIMEApplicationForm: - params, err := c.FormValues() + params, err := c.FormParams() if err != nil { - return ErrBadRequest.Wrap(err) + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - if err = bindData(target, params, "form", nil); err != nil { - return ErrBadRequest.Wrap(err) + if err = b.bindData(i, params, "form", nil); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } case MIMEMultipartForm: params, err := c.MultipartForm() if err != nil { - return ErrBadRequest.Wrap(err) + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - if err = bindData(target, params.Value, "form", params.File); err != nil { - return ErrBadRequest.Wrap(err) + if err = b.bindData(i, params.Value, "form", params.File); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } default: - return &HTTPError{Code: http.StatusUnsupportedMediaType} + return ErrUnsupportedMediaType } return nil } // BindHeaders binds HTTP headers to a bindable object -func BindHeaders(c *Context, target any) error { - if err := bindData(target, c.Request().Header, "header", nil); err != nil { - return ErrBadRequest.Wrap(err) +func (b *DefaultBinder) BindHeaders(c Context, i any) error { + if err := b.bindData(i, c.Request().Header, "header", nil); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous -// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues. -func (b *DefaultBinder) Bind(c *Context, target any) error { - if err := BindPathValues(c, target); err != nil { +// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. +func (b *DefaultBinder) Bind(i any, c Context) (err error) { + if err := b.BindPathParams(c, i); err != nil { return err } // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. @@ -130,82 +144,15 @@ func (b *DefaultBinder) Bind(c *Context, target any) error { // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { - if err := BindQueryParams(c, target); err != nil { + if err = b.BindQueryParams(c, i); err != nil { return err } } - return BindBody(c, target) -} - -// bindFieldMeta is the cached, type-level reflection metadata for a single struct field. Reading struct -// tags (reflect.StructTag.Get) parses the tag string on every call, so for binding-heavy endpoints we -// compute it once per struct type and reuse it across requests (see bindStructMeta). Only type-level data -// is cached here; per-request, per-instance reflect.Value operations still happen in bindData. -type bindFieldMeta struct { - index int // field index within the struct - // fieldKind is the DECLARED field kind (typeField.Type.Kind()), used only for unmarshal dispatch. - // It is intentionally not the post-anonymous-pointer-deref live kind; bindData computes that - // separately as structFieldKind where needed. - fieldKind reflect.Kind - anonymous bool // reflect.StructField.Anonymous - formatTag string // value of the `format` struct tag - // binding-source tag values. bindData is only ever called with one of these four tags (see the - // callers BindPathValues/BindQueryParams/BindBody/BindHeaders). Keep these fields, the four - // f.Tag.Get(...) lines in bindMetaFor, and the tagName switch in sync if a source is ever added. - param, query, form, header string -} - -// tagName returns the field's tag value for the given binding source tag. -// Keep in sync with the tag fields above and the f.Tag.Get calls in bindMetaFor. -func (m *bindFieldMeta) tagName(tag string) string { - switch tag { - case "param": - return m.param - case "query": - return m.query - case "form": - return m.form - case "header": - return m.header - default: - return "" - } -} - -// bindStructMeta is the cached field metadata for a whole struct type, in declaration order. -type bindStructMeta struct { - fields []bindFieldMeta -} - -// bindStructCache memoizes bindStructMeta keyed by struct reflect.Type. Concurrent double-computation is -// harmless because the result is deterministic and idempotent. -var bindStructCache sync.Map // map[reflect.Type]*bindStructMeta - -func bindMetaFor(typ reflect.Type) *bindStructMeta { - if cached, ok := bindStructCache.Load(typ); ok { - return cached.(*bindStructMeta) - } - n := typ.NumField() - meta := &bindStructMeta{fields: make([]bindFieldMeta, n)} - for i := 0; i < n; i++ { - f := typ.Field(i) - meta.fields[i] = bindFieldMeta{ - index: i, - anonymous: f.Anonymous, - fieldKind: f.Type.Kind(), - formatTag: f.Tag.Get("format"), - param: f.Tag.Get("param"), - query: f.Tag.Get("query"), - form: f.Tag.Get("form"), - header: f.Tag.Get("header"), - } - } - bindStructCache.Store(typ, meta) - return meta + return b.BindBody(c, i) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { +func (b *DefaultBinder) bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } @@ -224,7 +171,7 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m isElemInterface := k == reflect.Interface isElemString := k == reflect.String isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String - if !isElemSliceOfStrings && !isElemString && !isElemInterface { + if !(isElemSliceOfStrings || isElemString || isElemInterface) { return nil } if val.IsNil() { @@ -253,12 +200,11 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m return errors.New("binding element must be a struct") } - meta := bindMetaFor(typ) - for fi := range meta.fields { // iterate over all destination fields - fm := &meta.fields[fi] - structField := val.Field(fm.index) - if fm.anonymous { - if structField.Kind() == reflect.Pointer { + for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields + typeField := typ.Field(i) + structField := val.Field(i) + if typeField.Anonymous { + if structField.Kind() == reflect.Ptr { structField = structField.Elem() } } @@ -266,8 +212,8 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m continue } structFieldKind := structField.Kind() - inputFieldName := fm.tagName(tag) - if fm.anonymous && structFieldKind == reflect.Struct && inputFieldName != "" { + inputFieldName := typeField.Tag.Get(tag) + if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" { // if anonymous struct with query/param/form tags, report an error return errors.New("query/param/form tags are not allowed with anonymous struct field") } @@ -276,7 +222,7 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { + if err := b.bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } } @@ -317,16 +263,17 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` . // try unmarshalling first, in case we're dealing with an alias to an array type - if ok, err := unmarshalInputsToField(fm.fieldKind, inputValue, structField); ok { + if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok { if err != nil { - return fmt.Errorf("%s: %w", inputFieldName, err) + return err } continue } - if ok, err := unmarshalInputToField(fm.fieldKind, inputValue[0], structField, fm.formatTag); ok { + formatTag := typeField.Tag.Get("format") + if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok { if err != nil { - return fmt.Errorf("%s: %w", inputFieldName, err) + return err } continue } @@ -342,9 +289,9 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m sliceOf := structField.Type().Elem().Kind() numElems := len(inputValue) slice := reflect.MakeSlice(structField.Type(), numElems, numElems) - for j := range numElems { + for j := 0; j < numElems; j++ { if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { - return fmt.Errorf("%s: %w", inputFieldName, err) + return err } } structField.Set(slice) @@ -352,7 +299,7 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m } if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil { - return fmt.Errorf("%s: %w", inputFieldName, err) + return err } } return nil @@ -366,7 +313,7 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V } switch valueKind { - case reflect.Pointer: + case reflect.Ptr: return setWithProperType(structField.Elem().Kind(), val, structField.Elem()) case reflect.Int: return setIntField(val, 0, structField) @@ -403,7 +350,7 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V } func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) { - if valueKind == reflect.Pointer { + if valueKind == reflect.Ptr { if field.IsNil() { field.Set(reflect.New(field.Type().Elem())) } @@ -419,7 +366,7 @@ func unmarshalInputsToField(valueKind reflect.Kind, values []string, field refle } func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) { - if valueKind == reflect.Pointer { + if valueKind == reflect.Ptr { if field.IsNil() { field.Set(reflect.New(field.Type().Elem())) } @@ -427,6 +374,7 @@ func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Val } fieldIValue := field.Addr().Interface() + // Handle time.Time with custom format tag if formatTag != "" { if _, isTime := fieldIValue.(*time.Time); isTime { diff --git a/bind_cache_test.go b/bind_cache_test.go deleted file mode 100644 index f3da8a13e..000000000 --- a/bind_cache_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestBindCachedMetaPreservesFieldNameError ensures the per-type bind metadata cache preserves the -// field-name prefix in conversion errors on BOTH the cold (first) and warm (cached) bind of a type. -// DTO is declared locally so its reflect.Type is independent of suite ordering, making the second -// bind a deterministic cache hit (the bindMetaFor Load branch). -func TestBindCachedMetaPreservesFieldNameError(t *testing.T) { - type DTO struct { - Number int `query:"number"` - } - bind := func() error { - e := New() - req := httptest.NewRequest(http.MethodGet, "/?number=10a", nil) - var dto DTO - return e.NewContext(req, httptest.NewRecorder()).Bind(&dto) - } - - assert.ErrorContains(t, bind(), "number", "cold cache: error must carry field name") - assert.ErrorContains(t, bind(), "number", "warm cache: error must still carry field name") -} diff --git a/bind_field_error_test.go b/bind_field_error_test.go deleted file mode 100644 index 8bf63b58d..000000000 --- a/bind_field_error_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -// Regression test for #2629: when binding form data fails a type conversion, the -// returned error must identify which field failed (so applications can render a -// useful message), instead of a bare strconv error with no field context. -func TestBind_formConversionErrorIncludesFieldName(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("number=10a")) - req.Header.Set(HeaderContentType, MIMEApplicationForm) - c := e.NewContext(req, httptest.NewRecorder()) - - type DTO struct { - Number int `form:"number"` - } - var dto DTO - err := c.Bind(&dto) - - assert.Error(t, err) - assert.ErrorContains(t, err, "number", "bind error must identify the failing field") -} diff --git a/bind_test.go b/bind_test.go index e33298b39..6cc597b33 100644 --- a/bind_test.go +++ b/bind_test.go @@ -25,79 +25,79 @@ import ( ) type bindTestStruct struct { - T Timestamp - GoT time.Time + I int + PtrI *int + I8 int8 + PtrI8 *int8 + I16 int16 PtrI16 *int16 - PtrUI *uint - Tptr *Timestamp - PtrF32 *float32 - PtrB *bool + I32 int32 PtrI32 *int32 - GoTptr *time.Time + I64 int64 PtrI64 *int64 - PtrI *int - PtrI8 *int8 - PtrF64 *float64 + UI uint + PtrUI *uint + UI8 uint8 PtrUI8 *uint8 - PtrUI64 *uint64 + UI16 uint16 PtrUI16 *uint16 - PtrS *string + UI32 uint32 PtrUI32 *uint32 + UI64 uint64 + PtrUI64 *uint64 + B bool + PtrB *bool + F32 float32 + PtrF32 *float32 + F64 float64 + PtrF64 *float64 S string + PtrS *string cantSet string DoesntExist string + GoT time.Time + GoTptr *time.Time + T Timestamp + Tptr *Timestamp SA StringArray - F64 float64 - I int - UI64 uint64 - UI uint - I64 int64 - F32 float32 - UI32 uint32 - I32 int32 - UI16 uint16 - I16 int16 - B bool - UI8 uint8 - I8 int8 } type bindTestStructWithTags struct { - T Timestamp `json:"T" form:"T"` - GoT time.Time `json:"GoT" form:"GoT"` - PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` - PtrUI *uint `json:"PtrUI" form:"PtrUI"` - Tptr *Timestamp `json:"Tptr" form:"Tptr"` - PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` - PtrB *bool `json:"PtrB" form:"PtrB"` - PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` - GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` - PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` - PtrI *int `json:"PtrI" form:"PtrI"` - PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` - PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` - PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` - PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` - PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` - PtrS *string `json:"PtrS" form:"PtrS"` - PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` - S string `json:"S" form:"S"` + I int `json:"I" form:"I"` + PtrI *int `json:"PtrI" form:"PtrI"` + I8 int8 `json:"I8" form:"I8"` + PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` + I16 int16 `json:"I16" form:"I16"` + PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` + I32 int32 `json:"I32" form:"I32"` + PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` + I64 int64 `json:"I64" form:"I64"` + PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` + UI uint `json:"UI" form:"UI"` + PtrUI *uint `json:"PtrUI" form:"PtrUI"` + UI8 uint8 `json:"UI8" form:"UI8"` + PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` + UI16 uint16 `json:"UI16" form:"UI16"` + PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` + UI32 uint32 `json:"UI32" form:"UI32"` + PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` + UI64 uint64 `json:"UI64" form:"UI64"` + PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` + B bool `json:"B" form:"B"` + PtrB *bool `json:"PtrB" form:"PtrB"` + F32 float32 `json:"F32" form:"F32"` + PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` + F64 float64 `json:"F64" form:"F64"` + PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` + S string `json:"S" form:"S"` + PtrS *string `json:"PtrS" form:"PtrS"` cantSet string DoesntExist string `json:"DoesntExist" form:"DoesntExist"` + GoT time.Time `json:"GoT" form:"GoT"` + GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` + T Timestamp `json:"T" form:"T"` + Tptr *Timestamp `json:"Tptr" form:"Tptr"` SA StringArray `json:"SA" form:"SA"` - F64 float64 `json:"F64" form:"F64"` - I int `json:"I" form:"I"` - UI64 uint64 `json:"UI64" form:"UI64"` - UI uint `json:"UI" form:"UI"` - I64 int64 `json:"I64" form:"I64"` - F32 float32 `json:"F32" form:"F32"` - UI32 uint32 `json:"UI32" form:"UI32"` - I32 int32 `json:"I32" form:"I32"` - UI16 uint16 `json:"UI16" form:"UI16"` - I16 int16 `json:"I16" form:"I16"` - B bool `json:"B" form:"B"` - UI8 uint8 `json:"UI8" form:"UI8"` - I8 int8 `json:"I8" form:"I8"` } type Timestamp time.Time @@ -283,7 +283,7 @@ func TestBindHeaderParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := BindHeaders(c, u) + err := (&DefaultBinder{}).BindHeaders(c, u) if assert.NoError(t, err) { assert.Equal(t, 2, u.ID) assert.Equal(t, "Jon Doe", u.Name) @@ -297,7 +297,7 @@ func TestBindHeaderParamBadType(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := BindHeaders(c, u) + err := (&DefaultBinder{}).BindHeaders(c, u) assert.Error(t, err) httpErr, ok := err.(*HTTPError) @@ -312,13 +312,13 @@ func TestBindUnmarshalParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T Timestamp `query:"ts"` + T Timestamp `query:"ts"` + TA []Timestamp `query:"ta"` + SA StringArray `query:"sa"` ST Struct StWithTag struct { Foo string `query:"st"` } - TA []Timestamp `query:"ta"` - SA StringArray `query:"sa"` }{} err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) @@ -339,10 +339,10 @@ func TestBindUnmarshalText(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T time.Time `query:"ts"` - ST Struct + T time.Time `query:"ts"` TA []time.Time `query:"ta"` SA StringArray `query:"sa"` + ST Struct }{} err := c.Bind(&result) ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC) @@ -447,7 +447,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string", func(t *testing.T) { dest := map[string]string{} - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -459,7 +459,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { var dest map[string]string - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -471,7 +471,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string", func(t *testing.T) { dest := map[string][]string{} - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -483,7 +483,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { var dest map[string][]string - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -495,7 +495,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]interface", func(t *testing.T) { dest := map[string]any{} - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]any{ "multiple": "1", @@ -507,7 +507,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { var dest map[string]any - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]any{ "multiple": "1", @@ -519,32 +519,33 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]int skips", func(t *testing.T) { dest := map[string]int{} - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int{}, dest) }) t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { var dest map[string]int - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int(nil), dest) }) t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { dest := map[string][]int{} - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int{}, dest) }) t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { var dest map[string][]int - assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int(nil), dest) }) } func TestBindbindData(t *testing.T) { ts := new(bindTestStruct) - err := bindData(ts, values, "form", nil) + b := new(DefaultBinder) + err := b.bindData(ts, values, "form", nil) assert.NoError(t, err) assert.Equal(t, 0, ts.I) @@ -569,13 +570,9 @@ func TestBindParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.InitializeRoute( - &RouteInfo{Path: "/users/:id/:name"}, - &PathValues{ - {Name: "id", Value: "1"}, - {Name: "name", Value: "Jon Snow"}, - }, - ) + c.SetPath("/users/:id/:name") + c.SetParamNames("id", "name") + c.SetParamValues("1", "Jon Snow") u := new(user) err := c.Bind(u) @@ -586,12 +583,9 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.InitializeRoute( - &RouteInfo{Path: "/users/:id"}, - &PathValues{ - {Name: "id", Value: "1"}, - }, - ) + c2.SetPath("/users/:id") + c2.SetParamNames("id") + c2.SetParamValues("1") u = new(user) err = c2.Bind(u) @@ -609,12 +603,9 @@ func TestBindParam(t *testing.T) { rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.InitializeRoute( - &RouteInfo{Path: "/users/:id"}, - &PathValues{ - {Name: "id", Value: "1"}, - }, - ) + c3.SetPath("/users/:id") + c3.SetParamNames("id") + c3.SetParamValues("1") u = new(user) err = c3.Bind(u) @@ -636,12 +627,14 @@ func TestBindUnmarshalTypeError(t *testing.T) { err := c.Bind(u) - assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`) + he := &HTTPError{Code: http.StatusBadRequest, Message: "Unmarshal type error: expected=int, got=string, field=id, offset=14", Internal: err.(*HTTPError).Internal} + + assert.Equal(t, he, err) } func TestBindSetWithProperType(t *testing.T) { ts := new(bindTestStruct) - typ := reflect.TypeFor[bindTestStruct]() + typ := reflect.TypeOf(ts).Elem() val := reflect.ValueOf(ts).Elem() for i := 0; i < typ.NumField(); i++ { typeField := typ.Field(i) @@ -662,7 +655,7 @@ func TestBindSetWithProperType(t *testing.T) { Bar bytes.Buffer } v := &foo{} - typ = reflect.TypeFor[foo]() + typ = reflect.TypeOf(v).Elem() val = reflect.ValueOf(v).Elem() assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } @@ -670,10 +663,11 @@ func TestBindSetWithProperType(t *testing.T) { func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() ts := new(bindTestStructWithTags) + binder := new(DefaultBinder) var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = bindData(ts, values, "form", nil) + err = binder.bindData(ts, values, "form", nil) } assert.NoError(b, err) assertBindTestStruct(b, (*bindTestStruct)(ts)) @@ -748,36 +742,36 @@ func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal err strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): if assert.IsType(t, new(HTTPError), err) { assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) - assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } default: if assert.IsType(t, new(HTTPError), err) { assert.Equal(t, ErrUnsupportedMediaType, err) - assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } } } func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use - // binding is done in steps and one source could overwrite previous source bound data + // binding is done in steps and one source could overwrite previous source binded data // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Opts struct { + ID int `json:"id" form:"id" query:"id"` Node string `json:"node" form:"node" query:"node" param:"node"` Lang string - ID int `json:"id" form:"id" query:"id"` } var testCases = []struct { - givenContent io.Reader - whenBindTarget any - expect any name string givenURL string + givenContent io.Reader givenMethod string + whenBindTarget any + whenNoPathParams bool + expect any expectError string - whenNoPathValues bool }{ { name: "ok, POST bind to struct with: path param + query param + body", @@ -805,14 +799,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), - expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values + expect: &Opts{ID: 1, Node: "zzz"}, // body is binded last and overwrites previous (path,query) values }, { name: "ok, DELETE bind to struct with: path param + query param + body", givenMethod: http.MethodDelete, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), - expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params + expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is binded after query params }, { name: "ok, POST bind to struct with: path param + body", @@ -834,7 +828,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{`), expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target - expectError: "code=400, message=Bad Request, err=unexpected EOF", + expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", }, { name: "nok, GET with body bind failure when types are not convertible", @@ -842,7 +836,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?id=nope", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target - expectError: `code=400, message=Bad Request, err=id: strconv.ParseInt: parsing "nope": invalid syntax`, + expectError: "code=400, message=strconv.ParseInt: parsing \"nope\": invalid syntax, internal=strconv.ParseInt: parsing \"nope\": invalid syntax", }, { name: "nok, GET body bind failure - trying to bind json array to struct", @@ -850,14 +844,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target - expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`, + expectError: "code=400, message=Unmarshal type error: expected=echo.Opts, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Opts", }, { // query param is ignored as we do not know where exactly to bind it in slice name: "ok, GET bind to struct slice, ignore query param", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathValues: true, + whenNoPathParams: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{ {ID: 1, Node: ""}, @@ -868,7 +862,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?id=nope&node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathValues: true, + whenNoPathParams: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{{ID: 1}}, expectError: "", @@ -888,7 +882,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathValues: true, + whenNoPathParams: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{{ID: 1, Node: ""}}, expectError: "", @@ -904,10 +898,9 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if !tc.whenNoPathValues { - c.SetPathValues(PathValues{ - {Name: "node", Value: "node_from_path"}, - }) + if !tc.whenNoPathParams { + c.SetParamNames("node") + c.SetParamValues("node_from_path") } var bindTarget any @@ -918,7 +911,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } b := new(DefaultBinder) - err := b.Bind(c, bindTarget) + err := b.Bind(bindTarget, c) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -931,28 +924,28 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { func TestDefaultBinder_BindBody(t *testing.T) { // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use - // generally when binding from request body - URL and path params are ignored - unless form is being bound. + // generally when binding from request body - URL and path params are ignored - unless form is being binded. // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Node struct { - Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"` ID int `json:"id" xml:"id" form:"id" query:"id"` + Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"` } type Nodes struct { Nodes []Node `xml:"node" form:"node"` } var testCases = []struct { - givenContent io.Reader - whenBindTarget any - expect any name string givenURL string + givenContent io.Reader givenMethod string givenContentType string - expectError string - whenNoPathValues bool + whenNoPathParams bool whenChunkedBody bool + whenBindTarget any + expect any + expectError string }{ { name: "ok, JSON POST bind to struct with: path + query + empty field in body", @@ -976,7 +969,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathValues: true, + whenNoPathParams: true, whenBindTarget: &[]Node{}, expect: &[]Node{{ID: 1, Node: ""}}, expectError: "", @@ -1004,7 +997,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{`), expect: &Node{ID: 0, Node: ""}, - expectError: "code=400, message=Bad Request, err=unexpected EOF", + expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", }, { name: "ok, XML POST bind to struct with: path + query + empty body", @@ -1030,7 +1023,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenContentType: MIMEApplicationXML, givenContent: strings.NewReader(`<`), expect: &Node{ID: 0, Node: ""}, - expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF", + expectError: "code=400, message=Syntax error: line=1, error=XML syntax error on line 1: unexpected EOF, internal=XML syntax error on line 1: unexpected EOF", }, { name: "ok, FORM POST bind to struct with: path + query + body", @@ -1120,10 +1113,9 @@ func TestDefaultBinder_BindBody(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if !tc.whenNoPathValues { - c.SetPathValues(PathValues{ - {Name: "node", Value: "real_node"}, - }) + if !tc.whenNoPathParams { + c.SetParamNames("node") + c.SetParamValues("real_node") } var bindTarget any @@ -1132,8 +1124,9 @@ func TestDefaultBinder_BindBody(t *testing.T) { } else { bindTarget = &Node{} } + b := new(DefaultBinder) - err := BindBody(c, bindTarget) + err := b.BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -1196,7 +1189,7 @@ func TestBindUnmarshalParamExtras(t *testing.T) { }{} err := testBindURL("/?t=xxxx", &result) - assert.EqualError(t, err, `code=400, message=Bad Request, err=t: 'xxxx' is not an integer`) + assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") }) t.Run("ok, target is struct", func(t *testing.T) { @@ -1301,7 +1294,7 @@ func TestBindUnmarshalParams(t *testing.T) { }{} err := testBindURL("/?t=xxxx", &result) - assert.EqualError(t, err, "code=400, message=Bad Request, err=t: 'xxxx' is not an integer") + assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") }) t.Run("ok, target is struct", func(t *testing.T) { @@ -1368,7 +1361,7 @@ func TestBindInt8(t *testing.T) { } p := target{} err := testBindURL("/?v=x&v=2", &p) - assert.EqualError(t, err, `code=400, message=Bad Request, err=v: strconv.ParseInt: parsing "x": invalid syntax`) + assert.EqualError(t, err, "code=400, message=strconv.ParseInt: parsing \"x\": invalid syntax, internal=strconv.ParseInt: parsing \"x\": invalid syntax") }) t.Run("nok, int8 embedded in struct", func(t *testing.T) { @@ -1476,7 +1469,7 @@ func TestBindMultipartFormFiles(t *testing.T) { } err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored - assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`) + assert.EqualError(t, err, "code=400, message=binding to multipart.FileHeader struct is not supported, use pointer to struct, internal=binding to multipart.FileHeader struct is not supported, use pointer to struct") }) t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { @@ -1584,7 +1577,7 @@ func TestTimeFormatBinding(t *testing.T) { DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"` Date time.Time `query:"date" format:"2006-01-02"` CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"` - DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing + DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"` } @@ -1630,7 +1623,7 @@ func TestTimeFormatBinding(t *testing.T) { { name: "nok, wrong format should fail", contentType: MIMEApplicationForm, - data: "datetime_local=2023-12-25", // Missing time part + data: "datetime_local=2023-12-25", // Missing time part expectError: true, }, } diff --git a/binder.go b/binder.go index 7ae3709fa..300c00934 100644 --- a/binder.go +++ b/binder.go @@ -16,7 +16,7 @@ import ( /** Following functions provide handful of methods for binding to Go native types from request query or path parameters. * QueryParamsBinder(c) - binds query parameters (source URL) - * PathValuesBinder(c) - binds path parameters (source URL) + * PathParamsBinder(c) - binds path parameters (source URL) * FormFieldBinder(c) - binds form fields (source URL + body) Example: @@ -66,9 +66,6 @@ import ( */ // BindingError represents an error that occurred while binding request data. -// -// Note: JSON serialization is handled by the MarshalJSON method below, not by the -// struct tags (which are kept for documentation). MarshalJSON emits {"field","message"}. type BindingError struct { // Field is the field name where value binding failed Field string `json:"field"` @@ -78,11 +75,15 @@ type BindingError struct { } // NewBindingError creates new instance of binding error -func NewBindingError(sourceParam string, values []string, message string, err error) error { +func NewBindingError(sourceParam string, values []string, message any, internalError error) error { return &BindingError{ - Field: sourceParam, - Values: values, - HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err}, + Field: sourceParam, + Values: values, + HTTPError: &HTTPError{ + Code: http.StatusBadRequest, + Message: message, + Internal: internalError, + }, } } @@ -91,24 +92,6 @@ func (be *BindingError) Error() string { return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field) } -// MarshalJSON implements json.Marshaler so that binding errors are serialized into -// a structured response (e.g. {"field":"id","message":"..."}) rather than being -// flattened to a generic message. DefaultHTTPErrorHandler routes errors that -// implement json.Marshaler through their own encoding. -func (be *BindingError) MarshalJSON() ([]byte, error) { - message := be.Message - if message == "" { - message = http.StatusText(be.Code) - } - return json.Marshal(struct { - Field string `json:"field"` - Message string `json:"message"` - }{ - Field: be.Field, - Message: message, - }) -} - // ValueBinder provides utility methods for binding query or path parameter to various Go built-in types type ValueBinder struct { // ValueFunc is used to get single parameter (first) value from request @@ -116,14 +99,14 @@ type ValueBinder struct { // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2` ValuesFunc func(sourceParam string) []string // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response - ErrorFunc func(sourceParam string, values []string, message string, internalError error) error + ErrorFunc func(sourceParam string, values []string, message any, internalError error) error errors []error // failFast is flag for binding methods to return without attempting to bind when previous binding already failed failFast bool } // QueryParamsBinder creates query parameter value binder -func QueryParamsBinder(c *Context) *ValueBinder { +func QueryParamsBinder(c Context) *ValueBinder { return &ValueBinder{ failFast: true, ValueFunc: c.QueryParam, @@ -138,8 +121,8 @@ func QueryParamsBinder(c *Context) *ValueBinder { } } -// PathValuesBinder creates path parameter value binder -func PathValuesBinder(c *Context) *ValueBinder { +// PathParamsBinder creates path parameter value binder +func PathParamsBinder(c Context) *ValueBinder { return &ValueBinder{ failFast: true, ValueFunc: c.Param, @@ -165,7 +148,7 @@ func PathValuesBinder(c *Context) *ValueBinder { // NB: when binding forms take note that this implementation uses standard library form parsing // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See https://golang.org/pkg/net/http/#Request.ParseForm -func FormFieldBinder(c *Context) *ValueBinder { +func FormFieldBinder(c Context) *ValueBinder { vb := &ValueBinder{ failFast: true, ValueFunc: func(sourceParam string) string { @@ -176,7 +159,7 @@ func FormFieldBinder(c *Context) *ValueBinder { vb.ValuesFunc = func(sourceParam string) []string { if c.Request().Form == nil { // this is same as `Request().FormValue()` does internally - _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) + _ = c.Request().ParseMultipartForm(32 << 20) } values, ok := c.Request().Form[sourceParam] if !ok { @@ -548,11 +531,11 @@ func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize in case *int64: *d = n case *int32: - *d = int32(n) // #nosec G115 + *d = int32(n) case *int16: - *d = int16(n) // #nosec G115 + *d = int16(n) case *int8: - *d = int8(n) // #nosec G115 + *d = int8(n) case *int: *d = int(n) } @@ -776,13 +759,13 @@ func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize i case *uint64: *d = n case *uint32: - *d = uint32(n) // #nosec G115 + *d = uint32(n) case *uint16: - *d = uint16(n) // #nosec G115 + *d = uint16(n) case *uint8: // byte is alias to uint8 - *d = uint8(n) // #nosec G115 + *d = uint8(n) case *uint: - *d = uint(n) // #nosec G115 + *d = uint(n) } return b } @@ -1260,7 +1243,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder return b.unixTime(sourceParam, dest, false, time.Second) } -// MustUnixTime requires parameter value to exist to bind to time.Time variable (in local time corresponding +// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time). Returns error when value does not exist. // // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 @@ -1281,7 +1264,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB return b.unixTime(sourceParam, dest, false, time.Millisecond) } -// MustUnixTimeMilli requires parameter value to exist to bind to time.Time variable (in local time corresponding +// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time in millisecond precision). Returns error when value does not exist. // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 @@ -1305,8 +1288,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi return b.unixTime(sourceParam, dest, false, time.Nanosecond) } -// MustUnixTimeNano requires parameter value to exist to bind to time.Time variable (in local time corresponding -// to the given Unix time value in nanosecond precision). Returns error when value does not exist. +// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding +// to the given Unix time value in nano second precision). Returns error when value does not exist. // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 diff --git a/binder_error_response_test.go b/binder_error_response_test.go deleted file mode 100644 index 0ed077684..000000000 --- a/binder_error_response_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -// Regression test for #2771: a BindingError returned from a handler must be -// serialized by DefaultHTTPErrorHandler into a structured response that retains -// the field name (and the binder message), not flattened to {"message":"Bad Request"}. -func TestBindingError_serializesToStructuredJSON(t *testing.T) { - e := New() - e.GET("/doc", func(c *Context) error { - var docNum int - return QueryParamsBinder(c).MustInt("docNum", &docNum).BindError() - }) - - req := httptest.NewRequest(http.MethodGet, "/doc?docNum=abc", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusBadRequest, rec.Code) - - var body map[string]any - assert.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body)) - assert.Equal(t, "docNum", body["field"], "binding error response must retain the field name") - assert.Equal(t, "failed to bind field value to int", body["message"], "binding error response must retain the binder message") -} - -// When the binding error carries no message, MarshalJSON falls back to the -// status text (mirroring DefaultHTTPErrorHandler's *HTTPError branch). -func TestBindingError_marshalJSON_emptyMessageFallsBackToStatusText(t *testing.T) { - be := &BindingError{Field: "name", HTTPError: &HTTPError{Code: http.StatusBadRequest}} - - b, err := be.MarshalJSON() - - assert.NoError(t, err) - assert.JSONEq(t, `{"field":"name","message":"Bad Request"}`, string(b)) -} diff --git a/binder_external_test.go b/binder_external_test.go index d83c891b3..e44055a23 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -7,19 +7,18 @@ package echo_test import ( "encoding/base64" "fmt" + "github.com/labstack/echo/v4" "log" "net/http" "net/http/httptest" - - "github.com/labstack/echo/v5" ) func ExampleValueBinder_BindErrors() { // example route function that binds query params to different destinations and returns all bind errors in one go - routeFunc := func(c *echo.Context) error { + routeFunc := func(c echo.Context) error { var opts struct { - IDs []int64 Active bool + IDs []int64 } length := int64(50) // default length is 50 @@ -54,10 +53,10 @@ func ExampleValueBinder_BindErrors() { func ExampleValueBinder_BindError() { // example route function that binds query params to different destinations and stops binding on first bind error - failFastRouteFunc := func(c *echo.Context) error { + failFastRouteFunc := func(c echo.Context) error { var opts struct { - IDs []int64 Active bool + IDs []int64 } length := int64(50) // default length is 50 @@ -90,7 +89,7 @@ func ExampleValueBinder_BindError() { func ExampleValueBinder_CustomFunc() { // example route function that binds query params using custom function closure - routeFunc := func(c *echo.Context) error { + routeFunc := func(c echo.Context) error { length := int64(50) // default length is 50 var binary []byte diff --git a/binder_generic.go b/binder_generic.go index 62e1da512..f4d45af76 100644 --- a/binder_generic.go +++ b/binder_generic.go @@ -56,12 +56,13 @@ const ( // To treat empty values as errors, validate the result separately or check the raw value. // // See ParseValue for supported types and options -func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) { - for _, pv := range c.PathValues() { - if pv.Name == paramName { - v, err := ParseValue[T](pv.Value, opts...) +func PathParam[T any](c Context, paramName string, opts ...any) (T, error) { + for i, name := range c.ParamNames() { + if name == paramName { + pValues := c.ParamValues() + v, err := ParseValue[T](pValues[i], opts...) if err != nil { - return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) + return v, NewBindingError(paramName, []string{pValues[i]}, "path param", err) } return v, nil } @@ -82,12 +83,13 @@ func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) { // // If "id" is "abc": returns (0, BindingError) // // See ParseValue for supported types and options -func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) { - for _, pv := range c.PathValues() { - if pv.Name == paramName { - v, err := ParseValueOr[T](pv.Value, defaultValue, opts...) +func PathParamOr[T any](c Context, paramName string, defaultValue T, opts ...any) (T, error) { + for i, name := range c.ParamNames() { + if name == paramName { + pValues := c.ParamValues() + v, err := ParseValueOr[T](pValues[i], defaultValue, opts...) if err != nil { - return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) + return v, NewBindingError(paramName, []string{pValues[i]}, "path param", err) } return v, nil } @@ -111,7 +113,7 @@ func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...an // - Invalid value (?key=abc for int): returns (zero, BindingError) // // See ParseValue for supported types and options -func QueryParam[T any](c *Context, key string, opts ...any) (T, error) { +func QueryParam[T any](c Context, key string, opts ...any) (T, error) { values, ok := c.QueryParams()[key] if !ok { var zero T @@ -141,7 +143,7 @@ func QueryParam[T any](c *Context, key string, opts ...any) (T, error) { // // If "page" is "abc": returns (1, BindingError) // // See ParseValue for supported types and options -func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { +func QueryParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) { values, ok := c.QueryParams()[key] if !ok { return defaultValue, nil @@ -161,7 +163,7 @@ func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T // It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. // // See ParseValues for supported types and options -func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) { +func QueryParams[T any](c Context, key string, opts ...any) ([]T, error) { values, ok := c.QueryParams()[key] if !ok { return nil, ErrNonExistentKey @@ -186,7 +188,7 @@ func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) { // // If "ids" contains "abc": returns ([], BindingError) // // See ParseValues for supported types and options -func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { +func QueryParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) { values, ok := c.QueryParams()[key] if !ok { return defaultValue, nil @@ -199,7 +201,7 @@ func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) return result, nil } -// FormValue extracts and parses a single form value from the request by key. +// FormParam extracts and parses a single form value from the request by key. // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: @@ -210,11 +212,11 @@ func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) // To treat empty values as errors, validate the result separately or check the raw value. // // See ParseValue for supported types and options -func FormValue[T any](c *Context, key string, opts ...any) (T, error) { - formValues, err := c.FormValues() +func FormParam[T any](c Context, key string, opts ...any) (T, error) { + formValues, err := c.FormParams() if err != nil { var zero T - return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) + return zero, fmt.Errorf("failed to parse form param, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -228,12 +230,12 @@ func FormValue[T any](c *Context, key string, opts ...any) (T, error) { value := values[0] v, err := ParseValue[T](value, opts...) if err != nil { - return v, NewBindingError(key, []string{value}, "form value", err) + return v, NewBindingError(key, []string{value}, "form param", err) } return v, nil } -// FormValueOr extracts and parses a single form value from the request by key. +// FormParamOr extracts and parses a single form value from the request by key. // Returns defaultValue if the parameter is not found or has an empty value. // Returns an error only if parsing fails or form parsing errors occur. // @@ -245,11 +247,11 @@ func FormValue[T any](c *Context, key string, opts ...any) (T, error) { // // If "limit" is "abc": returns (100, BindingError) // // See ParseValue for supported types and options -func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { - formValues, err := c.FormValues() +func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) { + formValues, err := c.FormParams() if err != nil { var zero T - return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) + return zero, fmt.Errorf("failed to parse form param, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -261,19 +263,19 @@ func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, value := values[0] v, err := ParseValueOr[T](value, defaultValue, opts...) if err != nil { - return v, NewBindingError(key, []string{value}, "form value", err) + return v, NewBindingError(key, []string{value}, "form param", err) } return v, nil } -// FormValues extracts and parses all values for a form values key as a slice. +// FormParams extracts and parses all values for a form values key as a slice. // It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. // // See ParseValues for supported types and options -func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) { - formValues, err := c.FormValues() +func FormParams[T any](c Context, key string, opts ...any) ([]T, error) { + formValues, err := c.FormParams() if err != nil { - return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) + return nil, fmt.Errorf("failed to parse form params, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -281,26 +283,26 @@ func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) { } result, err := ParseValues[T](values, opts...) if err != nil { - return nil, NewBindingError(key, values, "form values", err) + return nil, NewBindingError(key, values, "form params", err) } return result, nil } -// FormValuesOr extracts and parses all values for a form values key as a slice. +// FormParamsOr extracts and parses all values for a form values key as a slice. // Returns defaultValue if the parameter is not found. // Returns an error only if parsing any value fails or form parsing errors occur. // // Example: // -// tags, err := echo.FormValuesOr[string](c, "tags", []string{}) +// tags, err := echo.FormParamsOr[string](c, "tags", []string{}) // // If "tags" is missing: returns ([], nil) // // If form parsing fails: returns (nil, error) // // See ParseValues for supported types and options -func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { - formValues, err := c.FormValues() +func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) { + formValues, err := c.FormParams() if err != nil { - return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) + return nil, fmt.Errorf("failed to parse form params, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -308,7 +310,7 @@ func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) } result, err := ParseValuesOr[T](values, defaultValue, opts...) if err != nil { - return nil, NewBindingError(key, values, "form values", err) + return nil, NewBindingError(key, values, "form params", err) } return result, nil } diff --git a/binder_generic_test.go b/binder_generic_test.go index 849d75962..96dfc5ed8 100644 --- a/binder_generic_test.go +++ b/binder_generic_test.go @@ -64,16 +64,15 @@ func TestPathParam(t *testing.T) { name: "nok, invalid value", givenValue: "can_parse_me", expect: false, - expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`, + expectErr: `code=400, message=path param, internal=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := NewContext(nil, nil) - c.SetPathValues(PathValues{{ - Name: cmp.Or(tc.givenKey, "key"), - Value: tc.givenValue, - }}) + e := New() + c := e.NewContext(nil, nil) + c.SetParamNames(cmp.Or(tc.givenKey, "key")) + c.SetParamValues(tc.givenValue) v, err := PathParam[bool](c, "key") if tc.expectErr != "" { @@ -87,12 +86,14 @@ func TestPathParam(t *testing.T) { } func TestPathParam_UnsupportedType(t *testing.T) { - c := NewContext(nil, nil) - c.SetPathValues(PathValues{{Name: "key", Value: "true"}}) + e := New() + c := e.NewContext(nil, nil) + c.SetParamNames("key") + c.SetParamValues("true") v, err := PathParam[[]bool](c, "key") - expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=path param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } @@ -119,13 +120,14 @@ func TestQueryParam(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, - expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=query param, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) v, err := QueryParam[bool](c, "key") if tc.expectErr != "" { @@ -140,11 +142,12 @@ func TestQueryParam(t *testing.T) { func TestQueryParam_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) v, err := QueryParam[[]bool](c, "key") - expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=query param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } @@ -171,13 +174,14 @@ func TestQueryParams(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), - expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=query params, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) v, err := QueryParams[bool](c, "key") if tc.expectErr != "" { @@ -192,11 +196,12 @@ func TestQueryParams(t *testing.T) { func TestQueryParams_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) v, err := QueryParams[[]bool](c, "key") - expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=query params, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, [][]bool(nil), v) } @@ -223,15 +228,16 @@ func TestFormValue(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, - expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=form param, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) - v, err := FormValue[bool](c, "key") + v, err := FormParam[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { @@ -244,11 +250,12 @@ func TestFormValue(t *testing.T) { func TestFormValue_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) - v, err := FormValue[[]bool](c, "key") + v, err := FormParam[[]bool](c, "key") - expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=form param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } @@ -275,15 +282,16 @@ func TestFormValues(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), - expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=form params, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) - v, err := FormValues[bool](c, "key") + v, err := FormParams[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { @@ -296,11 +304,12 @@ func TestFormValues(t *testing.T) { func TestFormValues_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) - v, err := FormValues[[]bool](c, "key") + v, err := FormParams[[]bool](c, "key") - expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=form params, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, [][]bool(nil), v) } @@ -1424,13 +1433,15 @@ func TestPathParamOr(t *testing.T) { givenKey: "id", givenValue: "invalid", defaultValue: 999, - expectErr: "code=400, message=path value, err=failed to parse value", + expectErr: "code=400, message=path param, internal=failed to parse value", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := NewContext(nil, nil) - c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}}) + e := New() + c := e.NewContext(nil, nil) + c.SetParamNames(tc.givenKey) + c.SetParamValues(tc.givenValue) v, err := PathParamOr[int](c, "id", tc.defaultValue) if tc.expectErr != "" { @@ -1479,7 +1490,8 @@ func TestQueryParamOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) v, err := QueryParamOr[int](c, "key", tc.defaultValue) if tc.expectErr != "" { @@ -1522,7 +1534,8 @@ func TestQueryParamsOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) v, err := QueryParamsOr[int](c, "key", tc.defaultValue) if tc.expectErr != "" { @@ -1565,9 +1578,10 @@ func TestFormValueOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) - v, err := FormValueOr[string](c, "name", tc.defaultValue) + v, err := FormParamOr[string](c, "name", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { @@ -1602,9 +1616,10 @@ func TestFormValuesOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) - v, err := FormValuesOr[string](c, "tags", tc.defaultValue) + v, err := FormParamsOr[string](c, "tags", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { diff --git a/binder_test.go b/binder_test.go index 8eced8208..31d13a02c 100644 --- a/binder_test.go +++ b/binder_test.go @@ -18,7 +18,7 @@ import ( "time" ) -func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context { +func createTestContext(URL string, body io.Reader, pathParams map[string]string) Context { e := New() req := httptest.NewRequest(http.MethodGet, URL, body) if body != nil { @@ -27,15 +27,15 @@ func createTestContext(URL string, body io.Reader, pathValues map[string]string) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if len(pathValues) > 0 { - params := make(PathValues, 0) - for name, value := range pathValues { - params = append(params, PathValue{ - Name: name, - Value: value, - }) + if len(pathParams) > 0 { + names := make([]string, 0) + values := make([]string, 0) + for name, value := range pathParams { + names = append(names, name) + values = append(values, value) } - c.SetPathValues(params) + c.SetParamNames(names...) + c.SetParamValues(values...) } return c @@ -43,12 +43,12 @@ func createTestContext(URL string, body io.Reader, pathValues map[string]string) func TestBindingError_Error(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) - assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`) + assert.EqualError(t, err, `code=400, message=bind failed, internal=internal error, field=id`) bErr := err.(*BindingError) assert.Equal(t, 400, bErr.Code) assert.Equal(t, "bind failed", bErr.Message) - assert.Equal(t, errors.New("internal error"), bErr.err) + assert.Equal(t, errors.New("internal error"), bErr.Internal) assert.Equal(t, "id", bErr.Field) assert.Equal(t, []string{"1", "nope"}, bErr.Values) @@ -62,13 +62,13 @@ func TestBindingError_ErrorJSON(t *testing.T) { assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } -func TestPathValuesBinder(t *testing.T) { +func TestPathParamsBinder(t *testing.T) { c := createTestContext("/api/user/999", nil, map[string]string{ "id": "1", "nr": "2", "slice": "3", }) - b := PathValuesBinder(c) + b := PathParamsBinder(c) id := int64(99) nr := int64(88) @@ -91,15 +91,15 @@ func TestQueryParamsBinder_FailFast(t *testing.T) { var testCases = []struct { name string whenURL string - expectError []string givenFailFast bool + expectError []string }{ { name: "ok, FailFast=true stops at first error", whenURL: "/api/user/999?nr=en&id=nope", givenFailFast: true, expectError: []string{ - `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, }, }, { @@ -107,8 +107,8 @@ func TestQueryParamsBinder_FailFast(t *testing.T) { whenURL: "/api/user/999?nr=en&id=nope", givenFailFast: false, expectError: []string{ - `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, - `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, + `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, }, }, } @@ -165,7 +165,7 @@ func TestFormFieldBinder(t *testing.T) { } func TestValueBinder_errorStopsBinding(t *testing.T) { - // this test documents "feature" that binding multiple params can change destination if it was bound before + // this test documents "feature" that binding multiple params can change destination if it was binded before // failing parameter binding c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil) @@ -177,7 +177,7 @@ func TestValueBinder_errorStopsBinding(t *testing.T) { Int64("nr", &nr). BindError() - assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") assert.Equal(t, int64(1), id) assert.Equal(t, int64(88), nr) } @@ -192,17 +192,17 @@ func TestValueBinder_BindError(t *testing.T) { Int64("nr", &nr). BindError() - assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id") + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id") assert.Nil(t, b.errors) assert.Nil(t, b.BindError()) } func TestValueBinder_GetValues(t *testing.T) { var testCases = []struct { - whenValuesFunc func(sourceParam string) []string name string - expectError string + whenValuesFunc func(sourceParam string) []string expect []int64 + expectError string }{ { name: "ok, default implementation", @@ -266,13 +266,13 @@ func TestValueBinder_CustomFuncWithError(t *testing.T) { func TestValueBinder_CustomFunc(t *testing.T) { var testCases = []struct { - expectValue any name string - whenURL string + givenFailFast bool givenFuncErrors []error + whenURL string expectParamValues []string + expectValue any expectErrors []string - givenFailFast bool }{ { name: "ok, binds value", @@ -341,13 +341,13 @@ func TestValueBinder_CustomFunc(t *testing.T) { func TestValueBinder_MustCustomFunc(t *testing.T) { var testCases = []struct { - expectValue any name string - whenURL string + givenFailFast bool givenFuncErrors []error + whenURL string expectParamValues []string + expectValue any expectErrors []string - givenFailFast bool }{ { name: "ok, binds value", @@ -418,12 +418,12 @@ func TestValueBinder_MustCustomFunc(t *testing.T) { func TestValueBinder_String(t *testing.T) { var testCases = []struct { name string + givenFailFast bool + givenBindErrors []error whenURL string + whenMust bool expectValue string expectError string - givenBindErrors []error - givenFailFast bool - whenMust bool }{ { name: "ok, binds value", @@ -494,12 +494,12 @@ func TestValueBinder_String(t *testing.T) { func TestValueBinder_Strings(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []string givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []string + expectError string }{ { name: "ok, binds value", @@ -570,12 +570,12 @@ func TestValueBinder_Strings(t *testing.T) { func TestValueBinder_Int64_intValue(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue int64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue int64 + expectError string }{ { name: "ok, binds value", @@ -598,7 +598,7 @@ func TestValueBinder_Int64_intValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -626,7 +626,7 @@ func TestValueBinder_Int64_intValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -667,19 +667,19 @@ func TestValueBinder_Int_errorMessage(t *testing.T) { assert.Equal(t, 99, destInt) assert.Equal(t, uint(98), destUint) - assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`) - assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`) + assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=param`) + assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, internal=strconv.ParseUint: parsing "nope": invalid syntax, field=param`) } func TestValueBinder_Uint64_uintValue(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue uint64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue uint64 + expectError string }{ { name: "ok, binds value", @@ -702,7 +702,7 @@ func TestValueBinder_Uint64_uintValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -730,7 +730,7 @@ func TestValueBinder_Uint64_uintValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, } @@ -881,12 +881,12 @@ func TestValueBinder_Int_Types(t *testing.T) { func TestValueBinder_Int64s_intsValue(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []int64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []int64 + expectError string }{ { name: "ok, binds value", @@ -909,7 +909,7 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []int64{99}, - expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -937,7 +937,7 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []int64{99}, - expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -970,12 +970,12 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) { func TestValueBinder_Uint64s_uintsValue(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []uint64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []uint64 + expectError string }{ { name: "ok, binds value", @@ -998,7 +998,7 @@ func TestValueBinder_Uint64s_uintsValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []uint64{99}, - expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1026,7 +1026,7 @@ func TestValueBinder_Uint64s_uintsValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []uint64{99}, - expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, } @@ -1169,7 +1169,7 @@ func TestValueBinder_Ints_Types(t *testing.T) { func TestValueBinder_Ints_Types_FailFast(t *testing.T) { // FailFast() should stop parsing and return early - errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" + errTmpl := "code=400, message=failed to bind field value to %v, internal=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil) var dest64 []int64 @@ -1226,12 +1226,12 @@ func TestValueBinder_Ints_Types_FailFast(t *testing.T) { func TestValueBinder_Bool(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool expectValue bool + expectError string }{ { name: "ok, binds value", @@ -1254,7 +1254,7 @@ func TestValueBinder_Bool(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: false, - expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1282,7 +1282,7 @@ func TestValueBinder_Bool(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: false, - expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, } @@ -1315,12 +1315,12 @@ func TestValueBinder_Bool(t *testing.T) { func TestValueBinder_Bools(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []bool givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []bool + expectError string }{ { name: "ok, binds value", @@ -1344,14 +1344,14 @@ func TestValueBinder_Bools(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=true¶m=nope¶m=100", expectValue: []bool(nil), - expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", givenFailFast: true, whenURL: "/search?param=true¶m=nope¶m=100", expectValue: []bool(nil), - expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1380,7 +1380,7 @@ func TestValueBinder_Bools(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []bool(nil), - expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, } @@ -1411,12 +1411,12 @@ func TestValueBinder_Bools(t *testing.T) { func TestValueBinder_Float64(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue float64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue float64 + expectError string }{ { name: "ok, binds value", @@ -1439,7 +1439,7 @@ func TestValueBinder_Float64(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1467,7 +1467,7 @@ func TestValueBinder_Float64(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1500,12 +1500,12 @@ func TestValueBinder_Float64(t *testing.T) { func TestValueBinder_Float64s(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []float64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []float64 + expectError string }{ { name: "ok, binds value", @@ -1529,14 +1529,14 @@ func TestValueBinder_Float64s(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []float64(nil), - expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", givenFailFast: true, whenURL: "/search?param=0¶m=nope¶m=100", expectValue: []float64(nil), - expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1565,7 +1565,7 @@ func TestValueBinder_Float64s(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []float64(nil), - expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1596,12 +1596,12 @@ func TestValueBinder_Float64s(t *testing.T) { func TestValueBinder_Float32(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue float32 givenNoFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue float32 + expectError string }{ { name: "ok, binds value", @@ -1624,7 +1624,7 @@ func TestValueBinder_Float32(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1652,7 +1652,7 @@ func TestValueBinder_Float32(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1685,12 +1685,12 @@ func TestValueBinder_Float32(t *testing.T) { func TestValueBinder_Float32s(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []float32 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []float32 + expectError string }{ { name: "ok, binds value", @@ -1714,14 +1714,14 @@ func TestValueBinder_Float32s(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []float32(nil), - expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", givenFailFast: true, whenURL: "/search?param=0¶m=nope¶m=100", expectValue: []float32(nil), - expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1750,7 +1750,7 @@ func TestValueBinder_Float32s(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []float32(nil), - expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1781,14 +1781,14 @@ func TestValueBinder_Float32s(t *testing.T) { func TestValueBinder_Time(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { - expectValue time.Time name string + givenFailFast bool + givenBindErrors []error whenURL string + whenMust bool whenLayout string + expectValue time.Time expectError string - givenBindErrors []error - givenFailFast bool - whenMust bool }{ { name: "ok, binds value", @@ -1863,13 +1863,13 @@ func TestValueBinder_Times(t *testing.T) { exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00") var testCases = []struct { name string + givenFailFast bool + givenBindErrors []error whenURL string + whenMust bool whenLayout string - expectError string - givenBindErrors []error expectValue []time.Time - givenFailFast bool - whenMust bool + expectError string }{ { name: "ok, binds value", @@ -1948,12 +1948,12 @@ func TestValueBinder_Duration(t *testing.T) { example := 42 * time.Second var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue time.Duration givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue time.Duration + expectError string }{ { name: "ok, binds value", @@ -2026,12 +2026,12 @@ func TestValueBinder_Durations(t *testing.T) { exampleDuration2 := 1 * time.Millisecond var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []time.Duration givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []time.Duration + expectError string }{ { name: "ok, binds value", @@ -2103,13 +2103,13 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { - expectValue Timestamp name string - whenURL string - expectError string - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue Timestamp + expectError string }{ { name: "ok, binds value", @@ -2132,7 +2132,7 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: Timestamp{}, - expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, { name: "ok (must), binds value", @@ -2160,7 +2160,7 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: Timestamp{}, - expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, } @@ -2195,12 +2195,12 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - expectValue big.Int - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue big.Int + expectError string }{ { name: "ok, binds value", @@ -2223,7 +2223,7 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, { name: "ok (must), binds value", @@ -2251,7 +2251,7 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, } @@ -2286,12 +2286,12 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - expectValue big.Int - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue big.Int + expectError string }{ { name: "ok, binds value", @@ -2314,7 +2314,7 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, { name: "ok (must), binds value", @@ -2342,7 +2342,7 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, } @@ -2374,9 +2374,9 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { - expect any name string whenURL string + expect any }{ { name: "ok, strings", @@ -2522,12 +2522,12 @@ func TestValueBinder_BindWithDelimiter_types(t *testing.T) { func TestValueBinder_BindWithDelimiter(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []int64 givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []int64 + expectError string }{ { name: "ok, binds value", @@ -2550,7 +2550,7 @@ func TestValueBinder_BindWithDelimiter(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []int64(nil), - expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2578,7 +2578,7 @@ func TestValueBinder_BindWithDelimiter(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []int64(nil), - expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2621,13 +2621,13 @@ func TestBindWithDelimiter_invalidType(t *testing.T) { func TestValueBinder_UnixTime(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603 var testCases = []struct { - expectValue time.Time name string - whenURL string - expectError string - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue time.Time + expectError string }{ { name: "ok, binds value, unix time in seconds", @@ -2655,7 +2655,7 @@ func TestValueBinder_UnixTime(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2683,7 +2683,7 @@ func TestValueBinder_UnixTime(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2717,13 +2717,13 @@ func TestValueBinder_UnixTime(t *testing.T) { func TestValueBinder_UnixTimeMilli(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 var testCases = []struct { - expectValue time.Time name string - whenURL string - expectError string - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue time.Time + expectError string }{ { name: "ok, binds value, unix time in milliseconds", @@ -2746,7 +2746,7 @@ func TestValueBinder_UnixTimeMilli(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2774,7 +2774,7 @@ func TestValueBinder_UnixTimeMilli(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2810,13 +2810,13 @@ func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00") var testCases = []struct { - expectValue time.Time name string - whenURL string - expectError string - givenBindErrors []error givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue time.Time + expectError string }{ { name: "ok, binds value, unix time in nano seconds (sec precision)", @@ -2849,7 +2849,7 @@ func TestValueBinder_UnixTimeNano(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2877,7 +2877,7 @@ func TestValueBinder_UnixTimeNano(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2919,7 +2919,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(c, &dest) + _ = binder.Bind(&dest, c) } } @@ -2967,16 +2967,17 @@ func BenchmarkRawFunc_Int64_single(b *testing.B) { func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { type Opts struct { - String string `query:"string"` - Strings []string `query:"strings"` - Int64 int64 `query:"int64"` + Int64 int64 `query:"int64"` + Int32 int32 `query:"int32"` + Int16 int16 `query:"int16"` + Int8 int8 `query:"int8"` + String string `query:"string"` + Uint64 uint64 `query:"uint64"` - Int32 int32 `query:"int32"` Uint32 uint32 `query:"uint32"` - Int16 int16 `query:"int16"` Uint16 uint16 `query:"uint16"` - Int8 int8 `query:"int8"` Uint8 uint8 `query:"uint8"` + Strings []string `query:"strings"` } c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) @@ -2985,7 +2986,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(c, &dest) + _ = binder.Bind(&dest, c) if dest.Int64 != 1 { b.Fatalf("int64!=1") } @@ -2994,16 +2995,17 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { type Opts struct { - String string `query:"string"` - Strings []string `query:"strings"` - Int64 int64 `query:"int64"` + Int64 int64 `query:"int64"` + Int32 int32 `query:"int32"` + Int16 int16 `query:"int16"` + Int8 int8 `query:"int8"` + String string `query:"string"` + Uint64 uint64 `query:"uint64"` - Int32 int32 `query:"int32"` Uint32 uint32 `query:"uint32"` - Int16 int16 `query:"int16"` Uint16 uint16 `query:"uint16"` - Int8 int8 `query:"int8"` Uint8 uint8 `query:"uint8"` + Strings []string `query:"strings"` } c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) @@ -3032,27 +3034,27 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { func TestValueBinder_TimeError(t *testing.T) { var testCases = []struct { - expectValue time.Time name string + givenFailFast bool + givenBindErrors []error whenURL string + whenMust bool whenLayout string + expectValue time.Time expectError string - givenBindErrors []error - givenFailFast bool - whenMust bool }{ { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", }, } @@ -3085,33 +3087,33 @@ func TestValueBinder_TimeError(t *testing.T) { func TestValueBinder_TimesError(t *testing.T) { var testCases = []struct { name string + givenFailFast bool + givenBindErrors []error whenURL string + whenMust bool whenLayout string - expectError string - givenBindErrors []error expectValue []time.Time - givenFailFast bool - whenMust bool + expectError string }{ { name: "nok, fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, } @@ -3147,25 +3149,25 @@ func TestValueBinder_TimesError(t *testing.T) { func TestValueBinder_DurationError(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue time.Duration givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue time.Duration + expectError string }{ { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", }, } @@ -3198,32 +3200,32 @@ func TestValueBinder_DurationError(t *testing.T) { func TestValueBinder_DurationsError(t *testing.T) { var testCases = []struct { name string - whenURL string - expectError string - givenBindErrors []error - expectValue []time.Duration givenFailFast bool + givenBindErrors []error + whenURL string whenMust bool + expectValue []time.Duration + expectError string }{ { name: "nok, fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", }, } diff --git a/context.go b/context.go index 9a7429f39..6500e5eef 100644 --- a/context.go +++ b/context.go @@ -6,194 +6,270 @@ package echo import ( "bytes" "encoding/xml" - "errors" "fmt" "io" - "io/fs" - "log/slog" "mime/multipart" "net" "net/http" "net/url" - "path" - "path/filepath" "strings" "sync" - "unsafe" ) -// stringToBytes returns a []byte view over s without copying, avoiding the allocation+copy of []byte(s) -// on the response write path (the zero-copy technique used by fasthttp/fiber). -// -// Contract β€” all of the following must hold at every call site: -// - The result is read-only: writing through it is undefined behaviour. -// - The callee must NOT retain the slice beyond the call. It is only passed to the response -// Writer's Write, whose io.Writer contract forbids retaining/mutating the argument. Note the -// concrete writer may be a wrapping ResponseWriter (e.g. gzip); such writers must copy, not alias. -// - s must stay reachable for as long as the slice is used: the slice aliases s's backing array. -func stringToBytes(s string) []byte { - if s == "" { - return nil - } - return unsafe.Slice(unsafe.StringData(s), len(s)) -} - -// jsonpOpen and jsonpClose are the constant byte wrappers for JSONP payloads, kept as package-level -// slices to avoid allocating them on every JSONP response. -var ( - jsonpOpen = []byte("(") - jsonpClose = []byte(");") -) +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request -const ( - // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. - // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. - // It is added to context only when Router does not find matching method handler for request. - ContextKeyHeaderAllow = "echo_header_allow" -) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) -const ( - // defaultMemory is default value for memory limit that is used when - // parsing multipart forms (See (*http.Request).ParseMultipartForm) - defaultMemory int64 = 32 << 20 // 32 MB - indexPage = "index.html" -) + // SetResponse sets `*Response`. + SetResponse(r *Response) -// Context represents the context of the current HTTP request. It holds request and -// response objects, path, path parameters, data and registered handler. -type Context struct { - request *http.Request - orgResponse *Response - response http.ResponseWriter - query url.Values + // Response returns `*Response`. + Response() *Response - // formParseMaxMemory is used for http.Request.ParseMultipartForm - formParseMaxMemory int64 + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - route *RouteInfo - pathValues *PathValues + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // handler is the route handler resolved during routing. It is invoked by the terminal of the global - // middleware chain (see Echo.buildRouterChains) so that the chain can be compiled once and reused. - handler HandlerFunc + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // dsw is reused by json() so that each JSON response does not heap-allocate a delayedStatusWriter. - // It lives on the pooled Context; &c.dsw is a stable, allocation-free pointer. Only json() may point - // the response at &c.dsw, and only via the nested-call guard there β€” aliasing it to itself (wrapping - // &c.dsw around &c.dsw) would make the response writer reference itself. - dsw delayedStatusWriter + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - store map[string]any - echo *Echo - logger *slog.Logger + // Path returns the registered path for the handler. + Path() string - path string - lock sync.RWMutex -} - -// NewContext returns a new Context instance. -// -// Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway -// these arguments are useful when creating context for tests and cases like that. -func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context { - var e *Echo - for _, opt := range opts { - switch v := opt.(type) { - case *Echo: - e = v - } - } - return newContext(r, w, e) -} + // SetPath sets the registered path for the handler. + SetPath(p string) -func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context { - // store is created lazily by Set and cleared (not freed) by Reset, so we deliberately do not allocate a map here. - c := &Context{ - pathValues: nil, - echo: e, - logger: nil, - } - var logger *slog.Logger - paramLen := int32(0) - formParseMaxMemory := defaultMemory - if e != nil { - paramLen = e.contextPathParamAllocSize.Load() - logger = e.Logger - formParseMaxMemory = e.formParseMaxMemory - } - if logger == nil { - logger = slog.Default() - } - c.logger = logger - p := make(PathValues, 0, paramLen) - c.pathValues = &p + // Param returns path parameter by name. + Param(name string) string + + // ParamNames returns path parameter names. + ParamNames() []string + + // SetParamNames sets path parameter names. + SetParamNames(names ...string) + + // ParamValues returns path parameter values. + ParamValues() []string + + // SetParamValues sets path parameter values. + SetParamValues(values ...string) + + // QueryParam returns the query param for the provided name. + QueryParam(name string) string + + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values + + // QueryString returns the URL query string. + QueryString() string + + // FormValue returns the form field value for the provided name. + FormValue(name string) string + + // FormParams returns the form parameters as `url.Values`. + FormParams() (url.Values, error) + + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) + + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) + + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) + + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) + + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie + + // Get retrieves data from the context. + Get(key string) any + + // Set saves data in the context. + Set(key string, val any) + + // Bind binds path params, query params and the request body into provided type `i`. The default binder + // binds body based on Content-Type header. + Bind(i any) error - c.SetRequest(r) - c.orgResponse = NewResponse(w, logger) - c.response = c.orgResponse - c.formParseMaxMemory = formParseMaxMemory - return c + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i any) error + + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data any) error + + // HTML sends an HTTP response with status code. + HTML(code int, html string) error + + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error + + // String sends a string response with status code. + String(code int, s string) error + + // JSON sends a JSON response with status code. + JSON(code int, i any) error + + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i any, indent string) error + + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error + + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i any) error + + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error + + // XML sends an XML response with status code. + XML(code int, i any) error + + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i any, indent string) error + + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error + + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error + + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error + + // File sends a response with the content of the file. + File(file string) error + + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error + + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error + + // NoContent sends a response with no body and a status code. + NoContent(code int) error + + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error + + // Error invokes the registered global HTTP error handler. Generally used by middleware. + // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and + // middlewares up in chain can not change Response status code or Response body anymore. + // + // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. + Error(err error) + + // Handler returns the matched handler by router. + Handler() HandlerFunc + + // SetHandler sets the matched handler by router. + SetHandler(h HandlerFunc) + + // Logger returns the `Logger` instance. + Logger() Logger + + // SetLogger Set the logger + SetLogger(l Logger) + + // Echo returns the `Echo` instance. + Echo() *Echo + + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) } -// Reset resets the context after request completes. It must be called along -// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. -// See `Echo#ServeHTTP()` -func (c *Context) Reset(r *http.Request, w http.ResponseWriter) { - c.request = r - c.orgResponse.reset(w) - c.response = c.orgResponse - c.query = nil - // clear (rather than nil) keeps the map allocated on the pooled Context so that requests using Set - // do not allocate a fresh map each time. clear(nil) is a no-op. - clear(c.store) - c.logger = c.echo.Logger - - c.route = nil - c.handler = nil - c.dsw = delayedStatusWriter{} - c.path = "" - // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times - *c.pathValues = (*c.pathValues)[:0] +type context struct { + logger Logger + request *http.Request + response *Response + query url.Values + echo *Echo + + store Map + lock sync.RWMutex + + // following fields are set by Router + handler HandlerFunc + + // path is route path that Router matched. It is empty string where there is no route match. + // Route registered with RouteNotFound is considered as a match and path therefore is not empty. + path string + + // Usually echo.Echo is sizing pvalues but there could be user created middlewares that decide to + // overwrite parameter by calling SetParamNames + SetParamValues. + // When echo.Echo allocated that slice it length/capacity is tied to echo.Echo.maxParam value. + // + // It is important that pvalues size is always equal or bigger to pnames length. + pvalues []string + + // pnames length is tied to param count for the matched route + pnames []string } -func (c *Context) writeContentType(value string) { - header := c.response.Header() +const ( + // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. + // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. + // It is added to context only when Router does not find matching method handler for request. + ContextKeyHeaderAllow = "echo_header_allow" +) + +const ( + defaultMemory = 32 << 20 // 32 MB + indexPage = "index.html" + defaultIndent = " " +) + +func (c *context) writeContentType(value string) { + header := c.Response().Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } -// Request returns `*http.Request`. -func (c *Context) Request() *http.Request { +func (c *context) Request() *http.Request { return c.request } -// SetRequest sets `*http.Request`. -func (c *Context) SetRequest(r *http.Request) { +func (c *context) SetRequest(r *http.Request) { c.request = r } -// Response returns `*Response`. -func (c *Context) Response() http.ResponseWriter { +func (c *context) Response() *Response { return c.response } -// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following -// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance. -func (c *Context) SetResponse(r http.ResponseWriter) { +func (c *context) SetResponse(r *Response) { c.response = r } -// IsTLS returns true if HTTP connection is TLS otherwise false. -func (c *Context) IsTLS() bool { +func (c *context) IsTLS() bool { return c.request.TLS != nil } -// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. -func (c *Context) IsWebSocket() bool { +func (c *context) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) - connection := c.request.Header.Get(HeaderConnection) - return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade") + return strings.EqualFold(upgrade, "websocket") } func isValidProto(proto string) bool { @@ -209,7 +285,7 @@ func isValidProto(proto string) bool { } // Scheme returns the HTTP protocol scheme, `http` or `https`. -func (c *Context) Scheme() string { +func (c *context) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { @@ -230,188 +306,107 @@ func (c *Context) Scheme() string { return "http" } -// RealIP returns the client IP address using the configured extraction strategy. -// -// If Echo#IPExtractor is set, it is used to resolve the client IP from the incoming request (typically via proxy -// headers such as X-Forwarded-For or X-Real-IP). -// Look into the `ip.go` file for comments and examples. -// -// See: -// - Echo#ExtractIPFromXFFHeader for `X-Forwarded-For` handling with trust checks -// - Echo#ExtractIPFromRealIPHeader for `X-Real-IP` handling with trust checks -// - Echo#LegacyIPExtractor for `v4` compatibility (spoofable, no trust checks built in) -// -// If no extractor is configured, RealIP falls back to the request RemoteAddr, returning only the host portion. -// -// Notes: -// - No validation or trust enforcement is performed unless implemented by the configured IPExtractor. -// - When relying on proxy headers, ensure the application is deployed behind trusted intermediaries to avoid spoofing. -func (c *Context) RealIP() string { +func (c *context) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } - // req.RemoteAddr is the IP address of the remote end of the connection, which may be a proxy. It is populated by the - // http.conn.readRequest() method and uses net.Conn.RemoteAddr().String() which we trust. + // Fall back to legacy behavior + if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { + i := strings.IndexAny(ip, ",") + if i > 0 { + xffip := strings.TrimSpace(ip[:i]) + xffip = strings.TrimPrefix(xffip, "[") + xffip = strings.TrimSuffix(xffip, "]") + return xffip + } + return ip + } + if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { + ip = strings.TrimPrefix(ip, "[") + ip = strings.TrimSuffix(ip, "]") + return ip + } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) return ra } -// Path returns the registered path for the handler. -func (c *Context) Path() string { +func (c *context) Path() string { return c.path } -// SetPath sets the registered path for the handler. -func (c *Context) SetPath(p string) { +func (c *context) SetPath(p string) { c.path = p } -// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. -// -// RouteInfo returns generic "empty" struct for these cases: -// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`) -// * Router did not find matching route - 404 (route not found) -// * Router did not find matching route with same method - 405 (method not allowed) -func (c *Context) RouteInfo() RouteInfo { - if c.route != nil { - return c.route.Clone() +func (c *context) Param(name string) string { + for i, n := range c.pnames { + if i < len(c.pvalues) { + if n == name { + return c.pvalues[i] + } + } } - return RouteInfo{} -} - -// Param returns path parameter by name. -func (c *Context) Param(name string) string { - return c.pathValues.GetOr(name, "") + return "" } -// ParamOr returns the path parameter or default value for the provided name. -// -// Notes for DefaultRouter implementation: -// Path parameter could be empty for cases like that: -// * route `/release-:version/bin` and request URL is `/release-/bin` -// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` -// but not when path parameter is last part of route path -// * route `/download/file.:ext` will not match request `/download/file.` -func (c *Context) ParamOr(name, defaultValue string) string { - return c.pathValues.GetOr(name, defaultValue) +func (c *context) ParamNames() []string { + return c.pnames } -// PathValues returns path parameter values. -func (c *Context) PathValues() PathValues { - return *c.pathValues -} +func (c *context) SetParamNames(names ...string) { + c.pnames = names -// SetPathValues sets path parameters for current request. -func (c *Context) SetPathValues(pathValues PathValues) { - if pathValues == nil { - panic("context SetPathValues called with nil PathValues") + l := len(names) + if len(c.pvalues) < l { + // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, + // probably those values will be overridden in a Context#SetParamValues + newPvalues := make([]string, l) + copy(newPvalues, c.pvalues) + c.pvalues = newPvalues } - c.setPathValues(&pathValues) -} - -// InitializeRoute sets the route related variables of this request to the context. -func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) { - c.route = ri - c.path = ri.Path - c.setPathValues(pathValues) } -func (c *Context) setPathValues(pv *PathValues) { - // Router accesses c.pathValues by index and may resize it to full capacity during routing - // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller - // slice than Router can set when routing Route with maximum amount of parameters. - pathValues := c.pathValues - if cap(*c.pathValues) < len(*pv) { - // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not - // be smaller than anything router knows as maximum path parameter count to be. - tmp := make(PathValues, len(*pv)) - c.pathValues = &tmp - pathValues = c.pathValues - } else if len(*c.pathValues) != len(*pv) { - *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work - } - copy(*pathValues, *pv) +func (c *context) ParamValues() []string { + return c.pvalues[:len(c.pnames)] } -// QueryParam returns the query param for the provided name. -func (c *Context) QueryParam(name string) string { - // If the full query map was already built (e.g. by QueryParams), use it. Otherwise look the single - // key up directly from the raw query, avoiding the url.Values map allocation for the common case of - // reading only a few params. The result is identical to url.Values.Get on the parsed query. - if c.query != nil { - return c.query.Get(name) +func (c *context) SetParamValues(values ...string) { + // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam (or bigger) at all times + // It will break the Router#Find code + limit := len(values) + if limit > len(c.pvalues) { + c.pvalues = make([]string, limit) } - return getRawQueryParam(c.request.URL.RawQuery, name) -} - -// getRawQueryParam returns the first value for name parsed directly from a raw URL query string. It -// matches url.Values.Get over url.ParseQuery output: first match wins, '+' decodes to space, percent -// escapes are decoded, segments containing ';' are skipped, and pairs whose key or value fail to -// unescape are skipped. It avoids allocating the full url.Values map for single-key lookups. -func getRawQueryParam(query, name string) string { - for query != "" { - var seg string - seg, query, _ = strings.Cut(query, "&") - if seg == "" || strings.Contains(seg, ";") { - continue - } - key, value, _ := strings.Cut(seg, "=") - k, err := url.QueryUnescape(key) - if err != nil || k != name { - continue - } - v, err := url.QueryUnescape(value) - if err != nil { - continue - } - return v + for i := 0; i < limit; i++ { + c.pvalues[i] = values[i] } - return "" } -// QueryParamOr returns the query param or default value for the provided name. -// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string -// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")` -func (c *Context) QueryParamOr(name, defaultValue string) string { - value := c.QueryParam(name) - if value == "" { - value = defaultValue +func (c *context) QueryParam(name string) string { + if c.query == nil { + c.query = c.request.URL.Query() } - return value + return c.query.Get(name) } -// QueryParams returns the query parameters as `url.Values`. -func (c *Context) QueryParams() url.Values { +func (c *context) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } -// QueryString returns the URL query string. -func (c *Context) QueryString() string { +func (c *context) QueryString() string { return c.request.URL.RawQuery } -// FormValue returns the form field value for the provided name. -func (c *Context) FormValue(name string) string { +func (c *context) FormValue(name string) string { return c.request.FormValue(name) } -// FormValueOr returns the form field value or default value for the provided name. -// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string -func (c *Context) FormValueOr(name, defaultValue string) string { - value := c.FormValue(name) - if value == "" { - value = defaultValue - } - return value -} - -// FormValues returns the form field values as `url.Values`. -func (c *Context) FormValues() (url.Values, error) { +func (c *context) FormParams() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { - if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil { + if err := c.request.ParseMultipartForm(defaultMemory); err != nil { return nil, err } } else { @@ -422,189 +417,141 @@ func (c *Context) FormValues() (url.Values, error) { return c.request.Form, nil } -// FormFile returns the multipart form file for the provided name. -func (c *Context) FormFile(name string) (*multipart.FileHeader, error) { +func (c *context) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) if err != nil { return nil, err } - _ = f.Close() + f.Close() return fh, nil } -// MultipartForm returns the multipart form. -func (c *Context) MultipartForm() (*multipart.Form, error) { - err := c.request.ParseMultipartForm(c.formParseMaxMemory) +func (c *context) MultipartForm() (*multipart.Form, error) { + err := c.request.ParseMultipartForm(defaultMemory) return c.request.MultipartForm, err } -// Cookie returns the named cookie provided in the request. -func (c *Context) Cookie(name string) (*http.Cookie, error) { +func (c *context) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -// SetCookie adds a `Set-Cookie` header in HTTP response. -func (c *Context) SetCookie(cookie *http.Cookie) { +func (c *context) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } -// Cookies returns the HTTP cookies sent with the request. -func (c *Context) Cookies() []*http.Cookie { +func (c *context) Cookies() []*http.Cookie { return c.request.Cookies() } -// Get retrieves data from the context. -// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)). -func (c *Context) Get(key string) any { - // Unlock without defer to avoid the deferred-call overhead on this hot path. +func (c *context) Get(key string) any { c.lock.RLock() - v := c.store[key] - c.lock.RUnlock() - return v + defer c.lock.RUnlock() + return c.store[key] } -// Set saves data in the context. -func (c *Context) Set(key string, val any) { +func (c *context) Set(key string, val any) { c.lock.Lock() + defer c.lock.Unlock() + if c.store == nil { - c.store = make(map[string]any) + c.store = make(Map) } c.store[key] = val - c.lock.Unlock() } -// Bind binds path params, query params and the request body into provided type `i`. The default binder -// binds body based on Content-Type header. -func (c *Context) Bind(i any) error { - return c.echo.Binder.Bind(c, i) +func (c *context) Bind(i any) error { + return c.echo.Binder.Bind(i, c) } -// Validate validates provided `i`. It is usually called after `Context#Bind()`. -// Validator must be registered using `Echo#Validator`. -func (c *Context) Validate(i any) error { +func (c *context) Validate(i any) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -// Render renders a template with data and sends a text/html response with status -// code. Renderer must be registered using `Echo.Renderer`. -func (c *Context) Render(code int, name string, data any) (err error) { +func (c *context) Render(code int, name string, data any) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } - // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until - // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write - // the rendered template to the buffer first. - // - // html.Template.ExecuteTemplate() documentations writes: - // > If an error occurs executing the template or writing its output, - // > execution stops, but partial results may already have been written to - // > the output writer. - buf := new(bytes.Buffer) - if err = c.echo.Renderer.Render(c, buf, name, data); err != nil { + if err = c.echo.Renderer.Render(buf, name, data, c); err != nil { return } return c.HTMLBlob(code, buf.Bytes()) } -// HTML sends an HTTP response with status code. -func (c *Context) HTML(code int, html string) (err error) { - return c.HTMLBlob(code, stringToBytes(html)) +func (c *context) HTML(code int, html string) (err error) { + return c.HTMLBlob(code, []byte(html)) } -// HTMLBlob sends an HTTP blob response with status code. -func (c *Context) HTMLBlob(code int, b []byte) (err error) { +func (c *context) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } -// String sends a string response with status code. -func (c *Context) String(code int, s string) (err error) { - return c.Blob(code, MIMETextPlainCharsetUTF8, stringToBytes(s)) +func (c *context) String(code int, s string) (err error) { + return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *Context) jsonPBlob(code int, callback string, i any) (err error) { +func (c *context) jsonPBlob(code int, callback string, i any) (err error) { + indent := "" + if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { + indent = defaultIndent + } c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) - if _, err = c.response.Write(stringToBytes(callback)); err != nil { - return - } - if _, err = c.response.Write(jsonpOpen); err != nil { + if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } - if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil { + if err = c.echo.JSONSerializer.Serialize(c, i, indent); err != nil { return } - if _, err = c.response.Write(jsonpClose); err != nil { + if _, err = c.response.Write([]byte(");")); err != nil { return } return } -func (c *Context) json(code int, i any, indent string) error { +func (c *context) json(code int, i any, indent string) error { c.writeContentType(MIMEApplicationJSON) - - // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until - // (global) error handler decides correct status code for the error to be sent to the client. - // For that we need to use writer that can store the proposed status code until the first Write is called. - resp := c.Response() - // Reuse the Context-owned delayedStatusWriter to avoid heap-allocating one per JSON response. - // If we are already nested inside a delayed write (rare: a serializer or handler calling c.JSON - // re-entrantly), allocate a fresh writer so the outer call's writer (which is &c.dsw) is not - // clobbered β€” reusing c.dsw here would make it reference itself. - if _, nested := resp.(*delayedStatusWriter); nested { - c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) - } else { - c.dsw = delayedStatusWriter{ResponseWriter: resp, status: code} - c.SetResponse(&c.dsw) - } - defer c.SetResponse(resp) - + c.response.Status = code return c.echo.JSONSerializer.Serialize(c, i, indent) } -// JSON sends a JSON response with status code. -func (c *Context) JSON(code int, i any) (err error) { - return c.json(code, i, "") +func (c *context) JSON(code int, i any) (err error) { + indent := "" + if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { + indent = defaultIndent + } + return c.json(code, i, indent) } -// JSONPretty sends a pretty-print JSON with status code. -func (c *Context) JSONPretty(code int, i any, indent string) (err error) { +func (c *context) JSONPretty(code int, i any, indent string) (err error) { return c.json(code, i, indent) } -// JSONBlob sends a JSON blob response with status code. -func (c *Context) JSONBlob(code int, b []byte) (err error) { +func (c *context) JSONBlob(code int, b []byte) (err error) { return c.Blob(code, MIMEApplicationJSON, b) } -// JSONP sends a JSONP response with status code. It uses `callback` to construct -// the JSONP payload. -func (c *Context) JSONP(code int, callback string, i any) (err error) { +func (c *context) JSONP(code int, callback string, i any) (err error) { return c.jsonPBlob(code, callback, i) } -// JSONPBlob sends a JSONP blob response with status code. It uses `callback` -// to construct the JSONP payload. -func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) { +func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) - if _, err = c.response.Write(stringToBytes(callback)); err != nil { - return - } - if _, err = c.response.Write(jsonpOpen); err != nil { + if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } if _, err = c.response.Write(b); err != nil { return } - _, err = c.response.Write(jsonpClose) + _, err = c.response.Write([]byte(");")) return } -func (c *Context) xml(code int, i any, indent string) (err error) { +func (c *context) xml(code int, i any, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -617,18 +564,19 @@ func (c *Context) xml(code int, i any, indent string) (err error) { return enc.Encode(i) } -// XML sends an XML response with status code. -func (c *Context) XML(code int, i any) (err error) { - return c.xml(code, i, "") +func (c *context) XML(code int, i any) (err error) { + indent := "" + if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { + indent = defaultIndent + } + return c.xml(code, i, indent) } -// XMLPretty sends a pretty-print XML with status code. -func (c *Context) XMLPretty(code int, i any, indent string) (err error) { +func (c *context) XMLPretty(code int, i any, indent string) (err error) { return c.xml(code, i, indent) } -// XMLBlob sends an XML blob response with status code. -func (c *Context) XMLBlob(code int, b []byte) (err error) { +func (c *context) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { @@ -638,98 +586,41 @@ func (c *Context) XMLBlob(code int, b []byte) (err error) { return } -// Blob sends a blob response with status code and content type. -func (c *Context) Blob(code int, contentType string, b []byte) (err error) { +func (c *context) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } -// Stream sends a streaming response with status code and content type. -func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) { +func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } -// File sends a response with the content of the file. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (c *Context) File(file string) error { - return fsFile(c, file, c.echo.Filesystem) -} - -// FileFS serves file from given file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (c *Context) FileFS(file string, filesystem fs.FS) error { - return fsFile(c, file, filesystem) -} - -func fsFile(c *Context, file string, filesystem fs.FS) error { - file = path.Clean(file) // `os.Open` and `os.DirFs.Open()` behave differently, later does not like ``, `.`, `..` at all, but we allowed those now need to clean - f, err := filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. - f, err = filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return err - } - } - ff, ok := f.(io.ReadSeeker) - if !ok { - return errors.New("file does not implement io.ReadSeeker") - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) - return nil -} - -// Attachment sends a response as attachment, prompting client to save the file. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (c *Context) Attachment(file, name string) error { +func (c *context) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } -// Inline sends a response as inline, opening the file in the browser. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (c *Context) Inline(file, name string) error { +func (c *context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") -func (c *Context) contentDisposition(file, name, dispositionType string) error { +func (c *context) contentDisposition(file, name, dispositionType string) error { c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } -// NoContent sends a response with no body and a status code. -func (c *Context) NoContent(code int) error { +func (c *context) NoContent(code int) error { c.response.WriteHeader(code) return nil } -// Redirect redirects the request to a provided URL with status code. -func (c *Context) Redirect(code int, url string) error { +func (c *context) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } @@ -738,20 +629,45 @@ func (c *Context) Redirect(code int, url string) error { return nil } -// Logger returns logger in Context -func (c *Context) Logger() *slog.Logger { - if c.logger != nil { - return c.logger +func (c *context) Error(err error) { + c.echo.HTTPErrorHandler(err, c) +} + +func (c *context) Echo() *Echo { + return c.echo +} + +func (c *context) Handler() HandlerFunc { + return c.handler +} + +func (c *context) SetHandler(h HandlerFunc) { + c.handler = h +} + +func (c *context) Logger() Logger { + res := c.logger + if res != nil { + return res } return c.echo.Logger } -// SetLogger sets logger in Context -func (c *Context) SetLogger(logger *slog.Logger) { - c.logger = logger +func (c *context) SetLogger(l Logger) { + c.logger = l } -// Echo returns the `Echo` instance. -func (c *Context) Echo() *Echo { - return c.echo +func (c *context) Reset(r *http.Request, w http.ResponseWriter) { + c.request = r + c.response.reset(w) + c.query = nil + c.handler = NotFoundHandler + c.store = nil + c.path = "" + c.pnames = nil + c.logger = nil + // NOTE: Don't reset because it has to have length c.echo.maxParam (or bigger) at all times + for i := 0; i < len(c.pvalues); i++ { + c.pvalues[i] = "" + } } diff --git a/context_fs.go b/context_fs.go new file mode 100644 index 000000000..1c25baf12 --- /dev/null +++ b/context_fs.go @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "errors" + "io" + "io/fs" + "net/http" + "path/filepath" +) + +func (c *context) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *context) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return err + } + } + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil +} diff --git a/context_fs_test.go b/context_fs_test.go new file mode 100644 index 000000000..83232ea45 --- /dev/null +++ b/context_fs_test.go @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "github.com/stretchr/testify/assert" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestContext_File(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec Context) error { + return ec.(*context).File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec Context) error { + return ec.(*context).FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} diff --git a/context_generic.go b/context_generic.go index 7cf8b296c..f06041bbf 100644 --- a/context_generic.go +++ b/context_generic.go @@ -13,12 +13,9 @@ var ErrInvalidKeyType = errors.New("invalid key type") // ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing. // Returns ErrInvalidKeyType error if the value is not castable to type T. -func ContextGet[T any](c *Context, key string) (T, error) { - c.lock.RLock() - defer c.lock.RUnlock() - - val, ok := c.store[key] - if !ok { +func ContextGet[T any](c Context, key string) (T, error) { + val := c.Get(key) + if val == any(nil) { var zero T return zero, ErrNonExistentKey } @@ -34,7 +31,7 @@ func ContextGet[T any](c *Context, key string) (T, error) { // ContextGetOr retrieves a value from the context store or returns a default value when the key // is missing. Returns ErrInvalidKeyType error if the value is not castable to type T. -func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) { +func ContextGetOr[T any](c Context, key string, defaultValue T) (T, error) { typed, err := ContextGet[T](c, key) if err == ErrNonExistentKey { return defaultValue, nil diff --git a/context_generic_test.go b/context_generic_test.go index ce468ac3e..9b6d2d04e 100644 --- a/context_generic_test.go +++ b/context_generic_test.go @@ -10,7 +10,8 @@ import ( ) func TestContextGetOK(t *testing.T) { - c := NewContext(nil, nil) + e := New() + c := e.NewContext(nil, nil) c.Set("key", int64(123)) @@ -20,7 +21,8 @@ func TestContextGetOK(t *testing.T) { } func TestContextGetNonExistentKey(t *testing.T) { - c := NewContext(nil, nil) + e := New() + c := e.NewContext(nil, nil) c.Set("key", int64(123)) @@ -30,7 +32,8 @@ func TestContextGetNonExistentKey(t *testing.T) { } func TestContextGetInvalidCast(t *testing.T) { - c := NewContext(nil, nil) + e := New() + c := e.NewContext(nil, nil) c.Set("key", int64(123)) @@ -40,7 +43,8 @@ func TestContextGetInvalidCast(t *testing.T) { } func TestContextGetOrOK(t *testing.T) { - c := NewContext(nil, nil) + e := New() + c := e.NewContext(nil, nil) c.Set("key", int64(123)) @@ -50,7 +54,8 @@ func TestContextGetOrOK(t *testing.T) { } func TestContextGetOrNonExistentKey(t *testing.T) { - c := NewContext(nil, nil) + e := New() + c := e.NewContext(nil, nil) c.Set("key", int64(123)) @@ -60,7 +65,8 @@ func TestContextGetOrNonExistentKey(t *testing.T) { } func TestContextGetOrInvalidCast(t *testing.T) { - c := NewContext(nil, nil) + e := New() + c := e.NewContext(nil, nil) c.Set("key", int64(123)) diff --git a/context_test.go b/context_test.go index 9b820c6e3..c848d9044 100644 --- a/context_test.go +++ b/context_test.go @@ -8,22 +8,20 @@ import ( "crypto/tls" "encoding/json" "encoding/xml" + "errors" "fmt" "io" - "io/fs" - "log/slog" "math" "mime/multipart" - "net" "net/http" "net/http/httptest" "net/url" - "os" "strings" "testing" "text/template" "time" + "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) @@ -31,14 +29,13 @@ type Template struct { templates *template.Template } -var testUser = user{ID: 1, Name: "Jon Snow"} +var testUser = user{1, "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) b.ResetTimer() b.ReportAllocs() @@ -50,10 +47,9 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) b.ResetTimer() b.ReportAllocs() @@ -65,10 +61,9 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) b.ResetTimer() b.ReportAllocs() @@ -79,7 +74,7 @@ func BenchmarkAllocXML(b *testing.B) { } func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { - c := Context{request: &http.Request{ + c := context{request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }} for i := 0; i < b.N; i++ { @@ -87,7 +82,7 @@ func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { } } -func (t *Template) Render(c *Context, w io.Writer, name string, data any) error { +func (t *Template) Render(w io.Writer, name string, data any, c Context) error { return t.templates.ExecuteTemplate(w, name, data) } @@ -96,7 +91,7 @@ func TestContextEcho(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) assert.Equal(t, e, c.Echo()) } @@ -106,7 +101,7 @@ func TestContextRequest(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) assert.NotNil(t, c.Request()) assert.Equal(t, req, c.Request()) @@ -117,7 +112,7 @@ func TestContextResponse(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) assert.NotNil(t, c.Response()) } @@ -127,12 +122,12 @@ func TestContextRenderTemplate(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } - c.Echo().Renderer = tmpl + c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) @@ -140,91 +135,24 @@ func TestContextRenderTemplate(t *testing.T) { } } -func TestContextRenderTemplateError(t *testing.T) { - // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do - e := New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - tmpl := &Template{ - templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), - } - c.Echo().Renderer = tmpl - err := c.Render(http.StatusOK, "not_existing", "Jon Snow") - - assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`) - assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client - assert.Empty(t, rec.Body.String()) // body must not be sent to the client -} - func TestContextRenderErrorsOnNoRenderer(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - c.Echo().Renderer = nil + c.echo.Renderer = nil assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow")) } -func TestContextStream(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) - - r, w := io.Pipe() - go func() { - defer w.Close() - for i := range 3 { - fmt.Fprintf(w, "data: index %v\n\n", i) - time.Sleep(5 * time.Millisecond) - } - }() - - err := c.Stream(http.StatusOK, "text/event-stream", r) - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType)) - assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String()) - } -} - -func TestContextHTML(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := NewContext(req, rec) - - err := c.HTML(http.StatusOK, "Hi, Jon Snow") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) - } -} - -func TestContextHTMLBlob(t *testing.T) { - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := NewContext(req, rec) - - err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow")) - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) - } -} - func TestContextJSON(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) + err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) @@ -236,37 +164,33 @@ func TestContextJSONErrorsOut(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusOK, make(chan bool)) assert.EqualError(t, err, "json: unsupported type: chan bool") - - assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client - assert.Empty(t, rec.Body.String()) // body must not be sent to the client } -func TestContextJSONWithNotEchoResponse(t *testing.T) { +func TestContextJSONPrettyURL(t *testing.T) { e := New() rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) - c := e.NewContext(req, rec) - - c.SetResponse(rec) - - err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()}) - assert.EqualError(t, err, "json: unsupported value: NaN") + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) - assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client - assert.Empty(t, rec.Body.String()) // body must not be sent to the client + err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } } func TestContextJSONPretty(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") + err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) @@ -278,16 +202,16 @@ func TestContextJSONWithEmptyIntent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - u := user{ID: 1, Name: "Jon Snow"} + u := user{1, "Jon Snow"} emptyIndent := "" buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetIndent(emptyIndent, emptyIndent) _ = enc.Encode(u) - err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) + err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) @@ -299,10 +223,10 @@ func TestContextJSONP(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) callback := "callback" - err := c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"}) + err := c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -314,9 +238,9 @@ func TestContextJSONBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) + data, err := json.Marshal(user{1, "Jon Snow"}) assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) if assert.NoError(t, err) { @@ -330,10 +254,10 @@ func TestContextJSONPBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) callback := "callback" - data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) + data, err := json.Marshal(user{1, "Jon Snow"}) assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) if assert.NoError(t, err) { @@ -347,9 +271,9 @@ func TestContextXML(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) + err := c.XML(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -357,13 +281,27 @@ func TestContextXML(t *testing.T) { } } +func TestContextXMLPrettyURL(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XML(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) + } +} + func TestContextXMLPretty(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") + err := c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -375,9 +313,9 @@ func TestContextXMLBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"}) + data, err := xml.Marshal(user{1, "Jon Snow"}) assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) if assert.NoError(t, err) { @@ -391,16 +329,16 @@ func TestContextXMLWithEmptyIntent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) - u := user{ID: 1, Name: "Jon Snow"} + u := user{1, "Jon Snow"} emptyIndent := "" buf := new(bytes.Buffer) enc := xml.NewEncoder(buf) enc.Indent(emptyIndent, emptyIndent) _ = enc.Encode(u) - err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) + err := c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -408,17 +346,71 @@ func TestContextXMLWithEmptyIntent(t *testing.T) { } } -func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { +type responseWriterErr struct { +} + +func (responseWriterErr) Header() http.Header { + return http.Header{} +} + +func (responseWriterErr) Write([]byte) (int, error) { + return 0, errors.New("responseWriterErr") +} + +func (responseWriterErr) WriteHeader(statusCode int) { +} + +func TestContextXMLError(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"}) + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + c.response.Writer = responseWriterErr{} + + err := c.XML(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "responseWriterErr") +} + +func TestContextString(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + err := c.String(http.StatusOK, "Hello, World!") if assert.NoError(t, err) { - assert.Equal(t, http.StatusCreated, rec.Code) - assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) - assert.Equal(t, userJSON+"\n", rec.Body.String()) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestContextHTML(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.HTML(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestContextStream(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + r := strings.NewReader("response from a stream") + err := c.Stream(http.StatusOK, "application/octet-stream", r) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) } } @@ -444,7 +436,7 @@ func TestContextAttachment(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) err := c.Attachment("_fixture/images/walle.png", tc.whenName) if assert.NoError(t, err) { @@ -479,7 +471,7 @@ func TestContextInline(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) err := c.Inline("_fixture/images/walle.png", tc.whenName) if assert.NoError(t, err) { @@ -496,12 +488,69 @@ func TestContextNoContent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) c.NoContent(http.StatusOK) assert.Equal(t, http.StatusOK, rec.Code) } +func TestContextError(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + c.Error(errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.True(t, c.Response().Committed) +} + +func TestContextReset(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + c.SetParamNames("foo") + c.SetParamValues("bar") + c.Set("foe", "ban") + c.query = url.Values(map[string][]string{"fon": {"baz"}}) + + c.Reset(req, httptest.NewRecorder()) + + assert.Len(t, c.ParamValues(), 0) + assert.Len(t, c.ParamNames(), 0) + assert.Len(t, c.Path(), 0) + assert.Len(t, c.QueryParams(), 0) + assert.Len(t, c.store, 0) +} + +func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) + + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) + } +} + +func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) + + if assert.Error(t, err) { + assert.False(t, c.response.Committed) + } +} + func TestContextCookie(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -510,7 +559,7 @@ func TestContextCookie(t *testing.T) { req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) // Read single cookie, err := c.Cookie("theme") @@ -547,237 +596,107 @@ func TestContextCookie(t *testing.T) { assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } -func TestContext_PathValues(t *testing.T) { - var testCases = []struct { - name string - given PathValues - expect PathValues - }{ - { - name: "param exists", - given: PathValues{ - {Name: "uid", Value: "101"}, - {Name: "fid", Value: "501"}, - }, - expect: PathValues{ - {Name: "uid", Value: "101"}, - {Name: "fid", Value: "501"}, - }, - }, - { - name: "params is empty", - given: PathValues{}, - expect: PathValues{}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) - - c.SetPathValues(tc.given) +func TestContextPath(t *testing.T) { + e := New() + r := e.Router() - assert.EqualValues(t, tc.expect, c.PathValues()) - }) - } -} + handler := func(c Context) error { return c.String(http.StatusOK, "OK") } -func TestContext_PathParam(t *testing.T) { - var testCases = []struct { - name string - given PathValues - whenParamName string - expect string - }{ - { - name: "param exists", - given: PathValues{ - {Name: "uid", Value: "101"}, - {Name: "fid", Value: "501"}, - }, - whenParamName: "uid", - expect: "101", - }, - { - name: "multiple same param values exists - return first", - given: PathValues{ - {Name: "uid", Value: "101"}, - {Name: "uid", Value: "202"}, - {Name: "fid", Value: "501"}, - }, - whenParamName: "uid", - expect: "101", - }, - { - name: "param does not exists", - given: PathValues{ - {Name: "uid", Value: "101"}, - }, - whenParamName: "nope", - expect: "", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) + r.Add(http.MethodGet, "/users/:id", handler) + c := e.NewContext(nil, nil) + r.Find(http.MethodGet, "/users/1", c) - c.SetPathValues(tc.given) + assert.Equal(t, "/users/:id", c.Path()) - assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName)) - }) - } + r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) + c = e.NewContext(nil, nil) + r.Find(http.MethodGet, "/users/1/files/1", c) + assert.Equal(t, "/users/:uid/files/:fid", c.Path()) } -func TestContext_PathParamDefault(t *testing.T) { - var testCases = []struct { - name string - given PathValues - whenParamName string - whenDefaultValue string - expect string - }{ - { - name: "param exists", - given: PathValues{ - {Name: "uid", Value: "101"}, - {Name: "fid", Value: "501"}, - }, - whenParamName: "uid", - whenDefaultValue: "999", - expect: "101", - }, - { - name: "param exists and is empty", - given: PathValues{ - {Name: "uid", Value: ""}, - {Name: "fid", Value: "501"}, - }, - whenParamName: "uid", - whenDefaultValue: "999", - expect: "", // <-- this is different from QueryParamOr behaviour - }, - { - name: "param does not exists", - given: PathValues{ - {Name: "uid", Value: "101"}, - }, - whenParamName: "nope", - whenDefaultValue: "999", - expect: "999", - }, - } +func TestContextPathParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) + // ParamNames + c.SetParamNames("uid", "fid") + assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) - c.SetPathValues(tc.given) + // ParamValues + c.SetParamValues("101", "501") + assert.EqualValues(t, []string{"101", "501"}, c.ParamValues()) - assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue)) - }) - } + // Param + assert.Equal(t, "501", c.Param("fid")) + assert.Equal(t, "", c.Param("undefined")) } -func TestContextGetAndSetPathValuesMutability(t *testing.T) { - t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) { - e := New() - e.contextPathParamAllocSize.Store(1) - - req := httptest.NewRequest(http.MethodGet, "/:foo", nil) - c := e.NewContext(req, nil) - - params := PathValues{{Name: "foo", Value: "101"}} - c.SetPathValues(params) - - // round-trip param values with modification - paramVals := c.PathValues() - assert.Equal(t, params, c.PathValues()) - - // PathValues() does not return copy and modifying raw slice mutates value in context - paramVals[0] = PathValue{Name: "xxx", Value: "yyy"} - assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues()) - }) - - t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) { - e := New() - e.contextPathParamAllocSize.Store(1) - - req := httptest.NewRequest(http.MethodGet, "/:foo", nil) - c := e.NewContext(req, nil) - // increase path param capacity in context - pathValues := PathValues{ - {Name: "aaa", Value: "bbb"}, - {Name: "ccc", Value: "ddd"}, - } - c.SetPathValues(pathValues) - assert.Equal(t, pathValues, c.PathValues()) - - // shouldn't explode during Reset() afterwards! - assert.NotPanics(t, func() { - c.Reset(nil, nil) - }) - assert.Equal(t, PathValues{}, c.PathValues()) - assert.Len(t, *c.pathValues, 0) - assert.Equal(t, 2, cap(*c.pathValues)) - }) - - t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) { - e := New() - - req := httptest.NewRequest(http.MethodGet, "/:foo", nil) - c := e.NewContext(req, nil) - c.pathValues = &PathValues{ - {Name: "aaa", Value: "bbb"}, - {Name: "ccc", Value: "ddd"}, - } - - pathValues := PathValues{ - {Name: "aaa", Value: "bbb"}, - } - // given pathValues slice is smaller. this should not decrease c.pathValues capacity - c.SetPathValues(pathValues) - assert.Equal(t, pathValues, c.PathValues()) - - // shouldn't explode during Reset() afterwards! - assert.NotPanics(t, func() { - c.Reset(nil, nil) - }) - assert.Equal(t, PathValues{}, c.PathValues()) - assert.Len(t, *c.pathValues, 0) - assert.Equal(t, 2, cap(*c.pathValues)) +func TestContextGetAndSetParam(t *testing.T) { + e := New() + r := e.Router() + r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + c.SetParamNames("foo") + + // round-trip param values with modification + paramVals := c.ParamValues() + assert.EqualValues(t, []string{""}, c.ParamValues()) + paramVals[0] = "bar" + c.SetParamValues(paramVals...) + assert.EqualValues(t, []string{"bar"}, c.ParamValues()) + + // shouldn't explode during Reset() afterwards! + assert.NotPanics(t, func() { + c.Reset(nil, nil) }) - } -// Issue #1655 -func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) { +func TestContextSetParamNamesEchoMaxParam(t *testing.T) { e := New() - c := e.NewContext(nil, nil) + assert.Equal(t, 0, *e.maxParam) + + expectedOneParam := []string{"one"} + expectedTwoParams := []string{"one", "two"} + expectedThreeParams := []string{"one", "two", ""} + + { + c := e.AcquireContext() + c.SetParamNames("1", "2") + c.SetParamValues(expectedTwoParams...) + assert.Equal(t, 0, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedTwoParams, c.ParamValues()) + e.ReleaseContext(c) + } - assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) - expectedTwoParams := PathValues{ - {Name: "1", Value: "one"}, - {Name: "2", Value: "two"}, + { + c := e.AcquireContext() + c.SetParamNames("1", "2", "3") + c.SetParamValues(expectedThreeParams...) + assert.Equal(t, 0, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedThreeParams, c.ParamValues()) + e.ReleaseContext(c) } - c.SetPathValues(expectedTwoParams) - assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) - assert.Equal(t, expectedTwoParams, c.PathValues()) - - expectedThreeParams := PathValues{ - {Name: "1", Value: "one"}, - {Name: "2", Value: "two"}, - {Name: "3", Value: "three"}, + + { // values is always same size as names length + c := e.NewContext(nil, nil) + c.SetParamValues([]string{"one", "two"}...) // more values than names should be ok + c.SetParamNames("1") + assert.Equal(t, 0, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedOneParam, c.ParamValues()) + } + + e.GET("/:id", handlerFunc) + assert.Equal(t, 1, *e.maxParam) // has not been changed + + { + c := e.NewContext(nil, nil) + c.SetParamValues([]string{"one", "two"}...) + c.SetParamNames("1") + assert.Equal(t, 1, *e.maxParam) // has not been changed + assert.EqualValues(t, expectedOneParam, c.ParamValues()) } - c.SetPathValues(expectedThreeParams) - assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) - assert.Equal(t, expectedThreeParams, c.PathValues()) } func TestContextFormValue(t *testing.T) { @@ -794,151 +713,41 @@ func TestContextFormValue(t *testing.T) { assert.Equal(t, "Jon Snow", c.FormValue("name")) assert.Equal(t, "jon@labstack.com", c.FormValue("email")) - // FormValueOr - assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope")) - assert.Equal(t, "default", c.FormValueOr("missing", "default")) - - // FormValues - values, err := c.FormValues() + // FormParams + params, err := c.FormParams() if assert.NoError(t, err) { assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, - }, values) + }, params) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) - values, err = c.FormValues() - assert.Nil(t, values) + params, err = c.FormParams() + assert.Nil(t, params) assert.Error(t, err) } -func TestContext_QueryParams(t *testing.T) { - var testCases = []struct { - expect url.Values - name string - givenURL string - }{ - { - name: "multiple values in url", - givenURL: "/?test=1&test=2&email=jon%40labstack.com", - expect: url.Values{ - "test": []string{"1", "2"}, - "email": []string{"jon@labstack.com"}, - }, - }, - { - name: "single value in url", - givenURL: "/?nope=1", - expect: url.Values{ - "nope": []string{"1"}, - }, - }, - { - name: "no query params in url", - givenURL: "/?", - expect: url.Values{}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) - - assert.Equal(t, tc.expect, c.QueryParams()) - }) - } -} - -func TestContext_QueryParam(t *testing.T) { - var testCases = []struct { - name string - givenURL string - whenParamName string - expect string - }{ - { - name: "value exists in url", - givenURL: "/?test=1", - whenParamName: "test", - expect: "1", - }, - { - name: "multiple values exists in url", - givenURL: "/?test=9&test=8", - whenParamName: "test", - expect: "9", // <-- first value in returned - }, - { - name: "value does not exists in url", - givenURL: "/?nope=1", - whenParamName: "test", - expect: "", - }, - { - name: "value is empty in url", - givenURL: "/?test=", - whenParamName: "test", - expect: "", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) - - assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) - }) - } -} - -func TestContext_QueryParamDefault(t *testing.T) { - var testCases = []struct { - name string - givenURL string - whenParamName string - whenDefaultValue string - expect string - }{ - { - name: "value exists in url", - givenURL: "/?test=1", - whenParamName: "test", - whenDefaultValue: "999", - expect: "1", - }, - { - name: "value does not exists in url", - givenURL: "/?nope=1", - whenParamName: "test", - whenDefaultValue: "999", - expect: "999", - }, - { - name: "value is empty in url", - givenURL: "/?test=", - whenParamName: "test", - whenDefaultValue: "999", - expect: "999", - }, - } +func TestContextQueryParam(t *testing.T) { + q := make(url.Values) + q.Set("name", "Jon Snow") + q.Set("email", "jon@labstack.com") + req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) + e := New() + c := e.NewContext(req, nil) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + // QueryParam + assert.Equal(t, "Jon Snow", c.QueryParam("name")) + assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) - assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue)) - }) - } + // QueryParams + assert.Equal(t, url.Values{ + "name": []string{"Jon Snow"}, + "email": []string{"jon@labstack.com"}, + }, c.QueryParams()) } func TestContextFormFile(t *testing.T) { @@ -999,47 +808,16 @@ func TestContextRedirect(t *testing.T) { assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } -func TestContextGet(t *testing.T) { - var testCases = []struct { - name string - given any - whenKey string - expect any - }{ - { - name: "ok, value exist", - given: "Jon Snow", - whenKey: "key", - expect: "Jon Snow", - }, - { - name: "ok, value does not exist", - given: "Jon Snow", - whenKey: "nope", - expect: nil, - }, - { - name: "ok, value is nil value", - given: []byte(nil), - whenKey: "key", - expect: []byte(nil), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var c = new(Context) - c.Set("key", tc.given) - - v := c.Get(tc.whenKey) - assert.Equal(t, tc.expect, v) - }) - } +func TestContextStore(t *testing.T) { + var c Context = new(context) + c.Set("name", "Jon Snow") + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} - c := &Context{ + c := &context{ echo: e, } @@ -1051,6 +829,42 @@ func BenchmarkContext_Store(b *testing.B) { } } +func TestContextHandler(t *testing.T) { + e := New() + r := e.Router() + b := new(bytes.Buffer) + + r.Add(http.MethodGet, "/handler", func(Context) error { + _, err := b.Write([]byte("handler")) + return err + }) + c := e.NewContext(nil, nil) + r.Find(http.MethodGet, "/handler", c) + err := c.Handler()(c) + assert.Equal(t, "handler", b.String()) + assert.NoError(t, err) +} + +func TestContext_SetHandler(t *testing.T) { + var c Context = new(context) + + assert.Nil(t, c.Handler()) + + c.SetHandler(func(c Context) error { + return nil + }) + assert.NotNil(t, c.Handler()) +} + +func TestContext_Path(t *testing.T) { + path := "/pa/th" + + var c Context = new(context) + + c.SetPath(path) + assert.Equal(t, path, c.Path()) +} + type validator struct{} func (*validator) Validate(i any) error { @@ -1079,7 +893,7 @@ func TestContext_QueryString(t *testing.T) { } func TestContext_Request(t *testing.T) { - var c = new(Context) + var c Context = new(context) assert.Nil(t, c.Request()) @@ -1252,9 +1066,10 @@ func TestContext_Scheme(t *testing.T) { if tc.givenHeaders != nil { req.Header = tc.givenHeaders } - c := NewContext(req, nil) + e := New() + c := e.NewContext(req, nil) if tc.givenIsTLS { - c.request.TLS = &tls.ConnectionState{} + c.Request().TLS = &tls.ConnectionState{} } assert.Equal(t, tc.expect, c.Scheme()) @@ -1264,56 +1079,39 @@ func TestContext_Scheme(t *testing.T) { func TestContext_IsWebSocket(t *testing.T) { tests := []struct { - c *Context + c Context ws assert.BoolAssertionFunc }{ { - &Context{ + &context{ request: &http.Request{ - Header: http.Header{ - HeaderUpgrade: []string{"websocket"}, - HeaderConnection: []string{"upgrade"}, - }, + Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, assert.True, }, { - &Context{ + &context{ request: &http.Request{ - Header: http.Header{ - HeaderUpgrade: []string{"Websocket"}, - HeaderConnection: []string{"Upgrade"}, - }, + Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, assert.True, }, { - &Context{ + &context{ request: &http.Request{}, }, assert.False, }, { - &Context{ + &context{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, assert.False, }, - { - &Context{ - request: &http.Request{ - Header: http.Header{ - HeaderUpgrade: []string{"websocket"}, - HeaderConnection: []string{"close"}, - }, - }, - }, - assert.False, - }, } for i, tt := range tests { @@ -1332,212 +1130,110 @@ func TestContext_Bind(t *testing.T) { req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) assert.NoError(t, err) - assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u) + assert.Equal(t, &user{1, "Jon Snow"}, u) } -func TestContext_RealIP(t *testing.T) { - _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") +func TestContext_Logger(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) - var testCases = []struct { - name string - givenIPExtrator IPExtractor - whenReq *http.Request - expect string + log1 := c.Logger() + assert.NotNil(t, log1) + + log2 := log.New("echo2") + c.SetLogger(log2) + assert.Equal(t, log2, c.Logger()) + + // Resetting the context returns the initial logger + c.Reset(nil, nil) + assert.Equal(t, log1, c.Logger()) +} + +func TestContext_RealIP(t *testing.T) { + tests := []struct { + c Context + s string }{ { - name: "ip from remote addr", - givenIPExtrator: nil, - whenReq: &http.Request{RemoteAddr: "89.89.89.89:1654"}, - expect: "89.89.89.89", + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, + }, + }, + "127.0.0.1", }, { - name: "ip from ip extractor", - givenIPExtrator: ExtractIPFromRealIPHeader(TrustIPRange(ipv6ForRemoteAddrExternalRange)), - whenReq: &http.Request{ - Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, - HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, }, - RemoteAddr: "[2001:db8::113:1]:8080", }, - expect: "2001:db8::113:199", + "127.0.0.1", }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - c := e.NewContext(tc.whenReq, nil) - if tc.givenIPExtrator != nil { - e.IPExtractor = tc.givenIPExtrator - } - assert.Equal(t, tc.expect, c.RealIP()) - }) - } -} - -func TestContext_File(t *testing.T) { - var testCases = []struct { - whenFS fs.FS - name string - whenFile string - expectError string - expectStartsWith []byte - expectStatus int - }{ { - name: "ok, from default file system", - whenFile: "_fixture/images/walle.png", - whenFS: nil, - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, + }, + }, + "127.0.0.1", }, { - name: "ok, from custom file system", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "Not Found", + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - if tc.whenFS != nil { - e.Filesystem = tc.whenFS - } - - handler := func(ec *Context) error { - return ec.File(tc.whenFile) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestContext_FileFS(t *testing.T) { - var testCases = []struct { - whenFS fs.FS - name string - whenFile string - expectError string - expectStartsWith []byte - expectStatus int - }{ { - name: "ok", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "Not Found", + &context{ + request: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{"192.168.0.1"}, + }, + }, + }, + "192.168.0.1", + }, + { + &context{ + request: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{"[2001:db8::1]"}, + }, + }, + }, + "2001:db8::1", }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - handler := func(ec *Context) error { - return ec.FileFS(tc.whenFile, tc.whenFS) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) + { + &context{ + request: &http.Request{ + RemoteAddr: "89.89.89.89:1654", + }, + }, + "89.89.89.89", + }, } -} - -func TestLogger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - assert.NotNil(t, log1) - assert.Equal(t, e.Logger, log1) - - customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - c.SetLogger(customLogger) - assert.Equal(t, customLogger, c.Logger()) - - // Resetting the context returns the initial Echo logger - c.Reset(nil, nil) - assert.Equal(t, e.Logger, c.Logger()) -} - -func TestRouteInfo(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - orgRI := RouteInfo{ - Name: "root", - Method: http.MethodGet, - Path: "/*", - Parameters: []string{"*"}, + for _, tt := range tests { + assert.Equal(t, tt.s, tt.c.RealIP()) } - c.route = &orgRI - ri := c.RouteInfo() - assert.Equal(t, orgRI, ri) - - // Test mutability when middlewares start to change things - - // RouteInfo inside context will not be affected when returned instance is changed - expect := orgRI.Clone() - ri.Path = "changed" - ri.Parameters[0] = "changed" - assert.Equal(t, expect, c.RouteInfo()) - - // RouteInfo inside context will not be affected when returned instance is changed - expect = c.RouteInfo() - orgRI.Name = "changed" - assert.NotEqual(t, expect, c.RouteInfo()) } diff --git a/dispatch_pool_test.go b/dispatch_pool_test.go deleted file mode 100644 index b2912139f..000000000 --- a/dispatch_pool_test.go +++ /dev/null @@ -1,101 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestContextResetClearsStore guards the clear(c.store) reuse: a pooled Context must not leak store -// values from a previous request into the next one, and Set must still work after a clear-based Reset. -func TestContextResetClearsStore(t *testing.T) { - e := New() - c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder()) - c.Set("secret", "req1") - assert.Equal(t, "req1", c.Get("secret")) - - c.Reset(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder()) - assert.Nil(t, c.Get("secret"), "store must not leak across Reset") - - c.Set("k", "req2") // Set must still work after clear-based reset - assert.Equal(t, "req2", c.Get("k")) -} - -// TestContextJSONStatusAcrossReset guards the reused delayedStatusWriter (c.dsw): a second JSON -// response on a pooled+Reset Context must use the new status, not inherit the previous one. -func TestContextJSONStatusAcrossReset(t *testing.T) { - e := New() - c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder()) - assert.NoError(t, c.JSON(http.StatusTeapot, map[string]int{"a": 1})) - - rec2 := httptest.NewRecorder() - c.Reset(httptest.NewRequest(http.MethodGet, "/", nil), rec2) - assert.NoError(t, c.JSON(http.StatusCreated, map[string]int{"b": 2})) - assert.Equal(t, http.StatusCreated, rec2.Code) - assert.JSONEq(t, `{"b":2}`, rec2.Body.String()) -} - -// TestNestedJSONUsesFreshDelayedWriter guards the nested c.JSON case: a serializer that calls c.JSON -// re-entrantly must not corrupt the outer delayedStatusWriter (regression test for c.dsw self-reference). -func TestNestedJSONUsesFreshDelayedWriter(t *testing.T) { - e := New() - e.JSONSerializer = nestedJSONSerializer{} - rec := httptest.NewRecorder() - c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec) - assert.NoError(t, c.JSON(http.StatusOK, map[string]int{"outer": 1})) - assert.Equal(t, http.StatusOK, rec.Code) -} - -type nestedJSONSerializer struct{} - -func (nestedJSONSerializer) Serialize(c *Context, i any, indent string) error { - if m, ok := i.(map[string]int); ok && m["outer"] == 1 { - // re-enter c.JSON once before encoding the outer payload - if err := c.JSON(http.StatusOK, map[string]int{"inner": 2}); err != nil { - return err - } - } - return (DefaultJSONSerializer{}).Serialize(c, i, indent) -} - -func (nestedJSONSerializer) Deserialize(c *Context, i any) error { - return (DefaultJSONSerializer{}).Deserialize(c, i) -} - -// TestGlobalMiddlewareRunsOnNotFoundAndMethodNotAllowed pins the dispatch contract: global (Use) and -// pre (Pre) middleware must execute even when the router returns 404 / 405 / OPTIONS handlers. -func TestGlobalMiddlewareRunsOnNotFoundAndMethodNotAllowed(t *testing.T) { - cases := []struct { - name, method, path string - code int - }{ - {"404", http.MethodGet, "/missing", http.StatusNotFound}, - {"405", http.MethodPost, "/", http.StatusMethodNotAllowed}, - {"OPTIONS", http.MethodOptions, "/", http.StatusNoContent}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - e := New() - var pre, use bool - e.Pre(func(n HandlerFunc) HandlerFunc { - return func(c *Context) error { pre = true; return n(c) } - }) - e.Use(func(n HandlerFunc) HandlerFunc { - return func(c *Context) error { use = true; return n(c) } - }) - e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "ok") }) - - rec := httptest.NewRecorder() - e.ServeHTTP(rec, httptest.NewRequest(tc.method, tc.path, nil)) - - assert.True(t, pre, "pre-middleware must run on %s", tc.name) - assert.True(t, use, "global middleware must run on %s", tc.name) - assert.Equal(t, tc.code, rec.Code) - }) - } -} diff --git a/echo.go b/echo.go index ee3a53cd0..489e16c23 100644 --- a/echo.go +++ b/echo.go @@ -9,33 +9,30 @@ Example: package main import ( - "log/slog" - "net/http" + "net/http" - "github.com/labstack/echo/v5" - "github.com/labstack/echo/v5/middleware" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" ) // Handler - func hello(c *echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") + func hello(c echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") } func main() { - // Echo instance - e := echo.New() + // Echo instance + e := echo.New() - // Middleware - e.Use(middleware.RequestLogger()) - e.Use(middleware.Recover()) + // Middleware + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) - // Routes - e.GET("/", hello) + // Routes + e.GET("/", hello) - // Start server - if err := e.Start(":8080"); err != nil { - slog.Error("failed to start server", "error", err) - } + // Start server + e.Logger.Fatal(e.Start(":1323")) } Learn more at https://echo.labstack.com @@ -44,99 +41,126 @@ package echo import ( stdContext "context" + "crypto/tls" "encoding/json" "errors" "fmt" - "io/fs" - "log/slog" + stdLog "log" + "net" "net/http" - "net/url" "os" - "os/signal" - "path" - "path/filepath" - "strings" + "reflect" + "runtime" "sync" - "sync/atomic" - "syscall" - - "github.com/labstack/echo/v5/internal/pathutil" + "time" + + "github.com/labstack/gommon/color" + "github.com/labstack/gommon/log" + "golang.org/x/crypto/acme" + "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) // Echo is the top-level framework instance. // // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these // fields from handlers/middlewares and changing field values at the same time leads to data-races. -// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. +// Adding new routes after the server has been started is also not safe! type Echo struct { - serveHTTPFunc func(http.ResponseWriter, *http.Request) - - Binder Binder - - // Filesystem is the file system used for serving static files. Defaults to the current working directory (os.Getwd()). - // - // Note: fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - // so if you have `fs := os.DirFS("/tmp")` and you try to `fs.Open("/tmp/file.txt")` it will fail, but "file.txt" - // would succeed. `echo.NewDefaultFS("/tmp")` overwrites this behavior and allows you to use Open with a matching - // absolute path prefix. - Filesystem fs.FS - - Renderer Renderer - Validator Validator + filesystem + common + // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // listener address info (on which interface/port was listener bound) without having data races. + startupMutex sync.RWMutex + colorer *color.Color + + // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns + // an error the router is not executed and the request will end up in the global error handler. + premiddleware []MiddlewareFunc + middleware []MiddlewareFunc + maxParam *int + router *Router + routers map[string]*Router + pool sync.Pool + + StdLogger *stdLog.Logger + Server *http.Server + TLSServer *http.Server + Listener net.Listener + TLSListener net.Listener + AutoTLSManager autocert.Manager + HTTPErrorHandler HTTPErrorHandler + Binder Binder JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger IPExtractor IPExtractor - OnAddRoute func(route Route) error - HTTPErrorHandler HTTPErrorHandler - Logger *slog.Logger + ListenerNetwork string - contextPool sync.Pool + // OnAddRouteHandler is called when Echo adds new route to specific host router. + OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) + DisableHTTP2 bool + Debug bool + HideBanner bool + HidePort bool +} - router Router +// Route contains a handler and information for matching against requests. +type Route struct { + Method string `json:"method"` + Path string `json:"path"` + Name string `json:"name"` +} - // premiddleware are middlewares that are called before routing is done - premiddleware []MiddlewareFunc +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Internal error `json:"-"` // Stores the error returned by an external dependency + Message any `json:"message"` + Code int `json:"-"` +} - // middleware are middlewares that are called after routing is done and before handler is called - middleware []MiddlewareFunc +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc - // chain is the global middleware chain (e.middleware) compiled once and reused for every request. - // It terminates in a dispatcher that invokes the route handler stored on the Context during routing. - // Rebuilt by Use(). See buildRouterChains. - chain HandlerFunc - // preChain is the pre-middleware chain (e.premiddleware) compiled once. It performs routing and then - // invokes chain. Rebuilt by Pre()/Use(). Only used when premiddleware is registered. - preChain HandlerFunc +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error - contextPathParamAllocSize atomic.Int32 +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(err error, c Context) - // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) - formParseMaxMemory int64 +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i any) error } // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. type JSONSerializer interface { - Serialize(c *Context, target any, indent string) error - Deserialize(c *Context, target any) error + Serialize(c Context, i any, indent string) error + Deserialize(c Context, i any) error } -// HTTPErrorHandler is a centralized HTTP error handler. -type HTTPErrorHandler func(c *Context, err error) +// Map defines a generic map of type `map[string]any`. +type Map map[string]any -// HandlerFunc defines a function to serve HTTP requests. -type HandlerFunc func(c *Context) error - -// MiddlewareFunc defines a function to process middleware. -type MiddlewareFunc func(next HandlerFunc) HandlerFunc +// Common struct for Echo & Group. +type common struct{} -// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. -type MiddlewareConfigurator interface { - ToMiddleware() (MiddlewareFunc, error) -} - -// Validator is the interface that wraps the Validate function. -type Validator interface { - Validate(i any) error -} +// HTTP methods +// NOTE: Deprecated, please use the stdlib constants directly instead. +const ( + CONNECT = http.MethodConnect + DELETE = http.MethodDelete + GET = http.MethodGet + HEAD = http.MethodHead + OPTIONS = http.MethodOptions + PATCH = http.MethodPatch + POST = http.MethodPost + // PROPFIND = "PROPFIND" + PUT = http.MethodPut + TRACE = http.MethodTrace +) // MIME types const ( @@ -145,7 +169,7 @@ const ( // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default. // No "charset" parameter is defined for this registration. // Adding one really has no effect on compliant recipients. - // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n" + // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1 MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 @@ -172,9 +196,6 @@ const ( REPORT = "REPORT" // RouteNotFound is special method type for routes handling "route not found" (404) cases RouteNotFound = "echo_route_not_found" - // RouteAny is special method type that matches any HTTP method in request. Any has lower - // priority that other methods that have been registered with Router to that path. - RouteAny = "echo_route_any" ) // Headers @@ -235,7 +256,7 @@ const ( HeaderXFrameOptions = "X-Frame-Options" HeaderContentSecurityPolicy = "Content-Security-Policy" HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" - HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101 + HeaderXCSRFToken = "X-CSRF-Token" HeaderReferrerPolicy = "Referrer-Policy" // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's @@ -244,277 +265,273 @@ const ( HeaderSecFetchSite = "Sec-Fetch-Site" ) -// Config is configuration for NewWithConfig function -type Config struct { - // Logger is the slog logger instance used for application-wide structured logging. - // If not set, a default TextHandler writing to stdout is created. - Logger *slog.Logger - - // HTTPErrorHandler is the centralized error handler that processes errors returned - // by handlers and middleware, converting them to appropriate HTTP responses. - // If not set, DefaultHTTPErrorHandler(false) is used. - HTTPErrorHandler HTTPErrorHandler - - // Router is the HTTP request router responsible for matching URLs to handlers - // using a radix tree-based algorithm. - // If not set, NewRouter(RouterConfig{}) is used. - Router Router - - // OnAddRoute is an optional callback hook executed when routes are registered. - // Useful for route validation, logging, or custom route processing. - // If not set, no callback is executed. - OnAddRoute func(route Route) error - - // Filesystem is the fs.FS implementation used for serving static files. - // Supports os.DirFS, embed.FS, and custom implementations. - // If not set, defaults to current working directory. - Filesystem fs.FS - - // Binder handles automatic data binding from HTTP requests to Go structs. - // Supports JSON, XML, form data, query parameters, and path parameters. - // If not set, DefaultBinder is used. - Binder Binder - - // Validator provides optional struct validation after data binding. - // Commonly used with third-party validation libraries. - // If not set, Context.Validate() returns ErrValidatorNotRegistered. - Validator Validator - - // Renderer provides template rendering for generating HTML responses. - // Requires integration with a template engine like html/template. - // If not set, Context.Render() returns ErrRendererNotRegistered. - Renderer Renderer - - // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses. - // Can be replaced with faster alternatives like jsoniter or sonic. - // If not set, DefaultJSONSerializer using encoding/json is used. - JSONSerializer JSONSerializer +const ( + // Version of Echo + Version = "4.15.3" + website = "https://echo.labstack.com" + // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo + banner = ` + ____ __ + / __/___/ / ___ + / _// __/ _ \/ _ \ +/___/\__/_//_/\___/ %s +High performance, minimalist Go web framework +%s +____________________________________O/_______ + O\ +` +) - // IPExtractor defines the strategy for extracting the real client IP address - // from requests, particularly important when behind proxies or load balancers. - // Used for rate limiting, access control, and logging. - // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers. - IPExtractor IPExtractor +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} + +// Errors +var ( + ErrBadRequest = NewHTTPError(http.StatusBadRequest) // HTTP 400 Bad Request + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) // HTTP 401 Unauthorized + ErrPaymentRequired = NewHTTPError(http.StatusPaymentRequired) // HTTP 402 Payment Required + ErrForbidden = NewHTTPError(http.StatusForbidden) // HTTP 403 Forbidden + ErrNotFound = NewHTTPError(http.StatusNotFound) // HTTP 404 Not Found + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) // HTTP 405 Method Not Allowed + ErrNotAcceptable = NewHTTPError(http.StatusNotAcceptable) // HTTP 406 Not Acceptable + ErrProxyAuthRequired = NewHTTPError(http.StatusProxyAuthRequired) // HTTP 407 Proxy AuthRequired + ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) // HTTP 408 Request Timeout + ErrConflict = NewHTTPError(http.StatusConflict) // HTTP 409 Conflict + ErrGone = NewHTTPError(http.StatusGone) // HTTP 410 Gone + ErrLengthRequired = NewHTTPError(http.StatusLengthRequired) // HTTP 411 Length Required + ErrPreconditionFailed = NewHTTPError(http.StatusPreconditionFailed) // HTTP 412 Precondition Failed + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) // HTTP 413 Payload Too Large + ErrRequestURITooLong = NewHTTPError(http.StatusRequestURITooLong) // HTTP 414 URI Too Long + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) // HTTP 415 Unsupported Media Type + ErrRequestedRangeNotSatisfiable = NewHTTPError(http.StatusRequestedRangeNotSatisfiable) // HTTP 416 Range Not Satisfiable + ErrExpectationFailed = NewHTTPError(http.StatusExpectationFailed) // HTTP 417 Expectation Failed + ErrTeapot = NewHTTPError(http.StatusTeapot) // HTTP 418 I'm a teapot + ErrMisdirectedRequest = NewHTTPError(http.StatusMisdirectedRequest) // HTTP 421 Misdirected Request + ErrUnprocessableEntity = NewHTTPError(http.StatusUnprocessableEntity) // HTTP 422 Unprocessable Entity + ErrLocked = NewHTTPError(http.StatusLocked) // HTTP 423 Locked + ErrFailedDependency = NewHTTPError(http.StatusFailedDependency) // HTTP 424 Failed Dependency + ErrTooEarly = NewHTTPError(http.StatusTooEarly) // HTTP 425 Too Early + ErrUpgradeRequired = NewHTTPError(http.StatusUpgradeRequired) // HTTP 426 Upgrade Required + ErrPreconditionRequired = NewHTTPError(http.StatusPreconditionRequired) // HTTP 428 Precondition Required + ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) // HTTP 429 Too Many Requests + ErrRequestHeaderFieldsTooLarge = NewHTTPError(http.StatusRequestHeaderFieldsTooLarge) // HTTP 431 Request Header Fields Too Large + ErrUnavailableForLegalReasons = NewHTTPError(http.StatusUnavailableForLegalReasons) // HTTP 451 Unavailable For Legal Reasons + ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) // HTTP 500 Internal Server Error + ErrNotImplemented = NewHTTPError(http.StatusNotImplemented) // HTTP 501 Not Implemented + ErrBadGateway = NewHTTPError(http.StatusBadGateway) // HTTP 502 Bad Gateway + ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) // HTTP 503 Service Unavailable + ErrGatewayTimeout = NewHTTPError(http.StatusGatewayTimeout) // HTTP 504 Gateway Timeout + ErrHTTPVersionNotSupported = NewHTTPError(http.StatusHTTPVersionNotSupported) // HTTP 505 HTTP Version Not Supported + ErrVariantAlsoNegotiates = NewHTTPError(http.StatusVariantAlsoNegotiates) // HTTP 506 Variant Also Negotiates + ErrInsufficientStorage = NewHTTPError(http.StatusInsufficientStorage) // HTTP 507 Insufficient Storage + ErrLoopDetected = NewHTTPError(http.StatusLoopDetected) // HTTP 508 Loop Detected + ErrNotExtended = NewHTTPError(http.StatusNotExtended) // HTTP 510 Not Extended + ErrNetworkAuthenticationRequired = NewHTTPError(http.StatusNetworkAuthenticationRequired) // HTTP 511 Network Authentication Required + + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) - // FormParseMaxMemory is default value for memory limit that is used - // when parsing multipart forms (See (*http.Request).ParseMultipartForm) - FormParseMaxMemory int64 +// NotFoundHandler is the handler that router uses in case there was no matching route found. Returns an error that results +// HTTP 404 status code. +var NotFoundHandler = func(c Context) error { + return ErrNotFound } -// NewWithConfig creates an instance of Echo with given configuration. -func NewWithConfig(config Config) *Echo { - e := New() - if config.Logger != nil { - e.Logger = config.Logger - } - if config.HTTPErrorHandler != nil { - e.HTTPErrorHandler = config.HTTPErrorHandler - } - if config.Router != nil { - e.router = config.Router - } - if config.OnAddRoute != nil { - e.OnAddRoute = config.OnAddRoute - } - if config.Filesystem != nil { - e.Filesystem = config.Filesystem - } - if config.Binder != nil { - e.Binder = config.Binder - } - if config.Validator != nil { - e.Validator = config.Validator +// MethodNotAllowedHandler is the handler thar router uses in case there was no matching route found but there was +// another matching routes for that requested URL. Returns an error that results HTTP 405 Method Not Allowed status code. +var MethodNotAllowedHandler = func(c Context) error { + // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) + // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned + routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) + if ok && routerAllowMethods != "" { + c.Response().Header().Set(HeaderAllow, routerAllowMethods) } - if config.Renderer != nil { - e.Renderer = config.Renderer - } - if config.JSONSerializer != nil { - e.JSONSerializer = config.JSONSerializer - } - if config.IPExtractor != nil { - e.IPExtractor = config.IPExtractor - } - if config.FormParseMaxMemory > 0 { - e.formParseMaxMemory = config.FormParseMaxMemory - } - return e + return ErrMethodNotAllowed } // New creates an instance of Echo. -func New() *Echo { - dir, _ := os.Getwd() - logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) - e := &Echo{ - Logger: logger, - Filesystem: NewDefaultFS(dir), - Binder: &DefaultBinder{}, - JSONSerializer: &DefaultJSONSerializer{}, - formParseMaxMemory: defaultMemory, - } - - e.serveHTTPFunc = e.serveHTTP - e.router = NewRouter(RouterConfig{}) - e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) - e.contextPool.New = func() any { - return newContext(nil, nil, e) - } - e.buildRouterChains() - return e +func New() (e *Echo) { + e = &Echo{ + filesystem: createFilesystem(), + Server: new(http.Server), + TLSServer: new(http.Server), + AutoTLSManager: autocert.Manager{ + Prompt: autocert.AcceptTOS, + }, + Logger: log.New("echo"), + colorer: color.New(), + maxParam: new(int), + ListenerNetwork: "tcp", + } + e.Server.Handler = e + e.TLSServer.Handler = e + e.HTTPErrorHandler = e.DefaultHTTPErrorHandler + e.Binder = &DefaultBinder{} + e.JSONSerializer = &DefaultJSONSerializer{} + e.Logger.SetLevel(log.ERROR) + e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) + e.pool.New = func() any { + return e.NewContext(nil, nil) + } + e.router = NewRouter(e) + e.routers = map[string]*Router{} + return } -// buildRouterChains compiles the global and pre-middleware chains once so that ServeHTTP does not have to -// re-wrap middleware closures on every request. It must be called whenever e.middleware or e.premiddleware -// changes (i.e. from Use/Pre). This is safe because middleware must not be mutated after the server starts. -func (e *Echo) buildRouterChains() { - // dispatch is the terminal of the global chain: it invokes the handler resolved during routing. - dispatch := func(c *Context) error { - return c.handler(c) - } - e.chain = applyMiddleware(dispatch, e.middleware...) - - // route performs routing (storing the matched handler on the Context) and then runs the global chain. - route := func(c *Context) error { - c.handler = e.router.Route(c) - return e.chain(c) +// NewContext returns a Context instance. +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { + return &context{ + request: r, + response: NewResponse(w, e), + store: make(Map), + echo: e, + pvalues: make([]string, *e.maxParam), + handler: NotFoundHandler, } - e.preChain = applyMiddleware(route, e.premiddleware...) -} - -// NewContext returns a new Context instance. -// -// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway -// these arguments are useful when creating context for tests and cases like that. -func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context { - return newContext(r, w, e) } // Router returns the default router. -func (e *Echo) Router() Router { +func (e *Echo) Router() *Router { return e.router } -// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response -// with status code. `exposeError` parameter decides if returned message will contain also error message or not +// Routers returns the map of host => router. +func (e *Echo) Routers() map[string]*Router { + return e.routers +} + +// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response +// with status code. // -// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) -// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "committed" the // response and status code header has been sent to the client. -func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { - return func(c *Context, err error) { - if r, _ := UnwrapResponse(c.response); r != nil && r.Committed { - return - } +func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - code := http.StatusInternalServerError - var sc HTTPStatusCoder - if errors.As(err, &sc) { - if tmp := sc.StatusCode(); tmp != 0 { - code = tmp - } - } + if c.Response().Committed { + return + } - var result any - switch m := sc.(type) { - case json.Marshaler: // this type knows how to format itself to JSON - result = m - case *HTTPError: - sText := m.Message - if sText == "" { - sText = http.StatusText(code) - } - msg := map[string]any{"message": sText} - if exposeError { - if wrappedErr := m.Unwrap(); wrappedErr != nil { - msg["error"] = wrappedErr.Error() - } - } - result = msg - default: - msg := map[string]any{"message": http.StatusText(code)} - if exposeError { - msg["error"] = err.Error() + he, ok := err.(*HTTPError) + if ok { + if he.Internal != nil { + if herr, ok := he.Internal.(*HTTPError); ok { + he = herr } - result = msg } + } else { + he = &HTTPError{ + Code: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + } + } + + // Issue #1426 + code := he.Code + message := he.Message - var cErr error - if c.Request().Method == http.MethodHead { // Issue #608 - cErr = c.NoContent(code) + switch m := he.Message.(type) { + case string: + if e.Debug { + message = Map{"message": m, "error": err.Error()} } else { - cErr = c.JSON(code, result) - } - if cErr != nil { - c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected + message = Map{"message": m} } + case json.Marshaler: + // do nothing - this type knows how to format itself to JSON + case error: + message = Map{"message": m.Error()} + } + + // Send response + if c.Request().Method == http.MethodHead { // Issue #608 + err = c.NoContent(he.Code) + } else { + err = c.JSON(code, message) + } + if err != nil { + e.Logger.Error(err) } } -// Pre adds middleware to the chain which is run before router tries to find matching route. -// Meaning middleware is executed even for 404 (not found) cases. +// Pre adds middleware to the chain which is run before router. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) - e.buildRouterChains() } -// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. +// Use adds middleware to the chain which is run after router. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) - e.buildRouterChains() } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. Panics on error. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// with optional route-level middleware. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. Panics on error. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// with optional route-level middleware. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. Panics on error. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// router with optional route-level middleware. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodTrace, path, h, m...) } @@ -523,8 +540,8 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo // Path supports static and named/any parameters just like other http method is defined. Generally path is ended with // wildcard/match-any character (`/*`, `/download/*` etc). // -// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` -func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(RouteNotFound, path, h, m...) } @@ -533,274 +550,388 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) Ro // // Note: this method only adds specific set of supported HTTP methods as handler and is not true // "catch-any-arbitrary-method" way of matching requests. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { - return e.Add(RouteAny, path, handler, middleware...) +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { + routes := make([]*Route, len(methods)) + for i, m := range methods { + routes[i] = e.Add(m, path, handler, middleware...) + } + return routes } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. Panics on error. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { - errs := make([]error, 0) - ris := make(Routes, 0) - for _, m := range methods { - ri, err := e.AddRoute(Route{ - Method: m, - Path: path, - Handler: handler, - Middlewares: middleware, - }) - if err != nil { - errs = append(errs, err) - continue - } - ris = append(ris, ri) +// handler in the router with optional route-level middleware. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { + routes := make([]*Route, len(methods)) + for i, m := range methods { + routes[i] = e.Add(m, path, handler, middleware...) } - if len(errs) > 0 { - panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + return routes +} + +func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, + m ...MiddlewareFunc) *Route { + return get(path, func(c Context) error { + return c.File(file) + }, m...) +} + +// File registers a new route with path to serve a static file with optional route-level middleware. +func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { + return e.file(path, file, e.GET, m...) +} + +func (e *Echo) add(host, method, path string, handler HandlerFunc, middlewares ...MiddlewareFunc) *Route { + router := e.findRouter(host) + //FIXME: when handler+middleware are both nil ... make it behave like handler removal + name := handlerName(handler) + route := router.add(method, path, name, func(c Context) error { + h := applyMiddleware(handler, middlewares...) + return h(c) + }) + + if e.OnAddRouteHandler != nil { + e.OnAddRouteHandler(host, *route, handler, middlewares) } - return ris + + return route } -// Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { - subFs := MustSubFS(e.Filesystem, fsRoot) - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(subFs, false), - middleware..., - ) +// Add registers a new route for an HTTP method and path with matching handler +// in the router with optional route-level middleware. +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { + return e.add("", method, path, handler, middleware...) } -// StaticFS registers a new route with path prefix to serve static files from the provided file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - middleware..., - ) +// Host creates a new router group for the provided host and optional host-level middleware. +func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { + e.routers[name] = NewRouter(e) + g = &Group{host: name, echo: e} + g.Use(m...) + return } -// StaticDirectoryHandler creates handler function to serve files from provided file system -// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. -func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { - return func(c *Context) error { - p := c.Param("*") - if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice - // By default the router matches routes against the raw, still-encoded request path - // (unless UseEscapedPathForMatching is enabled), so an encoded path separator is not - // treated as a segment boundary during routing. Unescaping it here would let it act as - // a separator and resolve a file outside the path the router authorized, bypassing - // route-level middleware (e.g. auth on a sibling route). No real filename contains a - // separator, so reject it as not found, carrying the reason internally for operators. - if pathutil.HasEncodedPathSeparator(p) { - return NewHTTPError(http.StatusNotFound, http.StatusText(http.StatusNotFound)). - Wrap(fmt.Errorf("rejected encoded path separator in static path %q", p)) - } - tmpPath, err := url.PathUnescape(p) - if err != nil { - return fmt.Errorf("failed to unescape path variable: %w", err) - } - p = tmpPath - } +// Group creates a new router group with prefix and optional group-level middleware. +func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { + g = &Group{prefix: prefix, echo: e} + g.Use(m...) + return +} - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid. - // Use path.Clean (not filepath.Clean): fs.FS paths are always forward-slash, so a backslash must stay a literal - // character rather than being interpreted as a separator on Windows (which would resolve a file across a boundary - // the router never matched on, the same Windows backslash traversal class as GHSA-pgvm-wxw2-hrv9). - name := path.Clean(strings.TrimPrefix(p, "/")) - fi, err := fs.Stat(fileSystem, name) - if err != nil { - return ErrNotFound - } +// URI generates an URI from handler. +func (e *Echo) URI(handler HandlerFunc, params ...any) string { + name := handlerName(handler) + return e.Reverse(name, params...) +} - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) - } - return fsFile(c, name, fileSystem) - } +// URL is an alias for `URI` function. +func (e *Echo) URL(h HandlerFunc, params ...any) string { + return e.URI(h, params...) } -// FileFS registers a new route with path to serve file from the provided file system. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { - return e.GET(path, StaticFileHandler(file, filesystem), m...) +// Reverse generates a URL from route name and provided parameters. +func (e *Echo) Reverse(name string, params ...any) string { + return e.router.Reverse(name, params...) } -// StaticFileHandler creates handler function to serve file from provided file system. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { - return func(c *Context) error { - return fsFile(c, file, filesystem) - } +// Routes returns the registered routes for default router. +// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes. +func (e *Echo) Routes() []*Route { + return e.router.Routes() } -// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { - handler := func(c *Context) error { - return c.File(file) - } - return e.Add(http.MethodGet, path, handler, middleware...) +// AcquireContext returns an empty `Context` instance from the pool. +// You must return the context by calling `ReleaseContext()`. +func (e *Echo) AcquireContext() Context { + return e.pool.Get().(Context) } -// AddRoute registers a new Route with default host Router -func (e *Echo) AddRoute(route Route) (RouteInfo, error) { - return e.add(route) +// ReleaseContext returns the `Context` instance back to the pool. +// You must call it after `AcquireContext()`. +func (e *Echo) ReleaseContext(c Context) { + e.pool.Put(c) } -func (e *Echo) add(route Route) (RouteInfo, error) { - if e.OnAddRoute != nil { - if err := e.OnAddRoute(route); err != nil { - return RouteInfo{}, err +// ServeHTTP implements `http.Handler` interface, which serves HTTP requests. +func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Acquire context + c := e.pool.Get().(*context) + c.Reset(r, w) + var h HandlerFunc + + if e.premiddleware == nil { + e.findRouter(r.Host).Find(r.Method, GetPath(r), c) + h = c.Handler() + h = applyMiddleware(h, e.middleware...) + } else { + h = func(c Context) error { + e.findRouter(r.Host).Find(r.Method, GetPath(r), c) + h := c.Handler() + h = applyMiddleware(h, e.middleware...) + return h(c) } + h = applyMiddleware(h, e.premiddleware...) } - ri, err := e.router.Add(route) - if err != nil { - return RouteInfo{}, err + // Execute chain + if err := h(c); err != nil { + e.HTTPErrorHandler(err, c) } - paramsCount := int32(len(ri.Parameters)) // #nosec G115 - if paramsCount > e.contextPathParamAllocSize.Load() { - e.contextPathParamAllocSize.Store(paramsCount) + // Release context + e.pool.Put(c) +} + +// Start starts an HTTP server. +func (e *Echo) Start(address string) error { + e.startupMutex.Lock() + e.Server.Addr = address + if err := e.configureServer(e.Server); err != nil { + e.startupMutex.Unlock() + return err } - return ri, nil + e.startupMutex.Unlock() + return e.Server.Serve(e.Listener) } -// Add registers a new route for an HTTP method and path with matching handler -// in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { - ri, err := e.add( - Route{ - Method: method, - Path: path, - Handler: handler, - Middlewares: middleware, - Name: "", - }, - ) - if err != nil { - panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage +// StartTLS starts an HTTPS server. +// If `certFile` or `keyFile` is `string` the values are treated as file paths. +// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. +func (e *Echo) StartTLS(address string, certFile, keyFile any) (err error) { + e.startupMutex.Lock() + var cert []byte + if cert, err = filepathOrContent(certFile); err != nil { + e.startupMutex.Unlock() + return } - return ri + + var key []byte + if key, err = filepathOrContent(keyFile); err != nil { + e.startupMutex.Unlock() + return + } + + s := e.TLSServer + s.TLSConfig = new(tls.Config) + s.TLSConfig.Certificates = make([]tls.Certificate, 1) + if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { + e.startupMutex.Unlock() + return + } + + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMutex.Unlock() + return err + } + e.startupMutex.Unlock() + return s.Serve(e.TLSListener) } -// Group creates a new router group with prefix and optional group-level middleware. -func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { - g = &Group{prefix: prefix, echo: e} - g.Use(m...) - return +func filepathOrContent(fileOrContent any) (content []byte, err error) { + switch v := fileOrContent.(type) { + case string: + return os.ReadFile(v) + case []byte: + return v, nil + default: + return nil, ErrInvalidCertOrKeyType + } } -// PreMiddlewares returns registered pre middlewares. These are middleware to the chain -// which are run before router tries to find matching route. -// Use this method to build your own ServeHTTP method. -// -// NOTE: returned slice is not a copy. Do not mutate. -func (e *Echo) PreMiddlewares() []MiddlewareFunc { - return e.premiddleware +// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. +func (e *Echo) StartAutoTLS(address string) error { + e.startupMutex.Lock() + s := e.TLSServer + s.TLSConfig = new(tls.Config) + s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) + + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMutex.Unlock() + return err + } + e.startupMutex.Unlock() + return s.Serve(e.TLSListener) } -// Middlewares returns registered route level middlewares. Does not contain any group level -// middlewares. Use this method to build your own ServeHTTP method. -// -// NOTE: returned slice is not a copy. Do not mutate. -func (e *Echo) Middlewares() []MiddlewareFunc { - return e.middleware +func (e *Echo) configureTLS(address string) { + s := e.TLSServer + s.Addr = address + if !e.DisableHTTP2 { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") + } } -// AcquireContext returns an empty `Context` instance from the pool. -// You must return the context by calling `ReleaseContext()`. -func (e *Echo) AcquireContext() *Context { - return e.contextPool.Get().(*Context) +// StartServer starts a custom http server. +func (e *Echo) StartServer(s *http.Server) (err error) { + e.startupMutex.Lock() + if err := e.configureServer(s); err != nil { + e.startupMutex.Unlock() + return err + } + if s.TLSConfig != nil { + e.startupMutex.Unlock() + return s.Serve(e.TLSListener) + } + e.startupMutex.Unlock() + return s.Serve(e.Listener) } -// ReleaseContext returns the `Context` instance back to the pool. -// You must call it after `AcquireContext()`. -func (e *Echo) ReleaseContext(c *Context) { - e.contextPool.Put(c) +func (e *Echo) configureServer(s *http.Server) error { + // Setup + e.colorer.SetOutput(e.Logger.Output()) + s.ErrorLog = e.StdLogger + s.Handler = e + if e.Debug { + e.Logger.SetLevel(log.DEBUG) + } + + if !e.HideBanner { + e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) + } + + if s.TLSConfig == nil { + if e.Listener == nil { + l, err := newListener(s.Addr, e.ListenerNetwork) + if err != nil { + return err + } + e.Listener = l + } + if !e.HidePort { + e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) + } + return nil + } + if e.TLSListener == nil { + l, err := newListener(s.Addr, e.ListenerNetwork) + if err != nil { + return err + } + e.TLSListener = tls.NewListener(l, s.TLSConfig) + } + if !e.HidePort { + e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) + } + return nil } -// ServeHTTP implements `http.Handler` interface, which serves HTTP requests. -func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - e.serveHTTPFunc(w, r) +// ListenerAddr returns net.Addr for Listener +func (e *Echo) ListenerAddr() net.Addr { + e.startupMutex.RLock() + defer e.startupMutex.RUnlock() + if e.Listener == nil { + return nil + } + return e.Listener.Addr() } -// serveHTTP implements `http.Handler` interface, which serves HTTP requests. -func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) { - c := e.contextPool.Get().(*Context) - defer e.contextPool.Put(c) +// TLSListenerAddr returns net.Addr for TLSListener +func (e *Echo) TLSListenerAddr() net.Addr { + e.startupMutex.RLock() + defer e.startupMutex.RUnlock() + if e.TLSListener == nil { + return nil + } + return e.TLSListener.Addr() +} - c.Reset(r, w) +// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). +func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error { + e.startupMutex.Lock() + // Setup + s := e.Server + s.Addr = address + e.colorer.SetOutput(e.Logger.Output()) + s.ErrorLog = e.StdLogger + s.Handler = h2c.NewHandler(e, h2s) + if e.Debug { + e.Logger.SetLevel(log.DEBUG) + } - // The global (e.chain) and pre-middleware (e.preChain) chains are compiled once in buildRouterChains and - // reused here, so no middleware closures are allocated per request. - var err error - if e.premiddleware == nil { - c.handler = e.router.Route(c) - err = e.chain(c) - } else { - err = e.preChain(c) + if !e.HideBanner { + e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) } - if err != nil { - e.HTTPErrorHandler(c, err) + if e.Listener == nil { + l, err := newListener(s.Addr, e.ListenerNetwork) + if err != nil { + e.startupMutex.Unlock() + return err + } + e.Listener = l } + if !e.HidePort { + e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) + } + e.startupMutex.Unlock() + return s.Serve(e.Listener) } -// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by -// sending os.Interrupt signal with `ctrl+c`. Method returns only errors that are not http.ErrServerClosed. -// -// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration -// options. -// -// In need of customization use: -// -// ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) -// defer cancel() -// sc := echo.StartConfig{Address: ":8080"} -// if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) { -// slog.Error(err.Error()) -// } -// -// // or standard library `http.Server` -// -// s := http.Server{Addr: ":8080", Handler: e} -// if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { -// slog.Error(err.Error()) -// } -func (e *Echo) Start(address string) error { - sc := StartConfig{Address: address} - ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c - defer cancel() - return sc.Start(ctx, e) +// Close immediately stops the server. +// It internally calls `http.Server#Close()`. +func (e *Echo) Close() error { + e.startupMutex.Lock() + defer e.startupMutex.Unlock() + if err := e.TLSServer.Close(); err != nil { + return err + } + return e.Server.Close() +} + +// Shutdown stops the server gracefully. +// It internally calls `http.Server#Shutdown()`. +func (e *Echo) Shutdown(ctx stdContext.Context) error { + e.startupMutex.Lock() + defer e.startupMutex.Unlock() + if err := e.TLSServer.Shutdown(ctx); err != nil { + return err + } + return e.Server.Shutdown(ctx) +} + +// NewHTTPError creates a new HTTPError instance. +func NewHTTPError(code int, message ...any) *HTTPError { + he := &HTTPError{Code: code, Message: http.StatusText(code)} + if len(message) > 0 { + he.Message = message[0] + } + return he +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } + return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) +} + +// SetInternal sets error to HTTPError.Internal +func (he *HTTPError) SetInternal(err error) *HTTPError { + he.Internal = err + return he +} + +// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field +func (he *HTTPError) WithInternal(err error) *HTTPError { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + Internal: err, + } +} + +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. func WrapHandler(h http.Handler) HandlerFunc { - return func(c *Context) error { - req := c.Request() - req.Pattern = c.Path() - for _, p := range c.PathValues() { - req.SetPathValue(p.Name, p.Value) - } - - h.ServeHTTP(c.Response(), req) + return func(c Context) error { + h.ServeHTTP(c.Response(), c.Request()) return nil } } @@ -808,101 +939,85 @@ func WrapHandler(h http.Handler) HandlerFunc { // WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc` func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { return func(next HandlerFunc) HandlerFunc { - return func(c *Context) (err error) { - req := c.Request() - req.Pattern = c.Path() - for _, p := range c.PathValues() { - req.SetPathValue(p.Name, p.Value) - } - + return func(c Context) (err error) { m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.SetRequest(r) - c.SetResponse(NewResponse(w, c.echo.Logger)) + c.SetResponse(NewResponse(w, c.Echo())) err = next(c) - })).ServeHTTP(c.Response(), req) + })).ServeHTTP(c.Response(), c.Request()) return } } } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +// GetPath returns RawPath, if it's empty returns Path from URL +// Difference between RawPath and Path is: +// - Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. +// - RawPath is an optional field which only gets set if the default encoding is different from Path. +func GetPath(r *http.Request) string { + path := r.URL.RawPath + if path == "" { + path = r.URL.Path } - return h + return path } -// defaultFS emulates os.Open behavior with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` -// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` -// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break -// all old applications that rely on being able to traverse up from current executable run path. -// NB: private because you really should use fs.FS implementation instances -type defaultFS struct { - fs fs.FS - prefix string +func (e *Echo) findRouter(host string) *Router { + if len(e.routers) > 0 { + if r, ok := e.routers[host]; ok { + return r + } + } + return e.router } -// NewDefaultFS returns a new defaultFS instance which allows `fs.FS.Open` to have absolute paths as input if it matches -// then given dir as prefix. -func NewDefaultFS(dir string) fs.FS { - return &defaultFS{ - prefix: dir, - fs: os.DirFS(dir), +func handlerName(h HandlerFunc) string { + t := reflect.ValueOf(h).Type() + if t.Kind() == reflect.Func { + return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() } + return t.String() } -func (fs defaultFS) Open(name string) (fs.File, error) { - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - // For example `f.Name()` returns file names as absolute paths (e.g. `/tmp/data.csv`) so in case user wants to open - // a file with an absolute path we need to remove prefix and then call fs.FS.Open(). - // not to force users to cut prefix from file name we do it here. - if filepath.IsAbs(name) { - if strings.HasPrefix(name, fs.prefix) { - name = name[len(fs.prefix):] - if len(name) > 1 && os.IsPathSeparator(name[0]) { - name = name[1:] - } - } - } - return fs.fs.Open(name) +// // PathUnescape is wraps `url.PathUnescape` +// func PathUnescape(s string) (string, error) { +// return url.PathUnescape(s) +// } + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener } -func subFS(currentFs fs.FS, root string) (fs.FS, error) { - root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows - if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. - // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we - // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if !filepath.IsAbs(root) { - root = filepath.Join(dFS.prefix, root) - } - return &defaultFS{ - prefix: root, - fs: os.DirFS(root), - }, nil +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + if c, err = ln.AcceptTCP(); err != nil { + return + } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { + return } - return fs.Sub(currentFs, root) + // Ignore error from setting the KeepAlivePeriod as some systems, such as + // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP + _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) + return } -// MustSubFS creates sub FS from current filesystem or panic on failure. -// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. -// -// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with -// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to -// create sub fs which uses necessary prefix for directory path. -func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { - subFs, err := subFS(currentFs, fsRoot) +func newListener(address, network string) (*tcpKeepAliveListener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, ErrInvalidListenerNetwork + } + l, err := net.Listen(network, address) if err != nil { - panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) + return nil, err } - return subFs + return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil } -func sanitizeURI(uri string) string { - // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri - // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash - if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { - uri = "/" + strings.TrimLeft(uri, `/\`) +func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) } - return uri + return h } diff --git a/echo_fs.go b/echo_fs.go new file mode 100644 index 000000000..f111095f1 --- /dev/null +++ b/echo_fs.go @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "fmt" + "io/fs" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strings" + + "github.com/labstack/echo/v4/internal/pathutil" +) + +type filesystem struct { + // Filesystem is file system used by Static and File handlers to access files. + // Defaults to os.DirFS(".") + // + // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary + // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths + // including `assets/images` as their prefix. + Filesystem fs.FS +} + +func createFilesystem() filesystem { + return filesystem{ + Filesystem: newDefaultFS(), + } +} + +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string) *Route { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + ) +} + +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c Context) error { + p := c.Param("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + // The router matches routes against the raw, still-encoded request path, so an + // encoded path separator (%2F or %5C) is not treated as a segment boundary during + // routing. Unescaping it here would let it act as a separator and resolve a file + // outside the path the router authorized, bypassing route-level middleware (e.g. auth + // on a sibling route). No real filename contains a separator, so reject it as not found. + if pathutil.HasEncodedPathSeparator(p) { + return ErrNotFound + } + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid. + // Use path.Clean (not filepath.Clean): fs.FS paths are always forward-slash, so a backslash must stay a literal + // character rather than being interpreted as a separator on Windows (which would resolve a file across a boundary + // the router never matched on). + name := path.Clean(strings.TrimPrefix(p, "/")) + fi, err := fs.Stat(fileSystem, name) + if err != nil { + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) + } + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c Context) error { + return fsFile(c, file, filesystem) + } +} + +// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. +// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. +// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` +// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not +// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to +// traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + fs fs.FS + prefix string +} + +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: nil, + } +} + +func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) + } + return fs.fs.Open(name) +} + +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if !filepath.IsAbs(root) { + root = filepath.Join(dFS.prefix, root) + } + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil + } + return fs.Sub(currentFs, root) +} + +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) + if err != nil { + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) + } + return subFs +} + +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) + } + return uri +} diff --git a/static_encoded_separator_test.go b/echo_fs_encoded_separator_test.go similarity index 81% rename from static_encoded_separator_test.go rename to echo_fs_encoded_separator_test.go index a10eafabc..16752e41d 100644 --- a/static_encoded_separator_test.go +++ b/echo_fs_encoded_separator_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + package echo import ( @@ -9,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" ) -// Regression for GHSA-vfp3-v2gw-7wfq: an encoded slash (%2F) must not let a static -// file request resolve across a path separator and bypass route-level middleware. +// Regression for GHSA-vfp3-v2gw-7wfq (v4 backport): an encoded path separator (%2F or %5C) +// must not let a static file request resolve across a separator and bypass route-level middleware. func TestStaticDirectoryHandler_EncodedSeparatorDoesNotBypassRoute(t *testing.T) { fsys := fstest.MapFS{ "admin/secret.txt": {Data: []byte("TOP-SECRET")}, @@ -18,9 +21,9 @@ func TestStaticDirectoryHandler_EncodedSeparatorDoesNotBypassRoute(t *testing.T) } e := New() g := e.Group("/admin", func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { return c.String(http.StatusForbidden, "denied") } + return func(c Context) error { return c.String(http.StatusForbidden, "denied") } }) - g.GET("/*", func(c *Context) error { return c.String(http.StatusOK, "reached-protected-handler") }) + g.GET("/*", func(c Context) error { return c.String(http.StatusOK, "reached-protected-handler") }) e.StaticFS("/", fsys) cases := []struct { @@ -31,7 +34,7 @@ func TestStaticDirectoryHandler_EncodedSeparatorDoesNotBypassRoute(t *testing.T) {"/admin/secret.txt", http.StatusForbidden, "denied"}, // protected route fires {"/admin%2Fsecret.txt", http.StatusNotFound, ""}, // encoded slash rejected, no disclosure {"/admin%2fsecret.txt", http.StatusNotFound, ""}, // lower-case hex variant - {"/admin%5Csecret.txt", http.StatusNotFound, ""}, // encoded backslash variant + {"/admin%5Csecret.txt", http.StatusNotFound, ""}, // encoded backslash (Windows separator) neutralized by path.Clean {"/admin%252Fsecret.txt", http.StatusNotFound, ""}, // double-encoded: single unescape -> literal filename, not a separator {"/index.html", http.StatusOK, "public"}, // legitimate static file still served } diff --git a/echo_fs_test.go b/echo_fs_test.go new file mode 100644 index 000000000..75f32dfb0 --- /dev/null +++ b/echo_fs_test.go @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "github.com/stretchr/testify/assert" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func TestEcho_StaticFS(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenFs fs.FS + givenFsRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenFs: os.DirFS("./_fixture/images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, from sub fs", + givenPrefix: "/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenFs: os.DirFS("_fixture/scripts"), + whenURL: "/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenFs: os.DirFS("_fixture/images"), + whenURL: "/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenFs: os.DirFS("_fixture"), + whenURL: "/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenFs: os.DirFS("_fixture"), + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenFs: os.DirFS("_fixture"), + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenFs: os.DirFS("_fixture"), + whenURL: "/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenFs: os.DirFS("_fixture"), + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenFs: os.DirFS("_fixture"), + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenFs: os.DirFS("_fixture"), + whenURL: "/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: `/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: `/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + // An encoded slash (%2f) is rejected outright (GHSA-vfp3-v2gw-7wfq): the router matches + // on the raw path so %2f is not a separator, and unescaping it here would let it act as + // one. No redirect is emitted, closing the open-redirect vector. + name: "encoded slash is rejected, not redirected", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/open.redirect.hackercom%2f..", + expectStatus: http.StatusNotFound, + expectHeaderLocation: "", + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + }, + { + name: "panics for /", + givenRoot: "/assets", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + assert.Panics(t, func() { + e.Static("../assets", tc.givenRoot) + }) + }) + } +} diff --git a/echo_test.go b/echo_test.go index 0af0b96ec..08cc7162b 100644 --- a/echo_test.go +++ b/echo_test.go @@ -6,23 +6,23 @@ package echo import ( "bytes" stdContext "context" + "crypto/tls" "errors" "fmt" "io" - "io/fs" - "log/slog" "net" "net/http" "net/http/httptest" "net/url" "os" - "path/filepath" - "runtime" + "reflect" "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/http2" ) type user struct { @@ -62,99 +62,33 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - e.HTTPErrorHandler(c, errors.New("error")) - + // DefaultHTTPErrorHandler + e.DefaultHTTPErrorHandler(errors.New("error"), c) assert.Equal(t, http.StatusInternalServerError, rec.Code) } -func TestNewWithConfig(t *testing.T) { - e := NewWithConfig(Config{}) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - e.GET("/", func(c *Context) error { - return c.String(http.StatusTeapot, "Hello, World!") - }) - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusTeapot, rec.Code) - assert.Equal(t, `Hello, World!`, rec.Body.String()) -} - -func TestNewDefaultFS(t *testing.T) { - tempDir := t.TempDir() - filename := filepath.Join(tempDir, "file.txt") - if err := os.WriteFile(filename, []byte("hello"), 0644); err != nil { - t.Fatalf("failed to write file: %v", err) - } - - var testCases = []struct { - name string - givenDir string - whenName string - expectedError string - }{ - { - name: "ok, can open absolute path", - givenDir: tempDir, - whenName: filename, - }, - { - name: "ok, can open path to fs", - givenDir: tempDir, - whenName: "file.txt", - }, - { - name: "nok, can not use ./ in path", - givenDir: tempDir, - whenName: "./file.txt", - expectedError: `open ./file.txt: invalid argument`, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - myFs := NewDefaultFS(tc.givenDir) - - f, err := myFs.Open(tc.whenName) - if tc.expectedError != "" { - assert.EqualError(t, err, tc.expectedError) - return - } - if err != nil { - t.Fatalf("failed to read file: %v", err) - } - defer f.Close() - - contents, err := io.ReadAll(f) - assert.NoError(t, err) - assert.Equal(t, []byte("hello"), contents) - }) - } -} - -func TestEcho_StaticFS(t *testing.T) { +func TestEchoStatic(t *testing.T) { var testCases = []struct { - givenFs fs.FS name string givenPrefix string - givenFsRoot string + givenRoot string whenURL string + expectStatus int expectHeaderLocation string expectBodyStartsWith string - expectStatus int }{ { name: "ok", givenPrefix: "/images", - givenFs: os.DirFS("./_fixture/images"), + givenRoot: "_fixture/images", whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { - name: "ok, from sub fs", + name: "ok with relative path for root points to directory", givenPrefix: "/images", - givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), + givenRoot: "./_fixture/images", whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), @@ -162,7 +96,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "No file", givenPrefix: "/images", - givenFs: os.DirFS("_fixture/scripts"), + givenRoot: "_fixture/scripts", whenURL: "/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -170,7 +104,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Directory", givenPrefix: "/images", - givenFs: os.DirFS("_fixture/images"), + givenRoot: "_fixture/images", whenURL: "/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -178,7 +112,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Directory Redirect", givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), + givenRoot: "_fixture", whenURL: "/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -187,7 +121,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Directory Redirect with non-root path", givenPrefix: "/static", - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/static/", @@ -196,7 +130,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -204,7 +138,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -213,7 +147,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Directory with index.html", givenPrefix: "/", - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -221,7 +155,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -229,7 +163,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -237,7 +171,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "Sub-directory with index.html", givenPrefix: "/", - givenFs: os.DirFS("_fixture"), + givenRoot: "_fixture", whenURL: "/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -245,7 +179,7 @@ func TestEcho_StaticFS(t *testing.T) { { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), + givenRoot: "_fixture/", whenURL: `/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -253,40 +187,20 @@ func TestEcho_StaticFS(t *testing.T) { { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), + givenRoot: "_fixture/", whenURL: `/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, - { - // An encoded slash (%2f) is rejected outright (GHSA-vfp3-v2gw-7wfq): by default the - // router matches on the raw path so %2f is not a separator, and unescaping it here - // would let it act as one. No redirect is emitted, closing the open-redirect vector. - name: "encoded slash is rejected, not redirected", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: "/open.redirect.hackercom%2f..", - expectStatus: http.StatusNotFound, - expectHeaderLocation: "", - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - - tmpFs := tc.givenFs - if tc.givenFsRoot != "" { - tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) - } - e.StaticFS(tc.givenPrefix, tmpFs) - + e.Static(tc.givenPrefix, tc.givenRoot) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { @@ -305,114 +219,44 @@ func TestEcho_StaticFS(t *testing.T) { } } -func TestEcho_FileFS(t *testing.T) { - var testCases = []struct { - whenFS fs.FS - name string - whenPath string - whenFile string - givenURL string - expectStartsWith []byte - expectCode int - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } +func TestEchoStaticRedirectIndex(t *testing.T) { + e := New() - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + // HandlerFunc + e.Static("/static", "_fixture") - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() + errCh := make(chan error) - e.ServeHTTP(rec, req) + go func() { + errCh <- e.Start(":0") + }() - assert.Equal(t, tc.expectCode, rec.Code) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] + addr := e.ListenerAddr().String() + if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + assert.Fail(t, err.Error()) } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} + }(resp.Body) + assert.Equal(t, http.StatusOK, resp.StatusCode) -func TestEcho_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - }{ - { - name: "panics for ../", - givenRoot: "../assets", - }, - { - name: "panics for /", - givenRoot: "/assets", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") + if body, err := io.ReadAll(resp.Body); err == nil { + assert.Equal(t, true, strings.HasPrefix(string(body), "")) + } else { + assert.Fail(t, err.Error()) + } - assert.Panics(t, func() { - e.Static("../assets", tc.givenRoot) - }) - }) + } else { + assert.NoError(t, err) } -} - -func TestEchoStaticRedirectIndex(t *testing.T) { - e := New() - - // HandlerFunc - ri := e.Static("/static", "_fixture") - assert.Equal(t, http.MethodGet, ri.Method) - assert.Equal(t, "/static*", ri.Path) - assert.Equal(t, "GET:/static*", ri.Name) - assert.Equal(t, []string{"*"}, ri.Parameters) - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) - defer cancel() - addr, err := startOnRandomPort(ctx, e) - if err != nil { - assert.Fail(t, err.Error()) + if err := e.Close(); err != nil { + t.Fatal(err) } - - code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) - assert.NoError(t, err) - assert.True(t, strings.HasPrefix(body, "")) - assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { @@ -421,8 +265,8 @@ func TestEchoFile(t *testing.T) { givenPath string givenFile string whenPath string - expectStartsWith string expectCode int + expectStartsWith string }{ { name: "ok", @@ -471,37 +315,36 @@ func TestEchoMiddleware(t *testing.T) { buf := new(bytes.Buffer) e.Pre(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - // before route match is found RouteInfo does not exist - assert.Equal(t, RouteInfo{}, c.RouteInfo()) + return func(c Context) error { + assert.Empty(t, c.Path()) buf.WriteString("-1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("2") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("3") return next(c) } }) // Route - e.GET("/", func(c *Context) error { + e.GET("/", func(c Context) error { return c.String(http.StatusOK, "OK") }) @@ -514,11 +357,11 @@ func TestEchoMiddleware(t *testing.T) { func TestEchoMiddlewareError(t *testing.T) { e := New() e.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return errors.New("error") } }) - e.GET("/", notFoundHandler) + e.GET("/", NotFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -527,7 +370,7 @@ func TestEchoHandler(t *testing.T) { e := New() // HandlerFunc - e.GET("/ok", func(c *Context) error { + e.GET("/ok", func(c Context) error { return c.String(http.StatusOK, "OK") }) @@ -538,256 +381,230 @@ func TestEchoHandler(t *testing.T) { func TestEchoWrapHandler(t *testing.T) { e := New() - - var actualID string - var actualPattern string - e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("test")) - actualID = r.PathValue("id") - actualPattern = r.Pattern - }))) - - req := httptest.NewRequest(http.MethodGet, "/123", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "test", rec.Body.String()) - assert.Equal(t, "123", actualID) - assert.Equal(t, "/:id", actualPattern) + c := e.NewContext(req, rec) + h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("test")) + if err != nil { + assert.Fail(t, err.Error()) + } + })) + if assert.NoError(t, h(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "test", rec.Body.String()) + } } func TestEchoWrapMiddleware(t *testing.T) { e := New() - - var actualID string - var actualPattern string - e.Use(WrapMiddleware(func(h http.Handler) http.Handler { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + buf := new(bytes.Buffer) + mw := WrapMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actualID = r.PathValue("id") - actualPattern = r.Pattern + buf.Write([]byte("mw")) h.ServeHTTP(w, r) }) - })) - - e.GET("/:id", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") }) - - req := httptest.NewRequest(http.MethodGet, "/123", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusTeapot, rec.Code) - assert.Equal(t, "OK", rec.Body.String()) - assert.Equal(t, "123", actualID) - assert.Equal(t, "/:id", actualPattern) + h := mw(func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, "mw", buf.String()) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + } } func TestEchoConnect(t *testing.T) { e := New() - - ri := e.CONNECT("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodConnect, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodConnect+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodConnect, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodConnect, "/", e) } func TestEchoDelete(t *testing.T) { e := New() - - ri := e.DELETE("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodDelete, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodDelete+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodDelete, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodDelete, "/", e) } func TestEchoGet(t *testing.T) { e := New() - - ri := e.GET("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodGet, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodGet+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodGet, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodGet, "/", e) } func TestEchoHead(t *testing.T) { e := New() - - ri := e.HEAD("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodHead, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodHead+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodHead, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodHead, "/", e) } func TestEchoOptions(t *testing.T) { e := New() - - ri := e.OPTIONS("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodOptions, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodOptions+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodOptions, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodOptions, "/", e) } func TestEchoPatch(t *testing.T) { e := New() - - ri := e.PATCH("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodPatch, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodPatch+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodPatch, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodPatch, "/", e) } func TestEchoPost(t *testing.T) { e := New() - - ri := e.POST("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodPost, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodPost+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodPost, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodPost, "/", e) } func TestEchoPut(t *testing.T) { e := New() - - ri := e.PUT("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodPut, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodPut+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodPut, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) + testMethod(t, http.MethodPut, "/", e) } func TestEchoTrace(t *testing.T) { e := New() + testMethod(t, http.MethodTrace, "/", e) +} - ri := e.TRACE("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") +func TestEchoAny(t *testing.T) { // JFC + e := New() + e.Any("/", func(c Context) error { + return c.String(http.StatusOK, "Any") }) - - assert.Equal(t, http.MethodTrace, ri.Method) - assert.Equal(t, "/", ri.Path) - assert.Equal(t, http.MethodTrace+":/", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodTrace, "/", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, "OK", body) } -func TestEcho_Any(t *testing.T) { +func TestEchoMatch(t *testing.T) { // JFC e := New() - - ri := e.Any("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK from ANY") + e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { + return c.String(http.StatusOK, "Match") }) +} - assert.Equal(t, RouteAny, ri.Method) - assert.Equal(t, "/activate", ri.Path) - assert.Equal(t, RouteAny+":/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodTrace, "/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK from ANY`, body) +func TestEchoURL(t *testing.T) { + e := New() + static := func(Context) error { return nil } + getUser := func(Context) error { return nil } + getAny := func(Context) error { return nil } + getFile := func(Context) error { return nil } + + e.GET("/static/file", static) + e.GET("/users/:id", getUser) + e.GET("/documents/*", getAny) + g := e.Group("/group") + g.GET("/users/:uid/files/:fid", getFile) + + assert.Equal(t, "/static/file", e.URL(static)) + assert.Equal(t, "/users/:id", e.URL(getUser)) + assert.Equal(t, "/users/1", e.URL(getUser, "1")) + assert.Equal(t, "/users/1", e.URL(getUser, "1")) + assert.Equal(t, "/documents/foo.txt", e.URL(getAny, "foo.txt")) + assert.Equal(t, "/documents/*", e.URL(getAny)) + assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1")) + assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1")) } -func TestEcho_Any_hasLowerPriority(t *testing.T) { +func TestEchoRoutes(t *testing.T) { e := New() + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + e.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } - e.Any("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "ANY") - }) - e.GET("/activate", func(c *Context) error { - return c.String(http.StatusLocked, "GET") + if assert.Equal(t, len(routes), len(e.Routes())) { + for _, r := range e.Routes() { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break + } + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } + } +} + +func TestEchoRoutesHandleAdditionalHosts(t *testing.T) { + e := New() + domain2Router := e.Host("domain2.router.com") + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + domain2Router.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } + e.Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") }) - status, body := request(http.MethodTrace, "/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `ANY`, body) + domain2Routes := e.Routers()["domain2.router.com"].Routes() - status, body = request(http.MethodGet, "/activate", e) - assert.Equal(t, http.StatusLocked, status) - assert.Equal(t, `GET`, body) + assert.Len(t, domain2Routes, len(routes)) + for _, r := range domain2Routes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break + } + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } } -func TestEchoMatch(t *testing.T) { // JFC +func TestEchoRoutesHandleDefaultHost(t *testing.T) { e := New() - ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error { - return c.String(http.StatusOK, "Match") + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + e.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } + e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") }) - assert.Len(t, ris, 2) + + defaultRouterRoutes := e.Routes() + assert.Len(t, defaultRouterRoutes, len(routes)) + for _, r := range defaultRouterRoutes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break + } + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } } func TestEchoServeHTTPPathEncoding(t *testing.T) { e := New() - e.GET("/with/slash", func(c *Context) error { + e.GET("/with/slash", func(c Context) error { return c.String(http.StatusOK, "/with/slash") }) - e.GET("/:id", func(c *Context) error { + e.GET("/:id", func(c Context) error { return c.String(http.StatusOK, c.Param("id")) }) @@ -824,16 +641,117 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } } +func TestEchoHost(t *testing.T) { + okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } + teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } + acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } + teapotMiddleware := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return teapotHandler }) + + e := New() + e.GET("/", acceptHandler) + e.GET("/foo", acceptHandler) + + ok := e.Host("ok.com") + ok.GET("/", okHandler) + ok.GET("/foo", okHandler) + + teapot := e.Host("teapot.com") + teapot.GET("/", teapotHandler) + teapot.GET("/foo", teapotHandler) + + middle := e.Host("middleware.com", teapotMiddleware) + middle.GET("/", okHandler) + middle.GET("/foo", okHandler) + + var testCases = []struct { + name string + whenHost string + whenPath string + expectBody string + expectStatus int + }{ + { + name: "No Host Root", + whenHost: "", + whenPath: "/", + expectBody: http.StatusText(http.StatusAccepted), + expectStatus: http.StatusAccepted, + }, + { + name: "No Host Foo", + whenHost: "", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusAccepted), + expectStatus: http.StatusAccepted, + }, + { + name: "OK Host Root", + whenHost: "ok.com", + whenPath: "/", + expectBody: http.StatusText(http.StatusOK), + expectStatus: http.StatusOK, + }, + { + name: "OK Host Foo", + whenHost: "ok.com", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusOK), + expectStatus: http.StatusOK, + }, + { + name: "Teapot Host Root", + whenHost: "teapot.com", + whenPath: "/", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + { + name: "Teapot Host Foo", + whenHost: "teapot.com", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + { + name: "Middleware Host", + whenHost: "middleware.com", + whenPath: "/", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + { + name: "Middleware Host Foo", + whenHost: "middleware.com", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.whenPath, nil) + req.Host = tc.whenHost + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) + }) + } +} + func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("0") return next(c) } })) - h := func(c *Context) error { + h := func(c Context) error { return c.NoContent(http.StatusOK) } @@ -846,7 +764,7 @@ func TestEchoGroup(t *testing.T) { // Group g1 := e.Group("/group1") g1.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("1") return next(c) } @@ -856,14 +774,14 @@ func TestEchoGroup(t *testing.T) { // Nested groups with middleware g2 := e.Group("/group2") g2.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("2") return next(c) } }) g3 := g2.Group("/group3") g3.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { buf.WriteString("3") return next(c) } @@ -882,12 +800,20 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "023", buf.String()) } -func TestEcho_RouteNotFound(t *testing.T) { - var testCases = []struct { - expectRoute any - name string - whenURL string - expectCode int +func TestEchoNotFound(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/files", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestEcho_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute any + expectCode int }{ { name: "404, route to static not found handler /a/c/xx", @@ -919,10 +845,10 @@ func TestEcho_RouteNotFound(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := New() - okHandler := func(c *Context) error { + okHandler := func(c Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } - notFoundHandler := func(c *Context) error { + notFoundHandler := func(c Context) error { return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } @@ -946,18 +872,10 @@ func TestEcho_RouteNotFound(t *testing.T) { } } -func TestEchoNotFound(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/files", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusNotFound, rec.Code) -} - func TestEchoMethodNotAllowed(t *testing.T) { e := New() - e.GET("/", func(c *Context) error { + e.GET("/", func(c Context) error { return c.String(http.StatusOK, "Echo!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) @@ -968,133 +886,348 @@ func TestEchoMethodNotAllowed(t *testing.T) { assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) } -func TestEcho_OnAddRoute(t *testing.T) { - exampleRoute := Route{ - Method: http.MethodGet, - Path: "/api/files/:id", - Handler: notFoundHandler, - Middlewares: nil, - Name: "x", +func TestEchoContext(t *testing.T) { + e := New() + c := e.AcquireContext() + assert.IsType(t, new(context), c) + e.ReleaseContext(c) +} + +func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + var addr net.Addr + if isTLS { + addr = e.TLSListenerAddr() + } else { + addr = e.ListenerAddr() + } + if addr != nil && strings.Contains(addr.String(), ":") { + return nil // was started + } + case err := <-errChan: + if err == http.ErrServerClosed { + return nil + } + return err + } } +} + +func TestEchoStart(t *testing.T) { + e := New() + errChan := make(chan error) + + go func() { + err := e.Start(":0") + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, false) + assert.NoError(t, err) + + assert.NoError(t, e.Close()) +} +func TestEcho_StartTLS(t *testing.T) { var testCases = []struct { - whenRoute Route - whenError error name string + addr string + certFile string + keyFile string expectError string - expectAdded []string - expectLen int }{ { - name: "ok", - whenRoute: exampleRoute, - whenError: nil, - expectAdded: []string{"/static", "/api/files/:id"}, - expectError: "", - expectLen: 2, + name: "ok", + addr: ":0", }, { - name: "nok, error is returned", - whenRoute: exampleRoute, - whenError: errors.New("nope"), - expectAdded: []string{"/static"}, - expectError: "nope", - expectLen: 1, + name: "nok, invalid certFile", + addr: ":0", + certFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, invalid keyFile", + addr: ":0", + keyFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, failed to create cert out of certFile and keyFile", + addr: ":0", + keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key + expectError: "tls: found a certificate rather than a key in the PEM for the private key", + }, + { + name: "nok, invalid tls address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", }, } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := New() + errChan := make(chan error) - added := make([]string, 0) - cnt := 0 - e.OnAddRoute = func(route Route) error { - if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests - return tc.whenError + go func() { + certFile := "_fixture/certs/cert.pem" + if tc.certFile != "" { + certFile = tc.certFile + } + keyFile := "_fixture/certs/key.pem" + if tc.keyFile != "" { + keyFile = tc.keyFile } - cnt++ - added = append(added, route.Path) - return nil - } - - e.GET("/static", notFoundHandler) - var err error - _, err = e.AddRoute(tc.whenRoute) + err := e.StartTLS(tc.addr, certFile, keyFile) + if err != nil { + errChan <- err + } + }() + err := waitForServerStart(e, errChan, true) if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) + if _, ok := err.(*os.PathError); ok { + assert.Error(t, err) // error messages for unix and windows are different. so test only error type here + } else { + assert.EqualError(t, err, tc.expectError) + } } else { assert.NoError(t, err) } - assert.Len(t, e.Router().Routes(), tc.expectLen) - assert.Equal(t, tc.expectAdded, added) + assert.NoError(t, e.Close()) }) } } -func TestEchoContext(t *testing.T) { +func TestEchoStartTLSAndStart(t *testing.T) { + // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server e := New() - c := e.AcquireContext() - assert.IsType(t, new(Context), c) - e.ReleaseContext(c) -} + e.GET("/", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) -func TestPreMiddlewares(t *testing.T) { - e := New() - assert.Equal(t, 0, len(e.PreMiddlewares())) + errTLSChan := make(chan error) + go func() { + certFile := "_fixture/certs/cert.pem" + keyFile := "_fixture/certs/key.pem" + err := e.StartTLS("localhost:", certFile, keyFile) + if err != nil { + errTLSChan <- err + } + }() - e.Pre(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - return next(c) + err := waitForServerStart(e, errTLSChan, true) + assert.NoError(t, err) + defer func() { + if err := e.Shutdown(stdContext.Background()); err != nil { + t.Error(err) } - }) + }() + + // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }} + res, err := client.Get("https://" + e.TLSListenerAddr().String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + errChan := make(chan error) + go func() { + err := e.Start("localhost:") + if err != nil { + errChan <- err + } + }() + err = waitForServerStart(e, errChan, false) + assert.NoError(t, err) + + // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS + res, err = http.Get("http://" + e.ListenerAddr().String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Equal(t, 1, len(e.PreMiddlewares())) + // see if HTTPS works after HTTP listener is also added + res, err = client.Get("https://" + e.TLSListenerAddr().String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) } -func TestMiddlewares(t *testing.T) { - e := New() - assert.Equal(t, 0, len(e.Middlewares())) +func TestEchoStartTLSByteString(t *testing.T) { + cert, err := os.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := os.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) - e.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - return next(c) - } - }) + testCases := []struct { + cert any + key any + expectedErr error + name string + }{ + { + cert: "_fixture/certs/cert.pem", + key: "_fixture/certs/key.pem", + expectedErr: nil, + name: `ValidCertAndKeyFilePath`, + }, + { + cert: cert, + key: key, + expectedErr: nil, + name: `ValidCertAndKeyByteString`, + }, + { + cert: cert, + key: 1, + expectedErr: ErrInvalidCertOrKeyType, + name: `InvalidKeyType`, + }, + { + cert: 0, + key: key, + expectedErr: ErrInvalidCertOrKeyType, + name: `InvalidCertType`, + }, + { + cert: 0, + key: 1, + expectedErr: ErrInvalidCertOrKeyType, + name: `InvalidCertAndKeyTypes`, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.name, func(t *testing.T) { + e := New() + e.HideBanner = true - assert.Equal(t, 1, len(e.Middlewares())) + errChan := make(chan error) + + go func() { + errChan <- e.StartTLS(":0", test.cert, test.key) + }() + + err := waitForServerStart(e, errChan, true) + if test.expectedErr != nil { + assert.EqualError(t, err, test.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } -func TestEcho_Start(t *testing.T) { - e := New() - e.GET("/", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - rndPort, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) +func TestEcho_StartAutoTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, } - defer rndPort.Close() - errChan := make(chan error, 1) - go func() { - errChan <- e.Start(rndPort.Addr().String()) - }() - select { - case <-time.After(250 * time.Millisecond): - t.Fatal("start did not error out") - case err := <-errChan: - expectContains := "bind: address already in use" - if runtime.GOOS == "windows" { - expectContains = "bind: Only one usage of each socket address" - } - assert.Contains(t, err.Error(), expectContains) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error) + + go func() { + errChan <- e.StartAutoTLS(tc.addr) + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) } } +func TestEcho_StartH2CServer(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + h2s := &http2.Server{} + + errChan := make(chan error) + go func() { + err := e.StartH2CServer(tc.addr, h2s) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, false) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } +} + +func testMethod(t *testing.T, method, path string, e *Echo) { + p := reflect.ValueOf(path) + h := reflect.ValueOf(func(c Context) error { + return c.String(http.StatusOK, method) + }) + i := any(e) + reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) + _, body := request(method, path, e) + assert.Equal(t, method, body) +} + func request(method, path string, e *Echo) (int, string) { req := httptest.NewRequest(method, path, nil) rec := httptest.NewRecorder() @@ -1102,143 +1235,589 @@ func request(method, path string, e *Echo) (int, string) { return rec.Code, rec.Body.String() } -type customError struct { - Code int - Message string +func TestHTTPError(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]any{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) + }) + + t.Run("internal and SetInternal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]any{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) + + t.Run("internal and WithInternal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]any{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) +} + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]any{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + + t.Run("unwrap internal and SetInternal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]any{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) + + t.Run("unwrap internal and WithInternal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]any{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) } -func (ce *customError) StatusCode() int { - return ce.Code +type customError struct { + s string } func (ce *customError) MarshalJSON() ([]byte, error) { - return fmt.Appendf(nil, `{"x":"%v"}`, ce.Message), nil + return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil } func (ce *customError) Error() string { - return ce.Message + return ce.s } func TestDefaultHTTPErrorHandler(t *testing.T) { var testCases = []struct { - whenError error - name string - whenMethod string - expectBody string - expectLogged string - expectStatus int - givenExposeError bool - givenLoggerFunc bool + name string + givenDebug bool + whenPath string + expectCode int + expectBody string }{ { - name: "ok, expose error = true, HTTPError, no wrapped err", - givenExposeError: true, - whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, - expectStatus: http.StatusTeapot, - expectBody: `{"message":"my_error"}` + "\n", + name: "with Debug=true plain response contains error message", + givenDebug: true, + whenPath: "/plain", + expectCode: http.StatusInternalServerError, + expectBody: "{\n \"error\": \"an error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", }, { - name: "ok, expose error = true, HTTPError + wrapped error", - givenExposeError: true, - whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")), - expectStatus: http.StatusTeapot, - expectBody: `{"error":"internal_error","message":"my_error"}` + "\n", + name: "with Debug=true special handling for HTTPError", + givenDebug: true, + whenPath: "/badrequest", + expectCode: http.StatusBadRequest, + expectBody: "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", }, { - name: "ok, expose error = true, HTTPError + wrapped HTTPError", - givenExposeError: true, - whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), - expectStatus: http.StatusTeapot, - expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n", + name: "with Debug=true complex errors are serialized to pretty JSON", + givenDebug: true, + whenPath: "/servererror", + expectCode: http.StatusInternalServerError, + expectBody: "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", }, { - name: "ok, expose error = false, HTTPError", - whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, - expectStatus: http.StatusTeapot, - expectBody: `{"message":"my_error"}` + "\n", + name: "with Debug=true if the body is already set HTTPErrorHandler should not add anything to response body", + givenDebug: true, + whenPath: "/early-return", + expectCode: http.StatusOK, + expectBody: "OK", }, { - name: "ok, expose error = false, HTTPError, no message", - whenError: &HTTPError{Code: http.StatusTeapot, Message: ""}, - expectStatus: http.StatusTeapot, - expectBody: `{"message":"I'm a teapot"}` + "\n", + name: "with Debug=true internal error should be reflected in the message", + givenDebug: true, + whenPath: "/internal-error", + expectCode: http.StatusBadRequest, + expectBody: "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", }, { - name: "ok, expose error = false, HTTPError + internal HTTPError", - whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), - expectStatus: http.StatusTooEarly, - expectBody: `{"message":"my_error"}` + "\n", + name: "with Debug=false the error response is shortened", + whenPath: "/plain", + expectCode: http.StatusInternalServerError, + expectBody: "{\"message\":\"Internal Server Error\"}\n", }, { - name: "ok, expose error = true, Error", - givenExposeError: true, - whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), - expectStatus: http.StatusInternalServerError, - expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", + name: "with Debug=false the error response is shortened", + whenPath: "/badrequest", + expectCode: http.StatusBadRequest, + expectBody: "{\"message\":\"Invalid request\"}\n", }, { - name: "ok, expose error = false, Error", - whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), - expectStatus: http.StatusInternalServerError, - expectBody: `{"message":"Internal Server Error"}` + "\n", + name: "with Debug=false No difference for error response with non plain string errors", + whenPath: "/servererror", + expectCode: http.StatusInternalServerError, + expectBody: "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", }, { - name: "ok, http.HEAD, expose error = true, Error", - givenExposeError: true, - whenMethod: http.MethodHead, - whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), - expectStatus: http.StatusInternalServerError, - expectBody: ``, + name: "with Debug=false when httpError contains an error", + whenPath: "/error-in-httperror", + expectCode: http.StatusBadRequest, + expectBody: "{\"message\":\"error in httperror\"}\n", }, { - name: "ok, custom error implement MarshalJSON + HTTPStatusCoder", - whenMethod: http.MethodGet, - whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"}, - expectStatus: http.StatusTeapot, - expectBody: `{"x":"custom error msg"}` + "\n", + name: "with Debug=false when httpError contains an error", + whenPath: "/customerror-in-httperror", + expectCode: http.StatusBadRequest, + expectBody: "{\"x\":\"custom error msg\"}\n", }, } - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - buf := new(bytes.Buffer) e := New() - e.Logger = slog.New(slog.DiscardHandler) - e.Any("/path", func(c *Context) error { - return tc.whenError + e.Debug = tc.givenDebug // With Debug=true plain response contains error message + + e.Any("/plain", func(c Context) error { + return errors.New("an error occurred") }) - e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) + e.Any("/badrequest", func(c Context) error { // and special handling for HTTPError + return NewHTTPError(http.StatusBadRequest, "Invalid request") + }) - method := http.MethodGet - if tc.whenMethod != "" { - method = tc.whenMethod - } - c, b := request(method, "/path", e) + e.Any("/servererror", func(c Context) error { // complex errors are serialized to pretty JSON + return NewHTTPError(http.StatusInternalServerError, map[string]any{ + "code": 33, + "message": "Something bad happened", + "error": "stackinfo", + }) + }) + + // if the body is already set HTTPErrorHandler should not add anything to response body + e.Any("/early-return", func(c Context) error { + err := c.String(http.StatusOK, "OK") + if err != nil { + assert.Fail(t, err.Error()) + } + return errors.New("ERROR") + }) + + // internal error should be reflected in the message + e.GET("/internal-error", func(c Context) error { + err := errors.New("internal error message body") + return NewHTTPError(http.StatusBadRequest).SetInternal(err) + }) + + e.GET("/error-in-httperror", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")) + }) + + e.GET("/customerror-in-httperror", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}) + }) - assert.Equal(t, tc.expectStatus, c) + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) assert.Equal(t, tc.expectBody, b) - assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { +func TestEchoClose(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - resp := httptest.NewRecorder() - c := e.NewContext(req, resp) + errCh := make(chan error) + + go func() { + errCh <- e.Start(":0") + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) + + if err := e.Close(); err != nil { + t.Fatal(err) + } - c.orgResponse.Committed = true - errHandler := DefaultHTTPErrorHandler(false) + assert.NoError(t, e.Close()) - errHandler(c, errors.New("my_error")) - assert.Equal(t, http.StatusOK, resp.Code) + err = <-errCh + assert.Equal(t, err.Error(), "http: Server closed") } -func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { +func TestEchoShutdown(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) + errCh := make(chan error) + + go func() { + errCh <- e.Start(":0") + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) + + if err := e.Close(); err != nil { + t.Fatal(err) + } + + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) + defer cancel() + assert.NoError(t, e.Shutdown(ctx)) + + err = <-errCh + assert.Equal(t, err.Error(), "http: Server closed") +} + +var listenerNetworkTests = []struct { + test string + network string + address string +}{ + {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, + {"tcp ipv6 address", "tcp", "[::1]:1323"}, + {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, + {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, +} + +func supportsIPv6() bool { + addrs, _ := net.InterfaceAddrs() + for _, addr := range addrs { + // Check if any interface has local IPv6 assigned + if strings.Contains(addr.String(), "::1") { + return true + } + } + return false +} + +func TestEchoListenerNetwork(t *testing.T) { + hasIPv6 := supportsIPv6() + for _, tt := range listenerNetworkTests { + if !hasIPv6 && strings.Contains(tt.address, "::") { + t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") + continue + } + t.Run(tt.test, func(t *testing.T) { + e := New() + e.ListenerNetwork = tt.network + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errCh := make(chan error) + + go func() { + errCh <- e.Start(tt.address) + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) + + if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + assert.Fail(t, err.Error()) + } + }(resp.Body) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + if body, err := io.ReadAll(resp.Body); err == nil { + assert.Equal(t, "OK", string(body)) + } else { + assert.Fail(t, err.Error()) + } + + } else { + assert.Fail(t, err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestEchoListenerNetworkInvalid(t *testing.T) { + e := New() + e.ListenerNetwork = "unix" + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) +} + +func TestEcho_OnAddRouteHandler(t *testing.T) { + type rr struct { + host string + route Route + handler HandlerFunc + middleware []MiddlewareFunc + } + dummyHandler := func(Context) error { return nil } + e := New() + + added := make([]rr, 0) + e.OnAddRouteHandler = func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) { + added = append(added, rr{ + host: host, + route: route, + handler: handler, + middleware: middleware, + }) + } + + e.GET("/static", dummyHandler) + e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + return next(c) + } + }) + + assert.Len(t, added, 2) + + assert.Equal(t, "", added[0].host) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[0].route) + assert.Len(t, added[0].middleware, 0) + + assert.Equal(t, "domain.site", added[1].host) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static/*", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[1].route) + assert.Len(t, added[1].middleware, 1) +} + +func TestEchoReverse(t *testing.T) { + var testCases = []struct { + name string + whenRouteName string + whenParams []any + expect string + }{ + { + name: "ok, not existing path returns empty url", + whenRouteName: "not-existing", + expect: "", + }, + { + name: "ok,static with no params", + whenRouteName: "/static", + expect: "/static", + }, + { + name: "ok,static with non existent param", + whenRouteName: "/static", + whenParams: []any{"missing param"}, + expect: "/static", + }, + { + name: "ok, wildcard with no params", + whenRouteName: "/static/*", + expect: "/static/*", + }, + { + name: "ok, wildcard with params", + whenRouteName: "/static/*", + whenParams: []any{"foo.txt"}, + expect: "/static/foo.txt", + }, + { + name: "ok, single param without param", + whenRouteName: "/params/:foo", + expect: "/params/:foo", + }, + { + name: "ok, single param with param", + whenRouteName: "/params/:foo", + whenParams: []any{"one"}, + expect: "/params/one", + }, + { + name: "ok, multi param without params", + whenRouteName: "/params/:foo/bar/:qux", + expect: "/params/:foo/bar/:qux", + }, + { + name: "ok, multi param with one param", + whenRouteName: "/params/:foo/bar/:qux", + whenParams: []any{"one"}, + expect: "/params/one/bar/:qux", + }, + { + name: "ok, multi param with all params", + whenRouteName: "/params/:foo/bar/:qux", + whenParams: []any{"one", "two"}, + expect: "/params/one/bar/two", + }, + { + name: "ok, multi param + wildcard with all params", + whenRouteName: "/params/:foo/bar/:qux/*", + whenParams: []any{"one", "two", "three"}, + expect: "/params/one/bar/two/three", + }, + { + name: "ok, backslash is not escaped", + whenRouteName: "/backslash", + whenParams: []any{"test"}, + expect: `/a\b/test`, + }, + { + name: "ok, escaped colon verbs", + whenRouteName: "/params:customVerb", + whenParams: []any{"PATCH"}, + expect: `/params:PATCH`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + dummyHandler := func(Context) error { return nil } + + e.GET("/static", dummyHandler).Name = "/static" + e.GET("/static/*", dummyHandler).Name = "/static/*" + e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" + e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" + e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" + e.GET("/a\\b/:x", dummyHandler).Name = "/backslash" + e.GET("/params\\::customVerb", dummyHandler).Name = "/params:customVerb" + + assert.Equal(t, tc.expect, e.Reverse(tc.whenRouteName, tc.whenParams...)) + }) + } +} + +func TestEchoReverseHandleHostProperly(t *testing.T) { + dummyHandler := func(Context) error { return nil } + + e := New() + + // routes added to the default router are different form different hosts + e.GET("/static", dummyHandler).Name = "default-host /static" + e.GET("/static/*", dummyHandler).Name = "xxx" + + // different host + h := e.Host("the_host") + h.GET("/static", dummyHandler).Name = "host2 /static" + h.GET("/static/v2/*", dummyHandler).Name = "xxx" + + assert.Equal(t, "/static", e.Reverse("default-host /static")) + // when actual route does not have params and we provide some to Reverse we should get that route url back + assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param")) + + host2Router := e.Routers()["the_host"] + assert.Equal(t, "/static", host2Router.Reverse("host2 /static")) + assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param")) + + assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx")) + assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt")) + +} + +func TestEcho_ListenerAddr(t *testing.T) { + e := New() + + addr := e.ListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.Start(":0") + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) +} + +func TestEcho_TLSListenerAddr(t *testing.T) { + cert, err := os.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := os.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + + e := New() + + addr := e.TLSListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.StartTLS(":0", cert, key) + }() + + err = waitForServerStart(e, errCh, true) + assert.NoError(t, err) +} + +func TestEcho_StartServer(t *testing.T) { + cert, err := os.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := os.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + certs, err := tls.X509KeyPair(cert, key) + require.NoError(t, err) + + var testCases = []struct { + name string + addr string + TLSConfig *tls.Config + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "ok, start with TLS", + addr: ":0", + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + { + name: "nok, invalid tls address", + addr: "nope", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + + server := new(http.Server) + server.Addr = tc.addr + if tc.TLSConfig != nil { + server.TLSConfig = tc.TLSConfig + } + + errCh := make(chan error) + go func() { + errCh <- e.StartServer(server) + }() + + err := waitForServerStart(e, errCh, tc.TLSConfig != nil) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.NoError(t, e.Close()) + }) + } +} + +func benchmarkEchoRoutes(b *testing.B, routes []*Route) { + e := New() + req := httptest.NewRequest("GET", "/", nil) u := req.URL w := httptest.NewRecorder() @@ -1246,7 +1825,7 @@ func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { // Add routes for _, route := range routes { - e.Add(route.Method, route.Path, func(c *Context) error { + e.Add(route.Method, route.Path, func(c Context) error { return nil }) } diff --git a/echotest/context.go b/echotest/context.go deleted file mode 100644 index ca3bd1056..000000000 --- a/echotest/context.go +++ /dev/null @@ -1,183 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echotest - -import ( - "bytes" - "io" - "mime/multipart" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/labstack/echo/v5" -) - -// ContextConfig is configuration for creating echo.Context for testing purposes. -type ContextConfig struct { - // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)` - Request *http.Request - - // Response will be used instead of default `httptest.NewRecorder()` - Response *httptest.ResponseRecorder - - // QueryValues will be set as Request.URL.RawQuery value - QueryValues url.Values - - // Headers will be set as Request.Header value - Headers http.Header - - // PathValues initializes context.PathValues with given value. - PathValues echo.PathValues - - // RouteInfo initializes context.RouteInfo() with given value - RouteInfo *echo.RouteInfo - - // FormValues creates form-urlencoded form out of given values. If there is no - // `content-type` header it will be set to `application/x-www-form-urlencoded` - // In case Request was not set the Request.Method is set to `POST` - // - // FormValues, MultipartForm and JSONBody are mutually exclusive. - FormValues url.Values - - // MultipartForm creates multipart form out of given value. If there is no - // `content-type` header it will be set to `multipart/form-data` - // In case Request was not set the Request.Method is set to `POST` - // - // FormValues, MultipartForm and JSONBody are mutually exclusive. - MultipartForm *MultipartForm - - // JSONBody creates JSON body out of given bytes. If there is no - // `content-type` header it will be set to `application/json` - // In case Request was not set the Request.Method is set to `POST` - // - // FormValues, MultipartForm and JSONBody are mutually exclusive. - JSONBody []byte -} - -// MultipartForm is used to create multipart form out of given value -type MultipartForm struct { - Fields map[string]string - Files []MultipartFormFile -} - -// MultipartFormFile is used to create file in multipart form out of given value -type MultipartFormFile struct { - Fieldname string - Filename string - Content []byte -} - -// ToContext converts ContextConfig to echo.Context -func (conf ContextConfig) ToContext(t *testing.T) *echo.Context { - c, _ := conf.ToContextRecorder(t) - return c -} - -// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder -func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) { - if conf.Response == nil { - conf.Response = httptest.NewRecorder() - } - isDefaultRequest := false - if conf.Request == nil { - isDefaultRequest = true - conf.Request = httptest.NewRequest(http.MethodGet, "/", nil) - } - - if len(conf.QueryValues) > 0 { - conf.Request.URL.RawQuery = conf.QueryValues.Encode() - } - if len(conf.Headers) > 0 { - conf.Request.Header = conf.Headers - } - if len(conf.FormValues) > 0 { - body := strings.NewReader(url.Values(conf.FormValues).Encode()) - conf.Request.Body = io.NopCloser(body) - conf.Request.ContentLength = int64(body.Len()) - - if conf.Request.Header.Get(echo.HeaderContentType) == "" { - conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - } - if isDefaultRequest { - conf.Request.Method = http.MethodPost - } - } else if conf.MultipartForm != nil { - var body bytes.Buffer - mw := multipart.NewWriter(&body) - for field, value := range conf.MultipartForm.Fields { - if err := mw.WriteField(field, value); err != nil { - t.Fatal(err) - } - } - for _, file := range conf.MultipartForm.Files { - fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) - if err != nil { - t.Fatal(err) - } - if _, err = fw.Write(file.Content); err != nil { - t.Fatal(err) - } - } - if err := mw.Close(); err != nil { - t.Fatal(err) - } - - conf.Request.Body = io.NopCloser(&body) - conf.Request.ContentLength = int64(body.Len()) - if conf.Request.Header.Get(echo.HeaderContentType) == "" { - conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType()) - } - if isDefaultRequest { - conf.Request.Method = http.MethodPost - } - } else if conf.JSONBody != nil { - body := bytes.NewReader(conf.JSONBody) - conf.Request.Body = io.NopCloser(body) - conf.Request.ContentLength = int64(body.Len()) - - if conf.Request.Header.Get(echo.HeaderContentType) == "" { - conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - } - if isDefaultRequest { - conf.Request.Method = http.MethodPost - } - } - - ec := echo.NewContext(conf.Request, conf.Response, echo.New()) - if conf.RouteInfo == nil { - conf.RouteInfo = &echo.RouteInfo{ - Name: "", - Method: conf.Request.Method, - Path: "/test", - Parameters: []string{}, - } - for _, p := range conf.PathValues { - conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name) - } - } - ec.InitializeRoute(conf.RouteInfo, &conf.PathValues) - return ec, conf.Response -} - -// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking -func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder { - c, rec := conf.ToContextRecorder(t) - - errHandler := echo.DefaultHTTPErrorHandler(false) - for _, opt := range opts { - switch o := opt.(type) { - case echo.HTTPErrorHandler: - errHandler = o - } - } - - err := handler(c) - if err != nil { - errHandler(c, err) - } - return rec -} diff --git a/echotest/context_external_test.go b/echotest/context_external_test.go deleted file mode 100644 index d98257148..000000000 --- a/echotest/context_external_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package echotest_test - -import ( - "net/http" - "testing" - - "github.com/labstack/echo/v5" - "github.com/labstack/echo/v5/echotest" - "github.com/stretchr/testify/assert" -) - -func TestToContext_JSONBody(t *testing.T) { - c := echotest.ContextConfig{ - JSONBody: echotest.LoadBytes(t, "testdata/test.json"), - }.ToContext(t) - - payload := struct { - Field string `json:"field"` - }{} - if err := c.Bind(&payload); err != nil { - t.Fatal(err) - } - - assert.Equal(t, "value", payload.Field) - assert.Equal(t, http.MethodPost, c.Request().Method) - assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) -} diff --git a/echotest/context_test.go b/echotest/context_test.go deleted file mode 100644 index 66815e4b0..000000000 --- a/echotest/context_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package echotest - -import ( - "net/http" - "net/url" - "strings" - "testing" - - "github.com/labstack/echo/v5" - "github.com/stretchr/testify/assert" -) - -func TestServeWithHandler(t *testing.T) { - handler := func(c *echo.Context) error { - return c.String(http.StatusOK, c.QueryParam("key")) - } - testConf := ContextConfig{ - QueryValues: url.Values{"key": []string{"value"}}, - } - - resp := testConf.ServeWithHandler(t, handler) - - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, "value", resp.Body.String()) -} - -func TestServeWithHandler_error(t *testing.T) { - handler := func(c *echo.Context) error { - return echo.NewHTTPError(http.StatusBadRequest, "something went wrong") - } - testConf := ContextConfig{ - QueryValues: url.Values{"key": []string{"value"}}, - } - - customErrHandler := echo.DefaultHTTPErrorHandler(true) - - resp := testConf.ServeWithHandler(t, handler, customErrHandler) - - assert.Equal(t, http.StatusBadRequest, resp.Code) - assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String()) -} - -func TestToContext_QueryValues(t *testing.T) { - testConf := ContextConfig{ - QueryValues: url.Values{"t": []string{"2006-01-02"}}, - } - c := testConf.ToContext(t) - - v, err := echo.QueryParam[string](c, "t") - - assert.NoError(t, err) - assert.Equal(t, "2006-01-02", v) -} - -func TestToContext_Headers(t *testing.T) { - testConf := ContextConfig{ - Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}}, - } - c := testConf.ToContext(t) - - id := c.Request().Header.Get(echo.HeaderXRequestID) - - assert.Equal(t, "ABC", id) -} - -func TestToContext_PathValues(t *testing.T) { - testConf := ContextConfig{ - PathValues: echo.PathValues{{ - Name: "key", - Value: "value", - }}, - } - c := testConf.ToContext(t) - - key := c.Param("key") - - assert.Equal(t, "value", key) -} - -func TestToContext_RouteInfo(t *testing.T) { - testConf := ContextConfig{ - RouteInfo: &echo.RouteInfo{ - Name: "my_route", - Method: http.MethodGet, - Path: "/:id", - Parameters: []string{"id"}, - }, - } - c := testConf.ToContext(t) - - ri := c.RouteInfo() - - assert.Equal(t, echo.RouteInfo{ - Name: "my_route", - Method: http.MethodGet, - Path: "/:id", - Parameters: []string{"id"}, - }, ri) -} - -func TestToContext_FormValues(t *testing.T) { - testConf := ContextConfig{ - FormValues: url.Values{"key": []string{"value"}}, - } - c := testConf.ToContext(t) - - assert.Equal(t, "value", c.FormValue("key")) - assert.Equal(t, http.MethodPost, c.Request().Method) - assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType)) -} - -func TestToContext_MultipartForm(t *testing.T) { - testConf := ContextConfig{ - MultipartForm: &MultipartForm{ - Fields: map[string]string{ - "key": "value", - }, - Files: []MultipartFormFile{ - { - Fieldname: "file", - Filename: "test.json", - Content: LoadBytes(t, "testdata/test.json"), - }, - }, - }, - } - c := testConf.ToContext(t) - - assert.Equal(t, "value", c.FormValue("key")) - assert.Equal(t, http.MethodPost, c.Request().Method) - assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary=")) - - fv, err := c.FormFile("file") - if err != nil { - t.Fatal(err) - } - assert.Equal(t, "test.json", fv.Filename) - assert.Equal(t, int64(23), fv.Size) -} - -func TestToContext_JSONBody(t *testing.T) { - testConf := ContextConfig{ - JSONBody: LoadBytes(t, "testdata/test.json"), - } - c := testConf.ToContext(t) - - payload := struct { - Field string `json:"field"` - }{} - if err := c.Bind(&payload); err != nil { - t.Fatal(err) - } - - assert.Equal(t, "value", payload.Field) - assert.Equal(t, http.MethodPost, c.Request().Method) - assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) -} diff --git a/echotest/reader.go b/echotest/reader.go deleted file mode 100644 index 0caceca02..000000000 --- a/echotest/reader.go +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echotest - -import ( - "os" - "path/filepath" - "runtime" - "testing" -) - -type loadBytesOpts func([]byte) []byte - -// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file. -func TrimNewlineEnd(bytes []byte) []byte { - bLen := len(bytes) - if bLen > 1 && bytes[bLen-1] == '\n' { - bytes = bytes[:bLen-1] - } - return bytes -} - -// LoadBytes is helper to load file contents relative to current (where test file is) package -// directory. -func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte { - bytes := loadBytes(t, name, 2) - - for _, f := range opts { - bytes = f(bytes) - } - - return bytes -} - -func loadBytes(t *testing.T, name string, callDepth int) []byte { - _, b, _, _ := runtime.Caller(callDepth) - basepath := filepath.Dir(b) - - path := filepath.Join(basepath, name) // relative path - bytes, err := os.ReadFile(path) - if err != nil { - t.Fatal(err) - } - return bytes[:] -} diff --git a/echotest/reader_external_test.go b/echotest/reader_external_test.go deleted file mode 100644 index 43fd57416..000000000 --- a/echotest/reader_external_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package echotest_test - -import ( - "strings" - "testing" - - "github.com/labstack/echo/v5/echotest" - "github.com/stretchr/testify/assert" -) - -const testJSONContent = `{ - "field": "value" -}` - -func TestLoadBytesOK(t *testing.T) { - data := echotest.LoadBytes(t, "testdata/test.json") - assert.Equal(t, []byte(testJSONContent+"\n"), data) -} - -func TestLoadBytes_custom(t *testing.T) { - data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte { - return []byte(strings.ToUpper(string(bytes))) - }) - assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data) -} diff --git a/echotest/reader_test.go b/echotest/reader_test.go deleted file mode 100644 index 23b3c2dd2..000000000 --- a/echotest/reader_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package echotest - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -const testJSONContent = `{ - "field": "value" -}` - -func TestLoadBytesOK(t *testing.T) { - data := LoadBytes(t, "testdata/test.json") - assert.Equal(t, []byte(testJSONContent+"\n"), data) -} - -func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) { - data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd) - assert.Equal(t, []byte(testJSONContent), data) -} diff --git a/echotest/testdata/test.json b/echotest/testdata/test.json deleted file mode 100644 index 94ae65f17..000000000 --- a/echotest/testdata/test.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "field": "value" -} diff --git a/go.mod b/go.mod index a2480a285..55990b60b 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,23 @@ -module github.com/labstack/echo/v5 +module github.com/labstack/echo/v4 go 1.25.0 require ( + github.com/labstack/gommon v0.5.0 github.com/stretchr/testify v1.11.1 - golang.org/x/net v0.49.0 - golang.org/x/time v0.14.0 + github.com/valyala/fasttemplate v1.2.2 + golang.org/x/crypto v0.50.0 + golang.org/x/net v0.53.0 + golang.org/x/time v0.15.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.22 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/text v0.33.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f1e80fc13..77f53a71d 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,29 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/labstack/gommon v0.5.0 h1:6VSQ2NOzsnEJ5W6+84E0RbcaDDmgB6NIAzWCczTEe6c= +github.com/labstack/gommon v0.5.0/go.mod h1:Rzlg7HHy1maLfzBYGg9NZcVuz1sA68HHhLjhcEllYE0= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4= +github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= -golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/group.go b/group.go index 8092bc904..cb37b123f 100644 --- a/group.go +++ b/group.go @@ -4,7 +4,6 @@ package echo import ( - "io/fs" "net/http" ) @@ -12,167 +11,119 @@ import ( // routes that share a common middleware or functionality that should be separate // from the parent echo instance while still inheriting from it. type Group struct { - echo *Echo + common + host string prefix string + echo *Echo middleware []MiddlewareFunc } // Use implements `Echo#Use()` for sub-routes within the Group. -// Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) + if len(g.middleware) == 0 { + return + } + // group level middlewares are different from Echo `Pre` and `Use` middlewares (those are global). Group level middlewares + // are only executed if they are added to the Router with route. + // So we register catch all route (404 is a safe way to emulate route match) for this group and now during routing the + // Router would find route to match our request path and therefore guarantee the middleware(s) will get executed. + g.RouteNotFound("", NotFoundHandler) + g.RouteNotFound("/*", NotFoundHandler) } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// GET implements `Echo#GET()` for sub-routes within the Group. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// POST implements `Echo#POST()` for sub-routes within the Group. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// PUT implements `Echo#PUT()` for sub-routes within the Group. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { - return g.Add(RouteAny, path, handler, middleware...) -} - -// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { - errs := make([]error, 0) - ris := make(Routes, 0) - for _, m := range methods { - ri, err := g.AddRoute(Route{ - Method: m, - Path: path, - Handler: handler, - Middlewares: middleware, - }) - if err != nil { - errs = append(errs, err) - continue - } - ris = append(ris, ri) +// Any implements `Echo#Any()` for sub-routes within the Group. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { + routes := make([]*Route, len(methods)) + for i, m := range methods { + routes[i] = g.Add(m, path, handler, middleware...) } - if len(errs) > 0 { - panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + return routes +} + +// Match implements `Echo#Match()` for sub-routes within the Group. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { + routes := make([]*Route, len(methods)) + for i, m := range methods { + routes[i] = g.Add(m, path, handler, middleware...) } - return ris + return routes } // Group creates a new sub-group with prefix and optional sub-group-level middleware. -// Important! Group middlewares are only executed in case there was exact route match and not -// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add -// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) sg = g.echo.Group(g.prefix+prefix, m...) + sg.host = g.host return } -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { - subFs := MustSubFS(g.echo.Filesystem, fsRoot) - return g.StaticFS(pathPrefix, subFs, middleware...) -} - -// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { - return g.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - middleware..., - ) -} - -// FileFS implements `Echo#FileFS()` for sub-routes within the Group. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { - return g.GET(path, StaticFileHandler(file, filesystem), m...) -} - -// File implements `Echo#File()` for sub-routes within the Group. Panics on error. -// -// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for -// file operations. -func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { - handler := func(c *Context) error { - return c.File(file) - } - return g.Add(http.MethodGet, path, handler, middleware...) +// File implements `Echo#File()` for sub-routes within the Group. +func (g *Group) File(path, file string) { + g.file(path, file, g.GET) } // RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. // -// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` -func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { +// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return g.Add(RouteNotFound, path, h, m...) } -// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { - ri, err := g.AddRoute(Route{ - Method: method, - Path: path, - Handler: handler, - Middlewares: middleware, - }) - if err != nil { - panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage - } - return ri -} - -// AddRoute registers a new Routable with Router -func (g *Group) AddRoute(route Route) (RouteInfo, error) { - // Combine middleware into a new slice to avoid accidentally passing the same slice for +// Add implements `Echo#Add()` for sub-routes within the Group. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { + // Combine into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) - return g.echo.add(groupRoute) + m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) + m = append(m, g.middleware...) + m = append(m, middleware...) + return g.echo.add(g.host, method, g.prefix+path, handler, m...) } diff --git a/group_fs.go b/group_fs.go new file mode 100644 index 000000000..c1b7ec2d3 --- /dev/null +++ b/group_fs.go @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "io/fs" + "net/http" +) + +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(pathPrefix, fsRoot string) { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + g.StaticFS(pathPrefix, subFs) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { + g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { + return g.GET(path, StaticFileHandler(file, filesystem), m...) +} diff --git a/group_fs_test.go b/group_fs_test.go new file mode 100644 index 000000000..caa200940 --- /dev/null +++ b/group_fs_test.go @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "github.com/stretchr/testify/assert" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + }{ + { + name: "panics for ../", + givenRoot: "../images", + }, + { + name: "panics for /", + givenRoot: "/images", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.Panics(t, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} diff --git a/group_method_handling_test.go b/group_method_handling_test.go deleted file mode 100644 index 4a8cf0979..000000000 --- a/group_method_handling_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -// These tests lock in v5's method-handling semantics for routes registered through -// a Group. v5 resolves method mismatches (405) and OPTIONS at the router level and -// does NOT register any implicit per-group catch-all route. -// -// They double as a regression gate. Registering a group-level catch-all β€” whether -// manually via g.RouteNotFound("/*", ...) or automatically (as proposed in #2996 to -// fix CORS-on-group preflight) β€” makes that catch-all match every method, which masks -// both 405 and v5's automatic OPTIONS response as 404 β€” demonstrated directly by -// TestGroupRoute_catchAllMasksMethodHandling below. If that masking becomes the -// default (e.g. #2996 lands), the first two tests below fail. - -// A method mismatch on an existing group route must return 405 with the allowed -// methods, not be masked to 404. -func TestGroupRoute_methodMismatchReturns405(t *testing.T) { - e := New() - g := e.Group("/api") - g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") }) - - req := httptest.NewRequest(http.MethodPost, "/api/users", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusMethodNotAllowed, rec.Code, - "POST to a GET-only group route must be 405, not masked to 404") - assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow), - "405 response must advertise the allowed methods") -} - -// OPTIONS on an existing group route is answered automatically by Echo (204 + -// Allow). This is the behavior CORS preflight relies on, so it must not be masked. -func TestGroupRoute_automaticOPTIONS(t *testing.T) { - e := New() - g := e.Group("/api") - g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") }) - - req := httptest.NewRequest(http.MethodOptions, "/api/users", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusNoContent, rec.Code, - "OPTIONS on a registered group route must be auto-answered (204), not masked to 404") - assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow), - "automatic OPTIONS response must advertise the allowed methods") -} - -// A matched concrete route resolves to its own handler; only a genuinely unmatched -// path under the prefix is a 404. -func TestGroupRoute_concreteRoutesResolve(t *testing.T) { - e := New() - g := e.Group("/api") - g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") }) - - status, body := request(http.MethodGet, "/api/users", e) - assert.Equal(t, http.StatusOK, status) - assert.Equal(t, "users", body) - - status, _ = request(http.MethodGet, "/api/nope", e) - assert.Equal(t, http.StatusNotFound, status) -} - -// A group prefix must not affect routing of routes registered outside the group. -func TestGroup_doesNotAffectRootRoutes(t *testing.T) { - e := New() - e.GET("/health", func(c *Context) error { return c.String(http.StatusOK, "root") }) - g := e.Group("/api") - g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") }) - - status, body := request(http.MethodGet, "/health", e) - assert.Equal(t, http.StatusOK, status) - assert.Equal(t, "root", body) -} - -// Characterization of the regression the 405/OPTIONS tests above guard against: -// registering a group-wide catch-all (the manual equivalent of #2996's auto- -// registration) makes it match every method, so method mismatches and the automatic -// OPTIONS response are masked as 404 even though the concrete route still resolves. -// If a future change teaches the catch-all to preserve method semantics, update this. -func TestGroupRoute_catchAllMasksMethodHandling(t *testing.T) { - e := New() - g := e.Group("/api") - g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") }) - g.RouteNotFound("/*", func(c *Context) error { return c.NoContent(http.StatusNotFound) }) - - // The concrete route still resolves. - status, body := request(http.MethodGet, "/api/users", e) - assert.Equal(t, http.StatusOK, status) - assert.Equal(t, "users", body) - - // But the catch-all masks the method mismatch (would be 405) ... - post := httptest.NewRequest(http.MethodPost, "/api/users", nil) - postRec := httptest.NewRecorder() - e.ServeHTTP(postRec, post) - assert.Equal(t, http.StatusNotFound, postRec.Code, - "a group-wide catch-all masks the 405 method-mismatch as 404") - - // ... and the automatic OPTIONS response (would be 204). - opts := httptest.NewRequest(http.MethodOptions, "/api/users", nil) - optsRec := httptest.NewRecorder() - e.ServeHTTP(optsRec, opts) - assert.Equal(t, http.StatusNotFound, optsRec.Code, - "a group-wide catch-all masks the automatic OPTIONS (204) response as 404") -} diff --git a/group_test.go b/group_test.go index 7078b6497..78c2ed485 100644 --- a/group_test.go +++ b/group_test.go @@ -4,70 +4,31 @@ package echo import ( - "io/fs" "net/http" "net/http/httptest" "os" - "strings" "testing" "github.com/stretchr/testify/assert" ) -func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { - e := New() - - called := false - mw := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - called = true - return c.NoContent(http.StatusTeapot) - } - } - // even though group has middleware it will not be executed when there are no routes under that group - _ = e.Group("/group", mw) - - status, body := request(http.MethodGet, "/group/nope", e) - assert.Equal(t, http.StatusNotFound, status) - assert.Equal(t, `{"message":"Not Found"}`+"\n", body) - - assert.False(t, called) -} - -func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { - e := New() - - called := false - mw := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - called = true - return c.NoContent(http.StatusTeapot) - } - } - // even though group has middleware and routes when we have no match on some route the middlewares for that - // group will not be executed - g := e.Group("/group", mw) - g.GET("/yes", handlerFunc) - - status, body := request(http.MethodGet, "/group/nope", e) - assert.Equal(t, http.StatusNotFound, status) - assert.Equal(t, `{"message":"Not Found"}`+"\n", body) - - assert.False(t, called) -} - -func TestGroup_multiLevelGroup(t *testing.T) { - e := New() - - api := e.Group("/api") - users := api.Group("/users") - users.GET("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - status, body := request(http.MethodGet, "/api/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) +// TODO: Fix me +func TestGroup(t *testing.T) { + g := New().Group("/group") + h := func(Context) error { return nil } + g.CONNECT("/", h) + g.DELETE("/", h) + g.GET("/", h) + g.HEAD("/", h) + g.OPTIONS("/", h) + g.PATCH("/", h) + g.POST("/", h) + g.PUT("/", h) + g.TRACE("/", h) + g.Any("/", h) + g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) + g.Static("/static", "/tmp") + g.File("/walle", "_fixture/images//walle.png") } func TestGroupFile(t *testing.T) { @@ -87,29 +48,29 @@ func TestGroupRouteMiddleware(t *testing.T) { // Ensure middleware slices are not re-used e := New() g := e.Group("/group") - h := func(*Context) error { return nil } + h := func(Context) error { return nil } m1 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return next(c) } } m3 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return next(c) } } m4 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return c.NoContent(404) } } m5 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return c.NoContent(405) } } @@ -128,17 +89,17 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { e := New() g := e.Group("/group") m1 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - return c.String(http.StatusOK, c.RouteInfo().Path) + return func(c Context) error { + return c.String(http.StatusOK, c.Path()) } } - h := func(c *Context) error { - return c.String(http.StatusOK, c.RouteInfo().Path) + h := func(c Context) error { + return c.String(http.StatusOK, c.Path()) } g.Use(m1) g.GET("/help", h, m2) @@ -162,155 +123,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { } -func TestGroup_CONNECT(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.CONNECT("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodConnect, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodConnect, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_DELETE(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.DELETE("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodDelete, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodDelete, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_HEAD(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.HEAD("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodHead, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodHead+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodHead, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_OPTIONS(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.OPTIONS("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodOptions, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodOptions, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_PATCH(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.PATCH("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodPatch, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodPatch, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_POST(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.POST("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodPost, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodPost+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodPost, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_PUT(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.PUT("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodPut, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodPut+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodPut, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - -func TestGroup_TRACE(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.TRACE("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - - assert.Equal(t, http.MethodTrace, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodTrace, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) -} - func TestGroup_RouteNotFound(t *testing.T) { var testCases = []struct { - expectRoute any name string whenURL string + expectRoute any expectCode int }{ { @@ -344,10 +161,10 @@ func TestGroup_RouteNotFound(t *testing.T) { e := New() g := e.Group("/group") - okHandler := func(c *Context) error { + okHandler := func(c Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } - notFoundHandler := func(c *Context) error { + notFoundHandler := func(c Context) error { return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } @@ -371,416 +188,44 @@ func TestGroup_RouteNotFound(t *testing.T) { } } -func TestGroup_Any(t *testing.T) { - e := New() - - users := e.Group("/users") - ri := users.Any("/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK from ANY") - }) - - assert.Equal(t, RouteAny, ri.Method) - assert.Equal(t, "/users/activate", ri.Path) - assert.Equal(t, RouteAny+":/users/activate", ri.Name) - assert.Nil(t, ri.Parameters) - - status, body := request(http.MethodTrace, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK from ANY`, body) -} - -func TestGroup_Match(t *testing.T) { - e := New() - - myMethods := []string{http.MethodGet, http.MethodPost} - users := e.Group("/users") - ris := users.Match(myMethods, "/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - assert.Len(t, ris, 2) - - for _, m := range myMethods { - status, body := request(m, "/users/activate", e) - assert.Equal(t, http.StatusTeapot, status) - assert.Equal(t, `OK`, body) - } -} - -func TestGroup_MatchWithErrors(t *testing.T) { - e := New() - - users := e.Group("/users") - users.GET("/activate", func(c *Context) error { - return c.String(http.StatusOK, "OK") - }) - myMethods := []string{http.MethodGet, http.MethodPost} - - errs := func() (errs []error) { - defer func() { - if r := recover(); r != nil { - if tmpErr, ok := r.([]error); ok { - errs = tmpErr - return - } - panic(r) - } - }() - - users.Match(myMethods, "/activate", func(c *Context) error { - return c.String(http.StatusTeapot, "OK") - }) - return nil - }() - assert.Len(t, errs, 1) - assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") - - for _, m := range myMethods { - status, body := request(m, "/users/activate", e) - - expect := http.StatusTeapot - if m == http.MethodGet { - expect = http.StatusOK - } - assert.Equal(t, expect, status) - assert.Equal(t, `OK`, body) - } -} - -func TestGroup_Static(t *testing.T) { - e := New() - - g := e.Group("/books") - ri := g.Static("/download", "_fixture") - assert.Equal(t, http.MethodGet, ri.Method) - assert.Equal(t, "/books/download*", ri.Path) - assert.Equal(t, "GET:/books/download*", ri.Name) - assert.Equal(t, []string{"*"}, ri.Parameters) - - req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - body := rec.Body.String() - assert.True(t, strings.HasPrefix(body, "")) -} - -func TestGroup_StaticMultiTest(t *testing.T) { - var testCases = []struct { - name string - givenPrefix string - givenRoot string - whenURL string - expectHeaderLocation string - expectBodyStartsWith string - expectBodyNotContains string - expectStatus int - }{ - { - name: "ok", - givenPrefix: "/images", - givenRoot: "_fixture/images", - whenURL: "/test/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "ok, without prefix", - givenPrefix: "", - givenRoot: "_fixture/images", - whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "nok, without prefix does not serve dir index", - givenPrefix: "", - givenRoot: "_fixture/images", - whenURL: "/test/", // `/test` + `*` creates route `/test*` - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "No file", - givenPrefix: "/images", - givenRoot: "_fixture/scripts", - whenURL: "/test/images/bolt.png", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory", - givenPrefix: "/images", - givenRoot: "_fixture/images", - whenURL: "/test/images/", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory Redirect", - givenPrefix: "/", - givenRoot: "_fixture", - whenURL: "/test/folder", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/test/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory Redirect with non-root path", - givenPrefix: "/static", - givenRoot: "_fixture", - whenURL: "/test/static", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/test/static/", - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory 404 (request URL without slash)", - givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenRoot: "_fixture", - whenURL: "/test/folder", // no trailing slash - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Prefixed directory redirect (without slash redirect to slash)", - givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenRoot: "_fixture", - whenURL: "/test/folder", // no trailing slash - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/test/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory with index.html", - givenPrefix: "/", - givenRoot: "_fixture", - whenURL: "/test/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending with slash)", - givenPrefix: "/assets/", - givenRoot: "_fixture", - whenURL: "/test/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending without slash)", - givenPrefix: "/assets", - givenRoot: "_fixture", - whenURL: "/test/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Sub-directory with index.html", - givenPrefix: "/", - givenRoot: "_fixture", - whenURL: "/test/folder/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "nok, URL encoded path traversal (single encoding, slash - unix separator)", - givenRoot: "_fixture/dist/public", - whenURL: "/%2e%2e%2fprivate.txt", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - expectBodyNotContains: `private file`, - }, - { - name: "nok, URL encoded path traversal (single encoding, backslash - windows separator)", - givenRoot: "_fixture/dist/public", - whenURL: "/%2e%2e%5cprivate.txt", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - expectBodyNotContains: `private file`, - }, - { - name: "do not allow directory traversal (backslash - windows separator)", - givenPrefix: "/", - givenRoot: "_fixture/", - whenURL: `/test/..\\middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "do not allow directory traversal (slash - unix separator)", - givenPrefix: "/", - givenRoot: "_fixture/", - whenURL: `/test/../middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - g := e.Group("/test") - g.Static(tc.givenPrefix, tc.givenRoot) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectStatus, rec.Code) - body := rec.Body.String() - if tc.expectBodyStartsWith != "" { - assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) - } else { - assert.Equal(t, "", body) - } - if tc.expectBodyNotContains != "" { - assert.NotContains(t, body, tc.expectBodyNotContains) - } - - if tc.expectHeaderLocation != "" { - assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) - } else { - _, ok := rec.Result().Header["Location"] - assert.False(t, ok) - } - }) - } -} - -func TestGroup_FileFS(t *testing.T) { - var testCases = []struct { - whenFS fs.FS - name string - whenPath string - whenFile string - givenURL string - expectStartsWith []byte - expectCode int - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/assets/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - g := e.Group("/assets") - g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestGroup_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - }{ - { - name: "panics for ../", - givenRoot: "../images", - }, - { - name: "panics for /", - givenRoot: "/images", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - g := e.Group("/assets") - - assert.Panics(t, func() { - g.Static("/images", tc.givenRoot) - }) - }) - } -} - func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { var testCases = []struct { - expectBody any - name string - whenURL string - expectCode int - givenCustom404 bool - expectMiddlewareCalled bool + name string + givenCustom404 bool + whenURL string + expectBody any + expectCode int }{ { - name: "ok, custom 404 handler is called with middleware", - givenCustom404: true, - whenURL: "/group/test3", - expectBody: "404 GET /group/*", - expectCode: http.StatusNotFound, - expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added + name: "ok, custom 404 handler is called with middleware", + givenCustom404: true, + whenURL: "/group/test3", + expectBody: "GET /group/*", + expectCode: http.StatusNotFound, }, { - name: "ok, default group 404 handler is not called with middleware", - givenCustom404: false, - whenURL: "/group/test3", - expectBody: "404 GET /*", - expectCode: http.StatusNotFound, - expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + name: "ok, default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group/test3", + expectBody: "{\"message\":\"Not Found\"}\n", + expectCode: http.StatusNotFound, }, { - name: "ok, (no slash) default group 404 handler is called with middleware", - givenCustom404: false, - whenURL: "/group", - expectBody: "404 GET /*", - expectCode: http.StatusNotFound, - expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + name: "ok, (no slash) default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group", + expectBody: "{\"message\":\"Not Found\"}\n", + expectCode: http.StatusNotFound, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - okHandler := func(c *Context) error { + okHandler := func(c Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } - notFoundHandler := func(c *Context) error { - return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path()) + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } e := New() @@ -792,7 +237,7 @@ func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { middlewareCalled := false g.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { + return func(c Context) error { middlewareCalled = true return next(c) } @@ -806,7 +251,7 @@ func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled) + assert.True(t, middlewareCalled) assert.Equal(t, tc.expectCode, rec.Code) assert.Equal(t, tc.expectBody, rec.Body.String()) }) diff --git a/httperror.go b/httperror.go deleted file mode 100644 index 8cb10c8ef..000000000 --- a/httperror.go +++ /dev/null @@ -1,162 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "errors" - "fmt" - "net/http" -) - -// The following errors can produce HTTP status code by implementing HTTPStatusCoder interface -var ( - ErrBadRequest = &httpError{http.StatusBadRequest} // 400 - ErrUnauthorized = &httpError{http.StatusUnauthorized} // 401 - ErrForbidden = &httpError{http.StatusForbidden} // 403 - ErrNotFound = &httpError{http.StatusNotFound} // 404 - ErrMethodNotAllowed = &httpError{http.StatusMethodNotAllowed} // 405 - ErrRequestTimeout = &httpError{http.StatusRequestTimeout} // 408 - ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413 - ErrUnsupportedMediaType = &httpError{http.StatusUnsupportedMediaType} // 415 - ErrTooManyRequests = &httpError{http.StatusTooManyRequests} // 429 - ErrInternalServerError = &httpError{http.StatusInternalServerError} // 500 - ErrBadGateway = &httpError{http.StatusBadGateway} // 502 - ErrServiceUnavailable = &httpError{http.StatusServiceUnavailable} // 503 -) - -// The following errors fall into 500 (InternalServerError) category -var ( - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") -) - -// HTTPStatusCoder is an interface that errors can implement to produce status code for HTTP response -type HTTPStatusCoder interface { - StatusCode() int -} - -// StatusCode returns status code from err if it implements HTTPStatusCoder interface. -// If err does not implement the interface, it returns 0. -func StatusCode(err error) int { - var sc HTTPStatusCoder - if errors.As(err, &sc) { - return sc.StatusCode() - } - return 0 -} - -// ResolveResponseStatus returns the Response and HTTP status code that should be (or has been) sent for rw, -// given an optional error. -// -// This function is useful for middleware and handlers that need to figure out the HTTP status -// code to return based on the error that occurred or what was set in the response. -// -// Precedence rules: -// 1. If the response has already been committed, the committed status wins (err is ignored). -// 2. Otherwise, start with 200 OK (net/http default if WriteHeader is never called). -// 3. If the response has a non-zero suggested status, use it. -// 4. If err != nil, it overrides the suggested status: -// - StatusCode(err) if non-zero -// - otherwise 500 Internal Server Error. -func ResolveResponseStatus(rw http.ResponseWriter, err error) (resp *Response, status int) { - resp, _ = UnwrapResponse(rw) - - // once committed (sent to the client), the wire status is fixed; err cannot change it. - if resp != nil && resp.Committed { - if resp.Status == 0 { - // unlikely path, but fall back to net/http implicit default if handler never calls WriteHeader - return resp, http.StatusOK - } - return resp, resp.Status - } - - // net/http implicit default if handler never calls WriteHeader. - status = http.StatusOK - - // suggested status written from middleware/handlers, if present. - if resp != nil && resp.Status != 0 { - status = resp.Status - } - - // error overrides suggested status (matches typical Echo error-handler semantics). - if err != nil { - if s := StatusCode(err); s != 0 { - status = s - } else { - status = http.StatusInternalServerError - } - } - - return resp, status -} - -// NewHTTPError creates a new instance of HTTPError -func NewHTTPError(code int, message string) *HTTPError { - return &HTTPError{ - Code: code, - Message: message, - } -} - -// HTTPError represents an error that occurred while handling a request. -type HTTPError struct { - // Code is status code for HTTP response - Code int `json:"-"` - Message string `json:"message"` - err error -} - -// StatusCode returns status code for HTTP response -func (he *HTTPError) StatusCode() int { - return he.Code -} - -// Error makes it compatible with the ` error ` interface. -func (he *HTTPError) Error() string { - msg := he.Message - if msg == "" { - msg = http.StatusText(he.Code) - } - if he.err == nil { - return fmt.Sprintf("code=%d, message=%v", he.Code, msg) - } - return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error()) -} - -// Wrap returns a new HTTPError with given errors wrapped inside -func (he HTTPError) Wrap(err error) error { - return &HTTPError{ - Code: he.Code, - Message: he.Message, - err: err, - } -} - -func (he *HTTPError) Unwrap() error { - return he.err -} - -type httpError struct { - code int -} - -func (he httpError) StatusCode() int { - return he.code -} - -func (he httpError) Error() string { - return http.StatusText(he.code) // does not include status code -} - -func (he httpError) Wrap(err error) error { - return &HTTPError{ - Code: he.code, - Message: http.StatusText(he.code), - err: err, - } -} diff --git a/httperror_external_test.go b/httperror_external_test.go deleted file mode 100644 index 91acdca25..000000000 --- a/httperror_external_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -// run tests as external package to get real feel for API -package echo_test - -import ( - "encoding/json" - "fmt" - "github.com/labstack/echo/v5" - "net/http" - "net/http/httptest" -) - -func ExampleDefaultHTTPErrorHandler() { - e := echo.New() - e.GET("/api/endpoint", func(c *echo.Context) error { - return &apiError{ - Code: http.StatusBadRequest, - Body: map[string]any{"message": "custom error"}, - } - }) - - req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil) - resp := httptest.NewRecorder() - - e.ServeHTTP(resp, req) - - fmt.Printf("%d %s", resp.Code, resp.Body.String()) - - // Output: 400 {"error":{"message":"custom error"}} -} - -type apiError struct { - Code int - Body any -} - -func (e *apiError) StatusCode() int { - return e.Code -} - -func (e *apiError) MarshalJSON() ([]byte, error) { - type body struct { - Error any `json:"error"` - } - return json.Marshal(body{Error: e.Body}) -} - -func (e *apiError) Error() string { - return http.StatusText(e.Code) -} diff --git a/httperror_test.go b/httperror_test.go deleted file mode 100644 index 778a186ce..000000000 --- a/httperror_test.go +++ /dev/null @@ -1,186 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "errors" - "fmt" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestHTTPError_StatusCode(t *testing.T) { - var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"} - - code := 0 - var sc HTTPStatusCoder - if errors.As(err, &sc) { - code = sc.StatusCode() - } - assert.Equal(t, http.StatusBadRequest, code) -} - -func TestHTTPError_Error(t *testing.T) { - var testCases = []struct { - name string - error error - expect string - }{ - { - name: "ok, without message", - error: &HTTPError{Code: http.StatusBadRequest}, - expect: "code=400, message=Bad Request", - }, - { - name: "ok, with message", - error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}, - expect: "code=400, message=my error message", - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.expect, tc.error.Error()) - }) - } -} - -func TestHTTPError_WrapUnwrap(t *testing.T) { - err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} - wrapped := err.Wrap(errors.New("my_error")).(*HTTPError) - - err.Code = http.StatusOK - err.Message = "changed" - - assert.Equal(t, http.StatusBadRequest, wrapped.Code) - assert.Equal(t, "bad", wrapped.Message) - - assert.Equal(t, errors.New("my_error"), wrapped.Unwrap()) - assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error()) -} - -func TestNewHTTPError(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, "bad") - err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} - - assert.Equal(t, err2, err) -} - -func TestStatusCode(t *testing.T) { - var testCases = []struct { - name string - err error - expect int - }{ - { - name: "ok, HTTPError", - err: &HTTPError{Code: http.StatusNotFound}, - expect: http.StatusNotFound, - }, - { - name: "ok, sentinel error", - err: ErrNotFound, - expect: http.StatusNotFound, - }, - { - name: "ok, wrapped HTTPError", - err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}), - expect: http.StatusTeapot, - }, - { - name: "nok, normal error", - err: errors.New("error"), - expect: 0, - }, - { - name: "nok, nil", - err: nil, - expect: 0, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.expect, StatusCode(tc.err)) - }) - } -} - -func TestResolveResponseStatus(t *testing.T) { - someErr := errors.New("some error") - - var testCases = []struct { - name string - whenResp http.ResponseWriter - whenErr error - expectStatus int - expectResp bool - }{ - { - name: "nil resp, nil err -> 200", - whenResp: nil, - whenErr: nil, - expectStatus: http.StatusOK, - expectResp: false, - }, - { - name: "resp suggested status used when no error", - whenResp: &Response{Status: http.StatusCreated}, - whenErr: nil, - expectStatus: http.StatusCreated, - expectResp: true, - }, - { - name: "error overrides suggested status with StatusCode(err)", - whenResp: &Response{Status: http.StatusAccepted}, - whenErr: ErrBadRequest, - expectStatus: http.StatusBadRequest, - expectResp: true, - }, - { - name: "error overrides suggested status with 500 when StatusCode(err)==0", - whenResp: &Response{Status: http.StatusAccepted}, - whenErr: ErrInternalServerError, - expectStatus: http.StatusInternalServerError, - expectResp: true, - }, - { - name: "nil resp, error -> 500 when StatusCode(err)==0", - whenResp: nil, - whenErr: someErr, - expectStatus: http.StatusInternalServerError, - expectResp: false, - }, - { - name: "committed response wins over error", - whenResp: &Response{Committed: true, Status: http.StatusNoContent}, - whenErr: someErr, - expectStatus: http.StatusNoContent, - expectResp: true, - }, - { - name: "committed response with status 0 falls back to 200 (defensive)", - whenResp: &Response{Committed: true, Status: 0}, - whenErr: someErr, - expectStatus: http.StatusOK, - expectResp: true, - }, - { - name: "resp with status 0 and no error -> 200", - whenResp: &Response{Status: 0}, - whenErr: nil, - expectStatus: http.StatusOK, - expectResp: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - resp, status := ResolveResponseStatus(tc.whenResp, tc.whenErr) - - assert.Equal(t, tc.expectResp, resp != nil) - assert.Equal(t, tc.expectStatus, status) - }) - } -} diff --git a/internal/pathutil/pathutil.go b/internal/pathutil/pathutil.go index e9171752e..9934fa31c 100644 --- a/internal/pathutil/pathutil.go +++ b/internal/pathutil/pathutil.go @@ -11,9 +11,9 @@ package pathutil // Backslash is included as defense-in-depth against Windows-style separators even // though fs.FS itself only uses forward slashes. // -// Such sequences let an attacker smuggle a separator past the router, which by -// default matches on the raw encoded path, so they must be rejected before -// unescaping when resolving static files. +// Such sequences let an attacker smuggle a separator past the router, which +// matches on the raw encoded path, so they must be rejected before unescaping +// when resolving static files. func HasEncodedPathSeparator(s string) bool { for i := 0; i+2 < len(s); i++ { if s[i] != '%' { diff --git a/ip.go b/ip.go index c864e0689..dce51f55d 100644 --- a/ip.go +++ b/ip.go @@ -202,8 +202,8 @@ func (c *ipChecker) trust(ip net.IP) bool { // See https://echo.labstack.com/guide/ip-address for more details. type IPExtractor func(*http.Request) string -// ExtractIPDirect extracts an IP address using an actual IP address. -// Use this if your server faces to internet directly (i.e.: uses no proxy). +// ExtractIPDirect extracts IP address using actual IP address. +// Use this if your server faces to internet directory (i.e.: uses no proxy). func ExtractIPDirect() IPExtractor { return extractIP } @@ -219,24 +219,30 @@ func extractIP(req *http.Request) string { return host } -// ExtractIPFromRealIPHeader extracts IP address using `x-real-ip` header. +// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. // Use this if you put proxy which uses this header. func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { + directIP := extractIP(req) realIP := req.Header.Get(HeaderXRealIP) - if realIP != "" { + if realIP == "" { + return directIP + } + + if checker.trust(net.ParseIP(directIP)) { realIP = strings.TrimPrefix(realIP, "[") realIP = strings.TrimSuffix(realIP, "]") - if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { + if rIP := net.ParseIP(realIP); rIP != nil { return realIP } } - return extractIP(req) + + return directIP } } -// ExtractIPFromXFFHeader extracts IP address using `x-forwarded-for` header. +// ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header. // Use this if you put proxy which uses this header. // This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]). func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { @@ -265,45 +271,3 @@ func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { return strings.TrimSpace(ips[0]) } } - -// LegacyIPExtractor returns an IPExtractor that derives the client IP address -// from common proxy headers, falling back to the request's remote address. -// -// Resolution order: -// 1. X-Forwarded-For: returns the first IP in the comma-separated list. -// If multiple values are present, only the left-most (original client) -// is used. Surrounding brackets (for IPv6) are stripped. -// 2. X-Real-IP: used if X-Forwarded-For is absent. Surrounding brackets -// (for IPv6) are stripped. -// 3. req.RemoteAddr: used as a fallback; the host portion is extracted -// via net.SplitHostPort. -// -// Notes: -// - No validation is performed on header values. -// - This function trusts headers as-is and is therefore not safe against -// spoofing unless the application is behind a trusted proxy that is -// configured to strip/replace/modify headers correctly. -// -// Use ExtractIPFromXFFHeader or ExtractIPFromRealIPHeader instead of LegacyIPExtractor. -func LegacyIPExtractor() IPExtractor { - return legacyIPExtractor -} - -func legacyIPExtractor(req *http.Request) string { - if ip := req.Header.Get(HeaderXForwardedFor); ip != "" { - i := strings.IndexAny(ip, ",") - if i > 0 { - ip = strings.TrimSpace(ip[:i]) - } - ip = strings.TrimPrefix(ip, "[") - ip = strings.TrimSuffix(ip, "]") - return ip - } - if ip := req.Header.Get(HeaderXRealIP); ip != "" { - ip = strings.TrimPrefix(ip, "[") - ip = strings.TrimSuffix(ip, "]") - return ip - } - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra -} diff --git a/ip_test.go b/ip_test.go index b20368616..e850b78cb 100644 --- a/ip_test.go +++ b/ip_test.go @@ -22,8 +22,8 @@ func mustParseCIDR(s string) *net.IPNet { func TestIPChecker_TrustOption(t *testing.T) { var testCases = []struct { name string - whenIP string givenOptions []TrustOption + whenIP string expect bool }{ { @@ -490,14 +490,14 @@ func TestExtractIPDirect(t *testing.T) { } func TestExtractIPFromRealIPHeader(t *testing.T) { - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.0/24") _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { - whenRequest http.Request name string - expectIP string givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string }{ { name: "request has no headers, extracts IP from request remote addr", @@ -518,36 +518,42 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + HeaderXRealIP: []string{"203.0.113.199"}, }, - RemoteAddr: "203.0.113.1:8080", + RemoteAddr: "8.8.8.8:8080", // <-- this is untrusted }, - expectIP: "203.0.113.1", + expectIP: "8.8.8.8", }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted + HeaderXRealIP: []string{"[bc01:1010::9090:1888]"}, }, - RemoteAddr: "[2001:db8::113:1]:8080", + RemoteAddr: "[fe64:aa10::1]:8080", // <-- this is untrusted }, - expectIP: "2001:db8::113:1", + expectIP: "fe64:aa10::1", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" - TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.0/24" }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXRealIP: []string{"8.8.8.8"}, }, RemoteAddr: "203.0.113.1:8080", }, - expectIP: "203.0.113.199", + expectIP: "8.8.8.8", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -556,11 +562,11 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXRealIP: []string{"[fe64:db8::113:199]"}, }, RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "2001:db8::113:199", + expectIP: "fe64:db8::113:199", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -569,12 +575,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, - HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + HeaderXRealIP: []string{"8.8.8.8"}, + HeaderXForwardedFor: []string{"1.1.1.1 ,8.8.8.8"}, // <-- should not affect anything }, RemoteAddr: "203.0.113.1:8080", }, - expectIP: "203.0.113.199", + expectIP: "8.8.8.8", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -583,12 +589,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[2001:db8::113:199]"}, - HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + HeaderXRealIP: []string{"[fe64:db8::113:199]"}, + HeaderXForwardedFor: []string{"[feab:cde9::113:198], [fe64:db8::113:199]"}, // <-- should not affect anything }, RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "2001:db8::113:199", + expectIP: "fe64:db8::113:199", }, } @@ -605,10 +611,10 @@ func TestExtractIPFromXFFHeader(t *testing.T) { _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { - whenRequest http.Request name string - expectIP string givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string }{ { name: "request has no headers, extracts IP from request remote addr", @@ -714,75 +720,3 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }) } } - -func TestLegacyIPExtractor(t *testing.T) { - var testCases = []struct { - name string - whenReq *http.Request - expect string - expectedError string - }{ - { - name: "extract first ip from X-Forwarded-For", - whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"203.0.113.10, 198.51.100.7"}}}, - expect: "203.0.113.10", - }, - { - name: "extract single ip from X-Forwarded-For", - whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"203.0.113.10"}}}, - expect: "203.0.113.10", - }, - { - name: "trim brackets from ipv6 in X-Forwarded-For when multiple values", - whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"[2001:db8::1], 198.51.100.7"}}}, - expect: "2001:db8::1", - }, - { - name: "prefer X-Forwarded-For over X-Real-Ip", - whenReq: &http.Request{ - Header: http.Header{ - "X-Forwarded-For": []string{"203.0.113.10"}, - "X-Real-Ip": []string{"198.51.100.7"}, - }, - }, - expect: "203.0.113.10", - }, - { - name: "extract from X-Real-Ip", - whenReq: &http.Request{Header: http.Header{"X-Real-Ip": []string{"[2001:db8::1]"}}}, - expect: "2001:db8::1", - }, - { - name: "extract plain ipv4 from X-Real-Ip", - whenReq: &http.Request{Header: http.Header{"X-Real-Ip": []string{"203.0.113.10"}}}, - expect: "203.0.113.10", - }, - { - name: "fallback to RemoteAddr host", - whenReq: &http.Request{RemoteAddr: "203.0.113.10:12345"}, - expect: "203.0.113.10", - }, - { - name: "fallback to RemoteAddr ipv6 host", - whenReq: &http.Request{RemoteAddr: "[2001:db8::1]:12345"}, - expect: "2001:db8::1", - }, - { - name: "returns empty string when RemoteAddr is invalid and no headers exist", - whenReq: &http.Request{RemoteAddr: "not-a-host-port"}, - expect: "", - }, - { - name: "trim brackets from single ipv6 in X-Forwarded-For", - whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"[2001:db8::1]"}}}, - expect: "2001:db8::1", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ip := LegacyIPExtractor()(tc.whenReq) - assert.Equal(t, tc.expect, ip) - }) - } -} diff --git a/json.go b/json.go index a969ccb8c..589cda55f 100644 --- a/json.go +++ b/json.go @@ -5,6 +5,8 @@ package echo import ( "encoding/json" + "fmt" + "net/http" ) // DefaultJSONSerializer implements JSON encoding using encoding/json. @@ -12,18 +14,21 @@ type DefaultJSONSerializer struct{} // Serialize converts an interface into a json and writes it to the response. // You can optionally use the indent parameter to produce pretty JSONs. -func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error { +func (d DefaultJSONSerializer) Serialize(c Context, i any, indent string) error { enc := json.NewEncoder(c.Response()) if indent != "" { enc.SetIndent("", indent) } - return enc.Encode(target) + return enc.Encode(i) } // Deserialize reads a JSON from a request body and converts it into an interface. -func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error { - if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil { - return ErrBadRequest.Wrap(err) +func (d DefaultJSONSerializer) Deserialize(c Context, i any) error { + err := json.NewDecoder(c.Request().Body).Decode(i) + if ute, ok := err.(*json.UnmarshalTypeError); ok { + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) + } else if se, ok := err.(*json.SyntaxError); ok { + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) } - return nil + return err } diff --git a/json_test.go b/json_test.go index 1804b3e82..0b15ed1a1 100644 --- a/json_test.go +++ b/json_test.go @@ -17,7 +17,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) // Echo assert.Equal(t, e, c.Echo()) @@ -34,15 +34,15 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { enc := new(DefaultJSONSerializer) - err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "") + err := enc.Serialize(c, user{1, "Jon Snow"}, "") if assert.NoError(t, err) { assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ") + c = e.NewContext(req, rec).(*context) + err = enc.Serialize(c, user{1, "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } @@ -54,7 +54,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(*context) // Echo assert.Equal(t, e, c.Echo()) @@ -80,10 +80,10 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var userUnmarshalSyntaxError = user{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec) + c = e.NewContext(req, rec).(*context) err = enc.Deserialize(c, &userUnmarshalSyntaxError) assert.IsType(t, &HTTPError{}, err) - assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value") + assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` @@ -92,9 +92,9 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec) + c = e.NewContext(req, rec).(*context) err = enc.Deserialize(c, &userUnmarshalTypeError) assert.IsType(t, &HTTPError{}, err) - assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string") + assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") } diff --git a/log.go b/log.go new file mode 100644 index 000000000..42164c325 --- /dev/null +++ b/log.go @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "github.com/labstack/gommon/log" + "io" +) + +// Logger defines the logging interface. +type Logger interface { + Output() io.Writer + SetOutput(w io.Writer) + Prefix() string + SetPrefix(p string) + Level() log.Lvl + SetLevel(v log.Lvl) + SetHeader(h string) + Print(i ...any) + Printf(format string, args ...any) + Printj(j log.JSON) + Debug(i ...any) + Debugf(format string, args ...any) + Debugj(j log.JSON) + Info(i ...any) + Infof(format string, args ...any) + Infoj(j log.JSON) + Warn(i ...any) + Warnf(format string, args ...any) + Warnj(j log.JSON) + Error(i ...any) + Errorf(format string, args ...any) + Errorj(j log.JSON) + Fatal(i ...any) + Fatalj(j log.JSON) + Fatalf(format string, args ...any) + Panic(i ...any) + Panicj(j log.JSON) + Panicf(format string, args ...any) +} diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md deleted file mode 100644 index 77cb226dd..000000000 --- a/middleware/DEVELOPMENT.md +++ /dev/null @@ -1,11 +0,0 @@ -# Development Guidelines for middlewares - -## Best practices: - -* Do not use `panic` in middleware creator functions in case of invalid configuration. -* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead - because previous middlewares up in call chain could have logic for dealing with returned errors. -* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they - want to create middleware with panics or with returning errors on configuration errors. -* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. - diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8a9500a93..4a46098e3 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,153 +4,105 @@ package middleware import ( - "bytes" - "cmp" "encoding/base64" - "errors" + "net/http" "strconv" "strings" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) -// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. -// -// SECURITY: The Validator function is responsible for securely comparing credentials. -// See BasicAuthValidator documentation for guidance on preventing timing attacks. +// BasicAuthConfig defines the config for BasicAuth middleware. type BasicAuthConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers - // this function would be called once for each header until first valid result is returned + // Validator is a function to validate BasicAuth credentials. // Required. Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuthWithConfig. + // Realm is a string to define realm attribute of BasicAuth. // Default value "Restricted". Realm string - - // AllowedCheckLimit set how many headers are allowed to be checked. This is useful - // environments like corporate test environments with application proxies restricting - // access to environment with their own auth scheme. - // Defaults to 1. - AllowedCheckLimit uint } -// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. -// -// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate -// valid usernames or passwords, validator implementations MUST use constant-time -// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead -// of standard string equality (==) or switch statements. -// -// Example of SECURE implementation: -// -// import "crypto/subtle" -// -// validator := func(c *echo.Context, username, password string) (bool, error) { -// // Fetch expected credentials from database/config -// expectedUser := "admin" -// expectedPass := "secretpassword" -// -// // Use constant-time comparison to prevent timing attacks -// userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1 -// passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1 -// -// if userMatch && passMatch { -// return true, nil -// } -// return false, nil -// } -// -// Example of INSECURE implementation (DO NOT USE): -// -// // VULNERABLE TO TIMING ATTACKS - DO NOT USE -// validator := func(c *echo.Context, username, password string) (bool, error) { -// if username == "admin" && password == "secret" { // Timing leak! -// return true, nil -// } -// return false, nil -// } -type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) +// BasicAuthValidator defines a function to validate BasicAuth credentials. +// The function should return a boolean indicating whether the credentials are valid, +// and an error if any error occurs during the validation process. +type BasicAuthValidator func(string, string, echo.Context) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) +// DefaultBasicAuthConfig is the default BasicAuth middleware config. +var DefaultBasicAuthConfig = BasicAuthConfig{ + Skipper: DefaultSkipper, + Realm: defaultRealm, +} + // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) + c := DefaultBasicAuthConfig + c.Validator = fn + return BasicAuthWithConfig(c) } -// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. +// BasicAuthWithConfig returns an BasicAuth middleware with config. +// See `BasicAuth()`. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration -func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + // Defaults if config.Validator == nil { - return nil, errors.New("echo basic-auth middleware requires a validator function") + panic("echo: basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultSkipper + config.Skipper = DefaultBasicAuthConfig.Skipper } - realm := defaultRealm - if config.Realm != "" { - realm = config.Realm + if config.Realm == "" { + config.Realm = defaultRealm } - realm = strconv.Quote(realm) - limit := cmp.Or(config.AllowedCheckLimit, 1) + + // Pre-compute the quoted realm for WWW-Authenticate header (RFC 7617) + quotedRealm := strconv.Quote(config.Realm) return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } - var lastError error + auth := c.Request().Header.Get(echo.HeaderAuthorization) l := len(basic) - i := uint(0) - for _, auth := range c.Request().Header[echo.HeaderAuthorization] { - if i >= limit { - break - } - if len(auth) <= l+1 || !strings.EqualFold(auth[:l], basic) { - continue - } - i++ + if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input - b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) - if errDecode != nil { - lastError = echo.ErrBadRequest.Wrap(errDecode) - continue + b, err := base64.StdEncoding.DecodeString(auth[l+1:]) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) } - before, after, ok := bytes.Cut(b, []byte{':'}) + + cred := string(b) + user, pass, ok := strings.Cut(cred, ":") if ok { - valid, errValidate := config.Validator(c, string(before), string(after)) - if errValidate != nil { - lastError = errValidate + // Verify credentials + valid, err := config.Validator(user, pass, c) + if err != nil { + return err } else if valid { return next(c) } } } - if lastError != nil { - return lastError - } - // Need to return `401` for browsers to pop-up login box. - c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) + // Realm is case-insensitive, so we can use "basic" directly. See RFC 7617. + c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+quotedRealm) return echo.ErrUnauthorized } - }, nil + } } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 42386354f..2d3192615 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -4,7 +4,6 @@ package middleware import ( - "crypto/subtle" "encoding/base64" "errors" "net/http" @@ -12,177 +11,116 @@ import ( "strings" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - validatorFunc := func(c *echo.Context, u, p string) (bool, error) { - // Use constant-time comparison to prevent timing attacks - userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1 - passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1 + e := echo.New() - if userMatch && passMatch { + mockValidator := func(u, p string, c echo.Context) (bool, error) { + if u == "joe" && p == "secret" { return true, nil } - - // Special case for testing error handling - if u == "error" { - return false, errors.New(p) - } - return false, nil } - defaultConfig := BasicAuthConfig{Validator: validatorFunc} - var testCases = []struct { - name string - givenConfig BasicAuthConfig - whenAuth []string - expectHeader string - expectErr string + tests := []struct { + name string + authHeader string + expectedCode int + expectedAuth string + skipperResult bool + expectedErr bool + expectedErrMsg string }{ { - name: "ok", - givenConfig: defaultConfig, - whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, - }, - { - name: "ok, multiple", - givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2}, - whenAuth: []string{ - "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), - basic + " NOT_BASE64", - basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), - }, + name: "Valid credentials", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + expectedCode: http.StatusOK, }, { - name: "nok, multiple, valid out of limit", - givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1}, - whenAuth: []string{ - "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), - basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")), - // limit only check first and should not check auth below - basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), - }, - expectHeader: basic + ` realm="Restricted"`, - expectErr: "Unauthorized", + name: "Case-insensitive header scheme", + authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + expectedCode: http.StatusOK, }, { - name: "nok, invalid Authorization header", - givenConfig: defaultConfig, - whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, - expectHeader: basic + ` realm="Restricted"`, - expectErr: "Unauthorized", + name: "Invalid credentials", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + expectedCode: http.StatusUnauthorized, + expectedAuth: basic + ` realm="someRealm"`, + expectedErr: true, + expectedErrMsg: "Unauthorized", }, { - name: "nok, not base64 Authorization header", - givenConfig: defaultConfig, - whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, - expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3", + name: "Invalid base64 string", + authHeader: basic + " invalidString", + expectedCode: http.StatusBadRequest, + expectedErr: true, + expectedErrMsg: "Bad Request", }, { - name: "nok, missing Authorization header", - givenConfig: defaultConfig, - expectHeader: basic + ` realm="Restricted"`, - expectErr: "Unauthorized", + name: "Missing Authorization header", + expectedCode: http.StatusUnauthorized, + expectedErr: true, + expectedErrMsg: "Unauthorized", }, { - name: "ok, realm", - givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, - whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + name: "Invalid Authorization header", + authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), + expectedCode: http.StatusUnauthorized, + expectedErr: true, + expectedErrMsg: "Unauthorized", }, { - name: "ok, realm, case-insensitive header scheme", - givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, - whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, - }, - { - name: "nok, realm, invalid Authorization header", - givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, - whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, - expectHeader: basic + ` realm="someRealm"`, - expectErr: "Unauthorized", - }, - { - name: "nok, validator func returns an error", - givenConfig: defaultConfig, - whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, - expectErr: "my_error", - }, - { - name: "ok, skipped", - givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool { - return true - }}, - whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + name: "Skipped Request", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), + expectedCode: http.StatusOK, + skipperResult: true, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) res := httptest.NewRecorder() c := e.NewContext(req, res) - config := tc.givenConfig - - mw, err := config.ToMiddleware() - assert.NoError(t, err) + if tt.authHeader != "" { + req.Header.Set(echo.HeaderAuthorization, tt.authHeader) + } - h := mw(func(c *echo.Context) error { - return c.String(http.StatusTeapot, "test") + h := BasicAuthWithConfig(BasicAuthConfig{ + Validator: mockValidator, + Realm: "someRealm", + Skipper: func(c echo.Context) bool { + return tt.skipperResult + }, + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") }) - if len(tc.whenAuth) != 0 { - for _, a := range tc.whenAuth { - req.Header.Add(echo.HeaderAuthorization, a) - } - } - err = h(c) + err := h(c) - if tc.expectErr != "" { - assert.Equal(t, http.StatusOK, res.Code) - assert.EqualError(t, err, tc.expectErr) + if tt.expectedErr { + var he *echo.HTTPError + errors.As(err, &he) + assert.Equal(t, tt.expectedCode, he.Code) + if tt.expectedAuth != "" { + assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) + } } else { - assert.Equal(t, http.StatusTeapot, res.Code) assert.NoError(t, err) - } - if tc.expectHeader != "" { - assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + assert.Equal(t, tt.expectedCode, res.Code) } }) } } -func TestBasicAuth_panic(t *testing.T) { - assert.Panics(t, func() { - mw := BasicAuth(nil) - assert.NotNil(t, mw) - }) - - mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) { - return true, nil - }) - assert.NotNil(t, mw) -} - -func TestBasicAuthWithConfig_panic(t *testing.T) { - assert.Panics(t, func() { - mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) - assert.NotNil(t, mw) - }) - - mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) { - return true, nil - }}) - assert.NotNil(t, mw) -} - func TestBasicAuthRealm(t *testing.T) { e := echo.New() - mockValidator := func(c *echo.Context, u, p string) (bool, error) { + mockValidator := func(u, p string, c echo.Context) (bool, error) { return false, nil // Always fail to trigger WWW-Authenticate header } @@ -227,13 +165,15 @@ func TestBasicAuthRealm(t *testing.T) { h := BasicAuthWithConfig(BasicAuthConfig{ Validator: mockValidator, Realm: tt.realm, - })(func(c *echo.Context) error { + })(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) err := h(c) - assert.Equal(t, echo.ErrUnauthorized, err) + var he *echo.HTTPError + errors.As(err, &he) + assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) }) } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index 0443a67ab..add778d67 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -10,9 +10,8 @@ import ( "io" "net" "net/http" - "sync" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // BodyDumpConfig defines the config for BodyDump middleware. @@ -20,127 +19,78 @@ type BodyDumpConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // Handler receives request, response payloads and handler error if there are any. + // Handler receives request and response payload. // Required. Handler BodyDumpHandler - - // MaxRequestBytes limits how much of the request body to dump. - // If the request body exceeds this limit, only the first MaxRequestBytes - // are dumped. The handler callback receives truncated data. - // Default: 5 * MB (5,242,880 bytes) - // Set to -1 to disable limits (not recommended in production). - MaxRequestBytes int64 - - // MaxResponseBytes limits how much of the response body to dump. - // If the response body exceeds this limit, only the first MaxResponseBytes - // are dumped. The handler callback receives truncated data. - // Default: 5 * MB (5,242,880 bytes) - // Set to -1 to disable limits (not recommended in production). - MaxResponseBytes int64 } // BodyDumpHandler receives the request and response payload. -type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) +type BodyDumpHandler func(echo.Context, []byte, []byte) type bodyDumpResponseWriter struct { io.Writer http.ResponseWriter } +// DefaultBodyDumpConfig is the default BodyDump middleware config. +var DefaultBodyDumpConfig = BodyDumpConfig{ + Skipper: DefaultSkipper, +} + // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. -// -// SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion -// attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended -// in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) + c := DefaultBodyDumpConfig + c.Handler = handler + return BodyDumpWithConfig(c) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. -// -// SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default -// to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable -// limits if needed for your use case. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration -func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + // Defaults if config.Handler == nil { - return nil, errors.New("echo body-dump middleware requires a handler function") + panic("echo: body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultSkipper - } - if config.MaxRequestBytes == 0 { - config.MaxRequestBytes = 5 * MB - } - if config.MaxResponseBytes == 0 { - config.MaxResponseBytes = 5 * MB + config.Skipper = DefaultBodyDumpConfig.Skipper } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } - reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) - reqBuf.Reset() - defer bodyDumpBufferPool.Put(reqBuf) - - var bodyReader io.Reader = c.Request().Body - if config.MaxRequestBytes > 0 { - bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes) - } - _, readErr := io.Copy(reqBuf, bodyReader) - if readErr != nil && readErr != io.EOF { - return readErr - } - if config.MaxRequestBytes > 0 { - // Drain any remaining body data to prevent connection issues - _, _ = io.Copy(io.Discard, c.Request().Body) - _ = c.Request().Body.Close() - } - - reqBody := make([]byte, reqBuf.Len()) - copy(reqBody, reqBuf.Bytes()) - c.Request().Body = io.NopCloser(bytes.NewReader(reqBody)) - - // response part - resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) - resBuf.Reset() - defer bodyDumpBufferPool.Put(resBuf) - - var respWriter io.Writer - if config.MaxResponseBytes > 0 { - respWriter = &limitedWriter{ - response: c.Response(), - dumpBuf: resBuf, - limit: config.MaxResponseBytes, + // Request + reqBody := []byte{} + if c.Request().Body != nil { + var readErr error + reqBody, readErr = io.ReadAll(c.Request().Body) + if readErr != nil { + return readErr } - } else { - respWriter = io.MultiWriter(c.Response(), resBuf) } - writer := &bodyDumpResponseWriter{ - Writer: respWriter, - ResponseWriter: c.Response(), - } - c.SetResponse(writer) + c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset + + // Response + resBody := new(bytes.Buffer) + mw := io.MultiWriter(c.Response().Writer, resBody) + writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} + c.Response().Writer = writer - err := next(c) + if err = next(c); err != nil { + c.Error(err) + } // Callback - config.Handler(c, reqBody, resBuf.Bytes(), err) + config.Handler(c, reqBody, resBody.Bytes()) - return err + return } - }, nil + } } func (w *bodyDumpResponseWriter) WriteHeader(code int) { @@ -165,34 +115,3 @@ func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } - -var bodyDumpBufferPool = sync.Pool{ - New: func() any { - return new(bytes.Buffer) - }, -} - -type limitedWriter struct { - response http.ResponseWriter - dumpBuf *bytes.Buffer - dumped int64 - limit int64 -} - -func (w *limitedWriter) Write(b []byte) (n int, err error) { - // Always write full data to actual response (don't truncate client response) - n, err = w.response.Write(b) - if err != nil { - return n, err - } - - // Write to dump buffer only up to limit - if w.dumped < w.limit { - remaining := w.limit - w.dumped - toDump := min(int64(n), remaining) - w.dumpBuf.Write(b[:toDump]) - w.dumped += toDump - } - - return n, nil -} diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e5f64541a..7a7dee3d9 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -11,7 +11,7 @@ import ( "strings" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -21,7 +21,7 @@ func TestBodyDump(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c *echo.Context) error { + h := func(c echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err @@ -31,11 +31,10 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { requestBody = string(reqBody) responseBody = string(resBody) - }}.ToMiddleware() - assert.NoError(t, err) + }) if assert.NoError(t, mw(h)(c)) { assert.Equal(t, requestBody, hw) @@ -44,76 +43,51 @@ func TestBodyDump(t *testing.T) { assert.Equal(t, hw, rec.Body.String()) } -} - -func TestBodyDump_skipper(t *testing.T) { - e := echo.New() - - isCalled := false - mw, err := BodyDumpConfig{ - Skipper: func(c *echo.Context) bool { - return true + // Must set default skipper + BodyDumpWithConfig(BodyDumpConfig{ + Skipper: nil, + Handler: func(c echo.Context, reqBody, resBody []byte) { + requestBody = string(reqBody) + responseBody = string(resBody) }, - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - isCalled = true - }, - }.ToMiddleware() - assert.NoError(t, err) - - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := func(c *echo.Context) error { - return errors.New("some error") - } - - err = mw(h)(c) - assert.EqualError(t, err, "some error") - assert.False(t, isCalled) + }) } -func TestBodyDump_fails(t *testing.T) { +func TestBodyDumpFails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c *echo.Context) error { + h := func(c echo.Context) error { return errors.New("some error") } - mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware() - assert.NoError(t, err) + mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) - err = mw(h)(c) - assert.EqualError(t, err, "some error") - assert.Equal(t, http.StatusOK, rec.Code) - -} + if !assert.Error(t, mw(h)(c)) { + t.FailNow() + } -func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw := BodyDumpWithConfig(BodyDumpConfig{ + mw = BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) - assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}) - assert.NotNil(t, mw) - }) -} - -func TestBodyDump_panic(t *testing.T) { - assert.Panics(t, func() { - mw := BodyDump(nil) - assert.NotNil(t, mw) - }) + mw = BodyDumpWithConfig(BodyDumpConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Handler: func(c echo.Context, reqBody, resBody []byte) { + }, + }) - assert.NotPanics(t, func() { - BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {}) + if !assert.Error(t, mw(h)(c)) { + t.FailNow() + } }) } @@ -121,6 +95,7 @@ func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush } + assert.PanicsWithError(t, "response writer flushing is not supported", func() { bdrw.Flush() }) @@ -131,6 +106,7 @@ func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, } + bdrw.Flush() assert.Equal(t, 1, trwu.unwrapCalled) } @@ -140,6 +116,7 @@ func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: trwu, } + result := bdrw.Unwrap() assert.Equal(t, trwu, result) } @@ -149,6 +126,7 @@ func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } + _, _, err := bdrw.Hijack() assert.EqualError(t, err, "can hijack") } @@ -158,6 +136,7 @@ func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } + _, _, err := bdrw.Hijack() assert.EqualError(t, err, "feature not supported") } @@ -176,14 +155,14 @@ func TestBodyDump_ReadError(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c *echo.Context) error { + h := func(c echo.Context) error { // This handler should not be reached if body read fails body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyReceived := "" - mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) { + mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { requestBodyReceived = string(reqBody) }) @@ -223,359 +202,3 @@ func (f *failingReadCloser) Read(p []byte) (n int, err error) { func (f *failingReadCloser) Close() error { return nil } - -func TestBodyDump_RequestWithinLimit(t *testing.T) { - e := echo.New() - requestData := "Hello, World!" - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - body, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(body)) - } - - requestBodyDumped := "" - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - requestBodyDumped = string(reqBody) - }, - MaxRequestBytes: 1 * MB, // 1MB limit - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped") - assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request") -} - -func TestBodyDump_RequestExceedsLimit(t *testing.T) { - e := echo.New() - // Create 2KB of data but limit to 1KB - largeData := strings.Repeat("A", 2*1024) - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - body, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(body)) - } - - requestBodyDumped := "" - limit := int64(1024) // 1KB limit - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - requestBodyDumped = string(reqBody) - }, - MaxRequestBytes: limit, - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit") - assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes") - // Handler should receive truncated data (what was dumped) - assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String()) -} - -func TestBodyDump_RequestAtExactLimit(t *testing.T) { - e := echo.New() - exactData := strings.Repeat("B", 1024) - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - body, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(body)) - } - - requestBodyDumped := "" - limit := int64(1024) - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - requestBodyDumped = string(reqBody) - }, - MaxRequestBytes: limit, - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data") - assert.Equal(t, exactData, requestBodyDumped) -} - -func TestBodyDump_ResponseWithinLimit(t *testing.T) { - e := echo.New() - responseData := "Response data" - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - return c.String(http.StatusOK, responseData) - } - - responseBodyDumped := "" - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - responseBodyDumped = string(resBody) - }, - MaxRequestBytes: 1 * MB, - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped") - assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response") -} - -func TestBodyDump_ResponseExceedsLimit(t *testing.T) { - e := echo.New() - largeResponse := strings.Repeat("X", 2*1024) // 2KB - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - return c.String(http.StatusOK, largeResponse) - } - - responseBodyDumped := "" - limit := int64(1024) // 1KB limit - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - responseBodyDumped = string(resBody) - }, - MaxRequestBytes: 1 * MB, - MaxResponseBytes: limit, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - // Dump should be truncated - assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit") - assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped) - // Client should still receive full response! - assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation") -} - -func TestBodyDump_ClientGetsFullResponse(t *testing.T) { - e := echo.New() - // This is critical - even when dump is limited, client gets everything - largeResponse := strings.Repeat("DATA", 500) // 2KB - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - // Write response in chunks to test incremental writes - for range 4 { - c.Response().Write([]byte(strings.Repeat("DATA", 125))) - } - return nil - } - - responseBodyDumped := "" - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - responseBodyDumped = string(resBody) - }, - MaxRequestBytes: 1 * MB, - MaxResponseBytes: 512, // Very small limit - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited") - assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response") -} - -func TestBodyDump_BothLimitsSimultaneous(t *testing.T) { - e := echo.New() - largeRequest := strings.Repeat("Q", 2*1024) - largeResponse := strings.Repeat("R", 2*1024) - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - io.ReadAll(c.Request().Body) // Consume request - return c.String(http.StatusOK, largeResponse) - } - - requestBodyDumped := "" - responseBodyDumped := "" - limit := int64(1024) - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - requestBodyDumped = string(reqBody) - responseBodyDumped = string(resBody) - }, - MaxRequestBytes: limit, - MaxResponseBytes: limit, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited") - assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited") -} - -func TestBodyDump_DefaultConfig(t *testing.T) { - e := echo.New() - smallData := "test" - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - body, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(body)) - } - - requestBodyDumped := "" - // Use default config which should have 1MB limits - config := BodyDumpConfig{} - config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) { - requestBodyDumped = string(reqBody) - } - mw, err := config.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, smallData, requestBodyDumped) -} - -func TestBodyDump_LargeRequestDosPrevention(t *testing.T) { - e := echo.New() - // Simulate a very large request (10MB) that could cause OOM - largeSize := 10 * 1024 * 1024 // 10MB - largeData := strings.Repeat("M", largeSize) - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - body, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(body)) - } - - requestBodyDumped := "" - limit := int64(1 * MB) // Only dump 1MB max - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - requestBodyDumped = string(reqBody) - }, - MaxRequestBytes: limit, - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - // Verify only 1MB was dumped, not 10MB - assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS") - assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request") -} - -func TestBodyDump_LargeResponseDosPrevention(t *testing.T) { - e := echo.New() - // Simulate a very large response (10MB) - largeSize := 10 * 1024 * 1024 // 10MB - largeResponse := strings.Repeat("R", largeSize) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h := func(c *echo.Context) error { - return c.String(http.StatusOK, largeResponse) - } - - responseBodyDumped := "" - limit := int64(1 * MB) // Only dump 1MB max - mw, err := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - responseBodyDumped = string(resBody) - }, - MaxRequestBytes: 1 * MB, - MaxResponseBytes: limit, - }.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - assert.NoError(t, err) - // Verify only 1MB was dumped, not 10MB - assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS") - assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response") - // Client still gets full response - assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response") -} - -func BenchmarkBodyDump_WithLimit(b *testing.B) { - e := echo.New() - requestData := strings.Repeat("data", 256) // 1KB - responseData := strings.Repeat("resp", 256) // 1KB - - h := func(c *echo.Context) error { - io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, responseData) - } - - mw, _ := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { - // Simulate logging - _ = len(reqBody) + len(resBody) - }, - MaxRequestBytes: 1 * MB, - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - mw(h)(c) - } -} - -func BenchmarkBodyDump_BufferPooling(b *testing.B) { - e := echo.New() - requestData := strings.Repeat("x", 1024) - responseData := "response" - - h := func(c *echo.Context) error { - io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, responseData) - } - - mw, _ := BodyDumpConfig{ - Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}, - MaxRequestBytes: 1 * MB, - MaxResponseBytes: 1 * MB, - }.ToMiddleware() - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - mw(h)(c) - } -} diff --git a/middleware/body_limit.go b/middleware/body_limit.go index 4f1963e18..1450e9293 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -4,20 +4,24 @@ package middleware import ( + "fmt" "io" "net/http" "sync" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/bytes" ) -// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. +// BodyLimitConfig defines the config for BodyLimit middleware. type BodyLimitConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // LimitBytes is maximum allowed size in bytes for a request body - LimitBytes int64 + // Maximum allowed size for a request body, it can be specified + // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. + Limit string `yaml:"limit"` + limit int64 } type limitedReader struct { @@ -26,43 +30,50 @@ type limitedReader struct { read int64 } +// DefaultBodyLimitConfig is the default BodyLimit middleware config. +var DefaultBodyLimitConfig = BodyLimitConfig{ + Skipper: DefaultSkipper, +} + // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it -// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the +// size exceeds the configured limit, it sends "413 - Request Entity Too Large" +// response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -func BodyLimit(limitBytes int64) echo.MiddlewareFunc { - return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) +// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, +// G, T or P. +func BodyLimit(limit string) echo.MiddlewareFunc { + c := DefaultBodyLimitConfig + c.Limit = limit + return BodyLimitWithConfig(c) } -// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for -// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. -// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which -// makes it super secure. +// BodyLimitWithConfig returns a BodyLimit middleware with config. +// See: `BodyLimit()`. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration -func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + // Defaults if config.Skipper == nil { - config.Skipper = DefaultSkipper + config.Skipper = DefaultBodyLimitConfig.Skipper } - pool := sync.Pool{ - New: func() any { - return &limitedReader{BodyLimitConfig: config} - }, + + limit, err := bytes.Parse(config.Limit) + if err != nil { + panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) } + config.limit = limit + pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } + req := c.Request() // Based on content length - if req.ContentLength > config.LimitBytes { + if req.ContentLength > config.limit { return echo.ErrStatusRequestEntityTooLarge } @@ -77,13 +88,13 @@ func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return next(c) } - }, nil + } } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.LimitBytes { + if r.read > r.limit { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -97,3 +108,11 @@ func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader r.read = 0 } + +func limitedReaderPool(c BodyLimitConfig) sync.Pool { + return sync.Pool{ + New: func() any { + return &limitedReader{BodyLimitConfig: c} + }, + } +} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 68d904da8..d14c2b649 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -10,17 +10,17 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) -func TestBodyLimitConfig_ToMiddleware(t *testing.T) { +func TestBodyLimit(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c *echo.Context) error { + h := func(c echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err @@ -29,76 +29,41 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { } // Based on content length (within limit) - mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() - assert.NoError(t, err) - - err = mw(h)(c) - if assert.NoError(t, err) { + if assert.NoError(t, BodyLimit("2M")(h)(c)) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, hw, rec.Body.Bytes()) } - // Based on content read (overlimit) - mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() - assert.NoError(t, err) - he := mw(h)(c).(echo.HTTPStatusCoder) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) + // Based on content length (overlimit) + he := BodyLimit("2B")(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - - mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() - assert.NoError(t, err) - err = mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "Hello, World!", rec.Body.String()) + if assert.NoError(t, BodyLimit("2M")(h)(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() - assert.NoError(t, err) - he = mw(h)(c).(echo.HTTPStatusCoder) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) -} - -func TestBodyLimitAfterDecompressUsesDecodedSize(t *testing.T) { - e := echo.New() - body := "ok" - gz, err := gzipString(body) - assert.NoError(t, err) - assert.Greater(t, len(gz), len(body)) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err = Decompress()(BodyLimit(int64(len(body)))(func(c *echo.Context) error { - body, readErr := io.ReadAll(c.Request().Body) - if readErr != nil { - return readErr - } - return c.String(http.StatusOK, string(body)) - }))(c) - - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, body, rec.Body.String()) + he = BodyLimit("2B")(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) } func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") config := BodyLimitConfig{ - Skipper: DefaultSkipper, - LimitBytes: 2, + Skipper: DefaultSkipper, + Limit: "2B", + limit: 2, } reader := &limitedReader{ BodyLimitConfig: config, @@ -107,8 +72,8 @@ func TestBodyLimitReader(t *testing.T) { // read all should return ErrStatusRequestEntityTooLarge _, err := io.ReadAll(reader) - he := err.(echo.HTTPStatusCoder) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) + he := err.(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // reset reader and read two bytes must succeed bt := make([]byte, 2) @@ -118,74 +83,91 @@ func TestBodyLimitReader(t *testing.T) { assert.Equal(t, nil, err) } -func TestBodyLimit_skipper(t *testing.T) { +func TestBodyLimitWithConfig_Skipper(t *testing.T) { e := echo.New() - h := func(c *echo.Context) error { + h := func(c echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err } return c.String(http.StatusOK, string(body)) } - mw, err := BodyLimitConfig{ - Skipper: func(c *echo.Context) bool { + mw := BodyLimitWithConfig(BodyLimitConfig{ + Skipper: func(c echo.Context) bool { return true }, - LimitBytes: 2, - }.ToMiddleware() - assert.NoError(t, err) + Limit: "2B", // if not skipped this limit would make request to fail limit check + }) hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - err = mw(h)(c) + err := mw(h)(c) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, hw, rec.Body.Bytes()) } func TestBodyLimitWithConfig(t *testing.T) { - e := echo.New() - hw := []byte("Hello, World!") - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := func(c *echo.Context) error { - body, err := io.ReadAll(c.Request().Body) - if err != nil { - return err - } - return c.String(http.StatusOK, string(body)) + var testCases = []struct { + name string + givenLimit string + whenBody []byte + expectBody []byte + expectError string + }{ + { + name: "ok, body is less than limit", + givenLimit: "10B", + whenBody: []byte("123456789"), + expectBody: []byte("123456789"), + expectError: "", + }, + { + name: "nok, body is more than limit", + givenLimit: "9B", + whenBody: []byte("1234567890"), + expectBody: []byte(nil), + expectError: "code=413, message=Request Entity Too Large", + }, } - mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) - - err := mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, hw, rec.Body.Bytes()) -} - -func TestBodyLimit(t *testing.T) { - e := echo.New() - hw := []byte("Hello, World!") - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := func(c *echo.Context) error { - body, err := io.ReadAll(c.Request().Body) - if err != nil { - return err - } - return c.String(http.StatusOK, string(body)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw := BodyLimitWithConfig(BodyLimitConfig{ + Limit: tc.givenLimit, + }) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(tc.whenBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(h)(c) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + // not testing status as middlewares return error instead of committing it and OK cases are anyway 200 + assert.Equal(t, tc.expectBody, rec.Body.Bytes()) + }) } +} - mw := BodyLimit(2 * MB) - - err := mw(h)(c) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, hw, rec.Body.Bytes()) +func TestBodyLimit_panicOnInvalidLimit(t *testing.T) { + assert.PanicsWithError( + t, + "echo: invalid body-limit=", + func() { BodyLimit("") }, + ) } diff --git a/middleware/compress.go b/middleware/compress.go index 7754d5db8..4a2497b49 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -7,18 +7,13 @@ import ( "bufio" "bytes" "compress/gzip" - "errors" "io" "net" "net/http" "strings" "sync" - "github.com/labstack/echo/v5" -) - -const ( - gzipScheme = "gzip" + "github.com/labstack/echo/v4" ) // GzipConfig defines the config for Gzip middleware. @@ -28,7 +23,7 @@ type GzipConfig struct { // Gzip compression level. // Optional. Default value -1. - Level int + Level int `yaml:"level"` // Length threshold before gzip compression is applied. // Optional. Default value 0. @@ -55,36 +50,42 @@ type gzipResponseWriter struct { code int } -// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. -func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(GzipConfig{}) +const ( + gzipScheme = "gzip" +) + +// DefaultGzipConfig is the default Gzip middleware config. +var DefaultGzipConfig = GzipConfig{ + Skipper: DefaultSkipper, + Level: -1, + MinLength: 0, } -// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. -func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) +// Gzip returns a middleware which compresses HTTP response using gzip compression +// scheme. +func Gzip() echo.MiddlewareFunc { + return GzipWithConfig(DefaultGzipConfig) } -// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration -func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { +// GzipWithConfig return Gzip middleware with config. +// See: `Gzip()`. +func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { + // Defaults if config.Skipper == nil { - config.Skipper = DefaultSkipper - } - if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression - return nil, errors.New("invalid gzip level") + config.Skipper = DefaultGzipConfig.Skipper } if config.Level == 0 { - config.Level = -1 + config.Level = DefaultGzipConfig.Level } if config.MinLength < 0 { - config.MinLength = 0 + config.MinLength = DefaultGzipConfig.MinLength } pool := gzipCompressPool(config) bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -97,18 +98,13 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if !ok { return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object") } - rw := res + rw := res.Writer w.Reset(rw) + buf := bpool.Get().(*bytes.Buffer) buf.Reset() - grw := &gzipResponseWriter{ - Writer: w, - ResponseWriter: rw, - minLength: config.MinLength, - buffer: buf, - } - c.SetResponse(grw) + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} defer func() { // There are different reasons for cases when we have not yet written response to the client and now need to do so. // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. @@ -123,25 +119,26 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. - c.SetResponse(rw) + res.Writer = rw w.Reset(io.Discard) } else if !grw.minLengthExceeded { // Write uncompressed response - c.SetResponse(rw) + res.Writer = rw if grw.wroteHeader { grw.ResponseWriter.WriteHeader(grw.code) } - _, _ = grw.buffer.WriteTo(rw) + grw.buffer.WriteTo(rw) w.Reset(io.Discard) } - _ = w.Close() + w.Close() bpool.Put(buf) pool.Put(w) }() + res.Writer = grw } return next(c) } - }, nil + } } func (w *gzipResponseWriter) WriteHeader(code int) { @@ -189,7 +186,7 @@ func (w *gzipResponseWriter) Flush() { w.ResponseWriter.WriteHeader(w.code) } - _, _ = w.Writer.Write(w.buffer.Bytes()) + w.Writer.Write(w.buffer.Bytes()) } if gw, ok := w.Writer.(*gzip.Writer); ok { @@ -198,14 +195,14 @@ func (w *gzipResponseWriter) Flush() { _ = http.NewResponseController(w.ResponseWriter).Flush() } -func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return http.NewResponseController(w.ResponseWriter).Hijack() -} - func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } +func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(w.ResponseWriter).Hijack() +} + func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { if p, ok := w.ResponseWriter.(http.Pusher); ok { return p.Push(target, opts) diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 084ffc9c7..c9083ee28 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -11,216 +11,91 @@ import ( "net/http/httptest" "os" "testing" - "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) -func TestGzip_NoAcceptEncodingHeader(t *testing.T) { - // Skip if no Accept-Encoding header - h := Gzip()(func(c *echo.Context) error { - c.Response().Write([]byte("test")) // For Content-Type sniffing - return nil - }) - +func TestGzip(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - err := h(c) - assert.NoError(t, err) - - assert.Equal(t, "test", rec.Body.String()) -} - -func TestMustGzipWithConfig_panics(t *testing.T) { - assert.Panics(t, func() { - GzipWithConfig(GzipConfig{Level: 999}) - }) -} - -func TestGzip_AcceptEncodingHeader(t *testing.T) { - h := Gzip()(func(c *echo.Context) error { + // Skip if no Accept-Encoding header + h := Gzip()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) + h(c) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := h(c) - assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) + // Gzip + req = httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - r, err := gzip.NewReader(rec.Body) - assert.NoError(t, err) - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal(t, "test", buf.String()) -} + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) + } -func TestGzip_chunked(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) + chunkBuf := make([]byte, 5) + + // Gzip chunked + req = httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + rec = httptest.NewRecorder() - chunkChan := make(chan struct{}) - waitChan := make(chan struct{}) - h := Gzip()(func(c *echo.Context) error { - rc := http.NewResponseController(c.Response()) + c = e.NewContext(req, rec) + Gzip()(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("first\n")) - rc.Flush() + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(t, rec.Flushed) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r.Reset(rec.Body) - chunkChan <- struct{}{} - <-waitChan + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) // Write and flush the second part of the data - c.Response().Write([]byte("second\n")) - rc.Flush() + c.Response().Write([]byte("test\n")) + c.Response().Flush() - chunkChan <- struct{}{} - <-waitChan + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) // Write the final part of the data and return - c.Response().Write([]byte("third")) - - chunkChan <- struct{}{} + c.Response().Write([]byte("test")) return nil - }) - - go func() { - err := h(c) - chunkChan <- struct{}{} - assert.NoError(t, err) - }() + })(c) - <-chunkChan // wait for first write - waitChan <- struct{}{} - - <-chunkChan // wait for second write - waitChan <- struct{}{} - - <-chunkChan // wait for final write in handler - <-chunkChan // wait for return from handler - time.Sleep(5 * time.Millisecond) // to have time for flushing - - assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - - r, err := gzip.NewReader(rec.Body) - assert.NoError(t, err) buf := new(bytes.Buffer) + defer r.Close() buf.ReadFrom(r) - assert.Equal(t, "first\nsecond\nthird", buf.String()) -} - -func TestGzip_NoContent(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := Gzip()(func(c *echo.Context) error { - return c.NoContent(http.StatusNoContent) - }) - if assert.NoError(t, h(c)) { - assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) - assert.Equal(t, 0, len(rec.Body.Bytes())) - } -} - -func TestGzip_Empty(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := Gzip()(func(c *echo.Context) error { - return c.String(http.StatusOK, "") - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(t, err) { - var buf bytes.Buffer - buf.ReadFrom(r) - assert.Equal(t, "", buf.String()) - } - } -} - -func TestGzip_ErrorReturned(t *testing.T) { - e := echo.New() - e.Use(Gzip()) - e.GET("/", func(c *echo.Context) error { - return echo.ErrNotFound - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusNotFound, rec.Code) - assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) -} - -func TestGzipWithConfig_invalidLevel(t *testing.T) { - mw, err := GzipConfig{Level: 12}.ToMiddleware() - assert.EqualError(t, err, "invalid gzip level") - assert.Nil(t, mw) -} - -// Issue #806 -func TestGzipWithStatic(t *testing.T) { - e := echo.New() - e.Filesystem = os.DirFS("../") - - e.Use(Gzip()) - e.Static("/test", "_fixture/images") - req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - // Data is written out in chunks when Content-Length == "", so only - // validate the content length if it's not set. - if cl := rec.Header().Get("Content-Length"); cl != "" { - assert.Equal(t, cl, rec.Body.Len()) - } - r, err := gzip.NewReader(rec.Body) - if assert.NoError(t, err) { - defer r.Close() - want, err := os.ReadFile("../_fixture/images/walle.png") - if assert.NoError(t, err) { - buf := new(bytes.Buffer) - buf.ReadFrom(r) - assert.Equal(t, want, buf.Bytes()) - } - } + assert.Equal(t, "test", buf.String()) } func TestGzipWithMinLength(t *testing.T) { e := echo.New() // Minimal response length e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) - e.GET("/", func(c *echo.Context) error { + e.GET("/", func(c echo.Context) error { c.Response().Write([]byte("foobarfoobar")) return nil }) @@ -243,7 +118,7 @@ func TestGzipWithMinLengthTooShort(t *testing.T) { e := echo.New() // Minimal response length e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) - e.GET("/", func(c *echo.Context) error { + e.GET("/", func(c echo.Context) error { c.Response().Write([]byte("test")) return nil }) @@ -259,7 +134,7 @@ func TestGzipWithResponseWithoutBody(t *testing.T) { e := echo.New() e.Use(Gzip()) - e.GET("/", func(c *echo.Context) error { + e.GET("/", func(c echo.Context) error { return c.Redirect(http.StatusMovedPermanently, "http://localhost") }) @@ -286,14 +161,13 @@ func TestGzipWithMinLengthChunked(t *testing.T) { var r *gzip.Reader = nil c := e.NewContext(req, rec) - next := func(c *echo.Context) error { - rc := http.NewResponseController(c.Response()) + GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data c.Response().Write([]byte("test\n")) - rc.Flush() + c.Response().Flush() // Read the first part of the data assert.True(t, rec.Flushed) @@ -309,7 +183,7 @@ func TestGzipWithMinLengthChunked(t *testing.T) { // Write and flush the second part of the data c.Response().Write([]byte("test\n")) - rc.Flush() + c.Response().Flush() _, err = io.ReadFull(r, chunkBuf) assert.NoError(t, err) @@ -318,10 +192,8 @@ func TestGzipWithMinLengthChunked(t *testing.T) { // Write the final part of the data and return c.Response().Write([]byte("test")) return nil - } - err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c) + })(c) - assert.NoError(t, err) assert.NotNil(t, r) buf := new(bytes.Buffer) @@ -338,7 +210,7 @@ func TestGzipWithMinLengthNoContent(t *testing.T) { req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error { + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) if assert.NoError(t, h(c)) { @@ -348,11 +220,106 @@ func TestGzipWithMinLengthNoContent(t *testing.T) { } } +func TestGzipNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Gzip()(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestGzipEmpty(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Gzip()(func(c echo.Context) error { + return c.String(http.StatusOK, "") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + var buf bytes.Buffer + buf.ReadFrom(r) + assert.Equal(t, "", buf.String()) + } + } +} + +func TestGzipErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Gzip()) + e.GET("/", func(c echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipErrorReturnedInvalidConfig(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(GzipWithConfig(GzipConfig{Level: 12})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), `{"message":"invalid pool object"}`) +} + +// Issue #806 +func TestGzipWithStatic(t *testing.T) { + e := echo.New() + e.Use(Gzip()) + e.Static("/test", "../_fixture/images") + req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + // Data is written out in chunks when Content-Length == "", so only + // validate the content length if it's not set. + if cl := rec.Header().Get("Content-Length"); cl != "" { + assert.Equal(t, cl, rec.Body.Len()) + } + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + defer r.Close() + want, err := os.ReadFile("../_fixture/images/walle.png") + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + buf.ReadFrom(r) + assert.Equal(t, want, buf.Bytes()) + } + } +} + func TestGzipResponseWriter_CanUnwrap(t *testing.T) { trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} bdrw := gzipResponseWriter{ ResponseWriter: trwu, } + result := bdrw.Unwrap() assert.Equal(t, trwu, result) } @@ -362,6 +329,7 @@ func TestGzipResponseWriter_CanHijack(t *testing.T) { bdrw := gzipResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } + _, _, err := bdrw.Hijack() assert.EqualError(t, err, "can hijack") } @@ -371,6 +339,7 @@ func TestGzipResponseWriter_CanNotHijack(t *testing.T) { bdrw := gzipResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } + _, _, err := bdrw.Hijack() assert.EqualError(t, err, "feature not supported") } @@ -381,7 +350,7 @@ func BenchmarkGzip(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - h := Gzip()(func(c *echo.Context) error { + h := Gzip()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go index 68465199a..5d9ae9755 100644 --- a/middleware/context_timeout.go +++ b/middleware/context_timeout.go @@ -8,18 +8,51 @@ import ( "errors" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) +// ContextTimeout Middleware +// +// ContextTimeout provides request timeout functionality using Go's context mechanism. +// It is the recommended replacement for the deprecated Timeout middleware. +// +// +// Basic Usage: +// +// e.Use(middleware.ContextTimeout(30 * time.Second)) +// +// With Configuration: +// +// e.Use(middleware.ContextTimeoutWithConfig(middleware.ContextTimeoutConfig{ +// Timeout: 30 * time.Second, +// Skipper: middleware.DefaultSkipper, +// })) +// +// Handler Example: +// +// e.GET("/task", func(c echo.Context) error { +// ctx := c.Request().Context() +// +// result, err := performTaskWithContext(ctx) +// if err != nil { +// if errors.Is(err, context.DeadlineExceeded) { +// return echo.NewHTTPError(http.StatusServiceUnavailable, "timeout") +// } +// return err +// } +// +// return c.JSON(http.StatusOK, result) +// }) + // ContextTimeoutConfig defines the config for ContextTimeout middleware. type ContextTimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // ErrorHandler is a function when error arises in middeware execution. - ErrorHandler func(c *echo.Context, err error) error + // ErrorHandler is a function when error arises in middleware execution. + ErrorHandler func(err error, c echo.Context) error - // Timeout configures a timeout for the middleware + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout Timeout time.Duration } @@ -31,7 +64,11 @@ func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { // ContextTimeoutWithConfig returns a Timeout middleware with config. func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw } // ToMiddleware converts Config to middleware. @@ -43,16 +80,16 @@ func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.Skipper = DefaultSkipper } if config.ErrorHandler == nil { - config.ErrorHandler = func(c *echo.Context, err error) error { + config.ErrorHandler = func(err error, c echo.Context) error { if err != nil && errors.Is(err, context.DeadlineExceeded) { - return echo.ErrServiceUnavailable.Wrap(err) + return echo.ErrServiceUnavailable.WithInternal(err) } return err } } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -63,7 +100,7 @@ func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { c.SetRequest(c.Request().WithContext(timeoutContext)) if err := next(c); err != nil { - return config.ErrorHandler(c, err) + return config.ErrorHandler(err, c) } return nil } diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go index c7ba76beb..e69bcd268 100644 --- a/middleware/context_timeout_test.go +++ b/middleware/context_timeout_test.go @@ -6,7 +6,6 @@ package middleware import ( "context" "errors" - "github.com/labstack/echo/v5" "net/http" "net/http/httptest" "net/url" @@ -14,13 +13,14 @@ import ( "testing" "time" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func TestContextTimeoutSkipper(t *testing.T) { t.Parallel() m := ContextTimeoutWithConfig(ContextTimeoutConfig{ - Skipper: func(context *echo.Context) bool { + Skipper: func(context echo.Context) bool { return true }, Timeout: 10 * time.Millisecond, @@ -32,7 +32,7 @@ func TestContextTimeoutSkipper(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c *echo.Context) error { + err := m(func(c echo.Context) error { if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { return err } @@ -65,7 +65,7 @@ func TestContextTimeoutErrorOutInHandler(t *testing.T) { c := e.NewContext(req, rec) rec.Code = 1 // we want to be sure that even 200 will not be sent - err := m(func(c *echo.Context) error { + err := m(func(c echo.Context) error { // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able // to handle returned error and this can be done only then handler has not yet committed (written status code) // the response. @@ -91,7 +91,7 @@ func TestContextTimeoutSuccessfulRequest(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c *echo.Context) error { + err := m(func(c echo.Context) error { return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) })(c) @@ -115,7 +115,7 @@ func TestContextTimeoutTestRequestClone(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c *echo.Context) error { + err := m(func(c echo.Context) error { // Cookie test cookie, err := c.Request().Cookie("cookie") if assert.NoError(t, err) { @@ -150,24 +150,23 @@ func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c *echo.Context) error { + err := m(func(c echo.Context) error { if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil { return err } return c.String(http.StatusOK, "Hello, World!") })(c) + assert.IsType(t, &echo.HTTPError{}, err) assert.Error(t, err) - if assert.IsType(t, &echo.HTTPError{}, err) { - assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) - assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) - } + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) } func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { t.Parallel() - timeoutErrorHandler := func(c *echo.Context, err error) error { + timeoutErrorHandler := func(err error, c echo.Context) error { if err != nil { if errors.Is(err, context.DeadlineExceeded) { return &echo.HTTPError{ @@ -192,7 +191,7 @@ func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c *echo.Context) error { + err := m(func(c echo.Context) error { // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky. diff --git a/middleware/cors.go b/middleware/cors.go index 96ed16985..a1f445321 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -4,13 +4,12 @@ package middleware import ( - "errors" - "fmt" "net/http" + "regexp" "strconv" "strings" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // CORSConfig defines the config for CORS middleware. @@ -20,41 +19,29 @@ type CORSConfig struct { // AllowOrigins determines the value of the Access-Control-Allow-Origin // response header. This header defines a list of origins that may access the - // resource. - // - // Origin consist of following parts: `scheme + "://" + host + optional ":" + port` - // Wildcard can be used, but has to be set explicitly []string{"*"} - // Example: `https://example.com`, `http://example.com:8080`, `*` + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. // // Security: use extreme caution when handling the origin, and carefully // validate any logic. Remember that attackers may register hostile domain names. // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin // - // Mandatory. - AllowOrigins []string - - // UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the - // origin as an argument and returns - // - string, allowed origin - // - bool, true if allowed or false otherwise. - // - error, if an error is returned, it is returned immediately by the handler. - // If this option is set, AllowOrigins is ignored. + // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowOrigins []string `yaml:"allow_origins"` + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. // // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile (sub)domain names. + // validate any logic. Remember that attackers may register hostile domain names. // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // - // Sub-domain checks example: - // UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { - // if strings.HasSuffix(origin, ".example.com") { - // return origin, true, nil - // } - // return "", false, nil - // }, - // // Optional. - UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) + AllowOriginFunc func(origin string) (bool, error) `yaml:"-"` // AllowMethods determines the value of the Access-Control-Allow-Methods // response header. This header specified the list of methods allowed when @@ -66,16 +53,16 @@ type CORSConfig struct { // from `Allow` header that echo.Router set into context. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - AllowMethods []string + AllowMethods []string `yaml:"allow_methods"` // AllowHeaders determines the value of the Access-Control-Allow-Headers // response header. This header is used in response to a preflight request to // indicate which HTTP headers can be used when making the actual request. // - // Optional. Defaults to empty list. No domains allowed for CORS. + // Optional. Default value []string{}. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - AllowHeaders []string + AllowHeaders []string `yaml:"allow_headers"` // AllowCredentials determines the value of the // Access-Control-Allow-Credentials response header. This header indicates @@ -92,7 +79,16 @@ type CORSConfig struct { // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - AllowCredentials bool + AllowCredentials bool `yaml:"allow_credentials"` + + // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials + // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. + // + // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) + // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. + // + // Optional. Default value is false. + UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` // ExposeHeaders determines the value of Access-Control-Expose-Headers, which // defines a list of headers that clients are allowed to access. @@ -100,7 +96,7 @@ type CORSConfig struct { // Optional. Default value []string{}, in which case the header is not set. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header - ExposeHeaders []string + ExposeHeaders []string `yaml:"expose_headers"` // MaxAge determines the value of the Access-Control-Max-Age response header. // This header indicates how long (in seconds) the results of a preflight @@ -110,16 +106,19 @@ type CORSConfig struct { // Optional. Default value 0 - meaning header is not sent. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - MaxAge int + MaxAge int `yaml:"max_age"` +} + +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, } // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See also [MDN: Cross-Origin Resource Sharing (CORS)]. // -// Origin consist of following parts: `scheme + "://" + host + optional ":" + port` -// Wildcard `*` can be used, but has to be set explicitly. -// Example: `https://example.com`, `http://example.com:8080`, `*` -// // Security: Poorly configured CORS can compromise security because it allows // relaxation of the browser's Same-Origin policy. See [Exploiting CORS // misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin @@ -128,29 +127,45 @@ type CORSConfig struct { // [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors -func CORS(allowOrigins ...string) echo.MiddlewareFunc { - c := CORSConfig{ - AllowOrigins: allowOrigins, - } - return CORSWithConfig(c) +func CORS() echo.MiddlewareFunc { + return CORSWithConfig(DefaultCORSConfig) } -// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. +// CORSWithConfig returns a CORS middleware with config. // See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration -func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { - config.Skipper = DefaultSkipper + config.Skipper = DefaultCORSConfig.Skipper + } + if len(config.AllowOrigins) == 0 { + config.AllowOrigins = DefaultCORSConfig.AllowOrigins } hasCustomAllowMethods := true if len(config.AllowMethods) == 0 { hasCustomAllowMethods = false - config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete} + config.AllowMethods = DefaultCORSConfig.AllowMethods + } + + allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins)) + for _, origin := range config.AllowOrigins { + if origin == "*" { + continue // "*" is handled differently and does not need regexp + } + pattern := regexp.QuoteMeta(origin) + pattern = strings.ReplaceAll(pattern, "\\*", ".*") + pattern = strings.ReplaceAll(pattern, "\\?", ".") + pattern = "^" + pattern + "$" + + re, err := regexp.Compile(pattern) + if err != nil { + // this is to preserve previous behaviour - invalid patterns were just ignored. + // If we would turn this to panic, users with invalid patterns + // would have applications crashing in production due unrecovered panic. + // TODO: this should be turned to error/panic in `v5` + continue + } + allowOriginPatterns = append(allowOriginPatterns, re) } allowMethods := strings.Join(config.AllowMethods, ",") @@ -162,29 +177,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { maxAge = strconv.Itoa(config.MaxAge) } - allowOriginFunc := config.UnsafeAllowOriginFunc - if config.UnsafeAllowOriginFunc == nil { - if len(config.AllowOrigins) == 0 { - return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided") - } - allowOriginFunc = config.defaultAllowOriginFunc - for _, origin := range config.AllowOrigins { - if origin == "*" { - if config.AllowCredentials { - return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc") - } - allowOriginFunc = config.starAllowOriginFunc - break - } - if err := validateOrigin(origin, "allow origin"); err != nil { - return nil, err - } - } - config.AllowOrigins = append([]string(nil), config.AllowOrigins...) - } - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -192,6 +186,7 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { req := c.Request() res := c.Response() origin := req.Header.Get(echo.HeaderOrigin) + allowOrigin := "" res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) @@ -216,51 +211,76 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { - if preflight { // req.Method=OPTIONS - return c.NoContent(http.StatusNoContent) + if !preflight { + return next(c) } - return next(c) // let non-browser calls through + return c.NoContent(http.StatusNoContent) } - allowedOrigin, allowed, err := allowOriginFunc(c, origin) - if err != nil { - return err - } - if !allowed { - // Origin existed and was NOT allowed - if preflight { - // From: https://github.com/labstack/echo/issues/2767 - // If the request's origin isn't allowed by the CORS configuration, - // the middleware should simply omit the relevant CORS headers from the response - // and let the browser fail the CORS check (if any). - return c.NoContent(http.StatusNoContent) + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { + allowOrigin = origin + } + } else { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { + allowOrigin = origin + break + } + if o == "*" || o == origin { + allowOrigin = o + break + } + if matchSubdomain(origin, o) { + allowOrigin = origin + break + } + } + + checkPatterns := false + if allowOrigin == "" { + // to avoid regex cost by invalid (long) domains (253 is domain name max limit) + if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { + checkPatterns = true + } + } + if checkPatterns { + for _, re := range allowOriginPatterns { + if match := re.MatchString(origin); match { + allowOrigin = origin + break + } + } } - // From: https://github.com/labstack/echo/issues/2767 - // no CORS middleware should block non-preflight requests; - // such requests should be let through. One reason is that not all requests that - // carry an Origin header participate in the CORS protocol. - return next(c) } - // Origin existed and was allowed + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") } - // Simple request will be let though + // Simple request if !preflight { if exposeHeaders != "" { res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } return next(c) } - // Below code is for Preflight (OPTIONS) request - // - // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if - // at the end of handler chain is actual OPTIONS route or 404/405 route which - // response code will confuse browsers + + // Preflight request res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) @@ -283,18 +303,5 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } return c.NoContent(http.StatusNoContent) } - }, nil -} - -func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { - return "*", true, nil -} - -func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { - for _, allowedOrigin := range config.AllowOrigins { - if strings.EqualFold(allowedOrigin, origin) { - return allowedOrigin, true, nil - } } - return "", false, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 5de4ca063..5461e9362 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -4,87 +4,72 @@ package middleware import ( - "cmp" "errors" "net/http" "net/http/httptest" - "strings" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func TestCORS(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request - req.Header.Set(echo.HeaderOrigin, "http://example.com") - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - mw := CORS("*") - handler := mw(func(c *echo.Context) error { - return nil - }) - - err := handler(c) - assert.NoError(t, err) - assert.Equal(t, http.StatusNoContent, rec.Code) - assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) -} - -func TestCORSConfig(t *testing.T) { var testCases = []struct { name string - givenConfig *CORSConfig + givenMW echo.MiddlewareFunc whenMethod string whenHeaders map[string]string expectHeaders map[string]string notExpectHeaders map[string]string - expectErr string }{ { - name: "ok, wildcard origin", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"*"}, - }, + name: "ok, wildcard origin", whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"}, }, { - name: "ok, wildcard AllowedOrigin with no Origin header in request", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"*"}, - }, + name: "ok, wildcard AllowedOrigin with no Origin header in request", notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, }, + { + name: "ok, invalid pattern is ignored", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{ + "\xff", // Invalid UTF-8 makes regexp.Compile to error + "*.example.com", + }, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, + }, { name: "ok, specific AllowOrigins and AllowCredentials", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"http://localhost", "http://localhost:8080"}, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: 3600, - }, - whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"}, + }), + whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, expectHeaders: map[string]string{ - echo.HeaderAccessControlAllowOrigin: "http://localhost", + echo.HeaderAccessControlAllowOrigin: "localhost", echo.HeaderAccessControlAllowCredentials: "true", }, }, { name: "ok, preflight request with matching origin for `AllowOrigins`", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"http://localhost"}, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: 3600, - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "http://localhost", + echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ - echo.HeaderAccessControlAllowOrigin: "http://localhost", + echo.HeaderAccessControlAllowOrigin: "localhost", echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", echo.HeaderAccessControlAllowCredentials: "true", echo.HeaderAccessControlMaxAge: "3600", @@ -92,14 +77,14 @@ func TestCORSConfig(t *testing.T) { }, { name: "ok, preflight request when `Access-Control-Max-Age` is set", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"http://localhost"}, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: 1, - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "http://localhost", + echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ @@ -108,14 +93,14 @@ func TestCORSConfig(t *testing.T) { }, { name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"http://localhost"}, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: -1, // forces `Access-Control-Max-Age: 0` - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "http://localhost", + echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ @@ -124,16 +109,16 @@ func TestCORSConfig(t *testing.T) { }, { name: "ok, CORS check are skipped", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"http://localhost"}, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, AllowCredentials: true, - Skipper: func(c *echo.Context) bool { + Skipper: func(c echo.Context) bool { return true }, - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "http://localhost", + echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, notExpectHeaders: map[string]string{ @@ -144,33 +129,31 @@ func TestCORSConfig(t *testing.T) { }, }, { - name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", - givenConfig: &CORSConfig{ + name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenMW: CORSWithConfig(CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: true, MaxAge: 3600, - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, - expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`, - }, - { - name: "nok, preflight request with invalid `AllowOrigins` value", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"http://server", "missing-scheme"}, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*` + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", }, - expectErr: `allow origin is missing scheme or host: missing-scheme`, }, { name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false", - givenConfig: &CORSConfig{ + givenMW: CORSWithConfig(CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: false, // important for this testcase MaxAge: 3600, - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", @@ -187,23 +170,29 @@ func TestCORSConfig(t *testing.T) { }, { name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", - givenConfig: &CORSConfig{ - AllowOrigins: []string{"*"}, - AllowCredentials: true, - MaxAge: 3600, - }, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase + MaxAge: 3600, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, - expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, }, { name: "ok, preflight request with Access-Control-Request-Headers", - givenConfig: &CORSConfig{ + givenMW: CORSWithConfig(CORSConfig{ AllowOrigins: []string{"*"}, - }, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", @@ -218,28 +207,18 @@ func TestCORSConfig(t *testing.T) { }, { name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *", - givenConfig: &CORSConfig{ - UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) { - if strings.HasSuffix(origin, ".example.com") { - allowed = true - } - return origin, allowed, nil - }, - }, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://*.example.com"}, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, }, { name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *", - givenConfig: &CORSConfig{ - UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { - if strings.HasSuffix(origin, ".example.com") { - return origin, true, nil - } - return "", false, nil - }, - }, + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://*.example.com"}, + }), whenMethod: http.MethodOptions, whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"}, @@ -249,26 +228,18 @@ func TestCORSConfig(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := echo.New() - var mw echo.MiddlewareFunc - var err error - if tc.givenConfig != nil { - mw, err = tc.givenConfig.ToMiddleware() - } else { - mw, err = CORSConfig{}.ToMiddleware() - } - if err != nil { - if tc.expectErr != "" { - assert.EqualError(t, err, tc.expectErr) - return - } - t.Fatal(err) + mw := CORS() + if tc.givenMW != nil { + mw = tc.givenMW } - - h := mw(func(c *echo.Context) error { + h := mw(func(c echo.Context) error { return nil }) - method := cmp.Or(tc.whenMethod, http.MethodGet) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } req := httptest.NewRequest(method, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -276,7 +247,7 @@ func TestCORSConfig(t *testing.T) { req.Header.Set(k, v) } - err = h(c) + err := h(c) assert.NoError(t, err) header := rec.Header() @@ -330,7 +301,98 @@ func Test_allowOriginScheme(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) + h := cors(echo.NotFoundHandler) + h(c) + + if tt.expected { + assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func Test_allowOriginSubdomain(t *testing.T) { + tests := []struct { + domain, pattern string + expected bool + }{ + { + domain: "http://aaa.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://bbb.aaa.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://bbb.aaa.example.com", + pattern: "http://*.aaa.example.com", + expected: true, + }, + { + domain: "http://aaa.example.com:8080", + pattern: "http://*.example.com:8080", + expected: true, + }, + + { + domain: "http://fuga.hoge.com", + pattern: "http://*.example.com", + expected: false, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://*.aaa.example.com", + expected: false, + }, + { + domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, + pattern: "http://*.example.com", + expected: false, + }, + { + domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, + pattern: "http://*.example.com", + expected: false, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://example.com", + expected: false, + }, + { + domain: "https://prod-preview--aaa.bbb.com", + pattern: "https://*--aaa.bbb.com", + expected: true, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://foo.[a-z]*.example.com", + expected: false, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, tt.domain) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.pattern}, + }) + h := cors(echo.NotFoundHandler) h(c) if tt.expected { @@ -343,53 +405,50 @@ func Test_allowOriginScheme(t *testing.T) { func TestCORSWithConfig_AllowMethods(t *testing.T) { var testCases = []struct { - name string - givenAllowOrigins []string - givenAllowMethods []string - whenAllowContextKey string - whenOrigin string + name string + allowOrigins []string + allowContextKey string + + whenOrigin string + whenAllowMethods []string + expectAllow string expectAccessControlAllowMethods string }{ { - name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", - givenAllowOrigins: []string{"*"}, - givenAllowMethods: []string{http.MethodGet, http.MethodHead}, - whenAllowContextKey: "OPTIONS, GET", - whenOrigin: "", - expectAllow: "OPTIONS, GET", + name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenOrigin: "", + expectAllow: "OPTIONS, GET", }, { - name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", - givenAllowOrigins: []string{"*"}, - givenAllowMethods: nil, - whenAllowContextKey: "", - whenOrigin: "", - expectAllow: "", + name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", + allowContextKey: "", + whenAllowMethods: nil, + whenOrigin: "", + expectAllow: "", }, { name: "custom AllowMethods, preflight, existing origin, sets both headers different values", - givenAllowOrigins: []string{"*"}, - givenAllowMethods: []string{http.MethodGet, http.MethodHead}, - whenAllowContextKey: "OPTIONS, GET", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: []string{http.MethodGet, http.MethodHead}, whenOrigin: "http://google.com", expectAllow: "OPTIONS, GET", expectAccessControlAllowMethods: "GET,HEAD", }, { name: "default AllowMethods, preflight, existing origin, sets both headers", - givenAllowOrigins: []string{"*"}, - givenAllowMethods: nil, - whenAllowContextKey: "OPTIONS, GET", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: nil, whenOrigin: "http://google.com", expectAllow: "OPTIONS, GET", expectAccessControlAllowMethods: "OPTIONS, GET", }, { name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods", - givenAllowOrigins: []string{"*"}, - givenAllowMethods: nil, - whenAllowContextKey: "", + allowContextKey: "", + whenAllowMethods: nil, whenOrigin: "http://google.com", expectAllow: "", expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", @@ -399,13 +458,13 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return c.String(http.StatusOK, "OK") }) cors := CORSWithConfig(CORSConfig{ - AllowOrigins: tc.givenAllowOrigins, - AllowMethods: tc.givenAllowMethods, + AllowOrigins: tc.allowOrigins, + AllowMethods: tc.whenAllowMethods, }) req := httptest.NewRequest(http.MethodOptions, "/test", nil) @@ -413,13 +472,11 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) - if tc.whenAllowContextKey != "" { - c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey) + if tc.allowContextKey != "" { + c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) } - h := cors(func(c *echo.Context) error { - return c.String(http.StatusOK, "OK") - }) + h := cors(echo.NotFoundHandler) h(c) assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) @@ -535,10 +592,10 @@ func TestCorsHeaders(t *testing.T) { //MaxAge: 3600, })) - e.GET("/", func(c *echo.Context) error { + e.GET("/", func(c echo.Context) error { return c.String(http.StatusOK, "OK") }) - e.POST("/", func(c *echo.Context) error { + e.POST("/", func(c echo.Context) error { return c.String(http.StatusCreated, "OK") }) @@ -582,17 +639,17 @@ func TestCorsHeaders(t *testing.T) { } func Test_allowOriginFunc(t *testing.T) { - returnTrue := func(c *echo.Context, origin string) (string, bool, error) { - return origin, true, nil + returnTrue := func(origin string) (bool, error) { + return true, nil } - returnFalse := func(c *echo.Context, origin string) (string, bool, error) { - return origin, false, nil + returnFalse := func(origin string) (bool, error) { + return false, nil } - returnError := func(c *echo.Context, origin string) (string, bool, error) { - return origin, true, errors.New("this is a test error") + returnError := func(origin string) (bool, error) { + return true, errors.New("this is a test error") } - allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){ + allowOriginFuncs := []func(origin string) (bool, error){ returnTrue, returnFalse, returnError, @@ -606,21 +663,21 @@ func Test_allowOriginFunc(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) - cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware() - assert.NoError(t, err) - - h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) - err = h(c) + cors := CORSWithConfig(CORSConfig{ + AllowOriginFunc: allowOriginFunc, + }) + h := cors(echo.NotFoundHandler) + err := h(c) - allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin) + expected, expectedErr := allowOriginFunc(origin) if expectedErr != nil { assert.Equal(t, expectedErr, err) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) continue } - if allowed { - assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + if expected { + assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } diff --git a/middleware/csrf.go b/middleware/csrf.go index e3616516f..1a35da63c 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site @@ -26,24 +26,25 @@ const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_" type CSRFConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // TrustedOrigins permits any request with `Sec-Fetch-Site` header whose `Origin` header - // exactly matches a configured origin. - // Values should be formatted as Origin header "scheme://host[:port]". + + // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header + // exactly matches the specified value. + // Values should be formated as Origin header "scheme://host[:port]". // // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers TrustedOrigins []string - // AllowSecFetchSiteFunc allows custom behaviour for `Sec-Fetch-Site` requests that are about to - // fail with CSRF error, to be allowed or replaced with custom error. + // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to + // fail with CRSF error, to be allowed or replaced with custom error. // This function applies to `Sec-Fetch-Site` values: // - `same-site` same registrable domain (subdomain and/or different port) // - `cross-site` request originates from different site // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers - AllowSecFetchSiteFunc func(c *echo.Context) (bool, error) + AllowSecFetchSiteFunc func(c echo.Context) (bool, error) // TokenLength is the length of the generated token. - TokenLength uint8 + TokenLength uint8 `yaml:"token_length"` // Optional. Default value 32. // TokenLookup is a string in the form of ":" or ":,:" that is used @@ -57,48 +58,49 @@ type CSRFConfig struct { // - "header:X-CSRF-Token,query:csrf" TokenLookup string `yaml:"token_lookup"` - // Generator defines a function to generate token. - // Optional. Defaults tp randomString(TokenLength). - Generator func() string - // Context key to store generated CSRF token into context. // Optional. Default value "csrf". - ContextKey string + ContextKey string `yaml:"context_key"` // Name of the CSRF cookie. This cookie will store CSRF token. // Optional. Default value "csrf". - CookieName string + CookieName string `yaml:"cookie_name"` // Domain of the CSRF cookie. // Optional. Default value none. - CookieDomain string + CookieDomain string `yaml:"cookie_domain"` // Path of the CSRF cookie. // Optional. Default value none. - CookiePath string + CookiePath string `yaml:"cookie_path"` // Max age (in seconds) of the CSRF cookie. // Optional. Default value 86400 (24hr). - CookieMaxAge int + CookieMaxAge int `yaml:"cookie_max_age"` // Indicates if CSRF cookie is secure. // Optional. Default value false. - CookieSecure bool + CookieSecure bool `yaml:"cookie_secure"` // Indicates if CSRF cookie is HTTP only. // Optional. Default value false. - CookieHTTPOnly bool + CookieHTTPOnly bool `yaml:"cookie_http_only"` // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite + CookieSameSite http.SameSite `yaml:"cookie_same_site"` // ErrorHandler defines a function which is executed for returning custom errors. - ErrorHandler func(c *echo.Context, err error) error + ErrorHandler CSRFErrorHandler + + generator func(length uint8) string } +// CSRFErrorHandler is a function which is executed for creating custom errors. +type CSRFErrorHandler func(err error, c echo.Context) error + // ErrCSRFInvalid is returned when CSRF check fails -var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"} +var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") // DefaultCSRFConfig is the default CSRF middleware config. var DefaultCSRFConfig = CSRFConfig{ @@ -114,26 +116,25 @@ var DefaultCSRFConfig = CSRFConfig{ // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - return CSRFWithConfig(DefaultCSRFConfig) + c := DefaultCSRFConfig + return CSRFWithConfig(c) } -// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. +// CSRFWithConfig returns a CSRF middleware with config. +// See `CSRF()`. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { - // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } - if config.Generator == nil { - config.Generator = createRandomStringGenerator(config.TokenLength) - } + if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -150,19 +151,23 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.CookieSecure = true } if len(config.TrustedOrigins) > 0 { - if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil { - return nil, err + if vErr := validateOrigins(config.TrustedOrigins, "trusted origin"); vErr != nil { + return nil, vErr } config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) } + tokenGenerator := randomString + if config.generator != nil { + tokenGenerator = config.generator + } - extractors, cErr := createExtractors(config.TokenLookup, 1) + extractors, cErr := CreateExtractors(config.TokenLookup) if cErr != nil { return nil, cErr } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -180,7 +185,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = config.Generator() // Generate token + token = tokenGenerator(config.TokenLength) } else { token = k.Value // Reuse token } @@ -193,7 +198,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastTokenErr error outer: for _, extractor := range extractors { - clientTokens, _, err := extractor(c) + clientTokens, err := extractor(c) if err != nil { lastExtractorErr = err continue @@ -212,11 +217,22 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if lastTokenErr != nil { finalErr = lastTokenErr } else if lastExtractorErr != nil { - finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr) + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") + } else if lastExtractorErr == errFormExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") + } else { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) + } + finalErr = lastExtractorErr } + if finalErr != nil { if config.ErrorHandler != nil { - return config.ErrorHandler(c, finalErr) + return config.ErrorHandler(finalErr, c) } return finalErr } @@ -257,7 +273,7 @@ func validateCSRFToken(token, clientToken string) bool { var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace} -func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) { +func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) { // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers // Sec-Fetch-Site values are: // - `same-origin` exact origin match - allow always @@ -295,13 +311,13 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) } // we are here when request is state-changing and `cross-site` or `same-site` - // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` + // Note: if you want to block `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` if config.AllowSecFetchSiteFunc != nil { return config.AllowSecFetchSiteFunc(c) } if secFetchSite == "same-site" { - return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF") + return false, nil // fall back to legacy token } return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF") } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index a13fdc82c..0b3210f07 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -11,7 +11,7 @@ import ( "strings" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -57,7 +57,6 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenFormTokens: map[string][]string{ "csrf": {"invalid", "token"}, }, - expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from POST form", @@ -75,7 +74,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{}, - expectError: "code=400, message=Bad Request, err=missing value in the form", + expectError: "code=400, message=missing csrf token in the form parameter", }, { name: "ok, token from POST header", @@ -94,7 +93,6 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenHeaderTokens: map[string][]string{ echo.HeaderXCSRFToken: {"invalid", "token"}, }, - expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from POST header", @@ -112,7 +110,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{}, - expectError: "code=400, message=Bad Request, err=missing value in request header", + expectError: "code=400, message=missing csrf token in request header", }, { name: "ok, token from PUT query param", @@ -131,7 +129,6 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenQueryTokens: map[string][]string{ "csrf": {"invalid", "token"}, }, - expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from PUT query form", @@ -149,7 +146,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, - expectError: "code=400, message=Bad Request, err=missing value in the query string", + expectError: "code=400, message=missing csrf token in the query string", }, { name: "nok, invalid TokenLookup", @@ -213,7 +210,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { assert.NoError(t, err) } - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -261,7 +258,7 @@ func TestCSRFWithConfig(t *testing.T) { name: "nok, POST without token", whenMethod: http.MethodPost, expectEmptyBody: true, - expectErr: `code=400, message=Bad Request, err=missing value in request header`, + expectErr: `code=400, message=missing csrf token in request header`, }, { name: "nok, POST empty token", @@ -328,12 +325,11 @@ func TestCSRFWithConfig(t *testing.T) { if tc.givenConfig != nil { config = *tc.givenConfig } - if config.Generator == nil { - config.Generator = func() string { + if config.generator == nil { + config.generator = func(_ uint8) string { return "TESTTOKEN" } } - mw, err := config.ToMiddleware() if tc.expectMWError != "" { assert.EqualError(t, err, tc.expectMWError) @@ -341,7 +337,7 @@ func TestCSRFWithConfig(t *testing.T) { } assert.NoError(t, err) - h := mw(func(c *echo.Context) error { + h := mw(func(c echo.Context) error { cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey)) assert.Equal(t, tc.expectTokenInContext, cToken) return c.String(http.StatusOK, "test") @@ -373,7 +369,7 @@ func TestCSRF(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRF() - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -393,7 +389,7 @@ func TestCSRFSetSameSiteMode(t *testing.T) { CookieSameSite: http.SameSiteStrictMode, }) - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -410,7 +406,7 @@ func TestCSRFWithoutSameSiteMode(t *testing.T) { csrf := CSRFWithConfig(CSRFConfig{}) - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -429,7 +425,7 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) { CookieSameSite: http.SameSiteDefaultMode, }) - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -449,7 +445,7 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { }.ToMiddleware() assert.NoError(t, err) - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -485,12 +481,12 @@ func TestCSRFConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ - Skipper: func(c *echo.Context) bool { + Skipper: func(c echo.Context) bool { return tc.whenSkip }, }) - h := csrf(func(c *echo.Context) error { + h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -504,13 +500,13 @@ func TestCSRFConfig_skipper(t *testing.T) { func TestCSRFErrorHandling(t *testing.T) { cfg := CSRFConfig{ - ErrorHandler: func(c *echo.Context, err error) error { + ErrorHandler: func(err error, c echo.Context) error { return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") }, } e := echo.New() - e.POST("/", func(c *echo.Context) error { + e.POST("/", func(c echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") }) @@ -583,7 +579,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodPost, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "ok, unsafe POST + same-origin passes", @@ -641,7 +636,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodPut, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "nok, unsafe DELETE + cross-site is blocked", @@ -657,7 +651,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodDelete, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "nok, unsafe PATCH + cross-site is blocked", @@ -770,7 +763,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "ok, unsafe POST + same-site + custom func allows", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { return true, nil }, }, @@ -781,7 +774,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "ok, unsafe POST + cross-site + custom func allows", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { return true, nil }, }, @@ -792,7 +785,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "nok, unsafe POST + same-site + custom func returns custom error", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func") }, }, @@ -804,7 +797,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "nok, unsafe POST + cross-site + custom func returns false with nil error", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { return false, nil }, }, @@ -825,7 +818,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, - AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "should not be called") }, }, @@ -838,7 +831,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, - AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "custom block") }, }, @@ -860,7 +853,8 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { } res := httptest.NewRecorder() - c := echo.NewContext(req, res) + e := echo.New() + c := e.NewContext(req, res) allow, err := tc.givenConfig.checkSecFetchSiteRequest(c) diff --git a/middleware/decompress.go b/middleware/decompress.go index 501ee6c5b..8c418efd7 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -9,7 +9,7 @@ import ( "net/http" "sync" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // DecompressConfig defines the config for Decompress middleware. @@ -19,13 +19,6 @@ type DecompressConfig struct { // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers GzipDecompressPool Decompressor - - // MaxDecompressedSize limits the maximum size of decompressed request body in bytes. - // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error. - // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes. - // Default: 100 * MB (104,857,600 bytes) - // Set to -1 to disable limits (not recommended in production). - MaxDecompressedSize int64 } // GZIPEncoding content-encoding header if set to "gzip", decompress body contents. @@ -36,6 +29,12 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } +// DefaultDecompressConfig defines the config for decompress middleware +var DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, +} + // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } @@ -45,39 +44,24 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { } // Decompress decompresses request body based if content encoding type is set to "gzip" with default config -// -// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks. -// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production), -// set MaxDecompressedSize to -1. func Decompress() echo.MiddlewareFunc { - return DecompressWithConfig(DecompressConfig{}) + return DecompressWithConfig(DefaultDecompressConfig) } -// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. -// -// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent -// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case. +// DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration -func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + // Defaults if config.Skipper == nil { - config.Skipper = DefaultSkipper + config.Skipper = DefaultGzipConfig.Skipper } if config.GzipDecompressPool == nil { - config.GzipDecompressPool = &DefaultGzipDecompressPool{} - } - // Apply secure default for decompression limit - if config.MaxDecompressedSize == 0 { - config.MaxDecompressedSize = 100 * MB + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool } return func(next echo.HandlerFunc) echo.HandlerFunc { pool := config.GzipDecompressPool.gzipDecompressPool() - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -89,10 +73,7 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { i := pool.Get() gr, ok := i.(*gzip.Reader) if !ok || gr == nil { - if err, isErr := i.(error); isErr { - return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) - } - return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool") + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) } defer pool.Put(gr) @@ -109,48 +90,9 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close. defer gr.Close() - // Apply decompression size limit to prevent zip bombs - if config.MaxDecompressedSize > 0 { - c.Request().Body = &limitedGzipReader{ - Reader: gr, - remaining: config.MaxDecompressedSize, - limit: config.MaxDecompressedSize, - } - } else { - // -1 means explicitly unlimited (not recommended) - c.Request().Body = gr - } - c.Request().ContentLength = -1 + c.Request().Body = gr return next(c) } - }, nil -} - -// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs -type limitedGzipReader struct { - *gzip.Reader - remaining int64 - limit int64 -} - -func (r *limitedGzipReader) Read(p []byte) (n int, err error) { - if r.remaining <= 0 { - // Limit exceeded - return 413 error - return 0, echo.ErrStatusRequestEntityTooLarge - } - - // Limit the read to remaining bytes - if int64(len(p)) > r.remaining { - p = p[:r.remaining] } - - n, err = r.Reader.Read(p) - r.remaining -= int64(n) - - return n, err -} - -func (r *limitedGzipReader) Close() error { - return r.Reader.Close() } diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 8dc3057ba..52506ce8e 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -14,91 +14,61 @@ import ( "sync" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func TestDecompress(t *testing.T) { e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - h := Decompress()(func(c *echo.Context) error { + // Skip if no Content-Encoding header + h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) + h(c) + + assert.Equal(t, "test", rec.Body.String()) - // Decompress request body + // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := h(c) - assert.NoError(t, err) - + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } -func TestDecompress_skippedIfNoHeader(t *testing.T) { +func TestDecompressDefaultConfig(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - // Skip if no Content-Encoding header - h := Decompress()(func(c *echo.Context) error { + h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) + h(c) - err := h(c) - assert.NoError(t, err) assert.Equal(t, "test", rec.Body.String()) -} - -func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{}.ToMiddleware() - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - c.Response().Write([]byte("test")) // For Content-Type sniffing - return nil - })(c) - assert.NoError(t, err) - - assert.Equal(t, "test", rec.Body.String()) - -} - -func TestDecompressWithConfig_DefaultConfig(t *testing.T) { - e := echo.New() - - h := Decompress()(func(c *echo.Context) error { - c.Response().Write([]byte("test")) // For Content-Type sniffing - return nil - }) - // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := h(c) - assert.NoError(t, err) - + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) @@ -113,9 +83,7 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) - e.ServeHTTP(rec, req) - assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) @@ -129,13 +97,10 @@ func TestDecompressNoContent(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Decompress()(func(c *echo.Context) error { + h := Decompress()(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) - - err := h(c) - - if assert.NoError(t, err) { + if assert.NoError(t, h(c)) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) @@ -145,15 +110,13 @@ func TestDecompressNoContent(t *testing.T) { func TestDecompressErrorReturned(t *testing.T) { e := echo.New() e.Use(Decompress()) - e.GET("/", func(c *echo.Context) error { + e.GET("/", func(c echo.Context) error { return echo.ErrNotFound }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } @@ -161,7 +124,7 @@ func TestDecompressErrorReturned(t *testing.T) { func TestDecompressSkipper(t *testing.T) { e := echo.New() e.Use(DecompressWithConfig(DecompressConfig{ - Skipper: func(c *echo.Context) bool { + Skipper: func(c echo.Context) bool { return c.Request().URL.Path == "/skip" }, })) @@ -170,9 +133,7 @@ func TestDecompressSkipper(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - e.ServeHTTP(rec, req) - assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -201,9 +162,7 @@ func TestDecompressPoolError(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - e.ServeHTTP(rec, req) - assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -215,8 +174,10 @@ func BenchmarkDecompress(b *testing.B) { e := echo.New() body := `{"name": "echo"}` gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - h := Decompress()(func(c *echo.Context) error { + h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte(body)) // For Content-Type sniffing return nil }) @@ -226,8 +187,6 @@ func BenchmarkDecompress(b *testing.B) { for i := 0; i < b.N; i++ { // Decompress - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) h(c) @@ -249,260 +208,3 @@ func gzipString(body string) ([]byte, error) { return buf.Bytes(), nil } - -func TestDecompress_WithinLimit(t *testing.T) { - e := echo.New() - body := strings.Repeat("test data ", 100) // Small payload ~1KB - gz, _ := gzipString(body) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - b, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(b)) - })(c) - - assert.NoError(t, err) - assert.Equal(t, body, rec.Body.String()) -} - -func TestDecompress_ExceedsLimit(t *testing.T) { - e := echo.New() - // Create 2KB of data but limit to 1KB - largeBody := strings.Repeat("A", 2*1024) - gz, _ := gzipString(largeBody) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - _, readErr := io.ReadAll(c.Request().Body) - return readErr - })(c) - - // Should return 413 error - assert.Error(t, err) - he, ok := err.(echo.HTTPStatusCoder) - assert.True(t, ok) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) -} - -func TestDecompress_AtExactLimit(t *testing.T) { - e := echo.New() - exactBody := strings.Repeat("B", 1024) // Exactly 1KB - gz, _ := gzipString(exactBody) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - b, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(b)) - })(c) - - assert.NoError(t, err) - assert.Equal(t, exactBody, rec.Body.String()) -} - -func TestDecompress_ZipBomb(t *testing.T) { - e := echo.New() - // Create highly compressed data that expands to 2MB - // but limit is 1MB - largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB - var buf bytes.Buffer - gzWriter := gzip.NewWriter(&buf) - gzWriter.Write(largeBody) - gzWriter.Close() - - req := httptest.NewRequest(http.MethodPost, "/", &buf) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - _, readErr := io.ReadAll(c.Request().Body) - return readErr - })(c) - - // Should return 413 error - assert.Error(t, err) - he, ok := err.(echo.HTTPStatusCoder) - assert.True(t, ok) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) -} - -func TestDecompress_UnlimitedExplicit(t *testing.T) { - e := echo.New() - largeBody := strings.Repeat("X", 10*1024) // 10KB - gz, _ := gzipString(largeBody) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - b, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(b)) - })(c) - - assert.NoError(t, err) - assert.Equal(t, largeBody, rec.Body.String()) -} - -func TestDecompress_DefaultLimit(t *testing.T) { - e := echo.New() - smallBody := "test" - gz, _ := gzipString(smallBody) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - // Use zero value which should apply 100MB default - h, err := DecompressConfig{}.ToMiddleware() - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - b, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(b)) - })(c) - - assert.NoError(t, err) - assert.Equal(t, smallBody, rec.Body.String()) -} - -func TestDecompress_SmallCustomLimit(t *testing.T) { - e := echo.New() - body := strings.Repeat("D", 512) // 512 bytes - gz, _ := gzipString(body) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - b, _ := io.ReadAll(c.Request().Body) - return c.String(http.StatusOK, string(b)) - })(c) - - assert.NoError(t, err) - assert.Equal(t, body, rec.Body.String()) -} - -func TestDecompress_MultipleReads(t *testing.T) { - e := echo.New() - // Test that limit is enforced across multiple Read() calls - largeBody := strings.Repeat("M", 2*1024) // 2KB - gz, _ := gzipString(largeBody) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - // Read in small chunks - buf := make([]byte, 256) - total := 0 - for { - n, readErr := c.Request().Body.Read(buf) - total += n - if readErr != nil { - if readErr == io.EOF { - return nil - } - return readErr - } - } - })(c) - - // Should return 413 error from cumulative reads - assert.Error(t, err) - he, ok := err.(echo.HTTPStatusCoder) - assert.True(t, ok) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) -} - -func TestDecompress_LargePayloadDosPrevention(t *testing.T) { - e := echo.New() - // Simulate a DoS attack with highly compressed large payload - largeSize := 10 * 1024 * 1024 // 10MB decompressed - largeBody := bytes.Repeat([]byte("Z"), largeSize) - var buf bytes.Buffer - gzWriter := gzip.NewWriter(&buf) - gzWriter.Write(largeBody) - gzWriter.Close() - - req := httptest.NewRequest(http.MethodPost, "/", &buf) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() - assert.NoError(t, err) - - err = h(func(c *echo.Context) error { - _, readErr := io.ReadAll(c.Request().Body) - return readErr - })(c) - - // Should prevent DoS by returning 413 - assert.Error(t, err) - he, ok := err.(echo.HTTPStatusCoder) - assert.True(t, ok) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) -} - -func BenchmarkDecompress_WithLimit(b *testing.B) { - e := echo.New() - body := strings.Repeat("benchmark data ", 1000) // ~15KB - gz, _ := gzipString(body) - - h, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - h(func(c *echo.Context) error { - io.ReadAll(c.Request().Body) - return nil - })(c) - } -} diff --git a/middleware/extractor.go b/middleware/extractor.go index f800a49e9..3f2741407 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -4,11 +4,11 @@ package middleware import ( + "errors" "fmt" + "github.com/labstack/echo/v4" "net/textproto" "strings" - - "github.com/labstack/echo/v5" ) const ( @@ -17,44 +17,18 @@ const ( extractorLimit = 20 ) -// ExtractorSource is type to indicate source for extracted value -type ExtractorSource string - -const ( - // ExtractorSourceHeader means value was extracted from request header - ExtractorSourceHeader ExtractorSource = "header" - // ExtractorSourceQuery means value was extracted from request query parameters - ExtractorSourceQuery ExtractorSource = "query" - // ExtractorSourcePathParam means value was extracted from route path parameters - ExtractorSourcePathParam ExtractorSource = "param" - // ExtractorSourceCookie means value was extracted from request cookies - ExtractorSourceCookie ExtractorSource = "cookie" - // ExtractorSourceForm means value was extracted from request form values - ExtractorSourceForm ExtractorSource = "form" -) - -// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups -type ValueExtractorError struct { - message string -} - -// Error returns errors text -func (e *ValueExtractorError) Error() string { - return e.message -} - -var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} -var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"} -var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} -var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} -var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} -var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} +var errHeaderExtractorValueMissing = errors.New("missing value in request header") +var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") +var errQueryExtractorValueMissing = errors.New("missing value in the query string") +var errParamExtractorValueMissing = errors.New("missing value in path params") +var errCookieExtractorValueMissing = errors.New("missing value in cookies") +var errFormExtractorValueMissing = errors.New("missing value in the form") // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. -type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) +type ValuesExtractor func(c echo.Context) ([]string, error) // CreateExtractors creates ValuesExtractors from given lookups. -// lookups is a string in the form of ":" or ":,:" that is used +// Lookups is a string in the form of ":" or ":,:" that is used // to extract key from the request. // Possible values: // - "header:" or "header::" @@ -69,25 +43,17 @@ type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) // // Multiple sources example: // - "header:Authorization,header:X-Api-Key" -// -// limit sets the maximum amount how many lookups can be returned. -func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { - return createExtractors(lookups, limit) +func CreateExtractors(lookups string) ([]ValuesExtractor, error) { + return createExtractors(lookups, "") } -func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { +func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil } - if limit == 0 { - limit = 1 - } else if limit > extractorLimit { - limit = extractorLimit - } - - sources := strings.SplitSeq(lookups, ",") + sources := strings.Split(lookups, ",") var extractors = make([]ValuesExtractor, 0) - for source := range sources { + for _, source := range sources { parts := strings.Split(source, ":") if len(parts) < 2 { return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) @@ -95,19 +61,28 @@ func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { switch parts[0] { case "query": - extractors = append(extractors, valuesFromQuery(parts[1], limit)) + extractors = append(extractors, valuesFromQuery(parts[1])) case "param": - extractors = append(extractors, valuesFromParam(parts[1], limit)) + extractors = append(extractors, valuesFromParam(parts[1])) case "cookie": - extractors = append(extractors, valuesFromCookie(parts[1], limit)) + extractors = append(extractors, valuesFromCookie(parts[1])) case "form": - extractors = append(extractors, valuesFromForm(parts[1], limit)) + extractors = append(extractors, valuesFromForm(parts[1])) case "header": prefix := "" if len(parts) > 2 { prefix = parts[2] + } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { + // backwards compatibility for JWT and KeyAuth: + // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc + // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that + // behaviour for default values and Authorization header. + prefix = authScheme + if !strings.HasSuffix(prefix, " ") { + prefix += " " + } } - extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit)) + extractors = append(extractors, valuesFromHeader(parts[1], prefix)) } } return extractors, nil @@ -119,32 +94,28 @@ func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { // note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove // is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. // If prefix is left empty the whole value is returned. -func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor { +func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { prefixLen := len(valuePrefix) // standard library parses http.Request header keys in canonical form but we may provide something else so fix this header = textproto.CanonicalMIMEHeaderKey(header) - if limit == 0 { - limit = 1 - } - return func(c *echo.Context) ([]string, ExtractorSource, error) { + return func(c echo.Context) ([]string, error) { values := c.Request().Header.Values(header) if len(values) == 0 { - return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing + return nil, errHeaderExtractorValueMissing } - i := uint(0) result := make([]string, 0) - for _, value := range values { + for i, value := range values { if prefixLen == 0 { result = append(result, value) - i++ - if i >= limit { + if i >= extractorLimit-1 { break } - } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + continue + } + if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { result = append(result, value[prefixLen:]) - i++ - if i >= limit { + if i >= extractorLimit-1 { break } } @@ -152,102 +123,85 @@ func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtra if len(result) == 0 { if prefixLen > 0 { - return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid + return nil, errHeaderExtractorValueInvalid } - return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing + return nil, errHeaderExtractorValueMissing } - return result, ExtractorSourceHeader, nil + return result, nil } } // valuesFromQuery returns a function that extracts values from the query string. -func valuesFromQuery(param string, limit uint) ValuesExtractor { - if limit == 0 { - limit = 1 - } - return func(c *echo.Context) ([]string, ExtractorSource, error) { +func valuesFromQuery(param string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { result := c.QueryParams()[param] if len(result) == 0 { - return nil, ExtractorSourceQuery, errQueryExtractorValueMissing - } else if len(result) > int(limit)-1 { - result = result[:limit] + return nil, errQueryExtractorValueMissing + } else if len(result) > extractorLimit-1 { + result = result[:extractorLimit] } - return result, ExtractorSourceQuery, nil + return result, nil } } // valuesFromParam returns a function that extracts values from the url param string. -func valuesFromParam(param string, limit uint) ValuesExtractor { - if limit == 0 { - limit = 1 - } - return func(c *echo.Context) ([]string, ExtractorSource, error) { +func valuesFromParam(param string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { result := make([]string, 0) - i := uint(0) - for _, p := range c.PathValues() { - if param != p.Name { - continue - } - result = append(result, p.Value) - i++ - if i >= limit { - break + paramVales := c.ParamValues() + for i, p := range c.ParamNames() { + if param == p { + result = append(result, paramVales[i]) + if i >= extractorLimit-1 { + break + } } } if len(result) == 0 { - return nil, ExtractorSourcePathParam, errParamExtractorValueMissing + return nil, errParamExtractorValueMissing } - return result, ExtractorSourcePathParam, nil + return result, nil } } // valuesFromCookie returns a function that extracts values from the named cookie. -func valuesFromCookie(name string, limit uint) ValuesExtractor { - if limit == 0 { - limit = 1 - } - return func(c *echo.Context) ([]string, ExtractorSource, error) { +func valuesFromCookie(name string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { cookies := c.Cookies() if len(cookies) == 0 { - return nil, ExtractorSourceCookie, errCookieExtractorValueMissing + return nil, errCookieExtractorValueMissing } - i := uint(0) result := make([]string, 0) - for _, cookie := range cookies { - if name != cookie.Name { - continue - } - result = append(result, cookie.Value) - i++ - if i >= limit { - break + for i, cookie := range cookies { + if name == cookie.Name { + result = append(result, cookie.Value) + if i >= extractorLimit-1 { + break + } } } if len(result) == 0 { - return nil, ExtractorSourceCookie, errCookieExtractorValueMissing + return nil, errCookieExtractorValueMissing } - return result, ExtractorSourceCookie, nil + return result, nil } } // valuesFromForm returns a function that extracts values from the form field. -func valuesFromForm(name string, limit uint) ValuesExtractor { - if limit == 0 { - limit = 1 - } - return func(c *echo.Context) ([]string, ExtractorSource, error) { +func valuesFromForm(name string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { if c.Request().Form == nil { - _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) + _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does } values := c.Request().Form[name] if len(values) == 0 { - return nil, ExtractorSourceForm, errFormExtractorValueMissing + return nil, errFormExtractorValueMissing } - if len(values) > int(limit)-1 { - values = values[:limit] + if len(values) > extractorLimit-1 { + values = values[:extractorLimit] } result := append([]string{}, values...) - return result, ExtractorSourceForm, nil + return result, nil } } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 04cc7b829..42cbcfeab 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -6,26 +6,39 @@ package middleware import ( "bytes" "fmt" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" "mime/multipart" "net/http" "net/http/httptest" "net/url" "strings" "testing" - - "github.com/labstack/echo/v5" - "github.com/stretchr/testify/assert" ) +type pathParam struct { + name string + value string +} + +func setPathParams(c echo.Context, params []pathParam) { + names := make([]string, 0, len(params)) + values := make([]string, 0, len(params)) + for _, pp := range params { + names = append(names, pp.name) + values = append(values, pp.value) + } + c.SetParamNames(names...) + c.SetParamValues(values...) +} + func TestCreateExtractors(t *testing.T) { var testCases = []struct { name string givenRequest func() *http.Request - givenPathValues echo.PathValues - whenLookups string - whenLimit uint + givenPathParams []pathParam + whenLoopups string expectValues []string - expectSource ExtractorSource expectCreateError string expectError string }{ @@ -36,9 +49,8 @@ func TestCreateExtractors(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer token") return req }, - whenLookups: "header:Authorization:Bearer ", + whenLoopups: "header:Authorization:Bearer ", expectValues: []string{"token"}, - expectSource: ExtractorSourceHeader, }, { name: "ok, form", @@ -50,9 +62,8 @@ func TestCreateExtractors(t *testing.T) { req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) return req }, - whenLookups: "form:name", + whenLoopups: "form:name", expectValues: []string{"Jon Snow"}, - expectSource: ExtractorSourceForm, }, { name: "ok, cookie", @@ -61,18 +72,16 @@ func TestCreateExtractors(t *testing.T) { req.Header.Set(echo.HeaderCookie, "_csrf=token") return req }, - whenLookups: "cookie:_csrf", + whenLoopups: "cookie:_csrf", expectValues: []string{"token"}, - expectSource: ExtractorSourceCookie, }, { name: "ok, param", - givenPathValues: echo.PathValues{ - {Name: "id", Value: "123"}, + givenPathParams: []pathParam{ + {name: "id", value: "123"}, }, - whenLookups: "param:id", + whenLoopups: "param:id", expectValues: []string{"123"}, - expectSource: ExtractorSourcePathParam, }, { name: "ok, query", @@ -80,13 +89,12 @@ func TestCreateExtractors(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) return req }, - whenLookups: "query:id", + whenLoopups: "query:id", expectValues: []string{"999"}, - expectSource: ExtractorSourceQuery, }, { name: "nok, invalid lookup", - whenLookups: "query", + whenLoopups: "query", expectCreateError: "extractor source for lookup could not be split into needed parts: query", }, } @@ -101,11 +109,11 @@ func TestCreateExtractors(t *testing.T) { } rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if tc.givenPathValues != nil { - c.SetPathValues(tc.givenPathValues) + if tc.givenPathParams != nil { + setPathParams(c, tc.givenPathParams) } - extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit) + extractors, err := CreateExtractors(tc.whenLoopups) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return @@ -113,9 +121,8 @@ func TestCreateExtractors(t *testing.T) { assert.NoError(t, err) for _, e := range extractors { - values, source, eErr := e(c) + values, eErr := e(c) assert.Equal(t, tc.expectValues, values) - assert.Equal(t, tc.expectSource, source) if tc.expectError != "" { assert.EqualError(t, eErr, tc.expectError) return @@ -136,7 +143,6 @@ func TestValuesFromHeader(t *testing.T) { givenRequest func(req *http.Request) whenName string whenValuePrefix string - whenLimit uint expectValues []string expectError string }{ @@ -162,7 +168,6 @@ func TestValuesFromHeader(t *testing.T) { }, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", - whenLimit: 2, expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, }, { @@ -208,7 +213,6 @@ func TestValuesFromHeader(t *testing.T) { }, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", - whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -223,7 +227,6 @@ func TestValuesFromHeader(t *testing.T) { }, whenName: echo.HeaderAuthorization, whenValuePrefix: "", - whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -242,11 +245,10 @@ func TestValuesFromHeader(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit) + extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) - values, source, err := extractor(c) + values, err := extractor(c) assert.Equal(t, tc.expectValues, values) - assert.Equal(t, ExtractorSourceHeader, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -261,7 +263,6 @@ func TestValuesFromQuery(t *testing.T) { name string givenQueryPart string whenName string - whenLimit uint expectValues []string expectError string }{ @@ -275,7 +276,6 @@ func TestValuesFromQuery(t *testing.T) { name: "ok, multiple value", givenQueryPart: "?id=123&id=456&name=test", whenName: "id", - whenLimit: 2, expectValues: []string{"123", "456"}, }, { @@ -290,8 +290,7 @@ func TestValuesFromQuery(t *testing.T) { "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + "&id=21&id=22&id=23&id=24&id=25", - whenName: "id", - whenLimit: extractorLimit, + whenName: "id", expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -307,11 +306,10 @@ func TestValuesFromQuery(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromQuery(tc.whenName, tc.whenLimit) + extractor := valuesFromQuery(tc.whenName) - values, source, err := extractor(c) + values, err := extractor(c) assert.Equal(t, tc.expectValues, values) - assert.Equal(t, ExtractorSourceQuery, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -322,56 +320,53 @@ func TestValuesFromQuery(t *testing.T) { } func TestValuesFromParam(t *testing.T) { - examplePathValues := echo.PathValues{ - {Name: "id", Value: "123"}, - {Name: "gid", Value: "456"}, - {Name: "gid", Value: "789"}, + examplePathParams := []pathParam{ + {name: "id", value: "123"}, + {name: "gid", value: "456"}, + {name: "gid", value: "789"}, } - examplePathValues20 := make(echo.PathValues, 0) + examplePathParams20 := make([]pathParam, 0) for i := 1; i < 25; i++ { - examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)}) + examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) } var testCases = []struct { name string - givenPathValues echo.PathValues + givenPathParams []pathParam whenName string - whenLimit uint expectValues []string expectError string }{ { name: "ok, single value", - givenPathValues: examplePathValues, + givenPathParams: examplePathParams, whenName: "id", expectValues: []string{"123"}, }, { name: "ok, multiple value", - givenPathValues: examplePathValues, + givenPathParams: examplePathParams, whenName: "gid", - whenLimit: 2, expectValues: []string{"456", "789"}, }, { name: "nok, no values", - givenPathValues: nil, + givenPathParams: nil, whenName: "nope", expectValues: nil, expectError: errParamExtractorValueMissing.Error(), }, { name: "nok, no matching value", - givenPathValues: examplePathValues, + givenPathParams: examplePathParams, whenName: "nope", expectValues: nil, expectError: errParamExtractorValueMissing.Error(), }, { name: "ok, cut values over extractorLimit", - givenPathValues: examplePathValues20, + givenPathParams: examplePathParams20, whenName: "id", - whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -386,15 +381,14 @@ func TestValuesFromParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if tc.givenPathValues != nil { - c.SetPathValues(tc.givenPathValues) + if tc.givenPathParams != nil { + setPathParams(c, tc.givenPathParams) } - extractor := valuesFromParam(tc.whenName, tc.whenLimit) + extractor := valuesFromParam(tc.whenName) - values, source, err := extractor(c) + values, err := extractor(c) assert.Equal(t, tc.expectValues, values) - assert.Equal(t, ExtractorSourcePathParam, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -413,7 +407,6 @@ func TestValuesFromCookie(t *testing.T) { name string givenRequest func(req *http.Request) whenName string - whenLimit uint expectValues []string expectError string }{ @@ -430,7 +423,6 @@ func TestValuesFromCookie(t *testing.T) { req.Header.Add(echo.HeaderCookie, "_csrf=token2") }, whenName: "_csrf", - whenLimit: 2, expectValues: []string{"token", "token2"}, }, { @@ -454,8 +446,7 @@ func TestValuesFromCookie(t *testing.T) { req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) } }, - whenName: "_csrf", - whenLimit: extractorLimit, + whenName: "_csrf", expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -474,11 +465,10 @@ func TestValuesFromCookie(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromCookie(tc.whenName, tc.whenLimit) + extractor := valuesFromCookie(tc.whenName) - values, source, err := extractor(c) + values, err := extractor(c) assert.Equal(t, tc.expectValues, values) - assert.Equal(t, ExtractorSourceCookie, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -537,7 +527,6 @@ func TestValuesFromForm(t *testing.T) { name string givenRequest *http.Request whenName string - whenLimit uint expectValues []string expectError string }{ @@ -553,7 +542,6 @@ func TestValuesFromForm(t *testing.T) { v.Add("emails[]", "snow@labstack.com") }), whenName: "emails[]", - whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { @@ -562,7 +550,6 @@ func TestValuesFromForm(t *testing.T) { w.WriteField("emails[]", "snow@labstack.com") }), whenName: "emails[]", - whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { @@ -577,7 +564,6 @@ func TestValuesFromForm(t *testing.T) { v.Add("emails[]", "snow@labstack.com") }), whenName: "emails[]", - whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { @@ -593,8 +579,7 @@ func TestValuesFromForm(t *testing.T) { v.Add("id[]", fmt.Sprintf("%v", i)) } }), - whenName: "id[]", - whenLimit: extractorLimit, + whenName: "id[]", expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -610,11 +595,10 @@ func TestValuesFromForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromForm(tc.whenName, tc.whenLimit) + extractor := valuesFromForm(tc.whenName) - values, source, err := extractor(c) + values, err := extractor(c) assert.Equal(t, tc.expectValues, values) - assert.Equal(t, ExtractorSourceForm, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 2dcb98039..79bee207c 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -4,25 +4,19 @@ package middleware import ( - "cmp" "errors" - "fmt" + "github.com/labstack/echo/v4" "net/http" - - "github.com/labstack/echo/v5" ) // KeyAuthConfig defines the config for KeyAuth middleware. -// -// SECURITY: The Validator function is responsible for securely comparing API keys. -// See KeyAuthValidator documentation for guidance on preventing timing attacks. type KeyAuthConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // KeyLookup is a string in the form of ":" or ":,:" that is used // to extract key from the request. - // Optional. Default value "header:Authorization:Bearer ". + // Optional. Default value "header:Authorization". // Possible values: // - "header:" or "header::" // `` is argument value to cut/trim prefix of the extracted value. This is useful if header @@ -36,22 +30,16 @@ type KeyAuthConfig struct { // - "header:Authorization,header:X-Api-Key" KeyLookup string - // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is - // useful environments like corporate test environments with application proxies restricting - // access to environment with their own auth scheme. - AllowedCheckLimit uint + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string // Validator is a function to validate key. // Required. Validator KeyAuthValidator - // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator - // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // ErrorHandler defines a function which is executed for an invalid key. // It may be used to define a custom error. - // - // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. - // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users - // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. ErrorHandler KeyAuthErrorHandler // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to @@ -63,55 +51,31 @@ type KeyAuthConfig struct { } // KeyAuthValidator defines a function to validate KeyAuth credentials. -// -// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate -// valid API keys, validator implementations MUST use constant-time comparison. -// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==) -// or switch statements. -// -// Example of SECURE implementation: -// -// import "crypto/subtle" -// -// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { -// // Fetch valid keys from database/config -// validKeys := []string{"key1", "key2", "key3"} -// -// for _, validKey := range validKeys { -// // Use constant-time comparison to prevent timing attacks -// if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { -// return true, nil -// } -// } -// return false, nil -// } -// -// Example of INSECURE implementation (DO NOT USE): -// -// // VULNERABLE TO TIMING ATTACKS - DO NOT USE -// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { -// switch key { // Timing leak! -// case "valid-key": -// return true, nil -// default: -// return false, nil -// } -// } -type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) +type KeyAuthValidator func(auth string, c echo.Context) (bool, error) // KeyAuthErrorHandler defines a function which is executed for an invalid key. -type KeyAuthErrorHandler func(c *echo.Context, err error) error - -// ErrKeyMissing denotes an error raised when key value could not be extracted from request -var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") +type KeyAuthErrorHandler func(err error, c echo.Context) error -// ErrInvalidKey denotes an error raised when key value is invalid by validator -var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") +// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups +type ErrKeyAuthMissing struct { + Err error +} // DefaultKeyAuthConfig is the default KeyAuth middleware config. var DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", +} + +// Error returns errors text +func (e *ErrKeyAuthMissing) Error() string { + return e.Err.Error() +} + +// Unwrap unwraps error +func (e *ErrKeyAuthMissing) Unwrap() error { + return e.Err } // KeyAuth returns an KeyAuth middleware. @@ -125,39 +89,31 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. -// -// For first valid key it calls the next handler. -// For invalid key, it sends "401 - Unauthorized" response. -// For missing key, it sends "400 - Bad Request" response. +// KeyAuthWithConfig returns an KeyAuth middleware with config. +// See `KeyAuth()`. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration -func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + // Defaults if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } + // Defaults + if config.AuthScheme == "" { + config.AuthScheme = DefaultKeyAuthConfig.AuthScheme + } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - return nil, errors.New("echo key-auth middleware requires a validator function") + panic("echo: key-auth middleware requires a validator function") } - limit := cmp.Or(config.AllowedCheckLimit, 1) - - extractors, cErr := createExtractors(config.KeyLookup, limit) + extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme) if cErr != nil { - return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr) - } - if len(extractors) == 0 { - return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") + panic(cErr) } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -165,41 +121,59 @@ func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { - keys, source, extrErr := extractor(c) - if extrErr != nil { - lastExtractorErr = extrErr + keys, err := extractor(c) + if err != nil { + lastExtractorErr = err continue } for _, key := range keys { - valid, err := config.Validator(c, key, source) + valid, err := config.Validator(key, c) if err != nil { lastValidatorErr = err continue } - if !valid { - lastValidatorErr = ErrInvalidKey - continue + if valid { + return next(c) } - return next(c) + lastValidatorErr = errors.New("invalid key") } } - // prioritize validator errors over extracting errors + // we are here only when we did not successfully extract and validate any of keys err := lastValidatorErr - if err == nil { - err = lastExtractorErr + if err == nil { // prioritize validator errors over extracting errors + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + err = errors.New("missing key in the query string") + } else if lastExtractorErr == errCookieExtractorValueMissing { + err = errors.New("missing key in cookies") + } else if lastExtractorErr == errFormExtractorValueMissing { + err = errors.New("missing key in the form") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + err = errors.New("missing key in request header") + } else if lastExtractorErr == errHeaderExtractorValueInvalid { + err = errors.New("invalid key in the request header") + } else { + err = lastExtractorErr + } + err = &ErrKeyAuthMissing{Err: err} } + if config.ErrorHandler != nil { - tmpErr := config.ErrorHandler(c, err) + tmpErr := config.ErrorHandler(err, c) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - if lastValidatorErr == nil { - return ErrKeyMissing.Wrap(err) + if lastValidatorErr != nil { // prioritize validator errors over extracting errors + return &echo.HTTPError{ + Code: http.StatusUnauthorized, + Message: "Unauthorized", + Internal: lastValidatorErr, + } } - return echo.ErrUnauthorized.Wrap(err) + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - }, nil + } } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 49a917ed3..447f0bee8 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -4,34 +4,30 @@ package middleware import ( - "crypto/subtle" "errors" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) -func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) { - // Use constant-time comparison to prevent timing attacks - if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 { +func testKeyValidator(key string, c echo.Context) (bool, error) { + switch key { + case "valid-key": return true, nil - } - - // Special case for testing error handling - if key == "error-key" { // Error path doesn't need constant-time + case "error-key": return false, errors.New("some user defined error") + default: + return false, nil } - - return false, nil } func TestKeyAuth(t *testing.T) { handlerCalled := false - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { handlerCalled = true return c.String(http.StatusOK, "test") } @@ -71,7 +67,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.Skipper = func(context *echo.Context) bool { + conf.Skipper = func(context echo.Context) bool { return true } }, @@ -83,7 +79,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, err=code=401, message=invalid key", + expectError: "code=401, message=Unauthorized, internal=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -91,13 +87,24 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=missing key, err=invalid value in request header", + expectError: "code=400, message=invalid key in the request header", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, - expectError: "code=401, message=missing key, err=missing value in request header", + expectError: "code=400, message=missing key in request header", + }, + { + name: "ok, custom key lookup from multiple places, query and header", + givenRequest: func(req *http.Request) { + req.URL.RawQuery = "key=invalid-key" + req.Header.Set("API-Key", "valid-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key,header:API-Key" + }, + expectHandlerCalled: true, }, { name: "ok, custom key lookup, header", @@ -117,7 +124,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, - expectError: "code=401, message=missing key, err=missing value in request header", + expectError: "code=400, message=missing key in request header", }, { name: "ok, custom key lookup, query", @@ -137,7 +144,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, - expectError: "code=401, message=missing key, err=missing value in the query string", + expectError: "code=400, message=missing key in the query string", }, { name: "ok, custom key lookup, form", @@ -162,7 +169,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, - expectError: "code=401, message=missing key, err=missing value in the form", + expectError: "code=400, message=missing key in the form", }, { name: "ok, custom key lookup, cookie", @@ -186,18 +193,20 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=401, message=missing key, err=missing value in cookies", + expectError: "code=400, message=missing key in cookies", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" - conf.ErrorHandler = func(c *echo.Context, err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err) + conf.ErrorHandler = func(err error, context echo.Context) error { + httpError := echo.NewHTTPError(http.StatusTeapot, "custom") + httpError.Internal = err + return httpError } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, err=missing value in request header", + expectError: "code=418, message=custom, internal=missing key in request header", }, { name: "nok, custom errorHandler, error from validator", @@ -205,12 +214,14 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.ErrorHandler = func(c *echo.Context, err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err) + conf.ErrorHandler = func(err error, context echo.Context) error { + httpError := echo.NewHTTPError(http.StatusTeapot, "custom") + httpError.Internal = err + return httpError } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, err=some user defined error", + expectError: "code=418, message=custom, internal=some user defined error", }, { name: "nok, defaults, error from validator", @@ -219,33 +230,14 @@ func TestKeyAuthWithConfig(t *testing.T) { }, whenConfig: func(conf *KeyAuthConfig) {}, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, err=some user defined error", - }, - { - name: "ok, custom validator checks source", - givenRequest: func(req *http.Request) { - q := req.URL.Query() - q.Add("key", "valid-key") - req.URL.RawQuery = q.Encode() - }, - whenConfig: func(conf *KeyAuthConfig) { - conf.KeyLookup = "query:key" - conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) { - if source == ExtractorSourceQuery { - return true, nil - } - return false, errors.New("invalid source") - } - - }, - expectHandlerCalled: true, + expectError: "code=401, message=Unauthorized, internal=some user defined error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { handlerCalled := false - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { handlerCalled = true return c.String(http.StatusOK, "test") } @@ -280,96 +272,108 @@ func TestKeyAuthWithConfig(t *testing.T) { } } -func TestKeyAuthWithConfig_errors(t *testing.T) { +func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { + assert.PanicsWithError( + t, + "extractor source for lookup could not be split into needed parts: a", + func() { + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + KeyLookup: "a", + })(handler) + }, + ) +} + +func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { + assert.PanicsWithValue( + t, + "echo: key-auth middleware requires a validator function", + func() { + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + KeyAuthWithConfig(KeyAuthConfig{ + Validator: nil, + })(handler) + }, + ) +} + +func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { var testCases = []struct { - name string - whenConfig KeyAuthConfig - expectError string + name string + whenContinueOnIgnoredError bool + givenKey string + expectStatus int + expectBody string }{ { - name: "ok, no error", - whenConfig: KeyAuthConfig{ - Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { - return false, nil - }, - }, + name: "no error handler is called", + whenContinueOnIgnoredError: true, + givenKey: "valid-key", + expectStatus: http.StatusTeapot, + expectBody: "", }, { - name: "ok, missing validator func", - whenConfig: KeyAuthConfig{ - Validator: nil, - }, - expectError: "echo key-auth middleware requires a validator function", + name: "ContinueOnIgnoredError is false and error handler is called for missing token", + whenContinueOnIgnoredError: false, + givenKey: "", + // empty response with 200. This emulates previous behaviour when error handler swallowed the error + expectStatus: http.StatusOK, + expectBody: "", }, { - name: "ok, extractor source can not be split", - whenConfig: KeyAuthConfig{ - KeyLookup: "nope", - Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { - return false, nil - }, - }, - expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", + name: "error handler is called for missing token", + whenContinueOnIgnoredError: true, + givenKey: "", + expectStatus: http.StatusTeapot, + expectBody: "public-auth", }, { - name: "ok, no extractors", - whenConfig: KeyAuthConfig{ - KeyLookup: "nope:nope", - Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { - return false, nil - }, - }, - expectError: "echo key-auth middleware could not create extractors from KeyLookup string", + name: "error handler is called for invalid token", + whenContinueOnIgnoredError: true, + givenKey: "x.x.x", + expectStatus: http.StatusUnauthorized, + expectBody: "{\"message\":\"Unauthorized\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mw, err := tc.whenConfig.ToMiddleware() - if tc.expectError != "" { - assert.Nil(t, mw) - assert.EqualError(t, err, tc.expectError) - } else { - assert.NotNil(t, mw) - assert.NoError(t, err) - } - }) - } -} + e := echo.New() -func TestMustKeyAuthWithConfig_panic(t *testing.T) { - assert.Panics(t, func() { - KeyAuthWithConfig(KeyAuthConfig{}) - }) -} + e.GET("/", func(c echo.Context) error { + testValue, _ := c.Get("test").(string) + return c.String(http.StatusTeapot, testValue) + }) -func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { - handlerCalled := false - var authValue string - handler := func(c *echo.Context) error { - handlerCalled = true - authValue = c.Get("auth").(string) - return c.String(http.StatusOK, "test") - } - middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - ErrorHandler: func(c *echo.Context, err error) error { - // could check error to decide if we can swallow the error - c.Set("auth", "public") - return nil - }, - ContinueOnIgnoredError: true, - })(handler) + e.Use(KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(err error, c echo.Context) error { + if _, ok := err.(*ErrKeyAuthMissing); ok { + c.Set("test", "public-auth") + return nil + } + return echo.ErrUnauthorized + }, + KeyLookup: "header:X-API-Key", + ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, + })) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - // no auth header this time - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenKey != "" { + req.Header.Set("X-API-Key", tc.givenKey) + } + res := httptest.NewRecorder() - err := middlewareChain(c) + e.ServeHTTP(res, req) - assert.NoError(t, err) - assert.True(t, handlerCalled) - assert.Equal(t, "public", authValue) + assert.Equal(t, tc.expectStatus, res.Code) + assert.Equal(t, tc.expectBody, res.Body.String()) + }) + } } diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 000000000..0766dbf55 --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,420 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bytes" + "io" + "strconv" + "strings" + "sync" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/color" + "github.com/valyala/fasttemplate" +) + +// LoggerConfig defines the config for Logger middleware. +// +// # Configuration Examples +// +// ## Basic Usage with Default Settings +// +// e.Use(middleware.Logger()) +// +// This uses the default JSON format that logs all common request/response details. +// +// ## Custom Simple Format +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n", +// })) +// +// ## JSON Format with Custom Fields +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"timestamp":"${time_rfc3339_nano}","level":"info","remote_ip":"${remote_ip}",` + +// `"method":"${method}","uri":"${uri}","status":${status},"latency":"${latency_human}",` + +// `"user_agent":"${user_agent}","error":"${error}"}` + "\n", +// })) +// +// ## Custom Time Format +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: "${time_custom} ${method} ${uri} ${status}\n", +// CustomTimeFormat: "2006-01-02 15:04:05", +// })) +// +// ## Logging Headers and Parameters +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"time":"${time_rfc3339_nano}","method":"${method}","uri":"${uri}",` + +// `"status":${status},"auth":"${header:Authorization}","user":"${query:user}",` + +// `"form_data":"${form:action}","session":"${cookie:session_id}"}` + "\n", +// })) +// +// ## Custom Output (File Logging) +// +// file, err := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) +// if err != nil { +// log.Fatal(err) +// } +// defer file.Close() +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Output: file, +// })) +// +// ## Custom Tag Function +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"time":"${time_rfc3339_nano}","user_id":"${custom}","method":"${method}"}` + "\n", +// CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { +// userID := getUserIDFromContext(c) // Your custom logic +// return buf.WriteString(strconv.Itoa(userID)) +// }, +// })) +// +// ## Conditional Logging (Skip Certain Requests) +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Skipper: func(c echo.Context) bool { +// // Skip logging for health check endpoints +// return c.Request().URL.Path == "/health" || c.Request().URL.Path == "/metrics" +// }, +// })) +// +// ## Integration with External Logging Service +// +// logBuffer := &SyncBuffer{} // Thread-safe buffer for external service +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: `{"timestamp":"${time_rfc3339_nano}","service":"my-api","level":"info",` + +// `"method":"${method}","uri":"${uri}","status":${status},"latency_ms":${latency},` + +// `"remote_ip":"${remote_ip}","user_agent":"${user_agent}","error":"${error}"}` + "\n", +// Output: logBuffer, +// })) +// +// # Available Tags +// +// ## Time Tags +// - time_unix: Unix timestamp (seconds) +// - time_unix_milli: Unix timestamp (milliseconds) +// - time_unix_micro: Unix timestamp (microseconds) +// - time_unix_nano: Unix timestamp (nanoseconds) +// - time_rfc3339: RFC3339 format (2006-01-02T15:04:05Z07:00) +// - time_rfc3339_nano: RFC3339 with nanoseconds +// - time_custom: Uses CustomTimeFormat field +// +// ## Request Information +// - id: Request ID from X-Request-ID header +// - remote_ip: Client IP address (respects proxy headers) +// - uri: Full request URI with query parameters +// - host: Host header value +// - method: HTTP method (GET, POST, etc.) +// - path: URL path without query parameters +// - route: Echo route pattern (e.g., /users/:id) +// - protocol: HTTP protocol version +// - referer: Referer header value +// - user_agent: User-Agent header value +// +// ## Response Information +// - status: HTTP status code +// - error: Error message if request failed +// - latency: Request processing time in nanoseconds +// - latency_human: Human-readable processing time +// - bytes_in: Request body size in bytes +// - bytes_out: Response body size in bytes +// +// ## Dynamic Tags +// - header:: Value of specific header (e.g., header:Authorization) +// - query:: Value of specific query parameter (e.g., query:user_id) +// - form:: Value of specific form field (e.g., form:username) +// - cookie:: Value of specific cookie (e.g., cookie:session_id) +// - custom: Output from CustomTagFunc +// +// # Troubleshooting +// +// ## Common Issues +// +// 1. **Missing logs**: Check if Skipper function is filtering out requests +// 2. **Invalid JSON**: Ensure CustomTagFunc outputs valid JSON content +// 3. **Performance issues**: Consider using a buffered writer for high-traffic applications +// 4. **File permission errors**: Ensure write permissions when logging to files +// +// ## Performance Tips +// +// - Use time_unix formats for better performance than time_rfc3339 +// - Minimize the number of dynamic tags (header:, query:, form:, cookie:) +// - Use Skipper to exclude high-frequency, low-value requests (health checks, etc.) +// - Consider async logging for very high-traffic applications +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + // Use this to exclude certain requests from logging (e.g., health checks). + // + // Example: + // Skipper: func(c echo.Context) bool { + // return c.Request().URL.Path == "/health" + // }, + Skipper Skipper + + // Format defines the logging format using template tags. + // Tags are enclosed in ${} and replaced with actual values. + // See the detailed tag documentation above for all available options. + // + // Default: JSON format with common fields + // Example: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n" + Format string `yaml:"format"` + + // CustomTimeFormat specifies the time format used by ${time_custom} tag. + // Uses Go's reference time: Mon Jan 2 15:04:05 MST 2006 + // + // Default: "2006-01-02 15:04:05.00000" + // Example: "2006-01-02 15:04:05" or "15:04:05.000" + CustomTimeFormat string `yaml:"custom_time_format"` + + // CustomTagFunc is called when ${custom} tag is encountered. + // Use this to add application-specific information to logs. + // The function should write valid content for your log format. + // + // Example: + // CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { + // userID := getUserFromContext(c) + // return buf.WriteString(`"user_id":"` + userID + `"`) + // }, + CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) + + // Output specifies where logs are written. + // Can be any io.Writer: files, buffers, network connections, etc. + // + // Default: os.Stdout + // Example: Custom file, syslog, or external logging service + Output io.Writer + + template *fasttemplate.Template + colorer *color.Color + pool *sync.Pool + timeNow func() time.Time +} + +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", + colorer: color.New(), + timeNow: time.Now, +} + +// Logger returns a middleware that logs HTTP requests using the default configuration. +// +// The default format logs requests as JSON with the following fields: +// - time: RFC3339 nano timestamp +// - id: Request ID from X-Request-ID header +// - remote_ip: Client IP address +// - host: Host header +// - method: HTTP method +// - uri: Request URI +// - user_agent: User-Agent header +// - status: HTTP status code +// - error: Error message (if any) +// - latency: Processing time in nanoseconds +// - latency_human: Human-readable processing time +// - bytes_in: Request body size +// - bytes_out: Response body size +// +// Example output: +// +// {"time":"2023-01-15T10:30:45.123456789Z","id":"","remote_ip":"127.0.0.1", +// "host":"localhost:8080","method":"GET","uri":"/users/123","user_agent":"curl/7.81.0", +// "status":200,"error":"","latency":1234567,"latency_human":"1.234567ms", +// "bytes_in":0,"bytes_out":42} +// +// For custom configurations, use LoggerWithConfig instead. +// +// Deprecated: please use middleware.RequestLogger or middleware.RequestLoggerWithConfig instead. +func Logger() echo.MiddlewareFunc { + return LoggerWithConfig(DefaultLoggerConfig) +} + +// LoggerWithConfig returns a Logger middleware with custom configuration. +// +// This function allows you to customize all aspects of request logging including: +// - Log format and fields +// - Output destination +// - Time formatting +// - Custom tags and logic +// - Request filtering +// +// See LoggerConfig documentation for detailed configuration examples and options. +// +// Example: +// +// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ +// Format: "${time_rfc3339} ${status} ${method} ${uri} ${latency_human}\n", +// Output: customLogWriter, +// Skipper: func(c echo.Context) bool { +// return c.Request().URL.Path == "/health" +// }, +// })) +// +// Deprecated: please use middleware.RequestLoggerWithConfig instead. +func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultLoggerConfig.Skipper + } + if config.Format == "" { + config.Format = DefaultLoggerConfig.Format + } + writeString := func(buf *bytes.Buffer, in string) (int, error) { return buf.WriteString(in) } + if config.Format[0] == '{' { // format looks like JSON, so we need to escape invalid characters + writeString = writeJSONSafeString + } + + if config.Output == nil { + config.Output = DefaultLoggerConfig.Output + } + timeNow := DefaultLoggerConfig.timeNow + if config.timeNow != nil { + timeNow = config.timeNow + } + + config.template = fasttemplate.New(config.Format, "${", "}") + config.colorer = color.New() + config.colorer.SetOutput(config.Output) + config.pool = &sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 256)) + }, + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) (err error) { + if config.Skipper(c) { + return next(c) + } + + req := c.Request() + res := c.Response() + start := time.Now() + if err = next(c); err != nil { + c.Error(err) + } + stop := time.Now() + buf := config.pool.Get().(*bytes.Buffer) + buf.Reset() + defer config.pool.Put(buf) + + if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { + switch tag { + case "custom": + if config.CustomTagFunc == nil { + return 0, nil + } + return config.CustomTagFunc(c, buf) + case "time_unix": + return buf.WriteString(strconv.FormatInt(timeNow().Unix(), 10)) + case "time_unix_milli": + return buf.WriteString(strconv.FormatInt(timeNow().UnixMilli(), 10)) + case "time_unix_micro": + return buf.WriteString(strconv.FormatInt(timeNow().UnixMicro(), 10)) + case "time_unix_nano": + return buf.WriteString(strconv.FormatInt(timeNow().UnixNano(), 10)) + case "time_rfc3339": + return buf.WriteString(timeNow().Format(time.RFC3339)) + case "time_rfc3339_nano": + return buf.WriteString(timeNow().Format(time.RFC3339Nano)) + case "time_custom": + return buf.WriteString(timeNow().Format(config.CustomTimeFormat)) + case "id": + id := req.Header.Get(echo.HeaderXRequestID) + if id == "" { + id = res.Header().Get(echo.HeaderXRequestID) + } + return writeString(buf, id) + case "remote_ip": + return writeString(buf, c.RealIP()) + case "host": + return writeString(buf, req.Host) + case "uri": + return writeString(buf, req.RequestURI) + case "method": + return writeString(buf, req.Method) + case "path": + p := req.URL.Path + if p == "" { + p = "/" + } + return writeString(buf, p) + case "route": + return writeString(buf, c.Path()) + case "protocol": + return writeString(buf, req.Proto) + case "referer": + return writeString(buf, req.Referer()) + case "user_agent": + return writeString(buf, req.UserAgent()) + case "status": + n := res.Status + s := config.colorer.Green(n) + switch { + case n >= 500: + s = config.colorer.Red(n) + case n >= 400: + s = config.colorer.Yellow(n) + case n >= 300: + s = config.colorer.Cyan(n) + } + return buf.WriteString(s) + case "error": + if err != nil { + return writeJSONSafeString(buf, err.Error()) + } + case "latency": + l := stop.Sub(start) + return buf.WriteString(strconv.FormatInt(int64(l), 10)) + case "latency_human": + return buf.WriteString(stop.Sub(start).String()) + case "bytes_in": + cl := req.Header.Get(echo.HeaderContentLength) + if cl == "" { + cl = "0" + } + return writeString(buf, cl) + case "bytes_out": + return buf.WriteString(strconv.FormatInt(res.Size, 10)) + default: + switch { + case strings.HasPrefix(tag, "header:"): + return writeString(buf, c.Request().Header.Get(tag[7:])) + case strings.HasPrefix(tag, "query:"): + return writeString(buf, c.QueryParam(tag[6:])) + case strings.HasPrefix(tag, "form:"): + return writeString(buf, c.FormValue(tag[5:])) + case strings.HasPrefix(tag, "cookie:"): + cookie, err := c.Cookie(tag[7:]) + if err == nil { + return buf.Write([]byte(cookie.Value)) + } + } + } + return 0, nil + }); err != nil { + return + } + + if config.Output == nil { + _, err = c.Logger().Output().Write(buf.Bytes()) + return + } + _, err = config.Output.Write(buf.Bytes()) + return + } + } +} diff --git a/middleware/logger_strings.go b/middleware/logger_strings.go new file mode 100644 index 000000000..8476cb046 --- /dev/null +++ b/middleware/logger_strings.go @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: BSD-3-Clause +// SPDX-FileCopyrightText: Copyright 2010 The Go Authors +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// +// Go LICENSE https://raw.githubusercontent.com/golang/go/36bca3166e18db52687a4d91ead3f98ffe6d00b8/LICENSE +/** +Copyright 2009 The Go Authors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google LLC nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +package middleware + +import ( + "bytes" + "unicode/utf8" +) + +// This function is modified copy from Go standard library encoding/json/encode.go `appendString` function +// Source: https://github.com/golang/go/blob/36bca3166e18db52687a4d91ead3f98ffe6d00b8/src/encoding/json/encode.go#L999 +func writeJSONSafeString(buf *bytes.Buffer, src string) (int, error) { + const hex = "0123456789abcdef" + + written := 0 + start := 0 + for i := 0; i < len(src); { + if b := src[i]; b < utf8.RuneSelf { + if safeSet[b] { + i++ + continue + } + + n, err := buf.Write([]byte(src[start:i])) + written += n + if err != nil { + return written, err + } + switch b { + case '\\', '"': + n, err := buf.Write([]byte{'\\', b}) + written += n + if err != nil { + return written, err + } + case '\b': + n, err := buf.Write([]byte{'\\', 'b'}) + written += n + if err != nil { + return n, err + } + case '\f': + n, err := buf.Write([]byte{'\\', 'f'}) + written += n + if err != nil { + return written, err + } + case '\n': + n, err := buf.Write([]byte{'\\', 'n'}) + written += n + if err != nil { + return written, err + } + case '\r': + n, err := buf.Write([]byte{'\\', 'r'}) + written += n + if err != nil { + return written, err + } + case '\t': + n, err := buf.Write([]byte{'\\', 't'}) + written += n + if err != nil { + return written, err + } + default: + // This encodes bytes < 0x20 except for \b, \f, \n, \r and \t. + n, err := buf.Write([]byte{'\\', 'u', '0', '0', hex[b>>4], hex[b&0xF]}) + written += n + if err != nil { + return written, err + } + } + i++ + start = i + continue + } + srcN := min(len(src)-i, utf8.UTFMax) + c, size := utf8.DecodeRuneInString(src[i : i+srcN]) + if c == utf8.RuneError && size == 1 { + n, err := buf.Write([]byte(src[start:i])) + written += n + if err != nil { + return written, err + } + n, err = buf.Write([]byte(`\ufffd`)) + written += n + if err != nil { + return written, err + } + i += size + start = i + continue + } + i += size + } + n, err := buf.Write([]byte(src[start:])) + written += n + return written, err +} + +// safeSet holds the value true if the ASCII character with the given array +// position can be represented inside a JSON string without any further +// escaping. +// +// All values are true except for the ASCII control characters (0-31), the +// double quote ("), and the backslash character ("\"). +var safeSet = [utf8.RuneSelf]bool{ + ' ': true, + '!': true, + '"': false, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '(': true, + ')': true, + '*': true, + '+': true, + ',': true, + '-': true, + '.': true, + '/': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + ':': true, + ';': true, + '<': true, + '=': true, + '>': true, + '?': true, + '@': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'V': true, + 'W': true, + 'X': true, + 'Y': true, + 'Z': true, + '[': true, + '\\': false, + ']': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '{': true, + '|': true, + '}': true, + '~': true, + '\u007f': true, +} diff --git a/middleware/logger_strings_test.go b/middleware/logger_strings_test.go new file mode 100644 index 000000000..3d66404c5 --- /dev/null +++ b/middleware/logger_strings_test.go @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteJSONSafeString(t *testing.T) { + testCases := []struct { + name string + whenInput string + expect string + expectN int + }{ + // Basic cases + { + name: "empty string", + whenInput: "", + expect: "", + expectN: 0, + }, + { + name: "simple ASCII without special chars", + whenInput: "hello", + expect: "hello", + expectN: 5, + }, + { + name: "single character", + whenInput: "a", + expect: "a", + expectN: 1, + }, + { + name: "alphanumeric", + whenInput: "Hello123World", + expect: "Hello123World", + expectN: 13, + }, + + // Special character escaping + { + name: "backslash", + whenInput: `path\to\file`, + expect: `path\\to\\file`, + expectN: 14, + }, + { + name: "double quote", + whenInput: `say "hello"`, + expect: `say \"hello\"`, + expectN: 13, + }, + { + name: "backslash and quote combined", + whenInput: `a\b"c`, + expect: `a\\b\"c`, + expectN: 7, + }, + { + name: "single backslash", + whenInput: `\`, + expect: `\\`, + expectN: 2, + }, + { + name: "single quote", + whenInput: `"`, + expect: `\"`, + expectN: 2, + }, + + // Control character escaping + { + name: "backspace", + whenInput: "hello\bworld", + expect: `hello\bworld`, + expectN: 12, + }, + { + name: "form feed", + whenInput: "hello\fworld", + expect: `hello\fworld`, + expectN: 12, + }, + { + name: "newline", + whenInput: "hello\nworld", + expect: `hello\nworld`, + expectN: 12, + }, + { + name: "carriage return", + whenInput: "hello\rworld", + expect: `hello\rworld`, + expectN: 12, + }, + { + name: "tab", + whenInput: "hello\tworld", + expect: `hello\tworld`, + expectN: 12, + }, + { + name: "multiple newlines", + whenInput: "line1\nline2\nline3", + expect: `line1\nline2\nline3`, + expectN: 19, + }, + + // Low control characters (< 0x20) + { + name: "null byte", + whenInput: "hello\x00world", + expect: `hello\u0000world`, + expectN: 16, + }, + { + name: "control character 0x01", + whenInput: "test\x01value", + expect: `test\u0001value`, + expectN: 15, + }, + { + name: "control character 0x0e", + whenInput: "test\x0evalue", + expect: `test\u000evalue`, + expectN: 15, + }, + { + name: "control character 0x1f", + whenInput: "test\x1fvalue", + expect: `test\u001fvalue`, + expectN: 15, + }, + { + name: "multiple control characters", + whenInput: "\x00\x01\x02", + expect: `\u0000\u0001\u0002`, + expectN: 18, + }, + + // UTF-8 handling + { + name: "valid UTF-8 Chinese", + whenInput: "hello δΈ–η•Œ", + expect: "hello δΈ–η•Œ", + expectN: 12, + }, + { + name: "valid UTF-8 emoji", + whenInput: "party πŸŽ‰ time", + expect: "party πŸŽ‰ time", + expectN: 15, + }, + { + name: "mixed ASCII and UTF-8", + whenInput: "HelloδΈ–η•Œ123", + expect: "HelloδΈ–η•Œ123", + expectN: 14, + }, + { + name: "UTF-8 with special chars", + whenInput: "δΈ–η•Œ\n\"test\"", + expect: `δΈ–η•Œ\n\"test\"`, + expectN: 16, + }, + + // Invalid UTF-8 + { + name: "invalid UTF-8 sequence", + whenInput: "hello\xff\xfeworld", + expect: `hello\ufffd\ufffdworld`, + expectN: 22, + }, + { + name: "incomplete UTF-8 sequence", + whenInput: "test\xc3value", + expect: `test\ufffdvalue`, + expectN: 15, + }, + + // Complex mixed cases + { + name: "all common escapes", + whenInput: "tab\there\nquote\"backslash\\", + expect: `tab\there\nquote\"backslash\\`, + expectN: 29, + }, + { + name: "mixed controls and UTF-8", + whenInput: "hello\tδΈ–η•Œ\ntest\"", + expect: `hello\tδΈ–η•Œ\ntest\"`, + expectN: 21, + }, + { + name: "all control characters", + whenInput: "\b\f\n\r\t", + expect: `\b\f\n\r\t`, + expectN: 10, + }, + { + name: "control and low ASCII", + whenInput: "a\nb\x00c", + expect: `a\nb\u0000c`, + expectN: 11, + }, + + // Edge cases + { + name: "starts with special char", + whenInput: "\\start", + expect: `\\start`, + expectN: 7, + }, + { + name: "ends with special char", + whenInput: "end\"", + expect: `end\"`, + expectN: 5, + }, + { + name: "consecutive special chars", + whenInput: "\\\\\"\"", + expect: `\\\\\"\"`, + expectN: 8, + }, + { + name: "only special characters", + whenInput: "\"\\\n\t", + expect: `\"\\\n\t`, + expectN: 8, + }, + { + name: "spaces and punctuation", + whenInput: "Hello, World! How are you?", + expect: "Hello, World! How are you?", + expectN: 26, + }, + { + name: "JSON-like string", + whenInput: "{\"key\":\"value\"}", + expect: `{\"key\":\"value\"}`, + expectN: 19, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + n, err := writeJSONSafeString(buf, tt.whenInput) + + assert.NoError(t, err) + assert.Equal(t, tt.expect, buf.String()) + assert.Equal(t, tt.expectN, n) + }) + } +} + +func BenchmarkWriteJSONSafeString(b *testing.B) { + testCases := []struct { + name string + input string + }{ + {"simple", "hello world"}, + {"with escapes", "tab\there\nquote\"backslash\\"}, + {"utf8", "hello δΈ–η•Œ πŸŽ‰"}, + {"mixed", "Hello\tδΈ–η•Œ\ntest\"value\\path"}, + {"long simple", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789"}, + {"long complex", "line1\nline2\tline3\"quote\\slash\x00nullδΈ–η•ŒπŸŽ‰"}, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + buf := &bytes.Buffer{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + writeJSONSafeString(buf, tc.input) + } + }) + } +} diff --git a/middleware/logger_test.go b/middleware/logger_test.go new file mode 100644 index 000000000..e4b783db5 --- /dev/null +++ b/middleware/logger_test.go @@ -0,0 +1,540 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bytes" + "cmp" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + "time" + "unsafe" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestLoggerDefaultMW(t *testing.T) { + var testCases = []struct { + name string + whenHeader map[string]string + whenStatusCode int + whenResponse string + whenError error + expect string + }{ + { + name: "ok, status 200", + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n", + }, + { + name: "ok, status 300", + whenStatusCode: http.StatusTemporaryRedirect, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":307,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n", + }, + { + name: "ok, handler error = status 500", + whenError: errors.New("error"), + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n", + }, + { + name: "error with invalid UTF-8 sequences", + whenError: errors.New("invalid data: \xFF\xFE"), + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"invalid data: \ufffd\ufffd","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n", + }, + { + name: "error with JSON special characters (quotes and backslashes)", + whenError: errors.New(`error with "quotes" and \backslash`), + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error with \"quotes\" and \\backslash","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n", + }, + { + name: "error with control characters (newlines and tabs)", + whenError: errors.New("error\nwith\nnewlines\tand\ttabs"), + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error\nwith\nnewlines\tand\ttabs","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n", + }, + { + name: "ok, remote_ip from X-Real-Ip header", + whenHeader: map[string]string{echo.HeaderXRealIP: "127.0.0.1"}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n", + }, + { + name: "ok, remote_ip from X-Forwarded-For header", + whenHeader: map[string]string{echo.HeaderXForwardedFor: "127.0.0.1"}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + if len(tc.whenHeader) > 0 { + for k, v := range tc.whenHeader { + req.Header.Add(k, v) + } + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + DefaultLoggerConfig.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() } + h := Logger()(func(c echo.Context) error { + if tc.whenError != nil { + return tc.whenError + } + return c.String(tc.whenStatusCode, tc.whenResponse) + }) + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + err := h(c) + assert.NoError(t, err) + + result := buf.String() + // handle everchanging latency numbers + result = regexp.MustCompile(`"latency":\d+,`).ReplaceAllString(result, `"latency":1,`) + result = regexp.MustCompile(`"latency_human":"[^"]+"`).ReplaceAllString(result, `"latency_human":"1Β΅s"`) + + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestLoggerWithLoggerConfig(t *testing.T) { + // to handle everchanging latency numbers + jsonLatency := map[string]*regexp.Regexp{ + `"latency":1,`: regexp.MustCompile(`"latency":\d+,`), + `"latency_human":"1Β΅s"`: regexp.MustCompile(`"latency_human":"[^"]+"`), + } + + form := make(url.Values) + form.Set("csrf", "token") + form.Add("multiple", "1") + form.Add("multiple", "2") + + var testCases = []struct { + name string + givenConfig LoggerConfig + whenURI string + whenMethod string + whenHost string + whenPath string + whenRoute string + whenProto string + whenRequestURI string + whenHeader map[string]string + whenFormValues url.Values + whenStatusCode int + whenResponse string + whenError error + whenReplacers map[string]*regexp.Regexp + expect string + }{ + { + name: "ok, skipper", + givenConfig: LoggerConfig{ + Skipper: func(c echo.Context) bool { return true }, + }, + expect: ``, + }, + { // this is an example how format that does not seem to be JSON is not currently escaped + name: "ok, NON json string is not escaped: method", + givenConfig: LoggerConfig{Format: `method:"${method}"`}, + whenMethod: `","method":":D"`, + expect: `method:"","method":":D""`, + }, + { + name: "ok, json string escape: method", + givenConfig: LoggerConfig{Format: `{"method":"${method}"}`}, + whenMethod: `","method":":D"`, + expect: `{"method":"\",\"method\":\":D\""}`, + }, + { + name: "ok, json string escape: id", + givenConfig: LoggerConfig{Format: `{"id":"${id}"}`}, + whenHeader: map[string]string{echo.HeaderXRequestID: `\"127.0.0.1\"`}, + expect: `{"id":"\\\"127.0.0.1\\\""}`, + }, + { + name: "ok, json string escape: remote_ip", + givenConfig: LoggerConfig{Format: `{"remote_ip":"${remote_ip}"}`}, + whenHeader: map[string]string{echo.HeaderXForwardedFor: `\"127.0.0.1\"`}, + expect: `{"remote_ip":"\\\"127.0.0.1\\\""}`, + }, + { + name: "ok, json string escape: host", + givenConfig: LoggerConfig{Format: `{"host":"${host}"}`}, + whenHost: `\"127.0.0.1\"`, + expect: `{"host":"\\\"127.0.0.1\\\""}`, + }, + { + name: "ok, json string escape: path", + givenConfig: LoggerConfig{Format: `{"path":"${path}"}`}, + whenPath: `\","` + "\n", + expect: `{"path":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: route", + givenConfig: LoggerConfig{Format: `{"route":"${route}"}`}, + whenRoute: `\","` + "\n", + expect: `{"route":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: proto", + givenConfig: LoggerConfig{Format: `{"protocol":"${protocol}"}`}, + whenProto: `\","` + "\n", + expect: `{"protocol":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: referer", + givenConfig: LoggerConfig{Format: `{"referer":"${referer}"}`}, + whenHeader: map[string]string{"Referer": `\","` + "\n"}, + expect: `{"referer":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: user_agent", + givenConfig: LoggerConfig{Format: `{"user_agent":"${user_agent}"}`}, + whenHeader: map[string]string{"User-Agent": `\","` + "\n"}, + expect: `{"user_agent":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: bytes_in", + givenConfig: LoggerConfig{Format: `{"bytes_in":"${bytes_in}"}`}, + whenHeader: map[string]string{echo.HeaderContentLength: `\","` + "\n"}, + expect: `{"bytes_in":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: query param", + givenConfig: LoggerConfig{Format: `{"query":"${query:test}"}`}, + whenURI: `/?test=1","`, + expect: `{"query":"1\",\""}`, + }, + { + name: "ok, json string escape: header", + givenConfig: LoggerConfig{Format: `{"header":"${header:referer}"}`}, + whenHeader: map[string]string{"referer": `\","` + "\n"}, + expect: `{"header":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: form", + givenConfig: LoggerConfig{Format: `{"csrf":"${form:csrf}"}`}, + whenMethod: http.MethodPost, + whenFormValues: url.Values{"csrf": {`token","`}}, + expect: `{"csrf":"token\",\""}`, + }, + { + name: "nok, json string escape: cookie - will not accept invalid chars", + // net/cookie.go: validCookieValueByte function allows these byte in cookie value + // only `0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'` + givenConfig: LoggerConfig{Format: `{"cookie":"${cookie:session}"}`}, + whenHeader: map[string]string{"Cookie": `_ga=GA1.2.000000000.0000000000; session=test\n`}, + expect: `{"cookie":""}`, + }, + { + name: "ok, format time_unix", + givenConfig: LoggerConfig{Format: `${time_unix}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200`, + }, + { + name: "ok, format time_unix_milli", + givenConfig: LoggerConfig{Format: `${time_unix_milli}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200000`, + }, + { + name: "ok, format time_unix_micro", + givenConfig: LoggerConfig{Format: `${time_unix_micro}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200000000`, + }, + { + name: "ok, format time_unix_nano", + givenConfig: LoggerConfig{Format: `${time_unix_nano}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200000000000`, + }, + { + name: "ok, format time_rfc3339", + givenConfig: LoggerConfig{Format: `${time_rfc3339}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `2020-04-28T01:26:40Z`, + }, + { + name: "ok, status 200", + whenStatusCode: http.StatusOK, + whenResponse: "test", + whenReplacers: jsonLatency, + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), nil) + if tc.whenFormValues != nil { + req = httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), strings.NewReader(tc.whenFormValues.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + } + + for k, v := range tc.whenHeader { + req.Header.Add(k, v) + } + if tc.whenHost != "" { + req.Host = tc.whenHost + } + if tc.whenMethod != "" { + req.Method = tc.whenMethod + } + if tc.whenProto != "" { + req.Proto = tc.whenProto + } + if tc.whenRequestURI != "" { + req.RequestURI = tc.whenRequestURI + } + if tc.whenPath != "" { + req.URL.Path = tc.whenPath + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.whenFormValues != nil { + c.FormValue("to trigger form parsing") + } + if tc.whenRoute != "" { + c.SetPath(tc.whenRoute) + } + + config := tc.givenConfig + if config.timeNow == nil { + config.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() } + } + buf := new(bytes.Buffer) + if config.Output == nil { + e.Logger.SetOutput(buf) + } + + h := LoggerWithConfig(config)(func(c echo.Context) error { + if tc.whenError != nil { + return tc.whenError + } + return c.String(cmp.Or(tc.whenStatusCode, http.StatusOK), cmp.Or(tc.whenResponse, "test")) + }) + + err := h(c) + assert.NoError(t, err) + + result := buf.String() + + for replaceTo, replacer := range tc.whenReplacers { + result = replacer.ReplaceAllString(result, replaceTo) + } + + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestLoggerTemplate(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` + + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + + `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", + Output: buf, + })) + + e.GET("/users/:id", func(c echo.Context) error { + return c.String(http.StatusOK, "Header Logged") + }) + + req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil) + req.RequestURI = "/" + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + req.Header.Add("Referer", "google.com") + req.Header.Add("User-Agent", "echo-tests-agent") + req.Header.Add("X-Custom-Header", "AAA-CUSTOM-VALUE") + req.Header.Add("X-Request-ID", "6ba7b810-9dad-11d1-80b4-00c04fd430c8") + req.Header.Add("Cookie", "_ga=GA1.2.000000000.0000000000; session=ac08034cd216a647fc2eb62f2bcf7b810") + req.Form = url.Values{ + "username": []string{"apagano-form"}, + "password": []string{"secret-form"}, + } + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + cases := map[string]bool{ + "apagano-param": true, + "apagano-form": true, + "AAA-CUSTOM-VALUE": true, + "BBB-CUSTOM-VALUE": false, + "secret-form": false, + "hexvalue": false, + "GET": true, + "127.0.0.1": true, + "\"path\":\"/users/1\"": true, + "\"route\":\"/users/:id\"": true, + "\"uri\":\"/\"": true, + "\"status\":200": true, + "\"bytes_in\":0": true, + "google.com": true, + "echo-tests-agent": true, + "6ba7b810-9dad-11d1-80b4-00c04fd430c8": true, + "ac08034cd216a647fc2eb62f2bcf7b810": true, + } + + for token, present := range cases { + assert.True(t, strings.Contains(buf.String(), token) == present, "Case: "+token) + } +} + +func TestLoggerCustomTimestamp(t *testing.T) { + buf := new(bytes.Buffer) + customTimeFormat := "2006-01-02 15:04:05.00000" + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `{"time":"${time_custom}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` + + `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", + CustomTimeFormat: customTimeFormat, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "custom time stamp test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + var objs map[string]*json.RawMessage + if err := json.Unmarshal(buf.Bytes(), &objs); err != nil { + panic(err) + } + loggedTime := *(*string)(unsafe.Pointer(objs["time"])) + _, err := time.Parse(customTimeFormat, loggedTime) + assert.Error(t, err) +} + +func TestLoggerCustomTagFunc(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `{"method":"${method}",${custom}}` + "\n", + CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { + return buf.WriteString(`"tag":"my-value"`) + }, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "custom time stamp test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String()) +} + +func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) { + e := echo.New() + + buf := new(bytes.Buffer) + mw := LoggerWithConfig(LoggerConfig{ + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"bytes_out":${bytes_out}, "protocol":"${protocol}"}` + "\n", + Output: buf, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Add("multiple", "1") + f.Add("multiple", "2") + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + buf.Reset() + } +} + +func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { + e := echo.New() + + buf := new(bytes.Buffer) + mw := LoggerWithConfig(LoggerConfig{ + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + + `"us":"${query:username}", "cf":"${form:csrf}", "Referer2":"${header:Referer}"}` + "\n", + Output: buf, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Add("multiple", "1") + f.Add("multiple", "2") + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + buf.Reset() + } +} diff --git a/middleware/method_override.go b/middleware/method_override.go index 25ec1f935..3991e1029 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -6,7 +6,7 @@ package middleware import ( "net/http" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // MethodOverrideConfig defines the config for MethodOverride middleware. @@ -20,7 +20,7 @@ type MethodOverrideConfig struct { } // MethodOverrideGetter is a function that gets overridden method from the request -type MethodOverrideGetter func(c *echo.Context) string +type MethodOverrideGetter func(echo.Context) string // DefaultMethodOverrideConfig is the default MethodOverride middleware config. var DefaultMethodOverrideConfig = MethodOverrideConfig{ @@ -37,13 +37,9 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. +// MethodOverrideWithConfig returns a MethodOverride middleware with config. +// See: `MethodOverride()`. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration -func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -53,7 +49,7 @@ func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -67,13 +63,13 @@ func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } return next(c) } - }, nil + } } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from // the request header. func MethodFromHeader(header string) MethodOverrideGetter { - return func(c *echo.Context) string { + return func(c echo.Context) string { return c.Request().Header.Get(header) } } @@ -81,7 +77,7 @@ func MethodFromHeader(header string) MethodOverrideGetter { // MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the // form parameter. func MethodFromForm(param string) MethodOverrideGetter { - return func(c *echo.Context) string { + return func(c echo.Context) string { return c.FormValue(param) } } @@ -89,7 +85,7 @@ func MethodFromForm(param string) MethodOverrideGetter { // MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from // the query parameter. func MethodFromQuery(param string) MethodOverrideGetter { - return func(c *echo.Context) string { + return func(c echo.Context) string { return c.QueryParam(param) } } diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 525ad10ba..0000d1d80 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -9,14 +9,14 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func TestMethodOverride(t *testing.T) { e := echo.New() m := MethodOverride() - h := func(c *echo.Context) error { + h := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -25,68 +25,28 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - - err := m(h)(c) - assert.NoError(t, err) - + m(h)(c) assert.Equal(t, http.MethodDelete, req.Method) -} - -func TestMethodOverride_formParam(t *testing.T) { - e := echo.New() - h := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } - // Override with form parameter - m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() - assert.NoError(t, err) - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec := httptest.NewRecorder() + m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) + req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec = httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c := e.NewContext(req, rec) - - err = m(h)(c) - assert.NoError(t, err) - + c = e.NewContext(req, rec) + m(h)(c) assert.Equal(t, http.MethodDelete, req.Method) -} - -func TestMethodOverride_queryParam(t *testing.T) { - e := echo.New() - h := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } // Override with query parameter - m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() - assert.NoError(t, err) - req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err = m(h)(c) - assert.NoError(t, err) - + m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) + req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + m(h)(c) assert.Equal(t, http.MethodDelete, req.Method) -} - -func TestMethodOverride_ignoreGet(t *testing.T) { - e := echo.New() - m := MethodOverride() - h := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } // Ignore `GET` - req := httptest.NewRequest(http.MethodGet, "/", nil) + req = httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := m(h)(c) - assert.NoError(t, err) - assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index 4562d03b5..164e52b4c 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -9,14 +9,15 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) -// Skipper defines a function to skip middleware. Returning true skips processing the middleware. -type Skipper func(c *echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing +// the middleware. +type Skipper func(c echo.Context) bool // BeforeFunc defines a function which is executed just before the middleware. -type BeforeFunc func(c *echo.Context) +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -53,7 +54,7 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error return nil } - // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. // We only want to use path part for rewriting and therefore trim prefix if it exists rawURI := req.RequestURI if rawURI != "" && rawURI[0] != '/' { @@ -84,11 +85,13 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error } // DefaultSkipper returns false which processes the middleware. -func DefaultSkipper(c *echo.Context) bool { +func DefaultSkipper(echo.Context) bool { return false } -func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { +func toMiddlewareOrPanic(config interface { + ToMiddleware() (echo.MiddlewareFunc, error) +}) echo.MiddlewareFunc { mw, err := config.ToMiddleware() if err != nil { panic(err) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 28407ed5c..7f3dc3866 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -102,9 +102,11 @@ type testResponseWriterNoFlushHijack struct { func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { } + func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { return 0, nil } + func (w *testResponseWriterNoFlushHijack) Header() http.Header { return nil } @@ -116,12 +118,15 @@ type testResponseWriterUnwrapper struct { func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { } + func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { return 0, nil } + func (w *testResponseWriterUnwrapper) Header() http.Header { return nil } + func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { w.unwrapCalled++ return w.rw diff --git a/middleware/proxy.go b/middleware/proxy.go index 497aefea4..f26870077 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -6,10 +6,8 @@ package middleware import ( "context" "crypto/tls" - "errors" "fmt" "io" - "maps" "math/rand" "net" "net/http" @@ -20,7 +18,7 @@ import ( "sync" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // TODO: Handle TLS proxy @@ -43,14 +41,14 @@ type ProxyConfig struct { // of previous retries is less than RetryCount. If the function returns true, the // request will be retried. The provided error indicates the reason for the request // failure. When the ProxyTarget is unavailable, the error will be an instance of - // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error // will indicate an internal error in the Proxy middleware. When a RetryFilter is not // specified, all requests that fail with http.StatusBadGateway will be retried. A custom // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is // only called when the request to the target fails, or an internal error in the Proxy // middleware has occurred. Successful requests that return a non-200 response code cannot // be retried. - RetryFilter func(c *echo.Context, e error) bool + RetryFilter func(c echo.Context, e error) bool // ErrorHandler defines a function which can be used to return custom errors from // the Proxy middleware. ErrorHandler is only invoked when there has been @@ -59,7 +57,7 @@ type ProxyConfig struct { // when a ProxyTarget returns a non-200 response. In these cases, the response // is already written so errors cannot be modified. ErrorHandler is only // invoked after all retry attempts have been exhausted. - ErrorHandler func(c *echo.Context, err error) error + ErrorHandler func(c echo.Context, err error) error // Rewrite defines URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. @@ -93,14 +91,20 @@ type ProxyConfig struct { type ProxyTarget struct { Name string URL *url.URL - Meta map[string]any + Meta echo.Map } // ProxyBalancer defines an interface to implement a load balancing technique. type ProxyBalancer interface { - AddTarget(target *ProxyTarget) bool - RemoveTarget(targetName string) bool - Next(c *echo.Context) (*ProxyTarget, error) + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) *ProxyTarget +} + +// TargetProvider defines an interface that gives the opportunity for balancer +// to return custom errors when selecting target. +type TargetProvider interface { + NextTarget(echo.Context) (*ProxyTarget, error) } type commonBalancer struct { @@ -127,7 +131,7 @@ var DefaultProxyConfig = ProxyConfig{ ContextKey: "target", } -func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler { +func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) if transport, ok := config.Transport.(*http.Transport); ok { if transport.TLSClientConfig != nil { @@ -143,13 +147,12 @@ func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - in, _, err := http.NewResponseController(w).Hijack() + in, _, err := c.Response().Hijack() if err != nil { c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() - out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host) if err != nil { c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) @@ -189,9 +192,7 @@ func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer { b := randomBalancer{} b.targets = targets - // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand) - // this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR. - b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404 + b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) return &b } @@ -235,15 +236,15 @@ func (b *commonBalancer) RemoveTarget(name string) bool { // Next randomly returns an upstream target. // // Note: `nil` is returned in case upstream target list is empty. -func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) { +func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { b.mutex.Lock() defer b.mutex.Unlock() if len(b.targets) == 0 { - return nil, nil + return nil } else if len(b.targets) == 1 { - return b.targets[0], nil + return b.targets[0] } - return b.targets[b.random.Intn(len(b.targets))], nil + return b.targets[b.random.Intn(len(b.targets))] } // Next returns an upstream target using round-robin technique. In the case @@ -254,13 +255,13 @@ func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) { // return the original failed target. // // Note: `nil` is returned in case upstream target list is empty. -func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) { +func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { b.mutex.Lock() defer b.mutex.Unlock() if len(b.targets) == 0 { - return nil, nil + return nil } else if len(b.targets) == 1 { - return b.targets[0], nil + return b.targets[0] } var i int @@ -282,8 +283,9 @@ func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) { i = b.i b.i++ } + c.Set(lastIdxKey, i) - return b.targets[i], nil + return b.targets[i] } // Proxy returns a Proxy middleware. @@ -295,26 +297,18 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. -// -// Proxy middleware forwards the request to upstream server using a configured load balancing technique. +// ProxyWithConfig returns a Proxy middleware with config. +// See: `Proxy()` func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration -func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Balancer == nil { + panic("echo: proxy middleware requires balancer") + } + // Defaults if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } - if config.ContextKey == "" { - config.ContextKey = DefaultProxyConfig.ContextKey - } - if config.Balancer == nil { - return nil, errors.New("echo proxy middleware requires balancer") - } if config.RetryFilter == nil { - config.RetryFilter = func(c *echo.Context, e error) bool { + config.RetryFilter = func(c echo.Context, e error) bool { if httpErr, ok := e.(*echo.HTTPError); ok { return httpErr.Code == http.StatusBadGateway } @@ -322,20 +316,23 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } } if config.ErrorHandler == nil { - config.ErrorHandler = func(c *echo.Context, err error) error { + config.ErrorHandler = func(c echo.Context, err error) error { return err } } - if config.Rewrite != nil { if config.RegexRewrite == nil { config.RegexRewrite = make(map[*regexp.Regexp]string) } - maps.Copy(config.RegexRewrite, rewriteRulesRegex(config.Rewrite)) + for k, v := range rewriteRulesRegex(config.Rewrite) { + config.RegexRewrite[k] = v + } } + provider, isTargetProvider := config.Balancer.(TargetProvider) + return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -355,24 +352,21 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if req.Header.Get(echo.HeaderXForwardedProto) == "" { req.Header.Set(echo.HeaderXForwardedProto, c.Scheme()) } - if c.IsWebSocket() { // For HTTP, this is set by Go HTTP reverse proxy. - // Append, not set, to preserve the incoming chain from upstream proxies. - prior := req.Header[echo.HeaderXForwardedFor] - if len(prior) > 0 { - req.Header.Set(echo.HeaderXForwardedFor, strings.Join(prior, ", ")+", "+c.RealIP()) - } else { - req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) - } + if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy. + req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } retries := config.RetryCount for { - tgt, err := config.Balancer.Next(c) - if err != nil { - return config.ErrorHandler(c, err) - } - if tgt == nil || tgt.URL == nil { - return config.ErrorHandler(c, echo.NewHTTPError(http.StatusBadGateway, "no proxy target available")) + var tgt *ProxyTarget + var err error + if isTargetProvider { + tgt, err = provider.NextTarget(c) + if err != nil { + return config.ErrorHandler(c, err) + } + } else { + tgt = config.Balancer.Next(c) } c.Set(config.ContextKey, tgt) @@ -391,9 +385,9 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Proxy switch { case c.IsWebSocket(): - proxyRaw(c, tgt, config).ServeHTTP(res, req) + proxyRaw(tgt, c, config).ServeHTTP(res, req) default: // even SSE requests - proxyHTTP(c, tgt, config).ServeHTTP(res, req) + proxyHTTP(tgt, c, config).ServeHTTP(res, req) } err, hasError := c.Get("_error").(error) @@ -409,7 +403,7 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { retries-- } } - }, nil + } } // StatusCodeContextCanceled is a custom HTTP status code for situations @@ -419,7 +413,7 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 -func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { +func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() @@ -429,17 +423,15 @@ func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handl // If the client canceled the request (usually by closing the connection), we can report a // client error (4xx) instead of a server error (5xx) to correctly identify the situation. // The Go standard library (at of late 2020) wraps the exported, standard - // context. Canceled error with unexported garbage value requiring a substring check, see + // context.Canceled error with unexported garbage value requiring a substring check, see // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 - // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352 - if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") { - httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err) + if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { + httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) + httpError.Internal = err c.Set("_error", httpError) } else { - httpError := echo.NewHTTPError( - http.StatusBadGateway, - "remote server unreachable, could not proxy request", - ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err)) + httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)) + httpError.Internal = err c.Set("_error", httpError) } } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 5053f7945..dbf07648b 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -15,12 +15,11 @@ import ( "net/http/httptest" "net/url" "regexp" - "strings" "sync" "testing" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) @@ -38,7 +37,6 @@ func TestProxy(t *testing.T) { })) defer t2.Close() url2, _ := url.Parse(t2.URL) - targets := []*ProxyTarget{ { Name: "target 1", @@ -62,7 +60,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) + e.Use(Proxy(rb)) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -84,7 +82,7 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) + e.Use(Proxy(rrb)) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -114,24 +112,68 @@ func TestProxy(t *testing.T) { // ProxyTarget is set in context contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) (err error) { + return func(c echo.Context) (err error) { next(c) assert.Contains(t, targets, c.Get("target"), "target is not set in context") return nil } } + rrb1 := NewRoundRobinBalancer(targets) e = echo.New() e.Use(contextObserver) - e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) + e.Use(Proxy(rrb1)) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } -func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { - assert.Panics(t, func() { - ProxyWithConfig(ProxyConfig{Balancer: nil}) - }) +type testProvider struct { + commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c echo.Context) *ProxyTarget { + return &ProxyTarget{} +} + +func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) } func TestProxyRealIPHeader(t *testing.T) { @@ -141,7 +183,7 @@ func TestProxyRealIPHeader(t *testing.T) { url, _ := url.Parse(upstream.URL) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) + e.Use(Proxy(rrb)) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -346,7 +388,7 @@ func TestProxyError(t *testing.T) { // Random e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) + e.Use(Proxy(rb)) req := httptest.NewRequest(http.MethodGet, "/", nil) // Remote unreachable @@ -357,202 +399,8 @@ func TestProxyError(t *testing.T) { assert.Equal(t, http.StatusBadGateway, rec.Code) } -func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { - var timeoutStop sync.WaitGroup - timeoutStop.Add(1) - HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - timeoutStop.Wait() // wait until we have canceled the request - w.WriteHeader(http.StatusOK) - })) - defer HTTPTarget.Close() - targetURL, _ := url.Parse(HTTPTarget.URL) - target := &ProxyTarget{ - Name: "target", - URL: targetURL, - } - rb := NewRandomBalancer(nil) - assert.True(t, rb.AddTarget(target)) - e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx, cancel := context.WithCancel(req.Context()) - req = req.WithContext(ctx) - go func() { - time.Sleep(10 * time.Millisecond) - cancel() - }() - e.ServeHTTP(rec, req) - timeoutStop.Done() - assert.Equal(t, 499, rec.Code) -} - -type testProvider struct { - commonBalancer - target *ProxyTarget - err error -} - -func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) { - return p.target, p.err -} - -func TestTargetProvider(t *testing.T) { - t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "target 1") - })) - defer t1.Close() - url1, _ := url.Parse(t1.URL) - - e := echo.New() - tp := &testProvider{} - tp.target = &ProxyTarget{Name: "target 1", URL: url1} - e.Use(Proxy(tp)) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - e.ServeHTTP(rec, req) - body := rec.Body.String() - assert.Equal(t, "target 1", body) -} - -func TestFailNextTarget(t *testing.T) { - url1, err := url.Parse("http://dummy:8080") - assert.Nil(t, err) - - e := echo.New() - tp := &testProvider{} - tp.target = &ProxyTarget{Name: "target 1", URL: url1} - tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") - - e.Use(Proxy(tp)) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - e.ServeHTTP(rec, req) - body := rec.Body.String() - assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) -} - -func TestRandomBalancerWithNoTargets(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - // Assert balancer with empty targets does return `nil` on `Next()` - rb := NewRandomBalancer(nil) - target, err := rb.Next(c) - assert.Nil(t, target) - assert.NoError(t, err) -} - -func TestRoundRobinBalancerWithNoTargets(t *testing.T) { - // Assert balancer with empty targets does return `nil` on `Next()` - rrb := NewRoundRobinBalancer([]*ProxyTarget{}) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - target, err := rrb.Next(c) - assert.Nil(t, target) - assert.NoError(t, err) -} - -func TestProxyWithNoTargetReturnsBadGateway(t *testing.T) { - targetURL, _ := url.Parse("http://127.0.0.1:8080") - target := &ProxyTarget{Name: "target", URL: targetURL} - emptyAfterRemove := NewRoundRobinBalancer([]*ProxyTarget{target}) - assert.True(t, emptyAfterRemove.RemoveTarget("target")) - - testCases := []struct { - name string - balancer ProxyBalancer - }{ - { - name: "random balancer with nil targets", - balancer: NewRandomBalancer(nil), - }, - { - name: "round-robin balancer with nil targets", - balancer: NewRoundRobinBalancer(nil), - }, - { - name: "round-robin balancer after removing last target", - balancer: emptyAfterRemove, - }, - { - name: "custom balancer with nil target", - balancer: &customBalancer{}, - }, - { - name: "custom balancer with nil target URL", - balancer: &customBalancer{target: &ProxyTarget{Name: "target"}}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - errorHandlerCalled := false - e.Use(ProxyWithConfig(ProxyConfig{ - Balancer: tc.balancer, - ErrorHandler: func(c *echo.Context, err error) error { - errorHandlerCalled = true - httpErr, ok := err.(*echo.HTTPError) - assert.True(t, ok, "expected http error to be passed to handler") - assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") - return err - }, - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - assert.NotPanics(t, func() { - e.ServeHTTP(rec, req) - }) - assert.True(t, errorHandlerCalled) - assert.Equal(t, http.StatusBadGateway, rec.Code) - }) - } -} - -func TestProxyWithNoTargetDoesNotRetry(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - targetURL, _ := url.Parse(server.URL) - - balancer := &sequenceBalancer{ - targets: []*ProxyTarget{ - nil, - {Name: "target", URL: targetURL}, - }, - } - - retryFilterCalled := false - e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{ - Balancer: balancer, - RetryCount: 1, - RetryFilter: func(c *echo.Context, err error) bool { - retryFilterCalled = true - return true - }, - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.False(t, retryFilterCalled) - assert.Equal(t, 1, balancer.calls) - assert.Equal(t, http.StatusBadGateway, rec.Code) -} - func TestProxyRetries(t *testing.T) { + newServer := func(res int) (*url.URL, *httptest.Server) { server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -583,13 +431,13 @@ func TestProxyRetries(t *testing.T) { URL: targetURL, } - alwaysRetryFilter := func(c *echo.Context, e error) bool { return true } - neverRetryFilter := func(c *echo.Context, e error) bool { return false } + alwaysRetryFilter := func(c echo.Context, e error) bool { return true } + neverRetryFilter := func(c echo.Context, e error) bool { return false } testCases := []struct { name string retryCount int - retryFilters []func(c *echo.Context, e error) bool + retryFilters []func(c echo.Context, e error) bool targets []*ProxyTarget expectedResponse int }{ @@ -612,7 +460,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 1 does retry on handler return true", retryCount: 1, - retryFilters: []func(c *echo.Context, e error) bool{ + retryFilters: []func(c echo.Context, e error) bool{ alwaysRetryFilter, }, targets: []*ProxyTarget{ @@ -624,7 +472,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 1 does not retry on handler return false", retryCount: 1, - retryFilters: []func(c *echo.Context, e error) bool{ + retryFilters: []func(c echo.Context, e error) bool{ neverRetryFilter, }, targets: []*ProxyTarget{ @@ -636,7 +484,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 2 returns error when no more retries left", retryCount: 2, - retryFilters: []func(c *echo.Context, e error) bool{ + retryFilters: []func(c echo.Context, e error) bool{ alwaysRetryFilter, alwaysRetryFilter, }, @@ -651,7 +499,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 2 returns error when retries left but handler returns false", retryCount: 3, - retryFilters: []func(c *echo.Context, e error) bool{ + retryFilters: []func(c echo.Context, e error) bool{ alwaysRetryFilter, alwaysRetryFilter, neverRetryFilter, @@ -667,7 +515,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 3 succeeds", retryCount: 3, - retryFilters: []func(c *echo.Context, e error) bool{ + retryFilters: []func(c echo.Context, e error) bool{ alwaysRetryFilter, alwaysRetryFilter, alwaysRetryFilter, @@ -695,7 +543,7 @@ func TestProxyRetries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { retryFilterCall := 0 - retryFilter := func(c *echo.Context, e error) bool { + retryFilter := func(c echo.Context, e error) bool { if len(tc.retryFilters) == 0 { assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) } @@ -771,13 +619,15 @@ func TestProxyRetryWithBackendTimeout(t *testing.T) { )) var wg sync.WaitGroup - for range 20 { - wg.Go(func() { + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, 200, rec.Code) - }) + }() } wg.Wait() @@ -808,13 +658,13 @@ func TestProxyErrorHandler(t *testing.T) { testCases := []struct { name string target *ProxyTarget - errorHandler func(c *echo.Context, e error) error + errorHandler func(c echo.Context, e error) error expectFinalError func(t *testing.T, err error) }{ { name: "Error handler not invoked when request success", target: goodTarget, - errorHandler: func(c *echo.Context, e error) error { + errorHandler: func(c echo.Context, e error) error { assert.FailNow(t, "error handler should not be invoked") return e }, @@ -822,7 +672,7 @@ func TestProxyErrorHandler(t *testing.T) { { name: "Error handler invoked when request fails", target: badTarget, - errorHandler: func(c *echo.Context, e error) error { + errorHandler: func(c echo.Context, e error) error { httpErr, ok := e.(*echo.HTTPError) assert.True(t, ok, "expected http error to be passed to handler") assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") @@ -845,11 +695,10 @@ func TestProxyErrorHandler(t *testing.T) { )) errorHandlerCalled := false - dheh := echo.DefaultHTTPErrorHandler(false) - e.HTTPErrorHandler = func(c *echo.Context, err error) { + e.HTTPErrorHandler = func(err error, c echo.Context) { errorHandlerCalled = true tc.expectFinalError(t, err) - dheh(c, err) + e.DefaultHTTPErrorHandler(err, c) } req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -865,7 +714,47 @@ func TestProxyErrorHandler(t *testing.T) { } } +func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { + var timeoutStop sync.WaitGroup + timeoutStop.Add(1) + HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timeoutStop.Wait() // wait until we have canceled the request + w.WriteHeader(http.StatusOK) + })) + defer HTTPTarget.Close() + targetURL, _ := url.Parse(HTTPTarget.URL) + target := &ProxyTarget{ + Name: "target", + URL: targetURL, + } + rb := NewRandomBalancer(nil) + assert.True(t, rb.AddTarget(target)) + e := echo.New() + e.Use(Proxy(rb)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + e.ServeHTTP(rec, req) + timeoutStop.Done() + assert.Equal(t, 499, rec.Code) +} + +// Assert balancer with empty targets does return `nil` on `Next()` +func TestProxyBalancerWithNoTargets(t *testing.T) { + rb := NewRandomBalancer(nil) + assert.Nil(t, rb.Next(nil)) + + rrb := NewRoundRobinBalancer([]*ProxyTarget{}) + assert.Nil(t, rrb.Next(nil)) +} + type testContextKey string + type customBalancer struct { target *ProxyTarget } @@ -873,33 +762,15 @@ type customBalancer struct { func (b *customBalancer) AddTarget(target *ProxyTarget) bool { return false } + func (b *customBalancer) RemoveTarget(name string) bool { return false } -func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) { +func (b *customBalancer) Next(c echo.Context) *ProxyTarget { ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER") c.SetRequest(c.Request().WithContext(ctx)) - return b.target, nil -} - -type sequenceBalancer struct { - targets []*ProxyTarget - calls int -} - -func (b *sequenceBalancer) AddTarget(target *ProxyTarget) bool { - return false -} - -func (b *sequenceBalancer) RemoveTarget(name string) bool { - return false -} - -func (b *sequenceBalancer) Next(c *echo.Context) (*ProxyTarget, error) { - target := b.targets[b.calls] - b.calls++ - return target, nil + return b.target } func TestModifyResponseUseContext(t *testing.T) { @@ -910,6 +781,7 @@ func TestModifyResponseUseContext(t *testing.T) { }), ) defer server.Close() + targetURL, _ := url.Parse(server.URL) e := echo.New() e.Use(ProxyWithConfig( @@ -930,9 +802,12 @@ func TestModifyResponseUseContext(t *testing.T) { }, }, )) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "OK", rec.Body.String()) assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) @@ -1165,108 +1040,3 @@ func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } - -// TestProxyWebSocketXForwardedFor verifies that for WebSocket Upgrade requests, -// the proxy middleware appends c.RealIP() to any existing X-Forwarded-For chain, -// mirroring net/http/httputil.(*ProxyRequest).SetXForwarded used by the HTTP path. -// -// Regression guard for the previous "set only if empty" behavior, which dropped -// the proxy's own peer IP from the chain whenever upstream proxies had already -// added entries. -func TestProxyWebSocketXForwardedFor(t *testing.T) { - tests := []struct { - name string - incomingXFF []string // nil = no incoming X-Forwarded-For header at all - wantPrefix string // expected join of entries preceding the appended proxy RealIP - }{ - { - name: "no incoming XFF, only proxy RealIP is set", - incomingXFF: nil, - wantPrefix: "", - }, - { - name: "single-line single-entry XFF is preserved with proxy RealIP appended", - incomingXFF: []string{"203.0.113.1"}, - wantPrefix: "203.0.113.1", - }, - { - name: "single-line comma-separated XFF is preserved with proxy RealIP appended", - incomingXFF: []string{"203.0.113.1, 10.0.0.5"}, - wantPrefix: "203.0.113.1, 10.0.0.5", - }, - { - name: "multi-line XFF (multiple header occurrences) is joined with proxy RealIP appended", - incomingXFF: []string{"203.0.113.1", "10.0.0.5"}, - wantPrefix: "203.0.113.1, 10.0.0.5", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Buffered so the upstream handler never blocks before the client reads. - headerCh := make(chan http.Header, 1) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wsHandler := func(conn *websocket.Conn) { - headerCh <- conn.Request().Header.Clone() - defer conn.Close() - var msg string - if err := websocket.Message.Receive(conn, &msg); err == nil { - _ = websocket.Message.Send(conn, msg) - } - } - websocket.Server{Handler: wsHandler}.ServeHTTP(w, r) - })) - defer upstream.Close() - - tgtURL, _ := url.Parse(upstream.URL) - e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})})) - proxySrv := httptest.NewServer(e) - defer proxySrv.Close() - - proxyWSURL, _ := url.Parse(proxySrv.URL) - proxyWSURL.Scheme = "ws" - - origin, _ := url.Parse(proxySrv.URL) - cfg := &websocket.Config{ - Location: proxyWSURL, - Origin: origin, - Version: websocket.ProtocolVersionHybi13, - Header: http.Header{}, - } - for _, v := range tt.incomingXFF { - cfg.Header.Add(echo.HeaderXForwardedFor, v) - } - - wsConn, err := websocket.DialConfig(cfg) - assert.NoError(t, err) - defer wsConn.Close() - - assert.NoError(t, websocket.Message.Send(wsConn, "ping")) - var got string - assert.NoError(t, websocket.Message.Receive(wsConn, &got)) - - // The handler sends to headerCh before echoing, so it arrives before Receive returns. - captured := <-headerCh - xff := captured.Get(echo.HeaderXForwardedFor) - - // The middleware uses Header.Set, so the upstream sees exactly one - // X-Forwarded-For header line. Split it back into entries. - entries := strings.Split(xff, ", ") - assert.NotEmpty(t, entries, "X-Forwarded-For must be set by the proxy middleware") - - // The tail entry is the proxy's c.RealIP(). When the test client dials - // via httptest.NewServer the proxy sees 127.0.0.1. - tail := entries[len(entries)-1] - assert.Equal(t, "127.0.0.1", tail, - "proxy RealIP must be appended at the tail of X-Forwarded-For") - - // The remaining entries must equal the prior chain, preserving order - // and joining multi-line headers with ", ". - gotPrefix := strings.Join(entries[:len(entries)-1], ", ") - assert.Equal(t, tt.wantPrefix, gotPrefix, - "prior X-Forwarded-For entries must be preserved before the appended RealIP") - }) - } -} diff --git a/middleware/randomstring_bench_test.go b/middleware/randomstring_bench_test.go deleted file mode 100644 index d37d3dfd4..000000000 --- a/middleware/randomstring_bench_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package middleware - -import ( - "bufio" - "io" - "testing" -) - -// randomStringUnpooled is the previous (pre-pooling) implementation, kept here only to A/B benchmark -// against the current pooled randomString. -func randomStringUnpooled(length uint8) string { - reader := randomReaderPool.Get().(*bufio.Reader) - defer randomReaderPool.Put(reader) - - b := make([]byte, length) - r := make([]byte, length+(length/4)) - var i uint8 = 0 - for { - if _, err := io.ReadFull(reader, r); err != nil { - panic("unexpected error reading from crypto/rand") - } - for _, rb := range r { - if rb > randomStringMaxByte { - continue - } - b[i] = randomStringCharset[rb%randomStringCharsetLen] - i++ - if i == length { - return string(b) - } - } - } -} - -func BenchmarkRandomString_Unpooled(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = randomStringUnpooled(32) - } -} - -func BenchmarkRandomString_Pooled(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = randomString(32) - } -} diff --git a/middleware/randomstring_concurrent_test.go b/middleware/randomstring_concurrent_test.go deleted file mode 100644 index c1efc31a5..000000000 --- a/middleware/randomstring_concurrent_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package middleware - -import ( - "strings" - "sync" - "testing" -) - -// TestRandomStringConcurrent guards the pooled scratch buffers in randomString: concurrent callers -// must not share/alias a buffer and corrupt each other's output. Run with -race. -func TestRandomStringConcurrent(t *testing.T) { - const goroutines, iterations = 100, 300 - var wg sync.WaitGroup - wg.Add(goroutines) - for g := 0; g < goroutines; g++ { - go func() { - defer wg.Done() - for i := 0; i < iterations; i++ { - s := randomString(32) - if len(s) != 32 { - t.Errorf("expected length 32, got %d (%q)", len(s), s) - return - } - for _, r := range s { - if !strings.ContainsRune(randomStringCharset, r) { - t.Errorf("char %q not in charset (%q)", r, s) - return - } - } - } - }() - } - wg.Wait() -} diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 9756daf50..2746a3de1 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -4,52 +4,37 @@ package middleware import ( - "errors" "math" "net/http" - "strconv" "sync" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "golang.org/x/time/rate" ) -// Rate limit response headers set by stores that implement RateLimiterStoreContext. -const ( - HeaderXRateLimitLimit = "X-RateLimit-Limit" - HeaderXRateLimitRemaining = "X-RateLimit-Remaining" -) - // RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface { + // Stores for the rate limiter have to implement the Allow method Allow(identifier string) (bool, error) } -// RateLimiterStoreContext is an optional interface a RateLimiterStore may implement. -// When the configured store implements it, the rate limiter calls AllowContext -// (with the request context) instead of Allow, allowing the store to set response -// headers such as Retry-After or X-RateLimit-* on the allow/deny decision. -type RateLimiterStoreContext interface { - AllowContext(c *echo.Context, identifier string) (bool, error) -} - // RateLimiterConfig defines the configuration for the rate limiter type RateLimiterConfig struct { Skipper Skipper BeforeFunc BeforeFunc - // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor IdentifierExtractor Extractor // Store defines a store for the rate limiter Store RateLimiterStore // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(c *echo.Context, err error) error + ErrorHandler func(context echo.Context, err error) error // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(c *echo.Context, identifier string, err error) error + DenyHandler func(context echo.Context, identifier string, err error) error } -// Extractor is used to extract data from *echo.Context -type Extractor func(c *echo.Context) (string, error) +// Extractor is used to extract data from echo.Context +type Extractor func(context echo.Context) (string, error) // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") @@ -60,15 +45,23 @@ var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while ext // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ Skipper: DefaultSkipper, - IdentifierExtractor: func(ctx *echo.Context) (string, error) { + IdentifierExtractor: func(ctx echo.Context) (string, error) { id := ctx.RealIP() return id, nil }, - ErrorHandler: func(c *echo.Context, err error) error { - return ErrExtractorError.Wrap(err) + ErrorHandler: func(context echo.Context, err error) error { + return &echo.HTTPError{ + Code: ErrExtractorError.Code, + Message: ErrExtractorError.Message, + Internal: err, + } }, - DenyHandler: func(c *echo.Context, identifier string, err error) error { - return ErrRateLimitExceeded.Wrap(err) + DenyHandler: func(context echo.Context, identifier string, err error) error { + return &echo.HTTPError{ + Code: ErrRateLimitExceeded.Code, + Message: ErrRateLimitExceeded.Message, + Internal: err, + } }, } @@ -79,7 +72,7 @@ RateLimiter returns a rate limiting middleware limiterStore := middleware.NewRateLimiterMemoryStore(20) - e.GET("/rate-limited", func(c *echo.Context) error { + e.GET("/rate-limited", func(c echo.Context) error { return c.String(http.StatusOK, "test") }, RateLimiter(limiterStore)) */ @@ -100,28 +93,23 @@ RateLimiterWithConfig returns a rate limiting middleware Store: middleware.NewRateLimiterMemoryStore( middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} ) - IdentifierExtractor: func(ctx *echo.Context) (string, error) { + IdentifierExtractor: func(ctx echo.Context) (string, error) { id := ctx.RealIP() return id, nil }, - ErrorHandler: func(ctx *echo.Context, err error) error { + ErrorHandler: func(context echo.Context, err error) error { return context.JSON(http.StatusTooManyRequests, nil) }, - DenyHandler: func(ctx *echo.Context, identifier string, err error) error { + DenyHandler: func(context echo.Context, identifier string) error { return context.JSON(http.StatusForbidden, nil) }, } - e.GET("/rate-limited", func(c *echo.Context) error { + e.GET("/rate-limited", func(c echo.Context) error { return c.String(http.StatusOK, "test") }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration -func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } @@ -135,10 +123,10 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { - return nil, errors.New("echo rate limiter store configuration must be provided") + panic("Store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -148,29 +136,25 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { identifier, err := config.IdentifierExtractor(c) if err != nil { - return config.ErrorHandler(c, err) + c.Error(config.ErrorHandler(c, err)) + return nil } - var allow bool - var allowErr error - if sc, ok := config.Store.(RateLimiterStoreContext); ok { - allow, allowErr = sc.AllowContext(c, identifier) - } else { - allow, allowErr = config.Store.Allow(identifier) - } - if !allow { - return config.DenyHandler(c, identifier, allowErr) + if allow, err := config.Store.Allow(identifier); !allow { + c.Error(config.DenyHandler(c, identifier, err)) + return nil } return next(c) } - }, nil + } } // RateLimiterMemoryStore is the built-in store implementation for RateLimiter type RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + burst int expiresIn time.Duration lastCleanup time.Time @@ -191,21 +175,21 @@ for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate Burst and ExpiresIn will be set to default values. -Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded up value of the rate. +Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. Example (with 20 requests/sec): limiterStore := middleware.NewRateLimiterMemoryStore(20) */ -func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) { +func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ - Rate: rateLimit, + Rate: rate, }) } /* NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore -with the provided configuration. Rate must be provided. Burst will be set to the rounded up value of +with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of the configured rate if not provided or set to 0. The built-in memory store is usually capable for modest loads. For higher loads other @@ -242,7 +226,7 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore type RateLimiterMemoryStoreConfig struct { - Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up } @@ -254,64 +238,32 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ // Allow implements RateLimiterStore.Allow func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { - _, allowed := store.allow(identifier) - return allowed, nil -} - -// AllowContext implements RateLimiterStoreContext: it makes the allow/deny decision -// and sets the X-RateLimit-* (and Retry-After when denied) response headers. -func (store *RateLimiterMemoryStore) AllowContext(c *echo.Context, identifier string) (bool, error) { - limiter, allowed := store.allow(identifier) - store.setRateLimitHeaders(c, limiter, allowed) - return allowed, nil -} - -func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) { store.mutex.Lock() - defer store.mutex.Unlock() - limiter, exists := store.visitors[identifier] if !exists { limiter = new(Visitor) - limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst) + limiter.Limiter = rate.NewLimiter(store.rate, store.burst) store.visitors[identifier] = limiter } now := store.timeNow() limiter.lastSeen = now if now.Sub(store.lastCleanup) > store.expiresIn { - store.cleanupStaleVisitors(now) - } - return limiter.Limiter, limiter.AllowN(now, 1) -} - -func (store *RateLimiterMemoryStore) setRateLimitHeaders(c *echo.Context, limiter *rate.Limiter, allowed bool) { - header := c.Response().Header() - header.Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst)) - - remaining := int(math.Floor(limiter.Tokens())) - if remaining < 0 { - remaining = 0 - } - header.Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining)) - - if !allowed { - reservation := limiter.ReserveN(store.timeNow(), 1) - if delay := reservation.Delay(); delay > 0 { - header.Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds())))) - } - reservation.Cancel() + store.cleanupStaleVisitors() } + allowed := limiter.AllowN(now, 1) + store.mutex.Unlock() + return allowed, nil } /* cleanupStaleVisitors helps manage the size of the visitors map by removing stale records of users who haven't visited again after the configured expiry time has elapsed */ -func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) { +func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { for id, visitor := range store.visitors { - if now.Sub(visitor.lastSeen) > store.expiresIn { + if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { delete(store.visitors, id) } } - store.lastCleanup = now + store.lastCleanup = store.timeNow() } diff --git a/middleware/rate_limiter_context_test.go b/middleware/rate_limiter_context_test.go deleted file mode 100644 index 629c01e47..000000000 --- a/middleware/rate_limiter_context_test.go +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors - -package middleware - -import ( - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/labstack/echo/v5" - "github.com/stretchr/testify/assert" -) - -// ctxAwareStore implements both Allow and the optional AllowContext. AllowContext -// gives the store the request context so it can set response headers (e.g. -// Retry-After / X-RateLimit-*) β€” see #2961. -type ctxAwareStore struct { - allowCalled bool - ctxAllowCalled bool - allow bool -} - -func (s *ctxAwareStore) Allow(identifier string) (bool, error) { - s.allowCalled = true - return s.allow, nil -} - -func (s *ctxAwareStore) AllowContext(c *echo.Context, identifier string) (bool, error) { - s.ctxAllowCalled = true - c.Response().Header().Set("Retry-After", "42") - return s.allow, nil -} - -// When the store implements AllowContext, the middleware must call it instead of -// Allow, so the store can set rate-limit headers on the response. -func TestRateLimiter_storeAllowContextIsPreferred(t *testing.T) { - e := echo.New() - store := &ctxAwareStore{allow: true} - mw := RateLimiterWithConfig(RateLimiterConfig{ - Store: store, - IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil }, - }) - handler := mw(func(c *echo.Context) error { return c.String(http.StatusOK, "ok") }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - assert.NoError(t, handler(c)) - assert.True(t, store.ctxAllowCalled, "AllowContext should be called when implemented") - assert.False(t, store.allowCalled, "Allow should not be called when AllowContext is implemented") - assert.Equal(t, "42", rec.Header().Get("Retry-After"), "store should be able to set headers via the context") -} - -// The built-in memory store implements AllowContext, so it sets X-RateLimit-Limit / -// X-RateLimit-Remaining on every request and Retry-After when the limit is hit (#2961). -func TestRateLimiterMemoryStore_AllowContextSetsHeaders(t *testing.T) { - store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - e := echo.New() - e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "ok") }, - RateLimiterWithConfig(RateLimiterConfig{ - Store: store, - IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil }, - })) - - do := func() *httptest.ResponseRecorder { - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - return rec - } - - // Burst of 3: each allowed request advertises the limit and decreasing remaining. - for i := 0; i < 3; i++ { - rec := do() - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit)) - assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining)) - assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter)) - } - - // 4th request is denied: 429, remaining 0, and a Retry-After hint. - rec := do() - assert.Equal(t, http.StatusTooManyRequests, rec.Code) - assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter)) -} diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 267e8d08f..655d4731d 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -21,25 +21,25 @@ import ( func TestRateLimiter(t *testing.T) { e := echo.New() - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) + mw := RateLimiter(inMemoryStore) testCases := []struct { - id string - expectErr string + id string + code int }{ - {id: "127.0.0.1"}, - {id: "127.0.0.1"}, - {id: "127.0.0.1"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, } for _, tc := range testCases { @@ -49,25 +49,20 @@ func TestRateLimiter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - err := mw(handler)(c) - if tc.expectErr != "" { - assert.EqualError(t, err, tc.expectErr) - } else { - assert.NoError(t, err) - } - assert.Equal(t, http.StatusOK, rec.Code) + _ = mw(handler)(c) + assert.Equal(t, tc.code, rec.Code) } } -func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { +func TestRateLimiter_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { - RateLimiterWithConfig(RateLimiterConfig{}) + RateLimiter(nil) }) assert.NotPanics(t, func() { - RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) + RateLimiter(inMemoryStore) }) } @@ -76,27 +71,26 @@ func TestRateLimiterWithConfig(t *testing.T) { e := echo.New() - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } - mw, err := RateLimiterConfig{ - IdentifierExtractor: func(c *echo.Context) (string, error) { + mw := RateLimiterWithConfig(RateLimiterConfig{ + IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { return "", errors.New("invalid identifier") } return id, nil }, - DenyHandler: func(ctx *echo.Context, identifier string, err error) error { + DenyHandler: func(ctx echo.Context, identifier string, err error) error { return ctx.JSON(http.StatusForbidden, nil) }, - ErrorHandler: func(ctx *echo.Context, err error) error { + ErrorHandler: func(ctx echo.Context, err error) error { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, - }.ToMiddleware() - assert.NoError(t, err) + }) testCases := []struct { id string @@ -119,9 +113,8 @@ func TestRateLimiterWithConfig(t *testing.T) { c := e.NewContext(req, rec) - err := mw(handler)(c) + _ = mw(handler)(c) - assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } @@ -131,12 +124,12 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { e := echo.New() - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } - mw, err := RateLimiterConfig{ - IdentifierExtractor: func(c *echo.Context) (string, error) { + mw := RateLimiterWithConfig(RateLimiterConfig{ + IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { return "", errors.New("invalid identifier") @@ -144,20 +137,19 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return id, nil }, Store: inMemoryStore, - }.ToMiddleware() - assert.NoError(t, err) + }) testCases := []struct { - id string - expectErr string + id string + code int }{ - {id: "127.0.0.1"}, - {id: "127.0.0.1"}, - {id: "127.0.0.1"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"", http.StatusForbidden}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, } for _, tc := range testCases { @@ -168,13 +160,9 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { c := e.NewContext(req, rec) - err := mw(handler)(c) - if tc.expectErr != "" { - assert.EqualError(t, err, tc.expectErr) - } else { - assert.NoError(t, err) - } - assert.Equal(t, http.StatusOK, rec.Code) + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) } } @@ -184,26 +172,25 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { e := echo.New() - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } - mw, err := RateLimiterConfig{ + mw := RateLimiterWithConfig(RateLimiterConfig{ Store: inMemoryStore, - }.ToMiddleware() - assert.NoError(t, err) + }) testCases := []struct { - id string - expectErr string + id string + code int }{ - {id: "127.0.0.1"}, - {id: "127.0.0.1"}, - {id: "127.0.0.1"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, - {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, } for _, tc := range testCases { @@ -214,13 +201,9 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { c := e.NewContext(req, rec) - err := mw(handler)(c) - if tc.expectErr != "" { - assert.EqualError(t, err, tc.expectErr) - } else { - assert.NoError(t, err) - } - assert.Equal(t, http.StatusOK, rec.Code) + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) } } } @@ -229,7 +212,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { e := echo.New() var beforeFuncRan bool - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStore(5) @@ -241,23 +224,21 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) - mw, err := RateLimiterConfig{ - Skipper: func(c *echo.Context) bool { + mw := RateLimiterWithConfig(RateLimiterConfig{ + Skipper: func(c echo.Context) bool { return true }, - BeforeFunc: func(c *echo.Context) { + BeforeFunc: func(c echo.Context) { beforeFuncRan = true }, Store: inMemoryStore, - IdentifierExtractor: func(ctx *echo.Context) (string, error) { + IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }.ToMiddleware() - assert.NoError(t, err) + }) - err = mw(handler)(c) + _ = mw(handler)(c) - assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } @@ -265,7 +246,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { e := echo.New() var beforeFuncRan bool - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStore(5) @@ -277,19 +258,18 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { c := e.NewContext(req, rec) - mw, err := RateLimiterConfig{ - Skipper: func(c *echo.Context) bool { + mw := RateLimiterWithConfig(RateLimiterConfig{ + Skipper: func(c echo.Context) bool { return false }, - BeforeFunc: func(c *echo.Context) { + BeforeFunc: func(c echo.Context) { beforeFuncRan = true }, Store: inMemoryStore, - IdentifierExtractor: func(ctx *echo.Context) (string, error) { + IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }.ToMiddleware() - assert.NoError(t, err) + }) _ = mw(handler)(c) @@ -299,7 +279,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { e := echo.New() - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -313,20 +293,18 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { c := e.NewContext(req, rec) - mw, err := RateLimiterConfig{ - BeforeFunc: func(c *echo.Context) { + mw := RateLimiterWithConfig(RateLimiterConfig{ + BeforeFunc: func(c echo.Context) { beforeRan = true }, Store: inMemoryStore, - IdentifierExtractor: func(ctx *echo.Context) (string, error) { + IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }.ToMiddleware() - assert.NoError(t, err) + }) - err = mw(handler)(c) + _ = mw(handler)(c) - assert.NoError(t, err) assert.Equal(t, true, beforeRan) } @@ -394,7 +372,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { } inMemoryStore.Allow("D") - inMemoryStore.cleanupStaleVisitors(time.Now()) + inMemoryStore.cleanupStaleVisitors() var exists bool @@ -413,7 +391,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { func TestNewRateLimiterMemoryStore(t *testing.T) { testCases := []struct { - rate float64 + rate rate.Limit burst int expiresIn time.Duration expectedExpiresIn time.Duration @@ -461,7 +439,7 @@ func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) - for i := range count { + for i := 0; i < count; i++ { addrs[i] = randomString(15) } return addrs @@ -477,7 +455,7 @@ func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) { addrs := generateAddressList(max) wg := &sync.WaitGroup{} - for range parallel { + for i := 0; i < parallel; i++ { wg.Add(1) go run(wg, store, addrs, max, b) } @@ -553,9 +531,11 @@ func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) { var wg sync.WaitGroup var allowedCount, deniedCount int32 - for range goroutines { - wg.Go(func() { - for range requestsPerGoroutine { + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { allowed, err := store.Allow("test-user") assert.NoError(t, err) if allowed { @@ -565,7 +545,7 @@ func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) { } time.Sleep(time.Millisecond) } - }) + }() } wg.Wait() @@ -596,11 +576,11 @@ func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) { var wg sync.WaitGroup identifiers := []string{"user1", "user2", "user3", "user4", "user5"} - for i := range goroutines { + for i := 0; i < goroutines; i++ { wg.Add(1) go func(routineID int) { defer wg.Done() - for range requestsPerGoroutine { + for j := 0; j < requestsPerGoroutine; j++ { identifier := identifiers[routineID%len(identifiers)] _, err := store.Allow(identifier) assert.NoError(t, err) diff --git a/middleware/recover.go b/middleware/recover.go index 01fde5152..e6a5940e4 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -8,9 +8,13 @@ import ( "net/http" "runtime" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" ) +// LogErrorFunc defines a function for custom logging in the middleware. +type LogErrorFunc func(c echo.Context, err error, stack []byte) error + // RecoverConfig defines the config for Recover middleware. type RecoverConfig struct { // Skipper defines a function to skip middleware. @@ -18,24 +22,41 @@ type RecoverConfig struct { // Size of the stack to be printed. // Optional. Default value 4KB. - StackSize int + StackSize int `yaml:"stack_size"` // DisableStackAll disables formatting stack traces of all other goroutines // into buffer after the trace for the current goroutine. // Optional. Default value false. - DisableStackAll bool + DisableStackAll bool `yaml:"disable_stack_all"` // DisablePrintStack disables printing stack trace. // Optional. Default value as false. - DisablePrintStack bool + DisablePrintStack bool `yaml:"disable_print_stack"` + + // LogLevel is log level to printing stack trace. + // Optional. Default value 0 (Print). + LogLevel log.Lvl + + // LogErrorFunc defines a function for custom logging in the middleware. + // If it's set you don't need to provide LogLevel for config. + // If this function returns nil, the centralized HTTPErrorHandler will not be called. + LogErrorFunc LogErrorFunc + + // DisableErrorHandler disables the call to centralized HTTPErrorHandler. + // The recovered error is then passed back to upstream middleware, instead of swallowing the error. + // Optional. Default value false. + DisableErrorHandler bool `yaml:"disable_error_handler"` } // DefaultRecoverConfig is the default Recover middleware config. var DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, + LogLevel: 0, + LogErrorFunc: nil, + DisableErrorHandler: false, } // Recover returns a middleware which recovers from panics anywhere in the chain @@ -44,13 +65,9 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. +// RecoverWithConfig returns a Recover middleware with config. +// See: `Recover()`. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} - -// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration -func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -60,7 +77,7 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) (err error) { + return func(c echo.Context) (returnErr error) { if config.Skipper(c) { return next(c) } @@ -70,34 +87,47 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if r == http.ErrAbortHandler { panic(r) } - tmpErr, ok := r.(error) + err, ok := r.(error) if !ok { - tmpErr = fmt.Errorf("%v", r) + err = fmt.Errorf("%v", r) } + var stack []byte + var length int + if !config.DisablePrintStack { - stack := make([]byte, config.StackSize) - length := runtime.Stack(stack, !config.DisableStackAll) - tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr} + stack = make([]byte, config.StackSize) + length = runtime.Stack(stack, !config.DisableStackAll) + stack = stack[:length] + } + + if config.LogErrorFunc != nil { + err = config.LogErrorFunc(c, err, stack) + } else if !config.DisablePrintStack { + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) + switch config.LogLevel { + case log.DEBUG: + c.Logger().Debug(msg) + case log.INFO: + c.Logger().Info(msg) + case log.WARN: + c.Logger().Warn(msg) + case log.ERROR: + c.Logger().Error(msg) + case log.OFF: + // None. + default: + c.Logger().Print(msg) + } + } + + if err != nil && !config.DisableErrorHandler { + c.Error(err) + } else { + returnErr = err } - err = tmpErr } }() return next(c) } - }, nil -} - -// PanicStackError is an error type that wraps an error along with its stack trace. -// It is returned when config.DisablePrintStack is set to false. -type PanicStackError struct { - Stack []byte - Err error -} - -func (e *PanicStackError) Error() string { - return fmt.Sprintf("[PANIC RECOVER] %s %s", e.Err.Error(), e.Stack) -} - -func (e *PanicStackError) Unwrap() error { - return e.Err + } } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 719e0cc3d..8fa34fa5c 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -6,72 +6,42 @@ package middleware import ( "bytes" "errors" - "log/slog" + "fmt" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger = slog.New(&discardHandler{}) + e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(func(c *echo.Context) error { + h := Recover()(echo.HandlerFunc(func(c echo.Context) error { panic("test") - }) + })) err := h(c) - assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") - - var pse *PanicStackError - if errors.As(err, &pse) { - assert.Contains(t, string(pse.Stack), "middleware/recover.go") - } else { - assert.Fail(t, "not of type PanicStackError") - } - - assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain - assert.Contains(t, buf.String(), "") // nothing is logged -} - -func TestRecover_skipper(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - config := RecoverConfig{ - Skipper: func(c *echo.Context) bool { - return true - }, - } - h := RecoverWithConfig(config)(func(c *echo.Context) error { - panic("testPANIC") - }) - - var err error - assert.Panics(t, func() { - err = h(c) - }) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, buf.String(), "PANIC RECOVER") } func TestRecoverErrAbortHandler(t *testing.T) { e := echo.New() + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(func(c *echo.Context) error { + h := Recover()(echo.HandlerFunc(func(c echo.Context) error { panic(http.ErrAbortHandler) - }) + })) defer func() { r := recover() if r == nil { @@ -85,66 +55,135 @@ func TestRecoverErrAbortHandler(t *testing.T) { } }() - hErr := h(c) + h(c) assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.NotContains(t, hErr.Error(), "PANIC RECOVER") + assert.NotContains(t, buf.String(), "PANIC RECOVER") } -func TestRecoverWithConfig(t *testing.T) { - var testCases = []struct { - name string - givenNoPanic bool - whenConfig RecoverConfig - expectErrContain string - expectErr string - }{ - { - name: "ok, default config", - whenConfig: DefaultRecoverConfig, - expectErrContain: "[PANIC RECOVER] testPANIC goroutine", - }, - { - name: "ok, no panic", - givenNoPanic: true, - whenConfig: DefaultRecoverConfig, - expectErrContain: "", - }, - { - name: "ok, DisablePrintStack", - whenConfig: RecoverConfig{ - DisablePrintStack: true, - }, - expectErr: "testPANIC", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { +func TestRecoverWithConfig_LogLevel(t *testing.T) { + tests := []struct { + logLevel log.Lvl + levelName string + }{{ + logLevel: log.DEBUG, + levelName: "DEBUG", + }, { + logLevel: log.INFO, + levelName: "INFO", + }, { + logLevel: log.WARN, + levelName: "WARN", + }, { + logLevel: log.ERROR, + levelName: "ERROR", + }, { + logLevel: log.OFF, + levelName: "OFF", + }} + + for _, tt := range tests { + tt := tt + t.Run(tt.levelName, func(t *testing.T) { e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := tc.whenConfig - h := RecoverWithConfig(config)(func(c *echo.Context) error { - if tc.givenNoPanic { - return nil - } - panic("testPANIC") - }) + config := DefaultRecoverConfig + config.LogLevel = tt.logLevel + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("test") + })) - err := h(c) + h(c) - if tc.expectErrContain != "" { - assert.Contains(t, err.Error(), tc.expectErrContain) - } else if tc.expectErr != "" { - assert.Contains(t, err.Error(), tc.expectErr) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + if tt.logLevel == log.OFF { + assert.Empty(t, output) } else { - assert.NoError(t, err) + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) } - assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain }) } } + +func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { + e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + testError := errors.New("test") + config := DefaultRecoverConfig + config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) + if errors.Is(err, testError) { + c.Logger().Debug(msg) + } else { + c.Logger().Error(msg) + } + return err + } + + t.Run("first branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic(testError) + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"DEBUG"`) + }) + + t.Run("else branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("other") + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"ERROR"`) + }) +} + +func TestRecoverWithDisabled_ErrorHandler(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := DefaultRecoverConfig + config.DisableErrorHandler = true + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("test") + })) + err := h(c) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, buf.String(), "PANIC RECOVER") + assert.EqualError(t, err, "test") +} diff --git a/middleware/redirect.go b/middleware/redirect.go index bb7045cfe..b772ac131 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -4,11 +4,10 @@ package middleware import ( - "errors" "net/http" "strings" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // RedirectConfig defines the config for Redirect middleware. @@ -18,9 +17,7 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int - - redirect redirectLogic + Code int `yaml:"code"` } // redirectLogic represents a function that given a scheme, host and uri @@ -30,33 +27,29 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// RedirectHTTPSConfig is the HTTPS Redirect middleware config. -var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} - -// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. -var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} - -// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. -var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} - -// RedirectWWWConfig is the WWW Redirect middleware config. -var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} - -// RedirectNonWWWConfig is the non WWW Redirect middleware config. -var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} +// DefaultRedirectConfig is the default Redirect middleware config. +var DefaultRedirectConfig = RedirectConfig{ + Skipper: DefaultSkipper, + Code: http.StatusMovedPermanently, +} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(RedirectHTTPSConfig) + return HTTPSRedirectWithConfig(DefaultRedirectConfig) } -// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. +// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSRedirect()`. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - config.redirect = redirectHTTPS - return toMiddlewareOrPanic(config) + return redirect(config, func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" + }) } // HTTPSWWWRedirect redirects http requests to https www. @@ -64,13 +57,18 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) + return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) } -// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. +// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSWWWRedirect()`. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - config.redirect = redirectHTTPSWWW - return toMiddlewareOrPanic(config) + return redirect(config, func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri + } + return false, "" + }) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -78,13 +76,19 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) + return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) } -// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. +// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `HTTPSNonWWWRedirect()`. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - config.redirect = redirectNonHTTPSWWW - return toMiddlewareOrPanic(config) + return redirect(config, func(scheme, host, uri string) (ok bool, url string) { + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" + }) } // WWWRedirect redirects non www requests to www. @@ -92,13 +96,18 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(RedirectWWWConfig) + return WWWRedirectWithConfig(DefaultRedirectConfig) } -// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. +// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `WWWRedirect()`. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - config.redirect = redirectWWW - return toMiddlewareOrPanic(config) + return redirect(config, func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" + }) } // NonWWWRedirect redirects www requests to non www. @@ -106,79 +115,41 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(RedirectNonWWWConfig) + return NonWWWRedirectWithConfig(DefaultRedirectConfig) } -// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. +// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. +// See `NonWWWRedirect()`. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - config.redirect = redirectNonWWW - return toMiddlewareOrPanic(config) + return redirect(config, func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri + } + return false, "" + }) } -// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration -func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { +func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { if config.Skipper == nil { - config.Skipper = DefaultSkipper + config.Skipper = DefaultRedirectConfig.Skipper } if config.Code == 0 { - config.Code = http.StatusMovedPermanently - } - if config.redirect == nil { - return nil, errors.New("redirectConfig is missing redirect function") + config.Code = DefaultRedirectConfig.Code } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := config.redirect(scheme, host, req.RequestURI); ok { + if ok, url := cb(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } - }, nil -} - -var redirectHTTPS = func(scheme, host, uri string) (bool, string) { - if scheme != "https" { - return true, "https://" + host + uri - } - return false, "" -} - -var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { - // Redirect if not HTTPS OR missing www prefix (needs either fix) - if scheme != "https" || !strings.HasPrefix(host, www) { - host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication - return true, "https://www." + host + uri - } - return false, "" -} - -var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { - // Redirect if not HTTPS OR has www prefix (needs either fix) - if scheme != "https" || strings.HasPrefix(host, www) { - host = strings.TrimPrefix(host, www) - return true, "https://" + host + uri - } - return false, "" -} - -var redirectWWW = func(scheme, host, uri string) (bool, string) { - if !strings.HasPrefix(host, www) { - return true, scheme + "://www." + host + uri - } - return false, "" -} - -var redirectNonWWW = func(scheme, host, uri string) (bool, string) { - if strings.HasPrefix(host, www) { - return true, scheme + "://" + host[4:] + uri } - return false, "" } diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index a127ca40c..88068ea2e 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -8,7 +8,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -58,8 +58,8 @@ func TestRedirectHTTPSWWWRedirect(t *testing.T) { }, { whenHost: "www.labstack.com", - expectLocation: "https://www.labstack.com/", - expectStatusCode: http.StatusMovedPermanently, + expectLocation: "", + expectStatusCode: http.StatusOK, }, { whenHost: "a.com", @@ -74,12 +74,6 @@ func TestRedirectHTTPSWWWRedirect(t *testing.T) { { whenHost: "labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, - expectLocation: "https://www.labstack.com/", - expectStatusCode: http.StatusMovedPermanently, - }, - { - whenHost: "www.labstack.com", - whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, @@ -120,12 +114,6 @@ func TestRedirectHTTPSNonWWWRedirect(t *testing.T) { { whenHost: "www.labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, - expectLocation: "https://labstack.com/", - expectStatusCode: http.StatusMovedPermanently, - }, - { - whenHost: "labstack.com", - whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, @@ -230,7 +218,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) { var testCases = []struct { name string givenCode int - givenSkipFunc func(c *echo.Context) bool + givenSkipFunc func(c echo.Context) bool whenHost string whenHeader http.Header expectLocation string @@ -244,7 +232,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) { }, { name: "redirect is skipped", - givenSkipFunc: func(c *echo.Context) bool { + givenSkipFunc: func(c echo.Context) bool { return true // skip always }, whenHost: "www.labstack.com", @@ -278,7 +266,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) { func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder { e := echo.New() - next := func(c *echo.Context) (err error) { + next := func(c echo.Context) (err error) { return c.NoContent(http.StatusOK) } req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/middleware/request_id.go b/middleware/request_id.go index b3de40d19..14bd4fd15 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -4,7 +4,7 @@ package middleware import ( - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // RequestIDConfig defines the config for RequestID middleware. @@ -13,45 +13,43 @@ type RequestIDConfig struct { Skipper Skipper // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). + // Optional. Defaults to generator for random string of length 32. Generator func() string // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(c *echo.Context, requestID string) + RequestIDHandler func(echo.Context, string) - // TargetHeader defines what header to look for to populate the id. - // Optional. Default value is `X-Request-Id` + // TargetHeader defines what header to look for to populate the id TargetHeader string } -// RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when -// the header value is empty, generates that value and sets request ID to response -// as RequestIDConfig.TargetHeader (`X-Request-Id`) value. -func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(RequestIDConfig{}) +// DefaultRequestIDConfig is the default RequestID middleware config. +var DefaultRequestIDConfig = RequestIDConfig{ + Skipper: DefaultSkipper, + Generator: generator, + TargetHeader: echo.HeaderXRequestID, } -// RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration. -// The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty, -// generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value. -func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) +// RequestID returns a X-Request-ID middleware. +func RequestID() echo.MiddlewareFunc { + return RequestIDWithConfig(DefaultRequestIDConfig) } -// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration -func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { +// RequestIDWithConfig returns a X-Request-ID middleware with config. +func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { + // Defaults if config.Skipper == nil { - config.Skipper = DefaultSkipper + config.Skipper = DefaultRequestIDConfig.Skipper } if config.Generator == nil { - config.Generator = createRandomStringGenerator(32) + config.Generator = generator } if config.TargetHeader == "" { config.TargetHeader = echo.HeaderXRequestID } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } @@ -69,5 +67,9 @@ func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return next(c) } - }, nil + } +} + +func generator() string { + return randomString(32) } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 465e6fc42..4e68b126a 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -8,7 +8,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -17,108 +17,29 @@ func TestRequestID(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } - rid := RequestID() - h := rid(handler) - err := h(c) - assert.NoError(t, err) - assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) -} - -func TestMustRequestIDWithConfig_skipper(t *testing.T) { - e := echo.New() - e.GET("/", func(c *echo.Context) error { - return c.String(http.StatusTeapot, "test") - }) - - generatorCalled := false - e.Use(RequestIDWithConfig(RequestIDConfig{ - Skipper: func(c *echo.Context) bool { - return true - }, - Generator: func() string { - generatorCalled = true - return "customGenerator" - }, - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - e.ServeHTTP(res, req) - - assert.Equal(t, http.StatusTeapot, res.Code) - assert.Equal(t, "test", res.Body.String()) - - assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") - assert.False(t, generatorCalled) -} - -func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - handler := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } - - rid := RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return "customGenerator" }, - }) - h := rid(handler) - err := h(c) - assert.NoError(t, err) - assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") -} - -func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - handler := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } - - called := false - rid := RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return "customGenerator" }, - RequestIDHandler: func(c *echo.Context, s string) { - called = true - }, - }) - h := rid(handler) - err := h(c) - assert.NoError(t, err) - assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") - assert.True(t, called) -} - -func TestRequestIDWithConfig(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - handler := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } - - rid, err := RequestIDConfig{}.ToMiddleware() - assert.NoError(t, err) + rid := RequestIDWithConfig(RequestIDConfig{}) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator + // Custom generator and handler + customID := "customGenerator" + calledHandler := false rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return "customGenerator" }, + Generator: func() string { return customID }, + RequestIDHandler: func(_ echo.Context, id string) { + calledHandler = true + assert.Equal(t, customID, id) + }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { @@ -128,7 +49,7 @@ func TestRequestID_IDNotAltered(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -143,7 +64,7 @@ func TestRequestIDConfigDifferentHeader(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c *echo.Context) error { + handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -158,7 +79,7 @@ func TestRequestIDConfigDifferentHeader(t *testing.T) { rid = RequestIDWithConfig(RequestIDConfig{ Generator: func() string { return customID }, TargetHeader: echo.HeaderXCorrelationID, - RequestIDHandler: func(_ *echo.Context, id string) { + RequestIDHandler: func(_ echo.Context, id string) { calledHandler = true assert.Equal(t, customID, id) }, diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 827fd3761..211abf464 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -10,7 +10,7 @@ import ( "net/http" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // Example for `slog` https://pkg.go.dev/log/slog @@ -18,8 +18,9 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, // LogURI: true, +// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", // slog.String("uri", v.URI), @@ -40,8 +41,9 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, // LogURI: true, +// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) // } else { @@ -56,8 +58,9 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, +// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.Info(). // Str("URI", v.URI). @@ -79,8 +82,9 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, +// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.Info("request", // zap.String("URI", v.URI), @@ -102,8 +106,9 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, +// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // log.WithFields(logrus.Fields{ // "URI": v.URI, @@ -126,10 +131,10 @@ type RequestLoggerConfig struct { Skipper Skipper // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain. - BeforeNextFunc func(c *echo.Context) + BeforeNextFunc func(c echo.Context) // LogValuesFunc defines a function that is called with values extracted by logger from request/response. // Mandatory. - LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error + LogValuesFunc func(c echo.Context, v RequestLoggerValues) error // HandleError instructs logger to call global error handler when next middleware/handler returns an error. // This is useful when you have custom error handler that can decide to use different status codes. @@ -160,9 +165,11 @@ type RequestLoggerConfig struct { LogReferer bool // LogUserAgent instructs logger to extract request user agent values. LogUserAgent bool - // LogStatus instructs logger to extract response status code. If handler chain returns an error, - // the status code is extracted from the error satisfying echo.StatusCoder interface. + // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError, + // the status code is extracted from the echo.HTTPError returned LogStatus bool + // LogError instructs logger to extract error returned from executed handler chain. + LogError bool // LogContentLength instructs logger to extract content length header value. Note: this value could be different from // actual request body size as it could be spoofed etc. LogContentLength bool @@ -211,7 +218,7 @@ type RequestLoggerValues struct { Referer string // UserAgent is request user agent values. UserAgent string - // Status is a response status code. When the handler returns an error satisfying echo.StatusCoder interface, then code from it. + // Status is response status code. Then handler returns an echo.HTTPError then code from there. Status int // Error is error returned from executed handler chain. Error error @@ -221,15 +228,15 @@ type RequestLoggerValues struct { // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice - // of values is what will be returned/logged for each given header. + // of values is been logger for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter - // with same name so slice of values is what will be returned/logged for each given query param name. + // with same name so slice of values is been logger for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with - // same name so slice of values is what will be returned/logged for each given form value name. + // same name so slice of values is been logger for each given form value name. FormValues map[string][]string } @@ -242,6 +249,72 @@ func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc { return mw } +// RequestLogger returns a RequestLogger middleware with default configuration which +// uses default slog.slog logger. +// +// To customize slog output format replace slog default logger: +// For JSON format: `slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))` +func RequestLogger() echo.MiddlewareFunc { + config := RequestLoggerConfig{ + LogLatency: true, + LogProtocol: false, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: false, + LogRoutePath: false, + LogRequestID: true, + LogReferer: false, + LogUserAgent: true, + LogStatus: true, + LogError: true, + LogContentLength: true, + LogResponseSize: true, + LogHeaders: nil, + LogQueryParams: nil, + LogFormValues: nil, + HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code + LogValuesFunc: func(c echo.Context, v RequestLoggerValues) error { + if v.Error == nil { + slog.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + ) + } else { + slog.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + + slog.String("error", v.Error.Error()), + ) + } + return nil + }, + } + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + // ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration. func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { @@ -266,12 +339,13 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { logFormValues := len(config.LogFormValues) > 0 return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) error { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() + res := c.Response() start := now() if config.BeforeNextFunc != nil { @@ -279,11 +353,8 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } err := next(c) if err != nil && config.HandleError { - // When global error handler writes the error to the client the Response gets "committed". This state can be - // checked with `c.Response().Committed` field. - c.Echo().HTTPErrorHandler(c, err) + c.Error(err) } - res := c.Response() v := RequestLoggerValues{ StartTime: start, @@ -329,26 +400,26 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.LogUserAgent { v.UserAgent = req.UserAgent() } - - if config.LogStatus || config.LogResponseSize { - resp, status := echo.ResolveResponseStatus(res, err) - - if config.LogStatus { - v.Status = status - } - if config.LogResponseSize { - v.ResponseSize = -1 - if resp != nil { - v.ResponseSize = resp.Size + if config.LogStatus { + v.Status = res.Status + if err != nil && !config.HandleError { + // this block should not be executed in case of HandleError=true as the global error handler will decide + // the status code. In that case status code could be different from what err contains. + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { + v.Status = httpErr.Code } } } - if err != nil { + if config.LogError && err != nil { v.Error = err } if config.LogContentLength { v.ContentLength = req.Header.Get(echo.HeaderContentLength) } + if config.LogResponseSize { + v.ResponseSize = res.Size + } if logHeaders { v.Headers = map[string][]string{} for _, header := range headers { @@ -378,69 +449,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { return errOnLog } + // in case of HandleError=true we are returning the error that we already have handled with global error handler // this is deliberate as this error could be useful for upstream middlewares and default global error handler // will ignore that error when it bubbles up in middleware chain. - // Committed response can be checked in custom error handler with following logic - // - // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { - // return - // } return err } }, nil } - -// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger. -func RequestLogger() echo.MiddlewareFunc { - return RequestLoggerWithConfig(RequestLoggerConfig{ - LogLatency: true, - LogRemoteIP: true, - LogHost: true, - LogMethod: true, - LogURI: true, - LogRequestID: true, - LogUserAgent: true, - LogStatus: true, - LogContentLength: true, - LogResponseSize: true, - // forwards error to the global error handler, so it can decide appropriate status code. - // NB: side-effect of that is - request is now "committed" written to the client. Middlewares up in chain can not - // change Response status code or response body. - HandleError: true, - LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error { - logger := c.Logger() - if v.Error == nil { - logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", - slog.String("method", v.Method), - slog.String("uri", v.URI), - slog.Int("status", v.Status), - slog.Duration("latency", v.Latency), - slog.String("host", v.Host), - slog.String("bytes_in", v.ContentLength), - slog.Int64("bytes_out", v.ResponseSize), - slog.String("user_agent", v.UserAgent), - slog.String("remote_ip", v.RemoteIP), - slog.String("request_id", v.RequestID), - ) - return nil - } - - logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", - slog.String("method", v.Method), - slog.String("uri", v.URI), - slog.Int("status", v.Status), - slog.Duration("latency", v.Latency), - slog.String("host", v.Host), - slog.String("bytes_in", v.ContentLength), - slog.Int64("bytes_out", v.ResponseSize), - slog.String("user_agent", v.UserAgent), - slog.String("remote_ip", v.RemoteIP), - slog.String("request_id", v.RequestID), - - slog.String("error", v.Error.Error()), - ) - return nil - }, - }) -} diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 2232c6f6a..202660287 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -26,13 +26,13 @@ func TestRequestLoggerOK(t *testing.T) { slog.SetDefault(old) }) - e := echo.New() - e.IPExtractor = echo.LegacyIPExtractor() buf := new(bytes.Buffer) - e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) + slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil))) + + e := echo.New() e.Use(RequestLogger()) - e.POST("/test", func(c *echo.Context) error { + e.POST("/test", func(c echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -76,12 +76,13 @@ func TestRequestLoggerError(t *testing.T) { slog.SetDefault(old) }) - e := echo.New() buf := new(bytes.Buffer) - e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) + slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil))) + + e := echo.New() e.Use(RequestLogger()) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return errors.New("nope") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -120,13 +121,13 @@ func TestRequestLoggerWithConfig(t *testing.T) { e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogRoutePath: true, LogURI: true, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { expect = values return nil }, })) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -152,16 +153,16 @@ func TestRequestLogger_skipper(t *testing.T) { loggerCalled := false e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ - Skipper: func(c *echo.Context) bool { + Skipper: func(c echo.Context) bool { return true }, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { loggerCalled = true return nil }, })) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -179,16 +180,16 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) { var myLoggerInstance int e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ - BeforeNextFunc: func(c *echo.Context) { + BeforeNextFunc: func(c echo.Context) { c.Set("myLoggerInstance", 42) }, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { myLoggerInstance = c.Get("myLoggerInstance").(int) return nil }, })) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -206,14 +207,15 @@ func TestRequestLogger_logError(t *testing.T) { var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogError: true, LogStatus: true, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { actual = values return nil }, })) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return echo.NewHTTPError(http.StatusNotAcceptable, "nope") }) @@ -236,22 +238,23 @@ func TestRequestLogger_HandleError(t *testing.T) { return time.Unix(1631045377, 0).UTC() }, HandleError: true, + LogError: true, LogStatus: true, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { actual = values return nil }, })) // to see if "HandleError" works we create custom error handler that uses its own status codes - e.HTTPErrorHandler = func(c *echo.Context, err error) { - if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { + e.HTTPErrorHandler = func(err error, c echo.Context) { + if c.Response().Committed { return } c.JSON(http.StatusTeapot, "custom error handler") } - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return echo.NewHTTPError(http.StatusForbidden, "nope") }) @@ -275,14 +278,15 @@ func TestRequestLogger_LogValuesFuncError(t *testing.T) { var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogError: true, LogStatus: true, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { expect = values return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError") }, })) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -323,13 +327,13 @@ func TestRequestLogger_ID(t *testing.T) { var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogRequestID: true, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { expect = values return nil }, })) - e.GET("/test", func(c *echo.Context) error { + e.GET("/test", func(c echo.Context) error { c.Response().Header().Set(echo.HeaderXRequestID, "321") return c.String(http.StatusTeapot, "OK") }) @@ -353,12 +357,12 @@ func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) { var expect RequestLoggerValues mw := RequestLoggerWithConfig(RequestLoggerConfig{ - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { expect = values return nil }, LogHeaders: []string{"referer", "User-Agent"}, - })(func(c *echo.Context) error { + })(func(c echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") @@ -383,7 +387,7 @@ func TestRequestLogger_allFields(t *testing.T) { isFirstNowCall := true var expect RequestLoggerValues mw := RequestLoggerWithConfig(RequestLoggerConfig{ - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { expect = values return nil }, @@ -399,6 +403,7 @@ func TestRequestLogger_allFields(t *testing.T) { LogReferer: true, LogUserAgent: true, LogStatus: true, + LogError: true, LogContentLength: true, LogResponseSize: true, LogHeaders: []string{"accept-encoding", "User-Agent"}, @@ -411,7 +416,7 @@ func TestRequestLogger_allFields(t *testing.T) { } return time.Unix(1631045377+10, 0) }, - })(func(c *echo.Context) error { + })(func(c echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") @@ -440,7 +445,7 @@ func TestRequestLogger_allFields(t *testing.T) { assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime) assert.Equal(t, 10*time.Second, expect.Latency) assert.Equal(t, "HTTP/1.1", expect.Protocol) - assert.Equal(t, "192.0.2.1", expect.RemoteIP) + assert.Equal(t, "8.8.8.8", expect.RemoteIP) assert.Equal(t, "example.com", expect.Host) assert.Equal(t, http.MethodPost, expect.Method) assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI) @@ -466,86 +471,12 @@ func TestRequestLogger_allFields(t *testing.T) { assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"]) } -func TestTestRequestLogger(t *testing.T) { - var testCases = []struct { - name string - whenStatus int - whenError error - expectStatus string - expectError string - }{ - { - name: "ok", - whenStatus: http.StatusTeapot, - expectStatus: "418", - }, - { - name: "error", - whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"), - expectStatus: "502", - expectError: `"error":"code=502, message=bad gw"`, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - buf := new(bytes.Buffer) - e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) - - e.Use(RequestLogger()) - e.POST("/test", func(c *echo.Context) error { - if tc.whenError != nil { - return tc.whenError - } - return c.String(tc.whenStatus, "OK") - }) - - f := make(url.Values) - f.Set("csrf", "token") - f.Set("multiple", "1") - f.Add("multiple", "2") - reader := strings.NewReader(f.Encode()) - req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) - req.Header.Set("Referer", "https://echo.labstack.com/") - req.Header.Set("User-Agent", "curl/7.68.0") - req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") - req.Header.Set(echo.HeaderXRequestID, "MY_ID") - - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - rawlog := buf.Bytes() - if tc.expectError != "" { - assert.Contains(t, string(rawlog), `"level":"ERROR"`) - assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`) - assert.Contains(t, string(rawlog), tc.expectError) - } else { - assert.Contains(t, string(rawlog), `"level":"INFO"`) - assert.Contains(t, string(rawlog), `"msg":"REQUEST"`) - } - assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus) - assert.Contains(t, string(rawlog), `"method":"POST"`) - assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`) - assert.Contains(t, string(rawlog), `"latency":`) // this value varies - assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`) - assert.Contains(t, string(rawlog), `"remote_ip":"192.0.2.1"`) - assert.Contains(t, string(rawlog), `"host":"example.com"`) - assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`) - assert.Contains(t, string(rawlog), `"bytes_in":"32"`) - assert.Contains(t, string(rawlog), `"bytes_out":2`) - }) - } -} - func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { e := echo.New() mw := RequestLoggerWithConfig(RequestLoggerConfig{ Skipper: nil, - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { return nil }, LogLatency: true, @@ -560,9 +491,10 @@ func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { LogReferer: true, LogUserAgent: true, LogStatus: true, + LogError: true, LogContentLength: true, LogResponseSize: true, - })(func(c *echo.Context) error { + })(func(c echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") return c.String(http.StatusTeapot, "OK") }) @@ -585,7 +517,7 @@ func BenchmarkRequestLogger_withMapFields(b *testing.B) { e := echo.New() mw := RequestLoggerWithConfig(RequestLoggerConfig{ - LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { return nil }, LogLatency: true, @@ -600,12 +532,13 @@ func BenchmarkRequestLogger_withMapFields(b *testing.B) { LogReferer: true, LogUserAgent: true, LogStatus: true, + LogError: true, LogContentLength: true, LogResponseSize: true, LogHeaders: []string{"accept-encoding", "User-Agent"}, LogQueryParams: []string{"lang", "checked"}, LogFormValues: []string{"csrf", "multiple"}, - })(func(c *echo.Context) error { + })(func(c echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 02907ca49..4c19cc1cc 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -4,11 +4,9 @@ package middleware import ( - "errors" - "maps" "regexp" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // RewriteConfig defines the config for Rewrite middleware. @@ -24,48 +22,51 @@ type RewriteConfig struct { // "/js/*": "/public/javascripts/$1", // "/users/*/orders/*": "/user/$1/order/$2", // Required. - Rules map[string]string + Rules map[string]string `yaml:"rules"` // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Example: // "^/old/[0.9]+/": "/new", // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string + RegexRules map[*regexp.Regexp]string `yaml:"-"` +} + +// DefaultRewriteConfig is the default Rewrite middleware config. +var DefaultRewriteConfig = RewriteConfig{ + Skipper: DefaultSkipper, } // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := RewriteConfig{} + c := DefaultRewriteConfig c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. -// -// Rewrite middleware rewrites the URL path based on the provided rules. +// RewriteWithConfig returns a Rewrite middleware with config. +// See: `Rewrite()`. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - return toMiddlewareOrPanic(config) -} + // Defaults + if config.Rules == nil && config.RegexRules == nil { + panic("echo: rewrite middleware requires url path rewrite rules or regex rules") + } -// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration -func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultSkipper - } - if config.Rules == nil && config.RegexRules == nil { - return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") + config.Skipper = DefaultBodyDumpConfig.Skipper } if config.RegexRules == nil { config.RegexRules = make(map[*regexp.Regexp]string) } - maps.Copy(config.RegexRules, rewriteRulesRegex(config.Rules)) + for k, v := range rewriteRulesRegex(config.Rules) { + config.RegexRules[k] = v + } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c *echo.Context) (err error) { + return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } @@ -75,5 +76,5 @@ func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } return next(c) } - }, nil + } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index adcc8e9f5..d137b2d13 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -11,7 +11,7 @@ import ( "regexp" "testing" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -26,10 +26,10 @@ func TestRewriteAfterRouting(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - e.GET("/public/*", func(c *echo.Context) error { + e.GET("/public/*", func(c echo.Context) error { return c.String(http.StatusOK, c.Param("*")) }) - e.GET("/*", func(c *echo.Context) error { + e.GET("/*", func(c echo.Context) error { return c.String(http.StatusOK, c.Param("*")) }) @@ -93,74 +93,20 @@ func TestRewriteAfterRouting(t *testing.T) { } } -func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { - assert.Panics(t, func() { - RewriteWithConfig(RewriteConfig{}) - }) -} - -func TestMustRewriteWithConfig_skipper(t *testing.T) { - var testCases = []struct { - name string - givenSkipper func(c *echo.Context) bool - whenURL string - expectURL string - expectStatus int - }{ - { - name: "not skipped", - whenURL: "/old", - expectURL: "/new", - expectStatus: http.StatusOK, - }, - { - name: "skipped", - givenSkipper: func(c *echo.Context) bool { - return true - }, - whenURL: "/old", - expectURL: "/old", - expectStatus: http.StatusNotFound, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.Pre(RewriteWithConfig( - RewriteConfig{ - Skipper: tc.givenSkipper, - Rules: map[string]string{"/old": "/new"}}, - )) - - e.GET("/new", func(c *echo.Context) error { - return c.NoContent(http.StatusOK) - }) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) - assert.Equal(t, tc.expectStatus, rec.Code) - }) - } -} - // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() + r := e.Router() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches - e.Pre(RewriteWithConfig(RewriteConfig{ - Rules: map[string]string{"/old": "/new"}}), - ) + e.Pre(Rewrite(map[string]string{ + "/old": "/new", + }, + )) // Route - e.Add(http.MethodGet, "/new", func(c *echo.Context) error { + r.Add(http.MethodGet, "/new", func(c echo.Context) error { return c.NoContent(http.StatusOK) }) @@ -174,6 +120,7 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() + r := e.Router() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ @@ -183,14 +130,14 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error { + r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { return c.String(http.StatusOK, "hosts") }) - e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error { + r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { return c.String(http.StatusOK, "eng") }) - for range 100 { + for i := 0; i < 100; i++ { req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) diff --git a/middleware/secure.go b/middleware/secure.go index 022cce4a1..c904abf1a 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -6,7 +6,7 @@ package middleware import ( "fmt" - "github.com/labstack/echo/v5" + "github.com/labstack/echo/v4" ) // SecureConfig defines the config for Secure middleware. @@ -17,12 +17,12 @@ type SecureConfig struct { // XSSProtection provides protection against cross-site scripting attack (XSS) // by setting the `X-XSS-Protection` header. // Optional. Default value "1; mode=block". - XSSProtection string + XSSProtection string `yaml:"xss_protection"` // ContentTypeNosniff provides protection against overriding Content-Type // header by setting the `X-Content-Type-Options` header. // Optional. Default value "nosniff". - ContentTypeNosniff string + ContentTypeNosniff string `yaml:"content_type_nosniff"` // XFrameOptions can be used to indicate whether or not a browser should // be allowed to render a page in a ,