diff --git a/.gitattributes b/.gitattributes
index 28981b84a..49b63e526 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -13,3 +13,8 @@
*.js text eol=lf
*.json text eol=lf
LICENSE text eol=lf
+
+# Exclude `website` and `cookbook` from GitHub's language statistics
+# https://github.com/github/linguist#using-gitattributes
+cookbook/* linguist-documentation
+website/* linguist-documentation
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
index 1a76adca7..82220c0a1 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/.github/ISSUE_TEMPLATE.md
@@ -6,7 +6,7 @@
package main
import (
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"net/http"
"net/http/httptest"
"testing"
@@ -15,7 +15,7 @@ import (
func TestExample(t *testing.T) {
e := echo.New()
- e.GET("/", func(c *echo.Context) error {
+ e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "Hello, World!")
})
diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index cfe4ca563..89fb3ce85 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -4,9 +4,11 @@ on:
push:
branches:
- master
+ - v4
pull_request:
branches:
- master
+ - v4
workflow_dispatch:
permissions:
@@ -21,10 +23,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
- uses: actions/checkout@v6
+ uses: actions/checkout@v5
- name: Set up Go ${{ matrix.go }}
- uses: actions/setup-go@v6
+ uses: actions/setup-go@v5
with:
go-version: ${{ env.LATEST_GO_VERSION }}
check-latest: true
@@ -44,4 +46,3 @@ jobs:
go version
go install golang.org/x/vuln/cmd/govulncheck@latest
govulncheck ./...
-
diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml
index e11a029fc..5f4d53f2c 100644
--- a/.github/workflows/echo.yml
+++ b/.github/workflows/echo.yml
@@ -4,9 +4,11 @@ on:
push:
branches:
- master
+ - v4
pull_request:
branches:
- master
+ - v4
workflow_dispatch:
permissions:
@@ -23,17 +25,17 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases (unless there are pressing vulnerabilities)
- # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations when
- # we derive from the last four major releases promise.
+ # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when
+ # we derive from last four major releases promise.
go: ["1.25", "1.26"]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Code
- uses: actions/checkout@v6
+ uses: actions/checkout@v5
- name: Set up Go ${{ matrix.go }}
- uses: actions/setup-go@v6
+ uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
@@ -42,7 +44,7 @@ jobs:
- name: Upload coverage to Codecov
if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest'
- uses: codecov/codecov-action@v6
+ uses: codecov/codecov-action@v5
with:
token:
fail_ci_if_error: false
@@ -53,18 +55,18 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Code (Previous)
- uses: actions/checkout@v6
+ uses: actions/checkout@v5
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
- uses: actions/checkout@v6
+ uses: actions/checkout@v5
with:
path: new
- name: Set up Go ${{ matrix.go }}
- uses: actions/setup-go@v6
+ uses: actions/setup-go@v5
with:
go-version: ${{ env.LATEST_GO_VERSION }}
diff --git a/.golangci.yaml b/.golangci.yaml
deleted file mode 100644
index eab6aa2d3..000000000
--- a/.golangci.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-version: "2"
-linters:
-# default: none
- enable:
- - revive
- disable:
- - errcheck
- settings:
- revive:
- rules:
- - name: exported
- exclusions:
- generated: lax
- presets:
- - common-false-positives
- - legacy
- - std-error-handling
- paths:
- - _test\.go$
-formatters:
- exclusions:
- generated: lax
- paths:
- - _test\.go$
diff --git a/API_CHANGES_V5.md b/API_CHANGES_V5.md
deleted file mode 100644
index d3ca81560..000000000
--- a/API_CHANGES_V5.md
+++ /dev/null
@@ -1,1178 +0,0 @@
-# Echo v5 Public API Changes
-
-**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches**
-
-Generated: 2026-01-01
-
----
-
-## Executive Summary (by authors)
-
-Echo `v5` is maintenance release with **major breaking changes**
-- `Context` is now struct instead of interface and we can add method to it in the future in minor versions.
-- Adds new `Router` interface for possible new routing implementations.
-- Drops old logging interface and uses moderm `log/slog` instead.
-- Rearranges alot of methods/function signatures to make them more consistent.
-
-## Executive Summary (by LLMs)
-
-Echo v5 represents a **major breaking release** with significant architectural changes focused on:
-- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*`
-- **Simplified API surface** by moving Context from interface to concrete struct
-- **Modern Go patterns** including slog.Logger integration
-- **Enhanced routing** with explicit RouteInfo and Routes types
-- **Better error handling** with simplified HTTPError
-- **New test helpers** via the `echotest` package
-
-### Change Statistics
-
-- **Major Breaking Changes**: 15+
-- **New Functions Added**: 30+
-- **Type Signature Changes**: 20+
-- **Removed APIs**: 10+
-- **New Packages Added**: 1 (`echotest`)
-- **Version Change**: `4.15.0` β `5.0.0-alpha`
-
----
-
-## Critical Breaking Changes
-
-### 1. **Context: Interface β Concrete Struct**
-
-**v4 (master):**
-```go
-type Context interface {
- Request() *http.Request
- // ... many methods
-}
-
-// Handler signature
-func handler(c echo.Context) error
-```
-
-**v5:**
-```go
-type Context struct {
- // Has unexported fields
-}
-
-// Handler signature - NOW USES POINTER!
-func handler(c *echo.Context) error
-```
-
-**Impact:** π΄ **CRITICAL BREAKING CHANGE**
-- ALL handlers must change from `echo.Context` to `*echo.Context`
-- Context is now a concrete struct, not an interface
-- This affects every single handler function in user code
-
-**Migration:**
-```go
-// Before (v4)
-func MyHandler(c echo.Context) error {
- return c.JSON(200, map[string]string{"hello": "world"})
-}
-
-// After (v5)
-func MyHandler(c *echo.Context) error {
- return c.JSON(200, map[string]string{"hello": "world"})
-}
-```
-
----
-
-### 2. **Logger: Custom Interface β slog.Logger**
-
-**v4:**
-```go
-type Echo struct {
- Logger Logger // Custom interface with Print, Debug, Info, etc.
-}
-
-type Logger interface {
- Output() io.Writer
- SetOutput(w io.Writer)
- Prefix() string
- // ... many custom methods
-}
-
-// Context returns Logger interface
-func (c Context) Logger() Logger
-```
-
-**v5:**
-```go
-type Echo struct {
- Logger *slog.Logger // Standard library structured logger
-}
-
-// Context returns slog.Logger
-func (c *Context) Logger() *slog.Logger
-func (c *Context) SetLogger(logger *slog.Logger)
-```
-
-**Impact:** π΄ **BREAKING CHANGE**
-- Must use Go's standard `log/slog` package
-- Logger interface completely removed
-- All logging code needs updating
-
----
-
-### 3. **Router: From Router to DefaultRouter**
-
-**v4:**
-```go
-type Router struct { ... }
-
-func NewRouter(e *Echo) *Router
-func (e *Echo) Router() *Router
-```
-
-**v5:**
-```go
-type DefaultRouter struct { ... }
-
-func NewRouter(config RouterConfig) *DefaultRouter
-func (e *Echo) Router() Router // Returns interface
-```
-
-**Changes:**
-- New `Router` interface introduced
-- `DefaultRouter` is the concrete implementation
-- `NewRouter()` now takes `RouterConfig` instead of `*Echo`
-- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing
-
----
-
-### 4. **Route Return Types Changed**
-
-**v4:**
-```go
-func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route
-func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route
-func (e *Echo) Routes() []*Route
-```
-
-**v5:**
-```go
-func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
-func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
-func (e *Echo) Match(...) Routes // Returns Routes type
-func (e *Echo) Router() Router // Returns interface
-```
-
-**New Types:**
-```go
-type RouteInfo struct {
- Name string
- Method string
- Path string
- Parameters []string
-}
-
-type Routes []RouteInfo // Collection with helper methods
-```
-
-**Impact:** π΄ **BREAKING CHANGE**
-- Route registration methods return `RouteInfo` instead of `*Route`
-- New `Routes` collection type with filtering methods
-- `Route` struct still exists but used differently
-
----
-
-### 5. **Response Type Changed**
-
-**v4:**
-```go
-func (c Context) Response() *Response
-type Response struct {
- Writer http.ResponseWriter
- Status int
- Size int64
- Committed bool
-}
-func NewResponse(w http.ResponseWriter, e *Echo) *Response
-```
-
-**v5:**
-```go
-func (c *Context) Response() http.ResponseWriter
-type Response struct {
- http.ResponseWriter // Embedded
- Status int
- Size int64
- Committed bool
-}
-func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
-func UnwrapResponse(rw http.ResponseWriter) (*Response, error)
-```
-
-**Changes:**
-- Context.Response() returns `http.ResponseWriter` instead of `*Response`
-- Response now embeds `http.ResponseWriter`
-- NewResponse takes `*slog.Logger` instead of `*Echo`
-- New `UnwrapResponse()` helper function
-
----
-
-### 6. **HTTPError Simplified**
-
-**v4:**
-```go
-type HTTPError struct {
- Internal error
- Message interface{} // Can be any type
- Code int
-}
-
-func NewHTTPError(code int, message ...interface{}) *HTTPError
-```
-
-**v5:**
-```go
-type HTTPError struct {
- Code int
- Message string // Now string only
- // Has unexported fields (Internal moved)
-}
-
-func NewHTTPError(code int, message string) *HTTPError
-func (he HTTPError) Wrap(err error) error // New method
-func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder
-```
-
-**Changes:**
-- `Message` field changed from `interface{}` to `string`
-- `NewHTTPError()` now takes `string` instead of `...interface{}`
-- Added `HTTPStatusCoder` interface and `StatusCode()` method
-- Added `Wrap(err error)` method for error wrapping
-
----
-
-### 7. **HTTPErrorHandler Signature Changed**
-
-**v4:**
-```go
-type HTTPErrorHandler func(err error, c Context)
-
-func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
-```
-
-**v5:**
-```go
-type HTTPErrorHandler func(c *Context, err error) // Parameters swapped!
-
-func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory
-```
-
-**Impact:** π΄ **BREAKING CHANGE**
-- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)`
-- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler
-- Takes `exposeError` bool to control error message exposure
-
----
-
-## Notable API Changes in v5
-
-### 1. **Generic Parameter Extraction Functions (Updated Signatures)**
-
-These helpers keep the same generic API but now accept `*Context`, and the
-form helpers are renamed from `FormParam*` to `FormValue*`:
-
-```go
-// Query Parameters
-func QueryParam[T any](c *Context, key string, opts ...any) (T, error)
-func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
-func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error)
-func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
-
-// Path Parameters
-func PathParam[T any](c *Context, paramName string, opts ...any) (T, error)
-func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error)
-
-// Form Values
-func FormValue[T any](c *Context, key string, opts ...any) (T, error)
-func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
-func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
-func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
-
-// Generic Parsing
-func ParseValue[T any](value string, opts ...any) (T, error)
-func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error)
-func ParseValues[T any](values []string, opts ...any) ([]T, error)
-func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error)
-```
-
-`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`.
-
-**Supported Types:**
-- bool, string
-- int, int8, int16, int32, int64
-- uint, uint8, uint16, uint32, uint64
-- float32, float64
-- time.Time, time.Duration
-- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler
-
-**Example Usage:**
-```go
-// v5 - Type-safe parameter binding
-id, err := echo.PathParam[int](c, "id")
-page, err := echo.QueryParamOr[int](c, "page", 1)
-tags, err := echo.QueryParams[string](c, "tags")
-```
-
----
-
-### 2. **Context Store Helpers Now Use `*Context`**
-
-```go
-// Type-safe context value retrieval
-func ContextGet[T any](c *Context, key string) (T, error)
-func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error)
-
-// Error types
-var ErrNonExistentKey = errors.New("non existent key")
-var ErrInvalidKeyType = errors.New("invalid key type")
-```
-
-These helpers existed in v4 with `Context` and now accept `*Context`.
-
-**Example:**
-```go
-// v5
-user, err := echo.ContextGet[*User](c, "user")
-count, err := echo.ContextGetOr[int](c, "count", 0)
-```
-
----
-
-### 3. **PathValues Type**
-
-New structured path parameter handling:
-
-```go
-type PathValue struct {
- Name string
- Value string
-}
-
-type PathValues []PathValue
-
-func (p PathValues) Get(name string) (string, bool)
-func (p PathValues) GetOr(name string, defaultValue string) string
-
-// Context methods
-func (c *Context) PathValues() PathValues
-func (c *Context) SetPathValues(pathValues PathValues)
-```
-
----
-
-### 4. **Time Parsing Options**
-
-```go
-type TimeLayout string
-
-const (
- TimeLayoutUnixTime = TimeLayout("UnixTime")
- TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli")
- TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano")
-)
-
-type TimeOpts struct {
- Layout TimeLayout
- ParseInLocation *time.Location
- ToInLocation *time.Location
-}
-```
-
----
-
-### 5. **StartConfig for Server Configuration**
-
-```go
-type StartConfig struct {
- Address string
- HideBanner bool
- HidePort bool
- CertFilesystem fs.FS
- TLSConfig *tls.Config
- ListenerNetwork string
- ListenerAddrFunc func(addr net.Addr)
- GracefulTimeout time.Duration
- OnShutdownError func(err error)
- BeforeServeFunc func(s *http.Server) error
-}
-
-func (sc StartConfig) Start(ctx context.Context, h http.Handler) error
-func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error
-```
-
-**Example:**
-```go
-// v5 - More control over server startup
-ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
-defer cancel()
-
-sc := echo.StartConfig{
- Address: ":8080",
- GracefulTimeout: 10 * time.Second,
-}
-if err := sc.Start(ctx, e); err != nil {
- log.Fatal(err)
-}
-```
-
----
-
-### 6. **Echo Config and Constructors**
-
-```go
-type Config struct {
- // Configuration for Echo (logger, binder, renderer, etc.)
-}
-
-func NewWithConfig(config Config) *Echo
-```
-
-This adds a configuration struct for creating an `Echo` instance without
-mutating fields after `New()`.
-
----
-
-### 7. **Enhanced Routing Features**
-
-```go
-// New route methods
-func (e *Echo) AddRoute(route Route) (RouteInfo, error)
-func (e *Echo) Middlewares() []MiddlewareFunc
-func (e *Echo) PreMiddlewares() []MiddlewareFunc
-type AddRouteError struct{ ... }
-
-// Routes collection with filters
-type Routes []RouteInfo
-
-func (r Routes) Clone() Routes
-func (r Routes) FilterByMethod(method string) (Routes, error)
-func (r Routes) FilterByName(name string) (Routes, error)
-func (r Routes) FilterByPath(path string) (Routes, error)
-func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error)
-func (r Routes) Reverse(routeName string, pathValues ...any) (string, error)
-
-// RouteInfo operations
-func (r RouteInfo) Clone() RouteInfo
-func (r RouteInfo) Reverse(pathValues ...any) string
-```
-
----
-
-### 8. **Middleware Configuration Interface**
-
-```go
-type MiddlewareConfigurator interface {
- ToMiddleware() (MiddlewareFunc, error)
-}
-```
-
-Allows middleware configs to be converted to middleware without panicking.
-
----
-
-### 9. **New Context Methods**
-
-```go
-// v5 additions
-func (c *Context) FileFS(file string, filesystem fs.FS) error
-func (c *Context) FormValueOr(name, defaultValue string) string
-func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues)
-func (c *Context) ParamOr(name, defaultValue string) string
-func (c *Context) QueryParamOr(name, defaultValue string) string
-func (c *Context) RouteInfo() RouteInfo
-```
-
----
-
-### 10. **Virtual Host Support**
-
-```go
-func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo
-```
-
-Creates an Echo instance that routes requests to different Echo instances based on host.
-
----
-
-### 11. **New Binder Functions**
-
-```go
-func BindBody(c *Context, target any) error
-func BindHeaders(c *Context, target any) error
-func BindPathValues(c *Context, target any) error // Renamed from BindPathParams
-func BindQueryParams(c *Context, target any) error
-```
-
-Top-level binding functions that work with `*Context`.
-
----
-
-### 12. **New echotest Package**
-
-```go
-package echotest // import "github.com/labstack/echo/v5/echotest"
-
-func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte
-func TrimNewlineEnd(bytes []byte) []byte
-type ContextConfig struct{ ... }
-type MultipartForm struct{ ... }
-type MultipartFormFile struct{ ... }
-```
-
-Helpers for loading fixtures and constructing test contexts.
-
----
-
-## Removed APIs in v5
-
-### Constants
-
-```go
-// v4 - Removed in v5
-const CONNECT = http.MethodConnect // Use http.MethodConnect directly
-```
-
-**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead.
-
----
-
-### Constants Added in v5
-
-```go
-// v5 additions
-const (
- NotFoundRouteName = "echo_route_not_found_name"
-)
-```
-
----
-
-### Error Variable Changes
-
-**v4 exports:**
-```go
-ErrBadRequest
-ErrInvalidKeyType
-ErrNonExistentKey
-```
-
-**v5 exports:**
-```go
-ErrBadRequest // Now backed by unexported httpError type
-ErrValidatorNotRegistered // New
-ErrInvalidKeyType
-ErrNonExistentKey
-```
-
-**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set
-of predefined HTTP error variables.
-
----
-
-### Functions Removed
-
-```go
-// v4 - Removed in v5
-func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath
-```
-
-### Variables Removed
-
-```go
-// v4 - Removed in v5
-var MethodNotAllowedHandler = func(c Context) error { ... }
-var NotFoundHandler = func(c Context) error { ... }
-```
-
-### Functions Renamed
-
-```go
-// v4
-func FormParam[T any](c Context, key string, opts ...any) (T, error)
-func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error)
-func FormParams[T any](c Context, key string, opts ...any) ([]T, error)
-func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error)
-
-// v5
-func FormValue[T any](c *Context, key string, opts ...any) (T, error)
-func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
-func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
-func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
-```
-
----
-
-### Type Methods Removed/Changed
-
-**Echo struct changes:**
-```go
-// v4 fields removed in v5
-type Echo struct {
- StdLogger *stdLog.Logger // Removed
- Server *http.Server // Removed (use StartConfig)
- TLSServer *http.Server // Removed (use StartConfig)
- Listener net.Listener // Removed (use StartConfig)
- TLSListener net.Listener // Removed (use StartConfig)
- AutoTLSManager autocert.Manager // Removed
- ListenerNetwork string // Removed
- OnAddRouteHandler func(...) // Changed to OnAddRoute
- DisableHTTP2 bool // Removed (use StartConfig)
- Debug bool // Removed
- HideBanner bool // Removed (use StartConfig)
- HidePort bool // Removed (use StartConfig)
-}
-
-// v5 Echo struct (simplified)
-type Echo struct {
- Binder Binder
- Filesystem fs.FS // NEW
- Renderer Renderer
- Validator Validator
- JSONSerializer JSONSerializer
- IPExtractor IPExtractor
- OnAddRoute func(route Route) error // Simplified
- HTTPErrorHandler HTTPErrorHandler
- Logger *slog.Logger // Changed from Logger interface
-}
-```
-
----
-
-**Context interface β struct:**
-```go
-// v4
-type Context interface {
- // Had: SetResponse(*Response)
- Response() *Response
-
- // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues()
- // These are removed in v5 (use PathValues() instead)
-}
-
-// v5
-type Context struct {
- // Concrete struct with unexported fields
-}
-
-func (c *Context) Response() http.ResponseWriter // Changed return type
-func (c *Context) PathValues() PathValues // Replaces ParamNames/Values
-```
-
----
-
-**Types removed:**
-```go
-// v4
-type Map map[string]interface{}
-```
-
-**Group changes:**
-```go
-// v4
-func (g *Group) File(path, file string) // No return value
-func (g *Group) Static(pathPrefix, fsRoot string) // No return value
-func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value
-
-// v5
-func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
-func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
-func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
-```
-
-Now return `RouteInfo` and accept middleware.
-
----
-
-### Value Binder Factory Name Changes
-
-```go
-// v4
-func PathParamsBinder(c Context) *ValueBinder
-func QueryParamsBinder(c Context) *ValueBinder
-func FormFieldBinder(c Context) *ValueBinder
-
-// v5
-func PathValuesBinder(c *Context) *ValueBinder // Renamed
-func QueryParamsBinder(c *Context) *ValueBinder
-func FormFieldBinder(c *Context) *ValueBinder
-```
-
----
-
-## Type Signature Changes
-
-### Binder Interface
-
-```go
-// v4
-type Binder interface {
- Bind(i interface{}, c Context) error
-}
-
-// v5
-type Binder interface {
- Bind(c *Context, target any) error // Parameters swapped!
-}
-```
-
----
-
-### DefaultBinder Methods
-
-```go
-// v4
-func (b *DefaultBinder) Bind(i interface{}, c Context) error
-func (b *DefaultBinder) BindBody(c Context, i interface{}) error
-func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error
-
-// v5
-func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params
-// BindBody, BindPathParams, etc. are now top-level functions
-```
-
----
-
-### JSONSerializer Interface
-
-```go
-// v4
-type JSONSerializer interface {
- Serialize(c Context, i interface{}, indent string) error
- Deserialize(c Context, i interface{}) error
-}
-
-// v5
-type JSONSerializer interface {
- Serialize(c *Context, target any, indent string) error
- Deserialize(c *Context, target any) error
-}
-```
-
----
-
-### Renderer Interface
-
-```go
-// v4
-type Renderer interface {
- Render(io.Writer, string, interface{}, Context) error
-}
-
-// v5
-type Renderer interface {
- Render(c *Context, w io.Writer, templateName string, data any) error
-}
-```
-
-Parameters reordered with Context first.
-
----
-
-### NewBindingError
-
-```go
-// v4
-func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error
-
-// v5
-func NewBindingError(sourceParam string, values []string, message string, err error) error
-```
-
-Message parameter changed from `interface{}` to `string`.
-
----
-
-### HandlerName
-
-```go
-// v5 only
-func HandlerName(h HandlerFunc) string
-```
-
-New utility function to get handler function name.
-
----
-
-## Middleware Package Changes
-
-### Signature and Type Updates
-
-```go
-// CORS now accepts optional allow-origins
-func CORS(allowOrigins ...string) echo.MiddlewareFunc
-
-// BodyLimit now accepts bytes
-func BodyLimit(limitBytes int64) echo.MiddlewareFunc
-
-// DefaultSkipper now uses *echo.Context
-func DefaultSkipper(c *echo.Context) bool
-
-// Trailing slash configs renamed/split
-func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc
-func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc
-type AddTrailingSlashConfig struct{ ... }
-type RemoveTrailingSlashConfig struct{ ... }
-
-// Auth + extractor signatures now use *echo.Context and add ExtractorSource
-type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
-type Extractor func(c *echo.Context) (string, error)
-type ExtractorSource string
-type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
-type KeyAuthErrorHandler func(c *echo.Context, err error) error
-
-// BodyDump handler now includes err
-type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
-
-// ValuesExtractor now returns extractor source and CreateExtractors takes a limit
-type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
-func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error)
-type ValueExtractorError struct{ ... }
-
-// New constants
-const KB = 1024
-
-// Rate limiter store now takes a float64 limit
-func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore)
-```
-
-### Added Middleware Exports
-
-```go
-var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
-var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
-var RedirectHTTPSConfig = RedirectConfig{ ... }
-var RedirectHTTPSWWWConfig = RedirectConfig{ ... }
-var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... }
-var RedirectNonWWWConfig = RedirectConfig{ ... }
-var RedirectWWWConfig = RedirectConfig{ ... }
-```
-
-### Removed/Consolidated Middleware Exports
-
-```go
-// Removed in v5
-func Logger() echo.MiddlewareFunc
-func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc
-func Timeout() echo.MiddlewareFunc
-func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc
-type ErrKeyAuthMissing struct{ ... }
-type CSRFErrorHandler func(err error, c echo.Context) error
-type LoggerConfig struct{ ... }
-type LogErrorFunc func(c echo.Context, err error, stack []byte) error
-type TargetProvider interface{ ... }
-type TrailingSlashConfig struct{ ... }
-type TimeoutConfig struct{ ... }
-```
-
-Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`,
-`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`,
-`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`,
-`DefaultTrailingSlashConfig`.
-
----
-
-## Router Interface Changes
-
-### v4 Router (Concrete Struct)
-
-```go
-type Router struct { ... }
-
-func NewRouter(e *Echo) *Router
-func (r *Router) Add(method, path string, h HandlerFunc)
-func (r *Router) Find(method, path string, c Context)
-func (r *Router) Reverse(name string, params ...interface{}) string
-func (r *Router) Routes() []*Route
-```
-
-### v5 Router (Interface + DefaultRouter)
-
-```go
-type Router interface {
- Add(routable Route) (RouteInfo, error)
- Remove(method string, path string) error
- Routes() Routes
- Route(c *Context) HandlerFunc
-}
-
-type DefaultRouter struct { ... }
-
-func NewRouter(config RouterConfig) *DefaultRouter
-func NewConcurrentRouter(r Router) Router // NEW
-
-type RouterConfig struct {
- NotFoundHandler HandlerFunc
- MethodNotAllowedHandler HandlerFunc
- OptionsMethodHandler HandlerFunc
- AllowOverwritingRoute bool
- UnescapePathParamValues bool
- UseEscapedPathForMatching bool
-}
-```
-
-**Key Changes:**
-- Router is now an interface
-- DefaultRouter is the concrete implementation
-- Add() returns `(RouteInfo, error)` instead of being void
-- New `Remove()` method
-- New `Route()` method replaces `Find()`
-- Configuration through `RouterConfig`
-
----
-
-## Echo Instance Method Changes
-
-### Route Registration
-
-```go
-// v4
-func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route
-
-// v5
-func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo
-func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW
-```
-
-### Static File Serving
-
-```go
-// v4
-func (e *Echo) Static(pathPrefix, fsRoot string) *Route
-func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route
-func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route
-func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route
-
-// v5
-func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
-func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
-func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
-func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo
-```
-
-Return type changed from `*Route` to `RouteInfo`.
-
-### Server Management
-
-```go
-// v4
-func (e *Echo) Start(address string) error
-func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error
-func (e *Echo) StartAutoTLS(address string) error
-func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error
-func (e *Echo) StartServer(s *http.Server) error
-func (e *Echo) Shutdown(ctx context.Context) error
-func (e *Echo) Close() error
-func (e *Echo) ListenerAddr() net.Addr
-func (e *Echo) TLSListenerAddr() net.Addr
-func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
-
-// v5
-func (e *Echo) Start(address string) error // Simplified
-func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request)
-
-// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer
-// Use StartConfig instead for advanced server configuration
-// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr
-// Removed: DefaultHTTPErrorHandler (now a top-level factory function)
-```
-
-**v5 provides** `StartConfig` type for all advanced server configuration.
-
-### Router Access
-
-```go
-// v4
-func (e *Echo) Router() *Router
-func (e *Echo) Routers() map[string]*Router // For multi-host
-func (e *Echo) Routes() []*Route
-func (e *Echo) Reverse(name string, params ...interface{}) string
-func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string
-func (e *Echo) URL(h HandlerFunc, params ...interface{}) string
-func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group
-
-// v5
-func (e *Echo) Router() Router // Returns interface
-// Removed: Routers(), Reverse(), URI(), URL(), Host()
-// Use router.Routes() and Routes.Reverse() instead
-```
-
----
-
-## NewContext Changes
-
-```go
-// v4
-func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context
-func NewResponse(w http.ResponseWriter, e *Echo) *Response
-
-// v5
-func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context
-func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone
-func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
-```
-
----
-
-## Migration Guide Summary
-
-If you are using Linux you can migrate easier parts like that:
-```bash
-find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
-find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
-```
-or in your favorite IDE
-
-Replace all:
-1. ` echo.Context` -> ` *echo.Context`
-2. `echo/v4` -> `echo/v5`
-
-
-### 1. Update All Handler Signatures
-
-```go
-// Before
-func MyHandler(c echo.Context) error { ... }
-
-// After
-func MyHandler(c *echo.Context) error { ... }
-```
-
-### 2. Update Logger Usage
-
-```go
-// Before
-e.Logger.Info("Server started")
-c.Logger().Error("Something went wrong")
-
-// After
-e.Logger.Info("Server started")
-c.Logger().Error("Something went wrong") // Same API, different logger
-```
-
-### 3. Use Type-Safe Parameter Extraction
-
-```go
-// Before
-idStr := c.Param("id")
-id, err := strconv.Atoi(idStr)
-
-// After
-id, err := echo.PathParam[int](c, "id")
-```
-
-### 4. Update Error Handler
-
-```go
-// Before
-e.HTTPErrorHandler = func(err error, c echo.Context) {
- // handle error
-}
-
-// After
-e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped!
- // handle error
-}
-
-// Or use factory
-e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true
-```
-
-### 5. Update Server Startup
-
-```go
-// Before
-e.Start(":8080")
-e.StartTLS(":443", "cert.pem", "key.pem")
-
-// After
-// Simple
-e.Start(":8080")
-
-// Advanced with graceful shutdown
-ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
-defer cancel()
-sc := echo.StartConfig{Address: ":8080"}
-sc.Start(ctx, e)
-```
-
-### 6. Update Route Info Access
-
-```go
-// Before
-routes := e.Routes()
-for _, r := range routes {
- fmt.Println(r.Method, r.Path)
-}
-
-// After
-routes := e.Router().Routes()
-for _, r := range routes {
- fmt.Println(r.Method, r.Path)
-}
-```
-
-### 7. Update HTTPError Creation
-
-```go
-// Before
-return echo.NewHTTPError(400, "invalid request", someDetail)
-
-// After
-return echo.NewHTTPError(400, "invalid request")
-```
-
-### 8. Update Custom Binder
-
-```go
-// Before
-type MyBinder struct{}
-func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... }
-
-// After
-type MyBinder struct{}
-func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped!
-```
-
-### 9. Path Parameters
-
-```go
-// Before
-names := c.ParamNames()
-values := c.ParamValues()
-
-// After
-pathValues := c.PathValues()
-for _, pv := range pathValues {
- fmt.Println(pv.Name, pv.Value)
-}
-```
-
-### 10. Response Access
-
-```go
-// Before
-resp := c.Response()
-resp.Header().Set("X-Custom", "value")
-
-// After
-c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter
-
-// To get *echo.Response
-resp, err := echo.UnwrapResponse(c.Response())
-```
-
-### Go Version Requirements
-
-- **v4**: Go 1.24.0 (per `go.mod`)
-- **v5**: Go 1.25.0 (per `go.mod`)
-
----
-
-**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches**
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e129befa2..86feea16b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,169 +1,28 @@
# Changelog
-## v5.2.0 - 2026-06-14
+## v4.15.3 - 2026-06-14
**Security**
-* fix(static): reject encoded path separators that bypass route-level middleware by @vishr in https://github.com/labstack/echo/pull/3009
-* fix(middleware/static): don't double-unescape request path (#2599) by @vishr in https://github.com/labstack/echo/pull/3006
+* fix(static): reject encoded path separators that bypass route-level middleware by @vishr in https://github.com/labstack/echo/pull/3011
-Fixes [GHSA-vfp3-v2gw-7wfq](https://github.com/labstack/echo/security/advisories/GHSA-vfp3-v2gw-7wfq): an encoded path separator (`%2F` or `%5C`) in a static file URL could bypass route-level middleware (e.g. authentication on a sibling route) and disclose static files. Both `StaticDirectoryHandler`/`StaticFS` and the `Static` middleware are affected. Thanks to @a-tt-om and @oran-gugu for reporting.
+Fixes [GHSA-vfp3-v2gw-7wfq](https://github.com/labstack/echo/security/advisories/GHSA-vfp3-v2gw-7wfq): an encoded path separator (`%2F` or `%5C`) in a static file URL could bypass route-level middleware (e.g. authentication on a sibling route) and disclose static files. Both `StaticDirectoryHandler` (used by `Static`/`StaticFS`) and the `Static` middleware are affected. Backport of the v5 fix (#3009). Thanks to @a-tt-om and @oran-gugu for reporting.
-**Enhancements**
-* feat(middleware): optional RateLimiterStoreContext for response headers (#2961) by @vishr in https://github.com/labstack/echo/pull/3007
-* perf: optimize core hot paths (chain, context, binding, responses) by @vishr in https://github.com/labstack/echo/pull/3008
-* fix(binder): include field name in bind conversion errors (#2629) by @vishr in https://github.com/labstack/echo/pull/3005
-* fix(binder): serialize BindingError to structured JSON (#2771) by @vishr in https://github.com/labstack/echo/pull/3004
-* fix(binder): MustUnixTime docs say time.Time, not time.Duration by @c-tonneslan in https://github.com/labstack/echo/pull/2988
-* fix(middleware): reset ContentLength after gzip decompression by @shblue21 in https://github.com/labstack/echo/pull/3000
-* fix(middleware/proxy): append RealIP to X-Forwarded-For for WebSocket requests by @kawaway in https://github.com/labstack/echo/pull/2994
-* Fix proxy panic when balancer has no targets by @shblue21 in https://github.com/labstack/echo/pull/2977
-* fix(middleware): correct documented KeyAuth KeyLookup default by @leestana01 in https://github.com/labstack/echo/pull/2992
-* test: lock in v5 group route method-handling (405 + OPTIONS) by @vishr in https://github.com/labstack/echo/pull/3003
-* docs: liveness signals in README + public ROADMAP by @vishr in https://github.com/labstack/echo/pull/3002
-* Fix typos in CSRFConfig comments by @shblue21 in https://github.com/labstack/echo/pull/2979
-* refactor: modernize code usage using gofix by @kumapower17 in https://github.com/labstack/echo/pull/2970
-* refactor: replace Split in loops with more efficient SplitSeq by @box4wangjing in https://github.com/labstack/echo/pull/2969
-* refactor: use the built-in max/min to simplify the code by @criciss in https://github.com/labstack/echo/pull/2966
-* Update GitHub actions deps versions by @aldas in https://github.com/labstack/echo/pull/2971
-
-**New Contributors**
-
-* @criciss made their first contribution in https://github.com/labstack/echo/pull/2966
-* @box4wangjing made their first contribution in https://github.com/labstack/echo/pull/2969
-* @shblue21 made their first contribution in https://github.com/labstack/echo/pull/2977
-* @c-tonneslan made their first contribution in https://github.com/labstack/echo/pull/2988
-* @leestana01 made their first contribution in https://github.com/labstack/echo/pull/2992
-* @kawaway made their first contribution in https://github.com/labstack/echo/pull/2994
-
-**Full Changelog**: https://github.com/labstack/echo/compare/v5.1.1...v5.2.0
-
-
-## v5.1.1 - 2026-05-01
+## v4.15.2 - 2026-05-01
**Security**
-* `Context.Scheme()` should validate values taken from header by @aldas in https://github.com/labstack/echo/pull/2953
+* `Context.Scheme()` should validate values taken from header by @aldas in https://github.com/labstack/echo/pull/2962
Thanks to @shblue21 for reporting this [issue](https://github.com/labstack/echo/issues/2952).
-**Enhancements**
-
-* Add golangci linter configuration by @aldas in https://github.com/labstack/echo/pull/2930
-* Make StartConfig listener creation context-aware by @EricGusmao in https://github.com/labstack/echo/pull/2936
-* fix(lint): resolve staticcheck issues and improve code quality by @itsllyaz in https://github.com/labstack/echo/pull/2941
-* Context.Scheme should validate values taken from header by @aldas in https://github.com/labstack/echo/pull/2953
-* chore: fix typos in httperror.go by @tisonkun in https://github.com/labstack/echo/pull/2958
-* Context.Json should not unwrap response by @aldas in https://github.com/labstack/echo/pull/2964
-
-
-## v5.1.0 - 2026-03-31
-
-**Security**
-
-This change does not break the API contract, but it does introduce breaking changes in logic/behavior.
-If your application is using `c.RealIP()` beware and read https://echo.labstack.com/docs/ip-address
-
-`v4` behavior can be restored with:
-```go
-e := echo.New()
-e.IPExtractor = echo.LegacyIPExtractor()
-```
-
-* Remove legacy IP extraction logic from context.RealIP method by @aldas in https://github.com/labstack/echo/pull/2933
+## v4.15.1 - 2026-02-22
**Enhancements**
-* Add echo-opentelemetry to the README.md by @aldas in https://github.com/labstack/echo/pull/2908
-* fix: correct spelling mistakes in comments and field name by @crawfordxx in https://github.com/labstack/echo/pull/2916
-* Add https://github.com/labstack/echo-prometheus to the middleware list in README.md by @aldas in https://github.com/labstack/echo/pull/2919
-* Add StartConfig.Listener so server with custom Listener is easier to create by @aldas in https://github.com/labstack/echo/pull/2920
-* Fix rate limiter documentation for default burst value by @karesansui-u in https://github.com/labstack/echo/pull/2925
-* Add doc comments to clarify usage of File related methods and leading slash handling by @aldas in https://github.com/labstack/echo/pull/2928
-* Add NewDefaultFS function to help create filesystem that allows absolute paths by @aldas in https://github.com/labstack/echo/pull/2931
-* Do not set http.Server.WriteTimeout in StartConfig by @aldas in https://github.com/labstack/echo/pull/2932
-
-
-
-## v5.0.4 - 2026-02-15
-
-**Enhancements**
-
-* Remove unused import 'errors' from README example by @kumapower17 in https://github.com/labstack/echo/pull/2889
-* Fix Graceful shutdown: after `http.Server.Serve` returns we need to wait for graceful shutdown goroutine to finish by @aldas in https://github.com/labstack/echo/pull/2898
-* Update location of oapi-codegen in README by @mromaszewicz in https://github.com/labstack/echo/pull/2896
-* Add Go 1.26 to CI flow by @aldas in https://github.com/labstack/echo/pull/2899
-* Add new function `echo.StatusCode` by @suwakei in https://github.com/labstack/echo/pull/2892
-* CSRF: support older token-based CSRF protection handler that want to render token into template by @aldas in https://github.com/labstack/echo/pull/2894
-* Add `echo.ResolveResponseStatus` function to help middleware/handlers determine HTTP status code and echo.Response by @aldas in https://github.com/labstack/echo/pull/2900
-
-
-## v5.0.3 - 2026-02-06
-
-**Security**
-
-* Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21.
-
-This applies to cases when:
-- Windows is used as OS
-- `middleware.StaticConfig.Filesystem` is `nil` (default)
-- `echo.Filesystem` is has not been set explicitly (default)
-
-Exposure is restricted to the active process working directory and its subfolders.
-
-
-## v5.0.2 - 2026-02-02
-
-**Security**
-
-* Fix Static middleware with `config.Browse=true` lists all files/subfolders from `config.Filesystem` root and not starting from `config.Root` in https://github.com/labstack/echo/pull/2887
-
-
-## v5.0.1 - 2026-01-28
-
-* Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871
-* Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878
-* improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875
-* fix: Context.Json() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877
-
-
-## v5.0.0 - 2026-01-18
-
-Echo `v5` is maintenance release with **major breaking changes**
-- `Context` is now struct instead of interface and we can add method to it in the future in minor versions.
-- Adds new `Router` interface for possible new routing implementations.
-- Drops old logging interface and uses moderm `log/slog` instead.
-- Rearranges alot of methods/function signatures to make them more consistent.
-
-Upgrade notes and `v4` support:
-- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31**
-- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading.
-- Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning.
-
-See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on **upgrading**.
-
-Upgrading TLDR:
-
-If you are using Linux you can migrate easier parts like that:
-```bash
-find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
-find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
-```
-macOS
-```bash
-find . -type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} +
-find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} +
-```
-
-or in your favorite IDE
-
-Replace all:
-1. ` echo.Context` -> ` *echo.Context`
-2. `echo/v4` -> `echo/v5`
-
-This should solve most of the issues. Probably the hardest part is updating all the tests.
+* CSRF: support older token-based CSRF protection handler that want to render token into template by @aldas in https://github.com/labstack/echo/pull/2905
## v4.15.0 - 2026-01-01
diff --git a/LICENSE b/LICENSE
index 2f18411bd..c46d0105f 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
The MIT License (MIT)
-Copyright (c) 2022 LabStack
+Copyright (c) 2021 LabStack
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/Makefile b/Makefile
index bd075bbae..cbd78f1bf 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,10 @@
PKG := "github.com/labstack/echo"
PKG_LIST := $(shell go list ${PKG}/...)
+tag:
+ @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'`
+ @git tag|grep -v ^v
+
.DEFAULT_GOAL := check
check: lint vet race ## Check project
@@ -22,11 +26,12 @@ race: ## Run tests with data race detector
@go test -race ${PKG_LIST}
benchmark: ## Run benchmarks
- @go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
+ @go test -run="-" -bench=".*" ${PKG_LIST}
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
-goversion ?= "1.25"
-test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25
- @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
+goversion ?= "1.22"
+docker_user ?= "1000"
+test_version: ## Run tests inside Docker with given version (defaults to 1.22 oldest supported). Example: make test_version goversion=1.22
+ @docker run --rm -it --user $(docker_user) -e HOME=/tmp -e GOCACHE=/tmp/go-cache -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "mkdir -p /tmp/go-cache /tmp/.cache && cd /project && make init check"
diff --git a/README.md b/README.md
index cd6abcbf6..5e52d1d4e 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,5 @@
-[](https://github.com/labstack/echo/releases)
-[](https://github.com/labstack/echo/commits/master)
[](https://sourcegraph.com/github.com/labstack/echo?badge)
-[](https://pkg.go.dev/github.com/labstack/echo/v5)
+[](https://pkg.go.dev/github.com/labstack/echo/v4)
[](https://goreportcard.com/report/github.com/labstack/echo)
[](https://github.com/labstack/echo/actions)
[](https://codecov.io/gh/labstack/echo)
@@ -13,14 +11,13 @@
High performance, extensible, minimalist Go web framework.
-Echo is built on Go's standard `net/http` β and interoperates with it via `echo.WrapHandler` / `echo.WrapMiddleware` β adding the parts the standard library leaves to you: a fast radix-tree router, request binding (with a pluggable validator), a deep middleware ecosystem, and centralized error handling. Actively maintained, with `v5` as the current release line (see badges above for the latest version and most recent commit).
-
* [Official website](https://echo.labstack.com)
* [Quick start](https://echo.labstack.com/docs/quick-start)
* [Middlewares](https://echo.labstack.com/docs/category/middleware)
Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions)
+
### Feature Overview
- Optimized HTTP router which smartly prioritize routes
@@ -51,23 +48,13 @@ Click [here](https://github.com/sponsors/labstack) for more information on spons
## [Guide](https://echo.labstack.com/guide)
-### Supported Echo versions
-
-- Latest major version of Echo is `v5` as of 2026-01-18.
- - See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on upgrading.
-- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31**
-
-See [ROADMAP.md](./ROADMAP.md) for where Echo is heading and the version support policy.
-
### Installation
```sh
// go get github.com/labstack/echo/{version}
-go get github.com/labstack/echo/v5
+go get github.com/labstack/echo/v4
```
-
-Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with
-older versions.
+Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
### Example
@@ -75,8 +62,8 @@ older versions.
package main
import (
- "github.com/labstack/echo/v5"
- "github.com/labstack/echo/v5/middleware"
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v4/middleware"
"log/slog"
"net/http"
)
@@ -86,20 +73,20 @@ func main() {
e := echo.New()
// Middleware
- e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger
- e.Use(middleware.Recover()) // recover panics as errors for proper error handling
+ e.Use(middleware.RequestLogger()) // use the default RequestLogger middleware with slog logger
+ e.Use(middleware.Recover()) // recover panics as errors for proper error handling
// Routes
e.GET("/", hello)
// Start server
- if err := e.Start(":8080"); err != nil {
+ if err := e.Start(":8080"); err != nil && !errors.Is(err, http.ErrServerClosed) {
slog.Error("failed to start server", "error", err)
}
}
// Handler
-func hello(c *echo.Context) error {
+func hello(c echo.Context) error {
return c.String(http.StatusOK, "Hello, World!")
}
```
@@ -108,12 +95,10 @@ func hello(c *echo.Context) error {
Following list of middleware is maintained by Echo team.
-| Repository | Description |
-|------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
-| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [pprof](https://pkg.go.dev/net/http/pprof)) middlewares |
-| [github.com/labstack/echo-opentelemetry](https://github.com/labstack/echo-opentelemetry) | [OpenTelemetry](https://opentelemetry.io/) middleware for tracing and metrics |
-| [github.com/labstack/echo-prometheus](https://github.com/labstack/echo-prometheus) | [Prometheus](https://github.com/prometheus/client_golang/) middleware for Echo |
+| Repository | Description |
+|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
+| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
# Third-party middleware repositories
@@ -122,11 +107,11 @@ of middlewares in this list.
| Repository | Description |
|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
+| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | UberΒ΄s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
-| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. |
+| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. |
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
diff --git a/ROADMAP.md b/ROADMAP.md
deleted file mode 100644
index 874d50b0c..000000000
--- a/ROADMAP.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Echo Roadmap
-
-> **DRAFT** β this is a starting point for maintainers to edit, not a commitment.
-> Dates and priorities are owned by the Echo team. Open a discussion to propose changes.
-
-This document exists so the community can see where Echo is heading. Echo is
-**actively maintained**. We publish releases regularly across two supported
-lines β see [README](./README.md) badges for the latest version and most recent commit.
-
-## Version policy
-
-| Line | Status | Support |
-|------|--------|---------|
-| `v5` | **Current** (since 2026-01-18) | New features, fixes, and improvements. |
-| `v4` | Maintenance / LTS | **Security and bug fixes until 2026-12-31.** No new features. |
-
-Upgrading from v4? See [API_CHANGES_V5.md](./API_CHANGES_V5.md).
-
-Echo supports the **latest four Go major releases** and may work with older versions.
-
-## Now (in progress)
-
-- Stabilizing the `v5` API surface through point releases.
-- Documentation catch-up for v5 behavior changes (e.g. CORS / `RouteNotFound`
- behavior on groups β see #2950).
-- Triaging and reducing the open issue / PR backlog.
-
-## Next (under consideration)
-
-These are frequently-requested items being discussed. Inclusion here is **not** a
-commitment β each still needs design agreement before implementation:
-
-- **Automatic `HEAD` for `GET` routes** (#2895; see #2949) β opt-in, likely via an
- `OnAddRoute` hook so users keep control.
-- **Rate limiter response metadata** β expose `Retry-After` / remaining quota
- through the store interface (#2961).
-- **Real-IP / `Forwarded` header handling** improvements (#2744).
-- **Proxy middleware** authorization-header handling (#2787).
-
-## Later / exploratory
-
-- Continued alignment with the Go standard library (`net/http`, `slog`).
-- Reducing third-party surface where the stdlib now covers the need.
-
-## How to influence the roadmap
-
-- **Discuss before large PRs** β open a [Discussion](https://github.com/labstack/echo/discussions)
- or issue so we can agree on the design first.
-- π reactions on issues help us gauge demand.
-- See [README β Contribute](./README.md#contribute) for contribution guidelines.
diff --git a/SECURITY.md b/SECURITY.md
deleted file mode 100644
index efb618697..000000000
--- a/SECURITY.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# Security Policy
-
-## Supported Versions
-
-| Version | Supported |
-|-----------|-------------------------------------|
-| 5.x.x | :white_check_mark: |
-| >= 4.15.x | :white_check_mark: until 2026.12.31 |
-| < 4.15 | :x: |
-
-## Reporting a Vulnerability
-
-https://github.com/labstack/echo/security/advisories/new
-
-or look for maintainers email(s) in commits and email them.
diff --git a/_fixture/dist/private.txt b/_fixture/dist/private.txt
deleted file mode 100644
index 0f9d2435b..000000000
--- a/_fixture/dist/private.txt
+++ /dev/null
@@ -1 +0,0 @@
-private file
diff --git a/_fixture/dist/public/assets/readme.md b/_fixture/dist/public/assets/readme.md
deleted file mode 100644
index 50590f554..000000000
--- a/_fixture/dist/public/assets/readme.md
+++ /dev/null
@@ -1 +0,0 @@
-readme in assets
diff --git a/_fixture/dist/public/assets/subfolder/subfolder.md b/_fixture/dist/public/assets/subfolder/subfolder.md
deleted file mode 100644
index 74c928b2f..000000000
--- a/_fixture/dist/public/assets/subfolder/subfolder.md
+++ /dev/null
@@ -1 +0,0 @@
-file inside subfolder
diff --git a/_fixture/dist/public/index.html b/_fixture/dist/public/index.html
deleted file mode 100644
index df6d9015a..000000000
--- a/_fixture/dist/public/index.html
+++ /dev/null
@@ -1 +0,0 @@
-
Hello from index
diff --git a/_fixture/dist/public/test.txt b/_fixture/dist/public/test.txt
deleted file mode 100644
index dd937160d..000000000
--- a/_fixture/dist/public/test.txt
+++ /dev/null
@@ -1 +0,0 @@
-test.txt contents
diff --git a/bind.go b/bind.go
index 4713afde2..1e1043c07 100644
--- a/bind.go
+++ b/bind.go
@@ -13,13 +13,12 @@ import (
"reflect"
"strconv"
"strings"
- "sync"
"time"
)
// Binder is the interface that wraps the Bind method.
type Binder interface {
- Bind(c *Context, target any) error
+ Bind(i any, c Context) error
}
// DefaultBinder is the default implementation of the Binder interface.
@@ -40,22 +39,31 @@ type bindMultipleUnmarshaler interface {
UnmarshalParams(params []string) error
}
-// BindPathValues binds path parameter values to bindable object
-func BindPathValues(c *Context, target any) error {
+// BindPathParams binds path params to bindable object
+//
+// Time format support: time.Time fields can use `format` tags to specify custom parsing layouts.
+// Example: `param:"created" format:"2006-01-02T15:04"` for datetime-local format
+// Example: `param:"date" format:"2006-01-02"` for date format
+// Uses Go's standard time format reference time: Mon Jan 2 15:04:05 MST 2006
+// Works with form data, query parameters, and path parameters (not JSON body)
+// Falls back to default time.Time parsing if no format tag is specified
+func (b *DefaultBinder) BindPathParams(c Context, i any) error {
+ names := c.ParamNames()
+ values := c.ParamValues()
params := map[string][]string{}
- for _, param := range c.PathValues() {
- params[param.Name] = []string{param.Value}
+ for i, name := range names {
+ params[name] = []string{values[i]}
}
- if err := bindData(target, params, "param", nil); err != nil {
- return ErrBadRequest.Wrap(err)
+ if err := b.bindData(i, params, "param", nil); err != nil {
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
return nil
}
// BindQueryParams binds query params to bindable object
-func BindQueryParams(c *Context, target any) error {
- if err := bindData(target, c.QueryParams(), "query", nil); err != nil {
- return ErrBadRequest.Wrap(err)
+func (b *DefaultBinder) BindQueryParams(c Context, i any) error {
+ if err := b.bindData(i, c.QueryParams(), "query", nil); err != nil {
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
return nil
}
@@ -65,7 +73,7 @@ func BindQueryParams(c *Context, target any) error {
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
-func BindBody(c *Context, target any) (err error) {
+func (b *DefaultBinder) BindBody(c Context, i any) (err error) {
req := c.Request()
if req.ContentLength == 0 {
return
@@ -77,52 +85,58 @@ func BindBody(c *Context, target any) (err error) {
switch mediatype {
case MIMEApplicationJSON:
- if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil {
- var hErr *HTTPError
- if errors.As(err, &hErr) {
+ if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil {
+ switch err.(type) {
+ case *HTTPError:
return err
+ default:
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
- return ErrBadRequest.Wrap(err)
}
case MIMEApplicationXML, MIMETextXML:
- if err = xml.NewDecoder(req.Body).Decode(target); err != nil {
- return ErrBadRequest.Wrap(err)
+ if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
+ if ute, ok := err.(*xml.UnsupportedTypeError); ok {
+ return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
+ } else if se, ok := err.(*xml.SyntaxError); ok {
+ return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
+ }
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
case MIMEApplicationForm:
- params, err := c.FormValues()
+ params, err := c.FormParams()
if err != nil {
- return ErrBadRequest.Wrap(err)
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
- if err = bindData(target, params, "form", nil); err != nil {
- return ErrBadRequest.Wrap(err)
+ if err = b.bindData(i, params, "form", nil); err != nil {
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
case MIMEMultipartForm:
params, err := c.MultipartForm()
if err != nil {
- return ErrBadRequest.Wrap(err)
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
- if err = bindData(target, params.Value, "form", params.File); err != nil {
- return ErrBadRequest.Wrap(err)
+ if err = b.bindData(i, params.Value, "form", params.File); err != nil {
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
default:
- return &HTTPError{Code: http.StatusUnsupportedMediaType}
+ return ErrUnsupportedMediaType
}
return nil
}
// BindHeaders binds HTTP headers to a bindable object
-func BindHeaders(c *Context, target any) error {
- if err := bindData(target, c.Request().Header, "header", nil); err != nil {
- return ErrBadRequest.Wrap(err)
+func (b *DefaultBinder) BindHeaders(c Context, i any) error {
+ if err := b.bindData(i, c.Request().Header, "header", nil); err != nil {
+ return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
return nil
}
// Bind implements the `Binder#Bind` function.
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
-// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues.
-func (b *DefaultBinder) Bind(c *Context, target any) error {
- if err := BindPathValues(c, target); err != nil {
+// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
+func (b *DefaultBinder) Bind(i any, c Context) (err error) {
+ if err := b.BindPathParams(c, i); err != nil {
return err
}
// Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
@@ -130,82 +144,15 @@ func (b *DefaultBinder) Bind(c *Context, target any) error {
// The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670)
method := c.Request().Method
if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
- if err := BindQueryParams(c, target); err != nil {
+ if err = b.BindQueryParams(c, i); err != nil {
return err
}
}
- return BindBody(c, target)
-}
-
-// bindFieldMeta is the cached, type-level reflection metadata for a single struct field. Reading struct
-// tags (reflect.StructTag.Get) parses the tag string on every call, so for binding-heavy endpoints we
-// compute it once per struct type and reuse it across requests (see bindStructMeta). Only type-level data
-// is cached here; per-request, per-instance reflect.Value operations still happen in bindData.
-type bindFieldMeta struct {
- index int // field index within the struct
- // fieldKind is the DECLARED field kind (typeField.Type.Kind()), used only for unmarshal dispatch.
- // It is intentionally not the post-anonymous-pointer-deref live kind; bindData computes that
- // separately as structFieldKind where needed.
- fieldKind reflect.Kind
- anonymous bool // reflect.StructField.Anonymous
- formatTag string // value of the `format` struct tag
- // binding-source tag values. bindData is only ever called with one of these four tags (see the
- // callers BindPathValues/BindQueryParams/BindBody/BindHeaders). Keep these fields, the four
- // f.Tag.Get(...) lines in bindMetaFor, and the tagName switch in sync if a source is ever added.
- param, query, form, header string
-}
-
-// tagName returns the field's tag value for the given binding source tag.
-// Keep in sync with the tag fields above and the f.Tag.Get calls in bindMetaFor.
-func (m *bindFieldMeta) tagName(tag string) string {
- switch tag {
- case "param":
- return m.param
- case "query":
- return m.query
- case "form":
- return m.form
- case "header":
- return m.header
- default:
- return ""
- }
-}
-
-// bindStructMeta is the cached field metadata for a whole struct type, in declaration order.
-type bindStructMeta struct {
- fields []bindFieldMeta
-}
-
-// bindStructCache memoizes bindStructMeta keyed by struct reflect.Type. Concurrent double-computation is
-// harmless because the result is deterministic and idempotent.
-var bindStructCache sync.Map // map[reflect.Type]*bindStructMeta
-
-func bindMetaFor(typ reflect.Type) *bindStructMeta {
- if cached, ok := bindStructCache.Load(typ); ok {
- return cached.(*bindStructMeta)
- }
- n := typ.NumField()
- meta := &bindStructMeta{fields: make([]bindFieldMeta, n)}
- for i := 0; i < n; i++ {
- f := typ.Field(i)
- meta.fields[i] = bindFieldMeta{
- index: i,
- anonymous: f.Anonymous,
- fieldKind: f.Type.Kind(),
- formatTag: f.Tag.Get("format"),
- param: f.Tag.Get("param"),
- query: f.Tag.Get("query"),
- form: f.Tag.Get("form"),
- header: f.Tag.Get("header"),
- }
- }
- bindStructCache.Store(typ, meta)
- return meta
+ return b.BindBody(c, i)
}
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
-func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error {
+func (b *DefaultBinder) bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error {
if destination == nil || (len(data) == 0 && len(dataFiles) == 0) {
return nil
}
@@ -224,7 +171,7 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
isElemInterface := k == reflect.Interface
isElemString := k == reflect.String
isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String
- if !isElemSliceOfStrings && !isElemString && !isElemInterface {
+ if !(isElemSliceOfStrings || isElemString || isElemInterface) {
return nil
}
if val.IsNil() {
@@ -253,12 +200,11 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
return errors.New("binding element must be a struct")
}
- meta := bindMetaFor(typ)
- for fi := range meta.fields { // iterate over all destination fields
- fm := &meta.fields[fi]
- structField := val.Field(fm.index)
- if fm.anonymous {
- if structField.Kind() == reflect.Pointer {
+ for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields
+ typeField := typ.Field(i)
+ structField := val.Field(i)
+ if typeField.Anonymous {
+ if structField.Kind() == reflect.Ptr {
structField = structField.Elem()
}
}
@@ -266,8 +212,8 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
continue
}
structFieldKind := structField.Kind()
- inputFieldName := fm.tagName(tag)
- if fm.anonymous && structFieldKind == reflect.Struct && inputFieldName != "" {
+ inputFieldName := typeField.Tag.Get(tag)
+ if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" {
// if anonymous struct with query/param/form tags, report an error
return errors.New("query/param/form tags are not allowed with anonymous struct field")
}
@@ -276,7 +222,7 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags).
// structs that implement BindUnmarshaler are bound only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
- if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil {
+ if err := b.bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil {
return err
}
}
@@ -317,16 +263,17 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
// but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` .
// try unmarshalling first, in case we're dealing with an alias to an array type
- if ok, err := unmarshalInputsToField(fm.fieldKind, inputValue, structField); ok {
+ if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok {
if err != nil {
- return fmt.Errorf("%s: %w", inputFieldName, err)
+ return err
}
continue
}
- if ok, err := unmarshalInputToField(fm.fieldKind, inputValue[0], structField, fm.formatTag); ok {
+ formatTag := typeField.Tag.Get("format")
+ if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok {
if err != nil {
- return fmt.Errorf("%s: %w", inputFieldName, err)
+ return err
}
continue
}
@@ -342,9 +289,9 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
sliceOf := structField.Type().Elem().Kind()
numElems := len(inputValue)
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
- for j := range numElems {
+ for j := 0; j < numElems; j++ {
if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
- return fmt.Errorf("%s: %w", inputFieldName, err)
+ return err
}
}
structField.Set(slice)
@@ -352,7 +299,7 @@ func bindData(destination any, data map[string][]string, tag string, dataFiles m
}
if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil {
- return fmt.Errorf("%s: %w", inputFieldName, err)
+ return err
}
}
return nil
@@ -366,7 +313,7 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
}
switch valueKind {
- case reflect.Pointer:
+ case reflect.Ptr:
return setWithProperType(structField.Elem().Kind(), val, structField.Elem())
case reflect.Int:
return setIntField(val, 0, structField)
@@ -403,7 +350,7 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
}
func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) {
- if valueKind == reflect.Pointer {
+ if valueKind == reflect.Ptr {
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
@@ -419,7 +366,7 @@ func unmarshalInputsToField(valueKind reflect.Kind, values []string, field refle
}
func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) {
- if valueKind == reflect.Pointer {
+ if valueKind == reflect.Ptr {
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
@@ -427,6 +374,7 @@ func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Val
}
fieldIValue := field.Addr().Interface()
+
// Handle time.Time with custom format tag
if formatTag != "" {
if _, isTime := fieldIValue.(*time.Time); isTime {
diff --git a/bind_cache_test.go b/bind_cache_test.go
deleted file mode 100644
index f3da8a13e..000000000
--- a/bind_cache_test.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// TestBindCachedMetaPreservesFieldNameError ensures the per-type bind metadata cache preserves the
-// field-name prefix in conversion errors on BOTH the cold (first) and warm (cached) bind of a type.
-// DTO is declared locally so its reflect.Type is independent of suite ordering, making the second
-// bind a deterministic cache hit (the bindMetaFor Load branch).
-func TestBindCachedMetaPreservesFieldNameError(t *testing.T) {
- type DTO struct {
- Number int `query:"number"`
- }
- bind := func() error {
- e := New()
- req := httptest.NewRequest(http.MethodGet, "/?number=10a", nil)
- var dto DTO
- return e.NewContext(req, httptest.NewRecorder()).Bind(&dto)
- }
-
- assert.ErrorContains(t, bind(), "number", "cold cache: error must carry field name")
- assert.ErrorContains(t, bind(), "number", "warm cache: error must still carry field name")
-}
diff --git a/bind_field_error_test.go b/bind_field_error_test.go
deleted file mode 100644
index 8bf63b58d..000000000
--- a/bind_field_error_test.go
+++ /dev/null
@@ -1,32 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// Regression test for #2629: when binding form data fails a type conversion, the
-// returned error must identify which field failed (so applications can render a
-// useful message), instead of a bare strconv error with no field context.
-func TestBind_formConversionErrorIncludesFieldName(t *testing.T) {
- e := New()
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("number=10a"))
- req.Header.Set(HeaderContentType, MIMEApplicationForm)
- c := e.NewContext(req, httptest.NewRecorder())
-
- type DTO struct {
- Number int `form:"number"`
- }
- var dto DTO
- err := c.Bind(&dto)
-
- assert.Error(t, err)
- assert.ErrorContains(t, err, "number", "bind error must identify the failing field")
-}
diff --git a/bind_test.go b/bind_test.go
index e33298b39..6cc597b33 100644
--- a/bind_test.go
+++ b/bind_test.go
@@ -25,79 +25,79 @@ import (
)
type bindTestStruct struct {
- T Timestamp
- GoT time.Time
+ I int
+ PtrI *int
+ I8 int8
+ PtrI8 *int8
+ I16 int16
PtrI16 *int16
- PtrUI *uint
- Tptr *Timestamp
- PtrF32 *float32
- PtrB *bool
+ I32 int32
PtrI32 *int32
- GoTptr *time.Time
+ I64 int64
PtrI64 *int64
- PtrI *int
- PtrI8 *int8
- PtrF64 *float64
+ UI uint
+ PtrUI *uint
+ UI8 uint8
PtrUI8 *uint8
- PtrUI64 *uint64
+ UI16 uint16
PtrUI16 *uint16
- PtrS *string
+ UI32 uint32
PtrUI32 *uint32
+ UI64 uint64
+ PtrUI64 *uint64
+ B bool
+ PtrB *bool
+ F32 float32
+ PtrF32 *float32
+ F64 float64
+ PtrF64 *float64
S string
+ PtrS *string
cantSet string
DoesntExist string
+ GoT time.Time
+ GoTptr *time.Time
+ T Timestamp
+ Tptr *Timestamp
SA StringArray
- F64 float64
- I int
- UI64 uint64
- UI uint
- I64 int64
- F32 float32
- UI32 uint32
- I32 int32
- UI16 uint16
- I16 int16
- B bool
- UI8 uint8
- I8 int8
}
type bindTestStructWithTags struct {
- T Timestamp `json:"T" form:"T"`
- GoT time.Time `json:"GoT" form:"GoT"`
- PtrI16 *int16 `json:"PtrI16" form:"PtrI16"`
- PtrUI *uint `json:"PtrUI" form:"PtrUI"`
- Tptr *Timestamp `json:"Tptr" form:"Tptr"`
- PtrF32 *float32 `json:"PtrF32" form:"PtrF32"`
- PtrB *bool `json:"PtrB" form:"PtrB"`
- PtrI32 *int32 `json:"PtrI32" form:"PtrI32"`
- GoTptr *time.Time `json:"GoTptr" form:"GoTptr"`
- PtrI64 *int64 `json:"PtrI64" form:"PtrI64"`
- PtrI *int `json:"PtrI" form:"PtrI"`
- PtrI8 *int8 `json:"PtrI8" form:"PtrI8"`
- PtrF64 *float64 `json:"PtrF64" form:"PtrF64"`
- PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"`
- PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"`
- PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"`
- PtrS *string `json:"PtrS" form:"PtrS"`
- PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"`
- S string `json:"S" form:"S"`
+ I int `json:"I" form:"I"`
+ PtrI *int `json:"PtrI" form:"PtrI"`
+ I8 int8 `json:"I8" form:"I8"`
+ PtrI8 *int8 `json:"PtrI8" form:"PtrI8"`
+ I16 int16 `json:"I16" form:"I16"`
+ PtrI16 *int16 `json:"PtrI16" form:"PtrI16"`
+ I32 int32 `json:"I32" form:"I32"`
+ PtrI32 *int32 `json:"PtrI32" form:"PtrI32"`
+ I64 int64 `json:"I64" form:"I64"`
+ PtrI64 *int64 `json:"PtrI64" form:"PtrI64"`
+ UI uint `json:"UI" form:"UI"`
+ PtrUI *uint `json:"PtrUI" form:"PtrUI"`
+ UI8 uint8 `json:"UI8" form:"UI8"`
+ PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"`
+ UI16 uint16 `json:"UI16" form:"UI16"`
+ PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"`
+ UI32 uint32 `json:"UI32" form:"UI32"`
+ PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"`
+ UI64 uint64 `json:"UI64" form:"UI64"`
+ PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"`
+ B bool `json:"B" form:"B"`
+ PtrB *bool `json:"PtrB" form:"PtrB"`
+ F32 float32 `json:"F32" form:"F32"`
+ PtrF32 *float32 `json:"PtrF32" form:"PtrF32"`
+ F64 float64 `json:"F64" form:"F64"`
+ PtrF64 *float64 `json:"PtrF64" form:"PtrF64"`
+ S string `json:"S" form:"S"`
+ PtrS *string `json:"PtrS" form:"PtrS"`
cantSet string
DoesntExist string `json:"DoesntExist" form:"DoesntExist"`
+ GoT time.Time `json:"GoT" form:"GoT"`
+ GoTptr *time.Time `json:"GoTptr" form:"GoTptr"`
+ T Timestamp `json:"T" form:"T"`
+ Tptr *Timestamp `json:"Tptr" form:"Tptr"`
SA StringArray `json:"SA" form:"SA"`
- F64 float64 `json:"F64" form:"F64"`
- I int `json:"I" form:"I"`
- UI64 uint64 `json:"UI64" form:"UI64"`
- UI uint `json:"UI" form:"UI"`
- I64 int64 `json:"I64" form:"I64"`
- F32 float32 `json:"F32" form:"F32"`
- UI32 uint32 `json:"UI32" form:"UI32"`
- I32 int32 `json:"I32" form:"I32"`
- UI16 uint16 `json:"UI16" form:"UI16"`
- I16 int16 `json:"I16" form:"I16"`
- B bool `json:"B" form:"B"`
- UI8 uint8 `json:"UI8" form:"UI8"`
- I8 int8 `json:"I8" form:"I8"`
}
type Timestamp time.Time
@@ -283,7 +283,7 @@ func TestBindHeaderParam(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
u := new(user)
- err := BindHeaders(c, u)
+ err := (&DefaultBinder{}).BindHeaders(c, u)
if assert.NoError(t, err) {
assert.Equal(t, 2, u.ID)
assert.Equal(t, "Jon Doe", u.Name)
@@ -297,7 +297,7 @@ func TestBindHeaderParamBadType(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
u := new(user)
- err := BindHeaders(c, u)
+ err := (&DefaultBinder{}).BindHeaders(c, u)
assert.Error(t, err)
httpErr, ok := err.(*HTTPError)
@@ -312,13 +312,13 @@ func TestBindUnmarshalParam(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
- T Timestamp `query:"ts"`
+ T Timestamp `query:"ts"`
+ TA []Timestamp `query:"ta"`
+ SA StringArray `query:"sa"`
ST Struct
StWithTag struct {
Foo string `query:"st"`
}
- TA []Timestamp `query:"ta"`
- SA StringArray `query:"sa"`
}{}
err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
@@ -339,10 +339,10 @@ func TestBindUnmarshalText(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
- T time.Time `query:"ts"`
- ST Struct
+ T time.Time `query:"ts"`
TA []time.Time `query:"ta"`
SA StringArray `query:"sa"`
+ ST Struct
}{}
err := c.Bind(&result)
ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)
@@ -447,7 +447,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string]string", func(t *testing.T) {
dest := map[string]string{}
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t,
map[string]string{
"multiple": "1",
@@ -459,7 +459,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) {
var dest map[string]string
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t,
map[string]string{
"multiple": "1",
@@ -471,7 +471,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string][]string", func(t *testing.T) {
dest := map[string][]string{}
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t,
map[string][]string{
"multiple": {"1", "2"},
@@ -483,7 +483,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) {
var dest map[string][]string
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t,
map[string][]string{
"multiple": {"1", "2"},
@@ -495,7 +495,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string]interface", func(t *testing.T) {
dest := map[string]any{}
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t,
map[string]any{
"multiple": "1",
@@ -507,7 +507,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) {
var dest map[string]any
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t,
map[string]any{
"multiple": "1",
@@ -519,32 +519,33 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) {
t.Run("ok, bind to map[string]int skips", func(t *testing.T) {
dest := map[string]int{}
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t, map[string]int{}, dest)
})
t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) {
var dest map[string]int
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t, map[string]int(nil), dest)
})
t.Run("ok, bind to map[string][]int skips", func(t *testing.T) {
dest := map[string][]int{}
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t, map[string][]int{}, dest)
})
t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) {
var dest map[string][]int
- assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil))
assert.Equal(t, map[string][]int(nil), dest)
})
}
func TestBindbindData(t *testing.T) {
ts := new(bindTestStruct)
- err := bindData(ts, values, "form", nil)
+ b := new(DefaultBinder)
+ err := b.bindData(ts, values, "form", nil)
assert.NoError(t, err)
assert.Equal(t, 0, ts.I)
@@ -569,13 +570,9 @@ func TestBindParam(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- c.InitializeRoute(
- &RouteInfo{Path: "/users/:id/:name"},
- &PathValues{
- {Name: "id", Value: "1"},
- {Name: "name", Value: "Jon Snow"},
- },
- )
+ c.SetPath("/users/:id/:name")
+ c.SetParamNames("id", "name")
+ c.SetParamValues("1", "Jon Snow")
u := new(user)
err := c.Bind(u)
@@ -586,12 +583,9 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param
c2 := e.NewContext(req, rec)
- c2.InitializeRoute(
- &RouteInfo{Path: "/users/:id"},
- &PathValues{
- {Name: "id", Value: "1"},
- },
- )
+ c2.SetPath("/users/:id")
+ c2.SetParamNames("id")
+ c2.SetParamValues("1")
u = new(user)
err = c2.Bind(u)
@@ -609,12 +603,9 @@ func TestBindParam(t *testing.T) {
rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2)
- c3.InitializeRoute(
- &RouteInfo{Path: "/users/:id"},
- &PathValues{
- {Name: "id", Value: "1"},
- },
- )
+ c3.SetPath("/users/:id")
+ c3.SetParamNames("id")
+ c3.SetParamValues("1")
u = new(user)
err = c3.Bind(u)
@@ -636,12 +627,14 @@ func TestBindUnmarshalTypeError(t *testing.T) {
err := c.Bind(u)
- assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`)
+ he := &HTTPError{Code: http.StatusBadRequest, Message: "Unmarshal type error: expected=int, got=string, field=id, offset=14", Internal: err.(*HTTPError).Internal}
+
+ assert.Equal(t, he, err)
}
func TestBindSetWithProperType(t *testing.T) {
ts := new(bindTestStruct)
- typ := reflect.TypeFor[bindTestStruct]()
+ typ := reflect.TypeOf(ts).Elem()
val := reflect.ValueOf(ts).Elem()
for i := 0; i < typ.NumField(); i++ {
typeField := typ.Field(i)
@@ -662,7 +655,7 @@ func TestBindSetWithProperType(t *testing.T) {
Bar bytes.Buffer
}
v := &foo{}
- typ = reflect.TypeFor[foo]()
+ typ = reflect.TypeOf(v).Elem()
val = reflect.ValueOf(v).Elem()
assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
}
@@ -670,10 +663,11 @@ func TestBindSetWithProperType(t *testing.T) {
func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs()
ts := new(bindTestStructWithTags)
+ binder := new(DefaultBinder)
var err error
b.ResetTimer()
for i := 0; i < b.N; i++ {
- err = bindData(ts, values, "form", nil)
+ err = binder.bindData(ts, values, "form", nil)
}
assert.NoError(b, err)
assertBindTestStruct(b, (*bindTestStruct)(ts))
@@ -748,36 +742,36 @@ func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal err
strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
if assert.IsType(t, new(HTTPError), err) {
assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
- assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
+ assert.IsType(t, expectedInternal, err.(*HTTPError).Internal)
}
default:
if assert.IsType(t, new(HTTPError), err) {
assert.Equal(t, ErrUnsupportedMediaType, err)
- assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
+ assert.IsType(t, expectedInternal, err.(*HTTPError).Internal)
}
}
}
func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
// tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
- // binding is done in steps and one source could overwrite previous source bound data
+ // binding is done in steps and one source could overwrite previous source binded data
// these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
type Opts struct {
+ ID int `json:"id" form:"id" query:"id"`
Node string `json:"node" form:"node" query:"node" param:"node"`
Lang string
- ID int `json:"id" form:"id" query:"id"`
}
var testCases = []struct {
- givenContent io.Reader
- whenBindTarget any
- expect any
name string
givenURL string
+ givenContent io.Reader
givenMethod string
+ whenBindTarget any
+ whenNoPathParams bool
+ expect any
expectError string
- whenNoPathValues bool
}{
{
name: "ok, POST bind to struct with: path param + query param + body",
@@ -805,14 +799,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
givenMethod: http.MethodGet,
givenURL: "/api/real_node/endpoint?node=xxx",
givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
- expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values
+ expect: &Opts{ID: 1, Node: "zzz"}, // body is binded last and overwrites previous (path,query) values
},
{
name: "ok, DELETE bind to struct with: path param + query param + body",
givenMethod: http.MethodDelete,
givenURL: "/api/real_node/endpoint?node=xxx",
givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
- expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params
+ expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is binded after query params
},
{
name: "ok, POST bind to struct with: path param + body",
@@ -834,7 +828,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
givenURL: "/api/real_node/endpoint?node=xxx",
givenContent: strings.NewReader(`{`),
expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target
- expectError: "code=400, message=Bad Request, err=unexpected EOF",
+ expectError: "code=400, message=unexpected EOF, internal=unexpected EOF",
},
{
name: "nok, GET with body bind failure when types are not convertible",
@@ -842,7 +836,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
givenURL: "/api/real_node/endpoint?id=nope",
givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target
- expectError: `code=400, message=Bad Request, err=id: strconv.ParseInt: parsing "nope": invalid syntax`,
+ expectError: "code=400, message=strconv.ParseInt: parsing \"nope\": invalid syntax, internal=strconv.ParseInt: parsing \"nope\": invalid syntax",
},
{
name: "nok, GET body bind failure - trying to bind json array to struct",
@@ -850,14 +844,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
givenURL: "/api/real_node/endpoint?node=xxx",
givenContent: strings.NewReader(`[{"id": 1}]`),
expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target
- expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`,
+ expectError: "code=400, message=Unmarshal type error: expected=echo.Opts, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Opts",
},
{ // query param is ignored as we do not know where exactly to bind it in slice
name: "ok, GET bind to struct slice, ignore query param",
givenMethod: http.MethodGet,
givenURL: "/api/real_node/endpoint?node=xxx",
givenContent: strings.NewReader(`[{"id": 1}]`),
- whenNoPathValues: true,
+ whenNoPathParams: true,
whenBindTarget: &[]Opts{},
expect: &[]Opts{
{ID: 1, Node: ""},
@@ -868,7 +862,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
givenMethod: http.MethodPost,
givenURL: "/api/real_node/endpoint?id=nope&node=xxx",
givenContent: strings.NewReader(`[{"id": 1}]`),
- whenNoPathValues: true,
+ whenNoPathParams: true,
whenBindTarget: &[]Opts{},
expect: &[]Opts{{ID: 1}},
expectError: "",
@@ -888,7 +882,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
givenMethod: http.MethodGet,
givenURL: "/api/real_node/endpoint",
givenContent: strings.NewReader(`[{"id": 1}]`),
- whenNoPathValues: true,
+ whenNoPathParams: true,
whenBindTarget: &[]Opts{},
expect: &[]Opts{{ID: 1, Node: ""}},
expectError: "",
@@ -904,10 +898,9 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- if !tc.whenNoPathValues {
- c.SetPathValues(PathValues{
- {Name: "node", Value: "node_from_path"},
- })
+ if !tc.whenNoPathParams {
+ c.SetParamNames("node")
+ c.SetParamValues("node_from_path")
}
var bindTarget any
@@ -918,7 +911,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
}
b := new(DefaultBinder)
- err := b.Bind(c, bindTarget)
+ err := b.Bind(bindTarget, c)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -931,28 +924,28 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
func TestDefaultBinder_BindBody(t *testing.T) {
// tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
- // generally when binding from request body - URL and path params are ignored - unless form is being bound.
+ // generally when binding from request body - URL and path params are ignored - unless form is being binded.
// these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
type Node struct {
- Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"`
ID int `json:"id" xml:"id" form:"id" query:"id"`
+ Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"`
}
type Nodes struct {
Nodes []Node `xml:"node" form:"node"`
}
var testCases = []struct {
- givenContent io.Reader
- whenBindTarget any
- expect any
name string
givenURL string
+ givenContent io.Reader
givenMethod string
givenContentType string
- expectError string
- whenNoPathValues bool
+ whenNoPathParams bool
whenChunkedBody bool
+ whenBindTarget any
+ expect any
+ expectError string
}{
{
name: "ok, JSON POST bind to struct with: path + query + empty field in body",
@@ -976,7 +969,7 @@ func TestDefaultBinder_BindBody(t *testing.T) {
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`[{"id": 1}]`),
- whenNoPathValues: true,
+ whenNoPathParams: true,
whenBindTarget: &[]Node{},
expect: &[]Node{{ID: 1, Node: ""}},
expectError: "",
@@ -1004,7 +997,7 @@ func TestDefaultBinder_BindBody(t *testing.T) {
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`{`),
expect: &Node{ID: 0, Node: ""},
- expectError: "code=400, message=Bad Request, err=unexpected EOF",
+ expectError: "code=400, message=unexpected EOF, internal=unexpected EOF",
},
{
name: "ok, XML POST bind to struct with: path + query + empty body",
@@ -1030,7 +1023,7 @@ func TestDefaultBinder_BindBody(t *testing.T) {
givenContentType: MIMEApplicationXML,
givenContent: strings.NewReader(`<`),
expect: &Node{ID: 0, Node: ""},
- expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF",
+ expectError: "code=400, message=Syntax error: line=1, error=XML syntax error on line 1: unexpected EOF, internal=XML syntax error on line 1: unexpected EOF",
},
{
name: "ok, FORM POST bind to struct with: path + query + body",
@@ -1120,10 +1113,9 @@ func TestDefaultBinder_BindBody(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- if !tc.whenNoPathValues {
- c.SetPathValues(PathValues{
- {Name: "node", Value: "real_node"},
- })
+ if !tc.whenNoPathParams {
+ c.SetParamNames("node")
+ c.SetParamValues("real_node")
}
var bindTarget any
@@ -1132,8 +1124,9 @@ func TestDefaultBinder_BindBody(t *testing.T) {
} else {
bindTarget = &Node{}
}
+ b := new(DefaultBinder)
- err := BindBody(c, bindTarget)
+ err := b.BindBody(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -1196,7 +1189,7 @@ func TestBindUnmarshalParamExtras(t *testing.T) {
}{}
err := testBindURL("/?t=xxxx", &result)
- assert.EqualError(t, err, `code=400, message=Bad Request, err=t: 'xxxx' is not an integer`)
+ assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer")
})
t.Run("ok, target is struct", func(t *testing.T) {
@@ -1301,7 +1294,7 @@ func TestBindUnmarshalParams(t *testing.T) {
}{}
err := testBindURL("/?t=xxxx", &result)
- assert.EqualError(t, err, "code=400, message=Bad Request, err=t: 'xxxx' is not an integer")
+ assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer")
})
t.Run("ok, target is struct", func(t *testing.T) {
@@ -1368,7 +1361,7 @@ func TestBindInt8(t *testing.T) {
}
p := target{}
err := testBindURL("/?v=x&v=2", &p)
- assert.EqualError(t, err, `code=400, message=Bad Request, err=v: strconv.ParseInt: parsing "x": invalid syntax`)
+ assert.EqualError(t, err, "code=400, message=strconv.ParseInt: parsing \"x\": invalid syntax, internal=strconv.ParseInt: parsing \"x\": invalid syntax")
})
t.Run("nok, int8 embedded in struct", func(t *testing.T) {
@@ -1476,7 +1469,7 @@ func TestBindMultipartFormFiles(t *testing.T) {
}
err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored
- assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`)
+ assert.EqualError(t, err, "code=400, message=binding to multipart.FileHeader struct is not supported, use pointer to struct, internal=binding to multipart.FileHeader struct is not supported, use pointer to struct")
})
t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) {
@@ -1584,7 +1577,7 @@ func TestTimeFormatBinding(t *testing.T) {
DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"`
Date time.Time `query:"date" format:"2006-01-02"`
CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"`
- DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing
+ DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing
PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"`
}
@@ -1630,7 +1623,7 @@ func TestTimeFormatBinding(t *testing.T) {
{
name: "nok, wrong format should fail",
contentType: MIMEApplicationForm,
- data: "datetime_local=2023-12-25", // Missing time part
+ data: "datetime_local=2023-12-25", // Missing time part
expectError: true,
},
}
diff --git a/binder.go b/binder.go
index 7ae3709fa..300c00934 100644
--- a/binder.go
+++ b/binder.go
@@ -16,7 +16,7 @@ import (
/**
Following functions provide handful of methods for binding to Go native types from request query or path parameters.
* QueryParamsBinder(c) - binds query parameters (source URL)
- * PathValuesBinder(c) - binds path parameters (source URL)
+ * PathParamsBinder(c) - binds path parameters (source URL)
* FormFieldBinder(c) - binds form fields (source URL + body)
Example:
@@ -66,9 +66,6 @@ import (
*/
// BindingError represents an error that occurred while binding request data.
-//
-// Note: JSON serialization is handled by the MarshalJSON method below, not by the
-// struct tags (which are kept for documentation). MarshalJSON emits {"field","message"}.
type BindingError struct {
// Field is the field name where value binding failed
Field string `json:"field"`
@@ -78,11 +75,15 @@ type BindingError struct {
}
// NewBindingError creates new instance of binding error
-func NewBindingError(sourceParam string, values []string, message string, err error) error {
+func NewBindingError(sourceParam string, values []string, message any, internalError error) error {
return &BindingError{
- Field: sourceParam,
- Values: values,
- HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err},
+ Field: sourceParam,
+ Values: values,
+ HTTPError: &HTTPError{
+ Code: http.StatusBadRequest,
+ Message: message,
+ Internal: internalError,
+ },
}
}
@@ -91,24 +92,6 @@ func (be *BindingError) Error() string {
return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field)
}
-// MarshalJSON implements json.Marshaler so that binding errors are serialized into
-// a structured response (e.g. {"field":"id","message":"..."}) rather than being
-// flattened to a generic message. DefaultHTTPErrorHandler routes errors that
-// implement json.Marshaler through their own encoding.
-func (be *BindingError) MarshalJSON() ([]byte, error) {
- message := be.Message
- if message == "" {
- message = http.StatusText(be.Code)
- }
- return json.Marshal(struct {
- Field string `json:"field"`
- Message string `json:"message"`
- }{
- Field: be.Field,
- Message: message,
- })
-}
-
// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types
type ValueBinder struct {
// ValueFunc is used to get single parameter (first) value from request
@@ -116,14 +99,14 @@ type ValueBinder struct {
// ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2`
ValuesFunc func(sourceParam string) []string
// ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response
- ErrorFunc func(sourceParam string, values []string, message string, internalError error) error
+ ErrorFunc func(sourceParam string, values []string, message any, internalError error) error
errors []error
// failFast is flag for binding methods to return without attempting to bind when previous binding already failed
failFast bool
}
// QueryParamsBinder creates query parameter value binder
-func QueryParamsBinder(c *Context) *ValueBinder {
+func QueryParamsBinder(c Context) *ValueBinder {
return &ValueBinder{
failFast: true,
ValueFunc: c.QueryParam,
@@ -138,8 +121,8 @@ func QueryParamsBinder(c *Context) *ValueBinder {
}
}
-// PathValuesBinder creates path parameter value binder
-func PathValuesBinder(c *Context) *ValueBinder {
+// PathParamsBinder creates path parameter value binder
+func PathParamsBinder(c Context) *ValueBinder {
return &ValueBinder{
failFast: true,
ValueFunc: c.Param,
@@ -165,7 +148,7 @@ func PathValuesBinder(c *Context) *ValueBinder {
// NB: when binding forms take note that this implementation uses standard library form parsing
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See https://golang.org/pkg/net/http/#Request.ParseForm
-func FormFieldBinder(c *Context) *ValueBinder {
+func FormFieldBinder(c Context) *ValueBinder {
vb := &ValueBinder{
failFast: true,
ValueFunc: func(sourceParam string) string {
@@ -176,7 +159,7 @@ func FormFieldBinder(c *Context) *ValueBinder {
vb.ValuesFunc = func(sourceParam string) []string {
if c.Request().Form == nil {
// this is same as `Request().FormValue()` does internally
- _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory)
+ _ = c.Request().ParseMultipartForm(32 << 20)
}
values, ok := c.Request().Form[sourceParam]
if !ok {
@@ -548,11 +531,11 @@ func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize in
case *int64:
*d = n
case *int32:
- *d = int32(n) // #nosec G115
+ *d = int32(n)
case *int16:
- *d = int16(n) // #nosec G115
+ *d = int16(n)
case *int8:
- *d = int8(n) // #nosec G115
+ *d = int8(n)
case *int:
*d = int(n)
}
@@ -776,13 +759,13 @@ func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize i
case *uint64:
*d = n
case *uint32:
- *d = uint32(n) // #nosec G115
+ *d = uint32(n)
case *uint16:
- *d = uint16(n) // #nosec G115
+ *d = uint16(n)
case *uint8: // byte is alias to uint8
- *d = uint8(n) // #nosec G115
+ *d = uint8(n)
case *uint:
- *d = uint(n) // #nosec G115
+ *d = uint(n)
}
return b
}
@@ -1260,7 +1243,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder
return b.unixTime(sourceParam, dest, false, time.Second)
}
-// MustUnixTime requires parameter value to exist to bind to time.Time variable (in local time corresponding
+// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding
// to the given Unix time). Returns error when value does not exist.
//
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
@@ -1281,7 +1264,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB
return b.unixTime(sourceParam, dest, false, time.Millisecond)
}
-// MustUnixTimeMilli requires parameter value to exist to bind to time.Time variable (in local time corresponding
+// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding
// to the given Unix time in millisecond precision). Returns error when value does not exist.
//
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
@@ -1305,8 +1288,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
}
-// MustUnixTimeNano requires parameter value to exist to bind to time.Time variable (in local time corresponding
-// to the given Unix time value in nanosecond precision). Returns error when value does not exist.
+// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding
+// to the given Unix time value in nano second precision). Returns error when value does not exist.
//
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
diff --git a/binder_error_response_test.go b/binder_error_response_test.go
deleted file mode 100644
index 0ed077684..000000000
--- a/binder_error_response_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// Regression test for #2771: a BindingError returned from a handler must be
-// serialized by DefaultHTTPErrorHandler into a structured response that retains
-// the field name (and the binder message), not flattened to {"message":"Bad Request"}.
-func TestBindingError_serializesToStructuredJSON(t *testing.T) {
- e := New()
- e.GET("/doc", func(c *Context) error {
- var docNum int
- return QueryParamsBinder(c).MustInt("docNum", &docNum).BindError()
- })
-
- req := httptest.NewRequest(http.MethodGet, "/doc?docNum=abc", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusBadRequest, rec.Code)
-
- var body map[string]any
- assert.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body))
- assert.Equal(t, "docNum", body["field"], "binding error response must retain the field name")
- assert.Equal(t, "failed to bind field value to int", body["message"], "binding error response must retain the binder message")
-}
-
-// When the binding error carries no message, MarshalJSON falls back to the
-// status text (mirroring DefaultHTTPErrorHandler's *HTTPError branch).
-func TestBindingError_marshalJSON_emptyMessageFallsBackToStatusText(t *testing.T) {
- be := &BindingError{Field: "name", HTTPError: &HTTPError{Code: http.StatusBadRequest}}
-
- b, err := be.MarshalJSON()
-
- assert.NoError(t, err)
- assert.JSONEq(t, `{"field":"name","message":"Bad Request"}`, string(b))
-}
diff --git a/binder_external_test.go b/binder_external_test.go
index d83c891b3..e44055a23 100644
--- a/binder_external_test.go
+++ b/binder_external_test.go
@@ -7,19 +7,18 @@ package echo_test
import (
"encoding/base64"
"fmt"
+ "github.com/labstack/echo/v4"
"log"
"net/http"
"net/http/httptest"
-
- "github.com/labstack/echo/v5"
)
func ExampleValueBinder_BindErrors() {
// example route function that binds query params to different destinations and returns all bind errors in one go
- routeFunc := func(c *echo.Context) error {
+ routeFunc := func(c echo.Context) error {
var opts struct {
- IDs []int64
Active bool
+ IDs []int64
}
length := int64(50) // default length is 50
@@ -54,10 +53,10 @@ func ExampleValueBinder_BindErrors() {
func ExampleValueBinder_BindError() {
// example route function that binds query params to different destinations and stops binding on first bind error
- failFastRouteFunc := func(c *echo.Context) error {
+ failFastRouteFunc := func(c echo.Context) error {
var opts struct {
- IDs []int64
Active bool
+ IDs []int64
}
length := int64(50) // default length is 50
@@ -90,7 +89,7 @@ func ExampleValueBinder_BindError() {
func ExampleValueBinder_CustomFunc() {
// example route function that binds query params using custom function closure
- routeFunc := func(c *echo.Context) error {
+ routeFunc := func(c echo.Context) error {
length := int64(50) // default length is 50
var binary []byte
diff --git a/binder_generic.go b/binder_generic.go
index 62e1da512..f4d45af76 100644
--- a/binder_generic.go
+++ b/binder_generic.go
@@ -56,12 +56,13 @@ const (
// To treat empty values as errors, validate the result separately or check the raw value.
//
// See ParseValue for supported types and options
-func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) {
- for _, pv := range c.PathValues() {
- if pv.Name == paramName {
- v, err := ParseValue[T](pv.Value, opts...)
+func PathParam[T any](c Context, paramName string, opts ...any) (T, error) {
+ for i, name := range c.ParamNames() {
+ if name == paramName {
+ pValues := c.ParamValues()
+ v, err := ParseValue[T](pValues[i], opts...)
if err != nil {
- return v, NewBindingError(paramName, []string{pv.Value}, "path value", err)
+ return v, NewBindingError(paramName, []string{pValues[i]}, "path param", err)
}
return v, nil
}
@@ -82,12 +83,13 @@ func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) {
// // If "id" is "abc": returns (0, BindingError)
//
// See ParseValue for supported types and options
-func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) {
- for _, pv := range c.PathValues() {
- if pv.Name == paramName {
- v, err := ParseValueOr[T](pv.Value, defaultValue, opts...)
+func PathParamOr[T any](c Context, paramName string, defaultValue T, opts ...any) (T, error) {
+ for i, name := range c.ParamNames() {
+ if name == paramName {
+ pValues := c.ParamValues()
+ v, err := ParseValueOr[T](pValues[i], defaultValue, opts...)
if err != nil {
- return v, NewBindingError(paramName, []string{pv.Value}, "path value", err)
+ return v, NewBindingError(paramName, []string{pValues[i]}, "path param", err)
}
return v, nil
}
@@ -111,7 +113,7 @@ func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...an
// - Invalid value (?key=abc for int): returns (zero, BindingError)
//
// See ParseValue for supported types and options
-func QueryParam[T any](c *Context, key string, opts ...any) (T, error) {
+func QueryParam[T any](c Context, key string, opts ...any) (T, error) {
values, ok := c.QueryParams()[key]
if !ok {
var zero T
@@ -141,7 +143,7 @@ func QueryParam[T any](c *Context, key string, opts ...any) (T, error) {
// // If "page" is "abc": returns (1, BindingError)
//
// See ParseValue for supported types and options
-func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
+func QueryParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) {
values, ok := c.QueryParams()[key]
if !ok {
return defaultValue, nil
@@ -161,7 +163,7 @@ func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T
// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found.
//
// See ParseValues for supported types and options
-func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) {
+func QueryParams[T any](c Context, key string, opts ...any) ([]T, error) {
values, ok := c.QueryParams()[key]
if !ok {
return nil, ErrNonExistentKey
@@ -186,7 +188,7 @@ func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) {
// // If "ids" contains "abc": returns ([], BindingError)
//
// See ParseValues for supported types and options
-func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
+func QueryParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) {
values, ok := c.QueryParams()[key]
if !ok {
return defaultValue, nil
@@ -199,7 +201,7 @@ func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any)
return result, nil
}
-// FormValue extracts and parses a single form value from the request by key.
+// FormParam extracts and parses a single form value from the request by key.
// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
//
// Empty String Handling:
@@ -210,11 +212,11 @@ func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any)
// To treat empty values as errors, validate the result separately or check the raw value.
//
// See ParseValue for supported types and options
-func FormValue[T any](c *Context, key string, opts ...any) (T, error) {
- formValues, err := c.FormValues()
+func FormParam[T any](c Context, key string, opts ...any) (T, error) {
+ formValues, err := c.FormParams()
if err != nil {
var zero T
- return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
+ return zero, fmt.Errorf("failed to parse form param, key: %s, err: %w", key, err)
}
values, ok := formValues[key]
if !ok {
@@ -228,12 +230,12 @@ func FormValue[T any](c *Context, key string, opts ...any) (T, error) {
value := values[0]
v, err := ParseValue[T](value, opts...)
if err != nil {
- return v, NewBindingError(key, []string{value}, "form value", err)
+ return v, NewBindingError(key, []string{value}, "form param", err)
}
return v, nil
}
-// FormValueOr extracts and parses a single form value from the request by key.
+// FormParamOr extracts and parses a single form value from the request by key.
// Returns defaultValue if the parameter is not found or has an empty value.
// Returns an error only if parsing fails or form parsing errors occur.
//
@@ -245,11 +247,11 @@ func FormValue[T any](c *Context, key string, opts ...any) (T, error) {
// // If "limit" is "abc": returns (100, BindingError)
//
// See ParseValue for supported types and options
-func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
- formValues, err := c.FormValues()
+func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) {
+ formValues, err := c.FormParams()
if err != nil {
var zero T
- return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
+ return zero, fmt.Errorf("failed to parse form param, key: %s, err: %w", key, err)
}
values, ok := formValues[key]
if !ok {
@@ -261,19 +263,19 @@ func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T,
value := values[0]
v, err := ParseValueOr[T](value, defaultValue, opts...)
if err != nil {
- return v, NewBindingError(key, []string{value}, "form value", err)
+ return v, NewBindingError(key, []string{value}, "form param", err)
}
return v, nil
}
-// FormValues extracts and parses all values for a form values key as a slice.
+// FormParams extracts and parses all values for a form values key as a slice.
// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found.
//
// See ParseValues for supported types and options
-func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) {
- formValues, err := c.FormValues()
+func FormParams[T any](c Context, key string, opts ...any) ([]T, error) {
+ formValues, err := c.FormParams()
if err != nil {
- return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
+ return nil, fmt.Errorf("failed to parse form params, key: %s, err: %w", key, err)
}
values, ok := formValues[key]
if !ok {
@@ -281,26 +283,26 @@ func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) {
}
result, err := ParseValues[T](values, opts...)
if err != nil {
- return nil, NewBindingError(key, values, "form values", err)
+ return nil, NewBindingError(key, values, "form params", err)
}
return result, nil
}
-// FormValuesOr extracts and parses all values for a form values key as a slice.
+// FormParamsOr extracts and parses all values for a form values key as a slice.
// Returns defaultValue if the parameter is not found.
// Returns an error only if parsing any value fails or form parsing errors occur.
//
// Example:
//
-// tags, err := echo.FormValuesOr[string](c, "tags", []string{})
+// tags, err := echo.FormParamsOr[string](c, "tags", []string{})
// // If "tags" is missing: returns ([], nil)
// // If form parsing fails: returns (nil, error)
//
// See ParseValues for supported types and options
-func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
- formValues, err := c.FormValues()
+func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) {
+ formValues, err := c.FormParams()
if err != nil {
- return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
+ return nil, fmt.Errorf("failed to parse form params, key: %s, err: %w", key, err)
}
values, ok := formValues[key]
if !ok {
@@ -308,7 +310,7 @@ func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any)
}
result, err := ParseValuesOr[T](values, defaultValue, opts...)
if err != nil {
- return nil, NewBindingError(key, values, "form values", err)
+ return nil, NewBindingError(key, values, "form params", err)
}
return result, nil
}
diff --git a/binder_generic_test.go b/binder_generic_test.go
index 849d75962..96dfc5ed8 100644
--- a/binder_generic_test.go
+++ b/binder_generic_test.go
@@ -64,16 +64,15 @@ func TestPathParam(t *testing.T) {
name: "nok, invalid value",
givenValue: "can_parse_me",
expect: false,
- expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`,
+ expectErr: `code=400, message=path param, internal=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- c := NewContext(nil, nil)
- c.SetPathValues(PathValues{{
- Name: cmp.Or(tc.givenKey, "key"),
- Value: tc.givenValue,
- }})
+ e := New()
+ c := e.NewContext(nil, nil)
+ c.SetParamNames(cmp.Or(tc.givenKey, "key"))
+ c.SetParamValues(tc.givenValue)
v, err := PathParam[bool](c, "key")
if tc.expectErr != "" {
@@ -87,12 +86,14 @@ func TestPathParam(t *testing.T) {
}
func TestPathParam_UnsupportedType(t *testing.T) {
- c := NewContext(nil, nil)
- c.SetPathValues(PathValues{{Name: "key", Value: "true"}})
+ e := New()
+ c := e.NewContext(nil, nil)
+ c.SetParamNames("key")
+ c.SetParamValues("true")
v, err := PathParam[[]bool](c, "key")
- expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ expectErr := "code=400, message=path param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key"
assert.EqualError(t, err, expectErr)
assert.Equal(t, []bool(nil), v)
}
@@ -119,13 +120,14 @@ func TestQueryParam(t *testing.T) {
name: "nok, invalid value",
givenURL: "/?key=invalidbool",
expect: false,
- expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ expectErr: `code=400, message=query param, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
v, err := QueryParam[bool](c, "key")
if tc.expectErr != "" {
@@ -140,11 +142,12 @@ func TestQueryParam(t *testing.T) {
func TestQueryParam_UnsupportedType(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
v, err := QueryParam[[]bool](c, "key")
- expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ expectErr := "code=400, message=query param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key"
assert.EqualError(t, err, expectErr)
assert.Equal(t, []bool(nil), v)
}
@@ -171,13 +174,14 @@ func TestQueryParams(t *testing.T) {
name: "nok, invalid value",
givenURL: "/?key=true&key=invalidbool",
expect: []bool(nil),
- expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ expectErr: `code=400, message=query params, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
v, err := QueryParams[bool](c, "key")
if tc.expectErr != "" {
@@ -192,11 +196,12 @@ func TestQueryParams(t *testing.T) {
func TestQueryParams_UnsupportedType(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
v, err := QueryParams[[]bool](c, "key")
- expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ expectErr := "code=400, message=query params, internal=failed to parse value, err: unsupported value type: *[]bool, field=key"
assert.EqualError(t, err, expectErr)
assert.Equal(t, [][]bool(nil), v)
}
@@ -223,15 +228,16 @@ func TestFormValue(t *testing.T) {
name: "nok, invalid value",
givenURL: "/?key=invalidbool",
expect: false,
- expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ expectErr: `code=400, message=form param, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- v, err := FormValue[bool](c, "key")
+ v, err := FormParam[bool](c, "key")
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
@@ -244,11 +250,12 @@ func TestFormValue(t *testing.T) {
func TestFormValue_UnsupportedType(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- v, err := FormValue[[]bool](c, "key")
+ v, err := FormParam[[]bool](c, "key")
- expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ expectErr := "code=400, message=form param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key"
assert.EqualError(t, err, expectErr)
assert.Equal(t, []bool(nil), v)
}
@@ -275,15 +282,16 @@ func TestFormValues(t *testing.T) {
name: "nok, invalid value",
givenURL: "/?key=true&key=invalidbool",
expect: []bool(nil),
- expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ expectErr: `code=400, message=form params, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- v, err := FormValues[bool](c, "key")
+ v, err := FormParams[bool](c, "key")
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
@@ -296,11 +304,12 @@ func TestFormValues(t *testing.T) {
func TestFormValues_UnsupportedType(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- v, err := FormValues[[]bool](c, "key")
+ v, err := FormParams[[]bool](c, "key")
- expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ expectErr := "code=400, message=form params, internal=failed to parse value, err: unsupported value type: *[]bool, field=key"
assert.EqualError(t, err, expectErr)
assert.Equal(t, [][]bool(nil), v)
}
@@ -1424,13 +1433,15 @@ func TestPathParamOr(t *testing.T) {
givenKey: "id",
givenValue: "invalid",
defaultValue: 999,
- expectErr: "code=400, message=path value, err=failed to parse value",
+ expectErr: "code=400, message=path param, internal=failed to parse value",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- c := NewContext(nil, nil)
- c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}})
+ e := New()
+ c := e.NewContext(nil, nil)
+ c.SetParamNames(tc.givenKey)
+ c.SetParamValues(tc.givenValue)
v, err := PathParamOr[int](c, "id", tc.defaultValue)
if tc.expectErr != "" {
@@ -1479,7 +1490,8 @@ func TestQueryParamOr(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
v, err := QueryParamOr[int](c, "key", tc.defaultValue)
if tc.expectErr != "" {
@@ -1522,7 +1534,8 @@ func TestQueryParamsOr(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
v, err := QueryParamsOr[int](c, "key", tc.defaultValue)
if tc.expectErr != "" {
@@ -1565,9 +1578,10 @@ func TestFormValueOr(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- v, err := FormValueOr[string](c, "name", tc.defaultValue)
+ v, err := FormParamOr[string](c, "name", tc.defaultValue)
if tc.expectErr != "" {
assert.ErrorContains(t, err, tc.expectErr)
} else {
@@ -1602,9 +1616,10 @@ func TestFormValuesOr(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- v, err := FormValuesOr[string](c, "tags", tc.defaultValue)
+ v, err := FormParamsOr[string](c, "tags", tc.defaultValue)
if tc.expectErr != "" {
assert.ErrorContains(t, err, tc.expectErr)
} else {
diff --git a/binder_test.go b/binder_test.go
index 8eced8208..31d13a02c 100644
--- a/binder_test.go
+++ b/binder_test.go
@@ -18,7 +18,7 @@ import (
"time"
)
-func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context {
+func createTestContext(URL string, body io.Reader, pathParams map[string]string) Context {
e := New()
req := httptest.NewRequest(http.MethodGet, URL, body)
if body != nil {
@@ -27,15 +27,15 @@ func createTestContext(URL string, body io.Reader, pathValues map[string]string)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- if len(pathValues) > 0 {
- params := make(PathValues, 0)
- for name, value := range pathValues {
- params = append(params, PathValue{
- Name: name,
- Value: value,
- })
+ if len(pathParams) > 0 {
+ names := make([]string, 0)
+ values := make([]string, 0)
+ for name, value := range pathParams {
+ names = append(names, name)
+ values = append(values, value)
}
- c.SetPathValues(params)
+ c.SetParamNames(names...)
+ c.SetParamValues(values...)
}
return c
@@ -43,12 +43,12 @@ func createTestContext(URL string, body io.Reader, pathValues map[string]string)
func TestBindingError_Error(t *testing.T) {
err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
- assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`)
+ assert.EqualError(t, err, `code=400, message=bind failed, internal=internal error, field=id`)
bErr := err.(*BindingError)
assert.Equal(t, 400, bErr.Code)
assert.Equal(t, "bind failed", bErr.Message)
- assert.Equal(t, errors.New("internal error"), bErr.err)
+ assert.Equal(t, errors.New("internal error"), bErr.Internal)
assert.Equal(t, "id", bErr.Field)
assert.Equal(t, []string{"1", "nope"}, bErr.Values)
@@ -62,13 +62,13 @@ func TestBindingError_ErrorJSON(t *testing.T) {
assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp))
}
-func TestPathValuesBinder(t *testing.T) {
+func TestPathParamsBinder(t *testing.T) {
c := createTestContext("/api/user/999", nil, map[string]string{
"id": "1",
"nr": "2",
"slice": "3",
})
- b := PathValuesBinder(c)
+ b := PathParamsBinder(c)
id := int64(99)
nr := int64(88)
@@ -91,15 +91,15 @@ func TestQueryParamsBinder_FailFast(t *testing.T) {
var testCases = []struct {
name string
whenURL string
- expectError []string
givenFailFast bool
+ expectError []string
}{
{
name: "ok, FailFast=true stops at first error",
whenURL: "/api/user/999?nr=en&id=nope",
givenFailFast: true,
expectError: []string{
- `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
+ `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
},
},
{
@@ -107,8 +107,8 @@ func TestQueryParamsBinder_FailFast(t *testing.T) {
whenURL: "/api/user/999?nr=en&id=nope",
givenFailFast: false,
expectError: []string{
- `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
- `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`,
+ `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
+ `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "en": invalid syntax, field=nr`,
},
},
}
@@ -165,7 +165,7 @@ func TestFormFieldBinder(t *testing.T) {
}
func TestValueBinder_errorStopsBinding(t *testing.T) {
- // this test documents "feature" that binding multiple params can change destination if it was bound before
+ // this test documents "feature" that binding multiple params can change destination if it was binded before
// failing parameter binding
c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil)
@@ -177,7 +177,7 @@ func TestValueBinder_errorStopsBinding(t *testing.T) {
Int64("nr", &nr).
BindError()
- assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr")
+ assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr")
assert.Equal(t, int64(1), id)
assert.Equal(t, int64(88), nr)
}
@@ -192,17 +192,17 @@ func TestValueBinder_BindError(t *testing.T) {
Int64("nr", &nr).
BindError()
- assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id")
+ assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id")
assert.Nil(t, b.errors)
assert.Nil(t, b.BindError())
}
func TestValueBinder_GetValues(t *testing.T) {
var testCases = []struct {
- whenValuesFunc func(sourceParam string) []string
name string
- expectError string
+ whenValuesFunc func(sourceParam string) []string
expect []int64
+ expectError string
}{
{
name: "ok, default implementation",
@@ -266,13 +266,13 @@ func TestValueBinder_CustomFuncWithError(t *testing.T) {
func TestValueBinder_CustomFunc(t *testing.T) {
var testCases = []struct {
- expectValue any
name string
- whenURL string
+ givenFailFast bool
givenFuncErrors []error
+ whenURL string
expectParamValues []string
+ expectValue any
expectErrors []string
- givenFailFast bool
}{
{
name: "ok, binds value",
@@ -341,13 +341,13 @@ func TestValueBinder_CustomFunc(t *testing.T) {
func TestValueBinder_MustCustomFunc(t *testing.T) {
var testCases = []struct {
- expectValue any
name string
- whenURL string
+ givenFailFast bool
givenFuncErrors []error
+ whenURL string
expectParamValues []string
+ expectValue any
expectErrors []string
- givenFailFast bool
}{
{
name: "ok, binds value",
@@ -418,12 +418,12 @@ func TestValueBinder_MustCustomFunc(t *testing.T) {
func TestValueBinder_String(t *testing.T) {
var testCases = []struct {
name string
+ givenFailFast bool
+ givenBindErrors []error
whenURL string
+ whenMust bool
expectValue string
expectError string
- givenBindErrors []error
- givenFailFast bool
- whenMust bool
}{
{
name: "ok, binds value",
@@ -494,12 +494,12 @@ func TestValueBinder_String(t *testing.T) {
func TestValueBinder_Strings(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []string
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []string
+ expectError string
}{
{
name: "ok, binds value",
@@ -570,12 +570,12 @@ func TestValueBinder_Strings(t *testing.T) {
func TestValueBinder_Int64_intValue(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue int64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue int64
+ expectError string
}{
{
name: "ok, binds value",
@@ -598,7 +598,7 @@ func TestValueBinder_Int64_intValue(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: 99,
- expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -626,7 +626,7 @@ func TestValueBinder_Int64_intValue(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: 99,
- expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -667,19 +667,19 @@ func TestValueBinder_Int_errorMessage(t *testing.T) {
assert.Equal(t, 99, destInt)
assert.Equal(t, uint(98), destUint)
- assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`)
- assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`)
+ assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=param`)
+ assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, internal=strconv.ParseUint: parsing "nope": invalid syntax, field=param`)
}
func TestValueBinder_Uint64_uintValue(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue uint64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue uint64
+ expectError string
}{
{
name: "ok, binds value",
@@ -702,7 +702,7 @@ func TestValueBinder_Uint64_uintValue(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: 99,
- expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -730,7 +730,7 @@ func TestValueBinder_Uint64_uintValue(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: 99,
- expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -881,12 +881,12 @@ func TestValueBinder_Int_Types(t *testing.T) {
func TestValueBinder_Int64s_intsValue(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []int64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []int64
+ expectError string
}{
{
name: "ok, binds value",
@@ -909,7 +909,7 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []int64{99},
- expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -937,7 +937,7 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []int64{99},
- expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -970,12 +970,12 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) {
func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []uint64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []uint64
+ expectError string
}{
{
name: "ok, binds value",
@@ -998,7 +998,7 @@ func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []uint64{99},
- expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1026,7 +1026,7 @@ func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []uint64{99},
- expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1169,7 +1169,7 @@ func TestValueBinder_Ints_Types(t *testing.T) {
func TestValueBinder_Ints_Types_FailFast(t *testing.T) {
// FailFast() should stop parsing and return early
- errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param"
+ errTmpl := "code=400, message=failed to bind field value to %v, internal=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param"
c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil)
var dest64 []int64
@@ -1226,12 +1226,12 @@ func TestValueBinder_Ints_Types_FailFast(t *testing.T) {
func TestValueBinder_Bool(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
expectValue bool
+ expectError string
}{
{
name: "ok, binds value",
@@ -1254,7 +1254,7 @@ func TestValueBinder_Bool(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: false,
- expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1282,7 +1282,7 @@ func TestValueBinder_Bool(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: false,
- expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1315,12 +1315,12 @@ func TestValueBinder_Bool(t *testing.T) {
func TestValueBinder_Bools(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []bool
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []bool
+ expectError string
}{
{
name: "ok, binds value",
@@ -1344,14 +1344,14 @@ func TestValueBinder_Bools(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=true¶m=nope¶m=100",
expectValue: []bool(nil),
- expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
},
{
name: "nok, conversion fails fast, value is not changed",
givenFailFast: true,
whenURL: "/search?param=true¶m=nope¶m=100",
expectValue: []bool(nil),
- expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1380,7 +1380,7 @@ func TestValueBinder_Bools(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []bool(nil),
- expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1411,12 +1411,12 @@ func TestValueBinder_Bools(t *testing.T) {
func TestValueBinder_Float64(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue float64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue float64
+ expectError string
}{
{
name: "ok, binds value",
@@ -1439,7 +1439,7 @@ func TestValueBinder_Float64(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: 1.123,
- expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1467,7 +1467,7 @@ func TestValueBinder_Float64(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: 1.123,
- expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1500,12 +1500,12 @@ func TestValueBinder_Float64(t *testing.T) {
func TestValueBinder_Float64s(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []float64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []float64
+ expectError string
}{
{
name: "ok, binds value",
@@ -1529,14 +1529,14 @@ func TestValueBinder_Float64s(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []float64(nil),
- expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
{
name: "nok, conversion fails fast, value is not changed",
givenFailFast: true,
whenURL: "/search?param=0¶m=nope¶m=100",
expectValue: []float64(nil),
- expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1565,7 +1565,7 @@ func TestValueBinder_Float64s(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []float64(nil),
- expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1596,12 +1596,12 @@ func TestValueBinder_Float64s(t *testing.T) {
func TestValueBinder_Float32(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue float32
givenNoFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue float32
+ expectError string
}{
{
name: "ok, binds value",
@@ -1624,7 +1624,7 @@ func TestValueBinder_Float32(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: 1.123,
- expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1652,7 +1652,7 @@ func TestValueBinder_Float32(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: 1.123,
- expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1685,12 +1685,12 @@ func TestValueBinder_Float32(t *testing.T) {
func TestValueBinder_Float32s(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []float32
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []float32
+ expectError string
}{
{
name: "ok, binds value",
@@ -1714,14 +1714,14 @@ func TestValueBinder_Float32s(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []float32(nil),
- expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
{
name: "nok, conversion fails fast, value is not changed",
givenFailFast: true,
whenURL: "/search?param=0¶m=nope¶m=100",
expectValue: []float32(nil),
- expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -1750,7 +1750,7 @@ func TestValueBinder_Float32s(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []float32(nil),
- expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -1781,14 +1781,14 @@ func TestValueBinder_Float32s(t *testing.T) {
func TestValueBinder_Time(t *testing.T) {
exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
var testCases = []struct {
- expectValue time.Time
name string
+ givenFailFast bool
+ givenBindErrors []error
whenURL string
+ whenMust bool
whenLayout string
+ expectValue time.Time
expectError string
- givenBindErrors []error
- givenFailFast bool
- whenMust bool
}{
{
name: "ok, binds value",
@@ -1863,13 +1863,13 @@ func TestValueBinder_Times(t *testing.T) {
exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00")
var testCases = []struct {
name string
+ givenFailFast bool
+ givenBindErrors []error
whenURL string
+ whenMust bool
whenLayout string
- expectError string
- givenBindErrors []error
expectValue []time.Time
- givenFailFast bool
- whenMust bool
+ expectError string
}{
{
name: "ok, binds value",
@@ -1948,12 +1948,12 @@ func TestValueBinder_Duration(t *testing.T) {
example := 42 * time.Second
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue time.Duration
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue time.Duration
+ expectError string
}{
{
name: "ok, binds value",
@@ -2026,12 +2026,12 @@ func TestValueBinder_Durations(t *testing.T) {
exampleDuration2 := 1 * time.Millisecond
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []time.Duration
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []time.Duration
+ expectError string
}{
{
name: "ok, binds value",
@@ -2103,13 +2103,13 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) {
exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
var testCases = []struct {
- expectValue Timestamp
name string
- whenURL string
- expectError string
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue Timestamp
+ expectError string
}{
{
name: "ok, binds value",
@@ -2132,7 +2132,7 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: Timestamp{},
- expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
{
name: "ok (must), binds value",
@@ -2160,7 +2160,7 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: Timestamp{},
- expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
}
@@ -2195,12 +2195,12 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- expectValue big.Int
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue big.Int
+ expectError string
}{
{
name: "ok, binds value",
@@ -2223,7 +2223,7 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=xxx",
expectValue: big.Int{},
- expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
{
name: "ok (must), binds value",
@@ -2251,7 +2251,7 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=xxx",
expectValue: big.Int{},
- expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
}
@@ -2286,12 +2286,12 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- expectValue big.Int
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue big.Int
+ expectError string
}{
{
name: "ok, binds value",
@@ -2314,7 +2314,7 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=xxx",
expectValue: big.Int{},
- expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
{
name: "ok (must), binds value",
@@ -2342,7 +2342,7 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=xxx",
expectValue: big.Int{},
- expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
}
@@ -2374,9 +2374,9 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) {
func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
var testCases = []struct {
- expect any
name string
whenURL string
+ expect any
}{
{
name: "ok, strings",
@@ -2522,12 +2522,12 @@ func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
func TestValueBinder_BindWithDelimiter(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []int64
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []int64
+ expectError string
}{
{
name: "ok, binds value",
@@ -2550,7 +2550,7 @@ func TestValueBinder_BindWithDelimiter(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []int64(nil),
- expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -2578,7 +2578,7 @@ func TestValueBinder_BindWithDelimiter(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []int64(nil),
- expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -2621,13 +2621,13 @@ func TestBindWithDelimiter_invalidType(t *testing.T) {
func TestValueBinder_UnixTime(t *testing.T) {
exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603
var testCases = []struct {
- expectValue time.Time
name string
- whenURL string
- expectError string
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue time.Time
+ expectError string
}{
{
name: "ok, binds value, unix time in seconds",
@@ -2655,7 +2655,7 @@ func TestValueBinder_UnixTime(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -2683,7 +2683,7 @@ func TestValueBinder_UnixTime(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -2717,13 +2717,13 @@ func TestValueBinder_UnixTime(t *testing.T) {
func TestValueBinder_UnixTimeMilli(t *testing.T) {
exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140
var testCases = []struct {
- expectValue time.Time
name string
- whenURL string
- expectError string
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue time.Time
+ expectError string
}{
{
name: "ok, binds value, unix time in milliseconds",
@@ -2746,7 +2746,7 @@ func TestValueBinder_UnixTimeMilli(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -2774,7 +2774,7 @@ func TestValueBinder_UnixTimeMilli(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -2810,13 +2810,13 @@ func TestValueBinder_UnixTimeNano(t *testing.T) {
exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789
exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00")
var testCases = []struct {
- expectValue time.Time
name string
- whenURL string
- expectError string
- givenBindErrors []error
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue time.Time
+ expectError string
}{
{
name: "ok, binds value, unix time in nano seconds (sec precision)",
@@ -2849,7 +2849,7 @@ func TestValueBinder_UnixTimeNano(t *testing.T) {
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
@@ -2877,7 +2877,7 @@ func TestValueBinder_UnixTimeNano(t *testing.T) {
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
@@ -2919,7 +2919,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
- _ = binder.Bind(c, &dest)
+ _ = binder.Bind(&dest, c)
}
}
@@ -2967,16 +2967,17 @@ func BenchmarkRawFunc_Int64_single(b *testing.B) {
func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
type Opts struct {
- String string `query:"string"`
- Strings []string `query:"strings"`
- Int64 int64 `query:"int64"`
+ Int64 int64 `query:"int64"`
+ Int32 int32 `query:"int32"`
+ Int16 int16 `query:"int16"`
+ Int8 int8 `query:"int8"`
+ String string `query:"string"`
+
Uint64 uint64 `query:"uint64"`
- Int32 int32 `query:"int32"`
Uint32 uint32 `query:"uint32"`
- Int16 int16 `query:"int16"`
Uint16 uint16 `query:"uint16"`
- Int8 int8 `query:"int8"`
Uint8 uint8 `query:"uint8"`
+ Strings []string `query:"strings"`
}
c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
@@ -2985,7 +2986,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
- _ = binder.Bind(c, &dest)
+ _ = binder.Bind(&dest, c)
if dest.Int64 != 1 {
b.Fatalf("int64!=1")
}
@@ -2994,16 +2995,17 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) {
type Opts struct {
- String string `query:"string"`
- Strings []string `query:"strings"`
- Int64 int64 `query:"int64"`
+ Int64 int64 `query:"int64"`
+ Int32 int32 `query:"int32"`
+ Int16 int16 `query:"int16"`
+ Int8 int8 `query:"int8"`
+ String string `query:"string"`
+
Uint64 uint64 `query:"uint64"`
- Int32 int32 `query:"int32"`
Uint32 uint32 `query:"uint32"`
- Int16 int16 `query:"int16"`
Uint16 uint16 `query:"uint16"`
- Int8 int8 `query:"int8"`
Uint8 uint8 `query:"uint8"`
+ Strings []string `query:"strings"`
}
c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
@@ -3032,27 +3034,27 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) {
func TestValueBinder_TimeError(t *testing.T) {
var testCases = []struct {
- expectValue time.Time
name string
+ givenFailFast bool
+ givenBindErrors []error
whenURL string
+ whenMust bool
whenLayout string
+ expectValue time.Time
expectError string
- givenBindErrors []error
- givenFailFast bool
- whenMust bool
}{
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: time.Time{},
- expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param",
},
}
@@ -3085,33 +3087,33 @@ func TestValueBinder_TimeError(t *testing.T) {
func TestValueBinder_TimesError(t *testing.T) {
var testCases = []struct {
name string
+ givenFailFast bool
+ givenBindErrors []error
whenURL string
+ whenMust bool
whenLayout string
- expectError string
- givenBindErrors []error
expectValue []time.Time
- givenFailFast bool
- whenMust bool
+ expectError string
}{
{
name: "nok, fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1¶m=100",
expectValue: []time.Time(nil),
- expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []time.Time(nil),
- expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []time.Time(nil),
- expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
}
@@ -3147,25 +3149,25 @@ func TestValueBinder_TimesError(t *testing.T) {
func TestValueBinder_DurationError(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue time.Duration
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue time.Duration
+ expectError string
}{
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: 0,
- expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: 0,
- expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
}
@@ -3198,32 +3200,32 @@ func TestValueBinder_DurationError(t *testing.T) {
func TestValueBinder_DurationsError(t *testing.T) {
var testCases = []struct {
name string
- whenURL string
- expectError string
- givenBindErrors []error
- expectValue []time.Duration
givenFailFast bool
+ givenBindErrors []error
+ whenURL string
whenMust bool
+ expectValue []time.Duration
+ expectError string
}{
{
name: "nok, fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1¶m=100",
expectValue: []time.Duration(nil),
- expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param",
+ expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope¶m=100",
expectValue: []time.Duration(nil),
- expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope¶m=100",
expectValue: []time.Duration(nil),
- expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
}
diff --git a/context.go b/context.go
index 9a7429f39..6500e5eef 100644
--- a/context.go
+++ b/context.go
@@ -6,194 +6,270 @@ package echo
import (
"bytes"
"encoding/xml"
- "errors"
"fmt"
"io"
- "io/fs"
- "log/slog"
"mime/multipart"
"net"
"net/http"
"net/url"
- "path"
- "path/filepath"
"strings"
"sync"
- "unsafe"
)
-// stringToBytes returns a []byte view over s without copying, avoiding the allocation+copy of []byte(s)
-// on the response write path (the zero-copy technique used by fasthttp/fiber).
-//
-// Contract β all of the following must hold at every call site:
-// - The result is read-only: writing through it is undefined behaviour.
-// - The callee must NOT retain the slice beyond the call. It is only passed to the response
-// Writer's Write, whose io.Writer contract forbids retaining/mutating the argument. Note the
-// concrete writer may be a wrapping ResponseWriter (e.g. gzip); such writers must copy, not alias.
-// - s must stay reachable for as long as the slice is used: the slice aliases s's backing array.
-func stringToBytes(s string) []byte {
- if s == "" {
- return nil
- }
- return unsafe.Slice(unsafe.StringData(s), len(s))
-}
-
-// jsonpOpen and jsonpClose are the constant byte wrappers for JSONP payloads, kept as package-level
-// slices to avoid allocating them on every JSONP response.
-var (
- jsonpOpen = []byte("(")
- jsonpClose = []byte(");")
-)
+// Context represents the context of the current HTTP request. It holds request and
+// response objects, path, path parameters, data and registered handler.
+type Context interface {
+ // Request returns `*http.Request`.
+ Request() *http.Request
-const (
- // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
- // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
- // It is added to context only when Router does not find matching method handler for request.
- ContextKeyHeaderAllow = "echo_header_allow"
-)
+ // SetRequest sets `*http.Request`.
+ SetRequest(r *http.Request)
-const (
- // defaultMemory is default value for memory limit that is used when
- // parsing multipart forms (See (*http.Request).ParseMultipartForm)
- defaultMemory int64 = 32 << 20 // 32 MB
- indexPage = "index.html"
-)
+ // SetResponse sets `*Response`.
+ SetResponse(r *Response)
-// Context represents the context of the current HTTP request. It holds request and
-// response objects, path, path parameters, data and registered handler.
-type Context struct {
- request *http.Request
- orgResponse *Response
- response http.ResponseWriter
- query url.Values
+ // Response returns `*Response`.
+ Response() *Response
- // formParseMaxMemory is used for http.Request.ParseMultipartForm
- formParseMaxMemory int64
+ // IsTLS returns true if HTTP connection is TLS otherwise false.
+ IsTLS() bool
- route *RouteInfo
- pathValues *PathValues
+ // IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
+ IsWebSocket() bool
- // handler is the route handler resolved during routing. It is invoked by the terminal of the global
- // middleware chain (see Echo.buildRouterChains) so that the chain can be compiled once and reused.
- handler HandlerFunc
+ // Scheme returns the HTTP protocol scheme, `http` or `https`.
+ Scheme() string
- // dsw is reused by json() so that each JSON response does not heap-allocate a delayedStatusWriter.
- // It lives on the pooled Context; &c.dsw is a stable, allocation-free pointer. Only json() may point
- // the response at &c.dsw, and only via the nested-call guard there β aliasing it to itself (wrapping
- // &c.dsw around &c.dsw) would make the response writer reference itself.
- dsw delayedStatusWriter
+ // RealIP returns the client's network address based on `X-Forwarded-For`
+ // or `X-Real-IP` request header.
+ // The behavior can be configured using `Echo#IPExtractor`.
+ RealIP() string
- store map[string]any
- echo *Echo
- logger *slog.Logger
+ // Path returns the registered path for the handler.
+ Path() string
- path string
- lock sync.RWMutex
-}
-
-// NewContext returns a new Context instance.
-//
-// Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway
-// these arguments are useful when creating context for tests and cases like that.
-func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context {
- var e *Echo
- for _, opt := range opts {
- switch v := opt.(type) {
- case *Echo:
- e = v
- }
- }
- return newContext(r, w, e)
-}
+ // SetPath sets the registered path for the handler.
+ SetPath(p string)
-func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context {
- // store is created lazily by Set and cleared (not freed) by Reset, so we deliberately do not allocate a map here.
- c := &Context{
- pathValues: nil,
- echo: e,
- logger: nil,
- }
- var logger *slog.Logger
- paramLen := int32(0)
- formParseMaxMemory := defaultMemory
- if e != nil {
- paramLen = e.contextPathParamAllocSize.Load()
- logger = e.Logger
- formParseMaxMemory = e.formParseMaxMemory
- }
- if logger == nil {
- logger = slog.Default()
- }
- c.logger = logger
- p := make(PathValues, 0, paramLen)
- c.pathValues = &p
+ // Param returns path parameter by name.
+ Param(name string) string
+
+ // ParamNames returns path parameter names.
+ ParamNames() []string
+
+ // SetParamNames sets path parameter names.
+ SetParamNames(names ...string)
+
+ // ParamValues returns path parameter values.
+ ParamValues() []string
+
+ // SetParamValues sets path parameter values.
+ SetParamValues(values ...string)
+
+ // QueryParam returns the query param for the provided name.
+ QueryParam(name string) string
+
+ // QueryParams returns the query parameters as `url.Values`.
+ QueryParams() url.Values
+
+ // QueryString returns the URL query string.
+ QueryString() string
+
+ // FormValue returns the form field value for the provided name.
+ FormValue(name string) string
+
+ // FormParams returns the form parameters as `url.Values`.
+ FormParams() (url.Values, error)
+
+ // FormFile returns the multipart form file for the provided name.
+ FormFile(name string) (*multipart.FileHeader, error)
+
+ // MultipartForm returns the multipart form.
+ MultipartForm() (*multipart.Form, error)
+
+ // Cookie returns the named cookie provided in the request.
+ Cookie(name string) (*http.Cookie, error)
+
+ // SetCookie adds a `Set-Cookie` header in HTTP response.
+ SetCookie(cookie *http.Cookie)
+
+ // Cookies returns the HTTP cookies sent with the request.
+ Cookies() []*http.Cookie
+
+ // Get retrieves data from the context.
+ Get(key string) any
+
+ // Set saves data in the context.
+ Set(key string, val any)
+
+ // Bind binds path params, query params and the request body into provided type `i`. The default binder
+ // binds body based on Content-Type header.
+ Bind(i any) error
- c.SetRequest(r)
- c.orgResponse = NewResponse(w, logger)
- c.response = c.orgResponse
- c.formParseMaxMemory = formParseMaxMemory
- return c
+ // Validate validates provided `i`. It is usually called after `Context#Bind()`.
+ // Validator must be registered using `Echo#Validator`.
+ Validate(i any) error
+
+ // Render renders a template with data and sends a text/html response with status
+ // code. Renderer must be registered using `Echo.Renderer`.
+ Render(code int, name string, data any) error
+
+ // HTML sends an HTTP response with status code.
+ HTML(code int, html string) error
+
+ // HTMLBlob sends an HTTP blob response with status code.
+ HTMLBlob(code int, b []byte) error
+
+ // String sends a string response with status code.
+ String(code int, s string) error
+
+ // JSON sends a JSON response with status code.
+ JSON(code int, i any) error
+
+ // JSONPretty sends a pretty-print JSON with status code.
+ JSONPretty(code int, i any, indent string) error
+
+ // JSONBlob sends a JSON blob response with status code.
+ JSONBlob(code int, b []byte) error
+
+ // JSONP sends a JSONP response with status code. It uses `callback` to construct
+ // the JSONP payload.
+ JSONP(code int, callback string, i any) error
+
+ // JSONPBlob sends a JSONP blob response with status code. It uses `callback`
+ // to construct the JSONP payload.
+ JSONPBlob(code int, callback string, b []byte) error
+
+ // XML sends an XML response with status code.
+ XML(code int, i any) error
+
+ // XMLPretty sends a pretty-print XML with status code.
+ XMLPretty(code int, i any, indent string) error
+
+ // XMLBlob sends an XML blob response with status code.
+ XMLBlob(code int, b []byte) error
+
+ // Blob sends a blob response with status code and content type.
+ Blob(code int, contentType string, b []byte) error
+
+ // Stream sends a streaming response with status code and content type.
+ Stream(code int, contentType string, r io.Reader) error
+
+ // File sends a response with the content of the file.
+ File(file string) error
+
+ // Attachment sends a response as attachment, prompting client to save the
+ // file.
+ Attachment(file string, name string) error
+
+ // Inline sends a response as inline, opening the file in the browser.
+ Inline(file string, name string) error
+
+ // NoContent sends a response with no body and a status code.
+ NoContent(code int) error
+
+ // Redirect redirects the request to a provided URL with status code.
+ Redirect(code int, url string) error
+
+ // Error invokes the registered global HTTP error handler. Generally used by middleware.
+ // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and
+ // middlewares up in chain can not change Response status code or Response body anymore.
+ //
+ // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that.
+ Error(err error)
+
+ // Handler returns the matched handler by router.
+ Handler() HandlerFunc
+
+ // SetHandler sets the matched handler by router.
+ SetHandler(h HandlerFunc)
+
+ // Logger returns the `Logger` instance.
+ Logger() Logger
+
+ // SetLogger Set the logger
+ SetLogger(l Logger)
+
+ // Echo returns the `Echo` instance.
+ Echo() *Echo
+
+ // Reset resets the context after request completes. It must be called along
+ // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
+ // See `Echo#ServeHTTP()`
+ Reset(r *http.Request, w http.ResponseWriter)
}
-// Reset resets the context after request completes. It must be called along
-// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
-// See `Echo#ServeHTTP()`
-func (c *Context) Reset(r *http.Request, w http.ResponseWriter) {
- c.request = r
- c.orgResponse.reset(w)
- c.response = c.orgResponse
- c.query = nil
- // clear (rather than nil) keeps the map allocated on the pooled Context so that requests using Set
- // do not allocate a fresh map each time. clear(nil) is a no-op.
- clear(c.store)
- c.logger = c.echo.Logger
-
- c.route = nil
- c.handler = nil
- c.dsw = delayedStatusWriter{}
- c.path = ""
- // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times
- *c.pathValues = (*c.pathValues)[:0]
+type context struct {
+ logger Logger
+ request *http.Request
+ response *Response
+ query url.Values
+ echo *Echo
+
+ store Map
+ lock sync.RWMutex
+
+ // following fields are set by Router
+ handler HandlerFunc
+
+ // path is route path that Router matched. It is empty string where there is no route match.
+ // Route registered with RouteNotFound is considered as a match and path therefore is not empty.
+ path string
+
+ // Usually echo.Echo is sizing pvalues but there could be user created middlewares that decide to
+ // overwrite parameter by calling SetParamNames + SetParamValues.
+ // When echo.Echo allocated that slice it length/capacity is tied to echo.Echo.maxParam value.
+ //
+ // It is important that pvalues size is always equal or bigger to pnames length.
+ pvalues []string
+
+ // pnames length is tied to param count for the matched route
+ pnames []string
}
-func (c *Context) writeContentType(value string) {
- header := c.response.Header()
+const (
+ // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
+ // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
+ // It is added to context only when Router does not find matching method handler for request.
+ ContextKeyHeaderAllow = "echo_header_allow"
+)
+
+const (
+ defaultMemory = 32 << 20 // 32 MB
+ indexPage = "index.html"
+ defaultIndent = " "
+)
+
+func (c *context) writeContentType(value string) {
+ header := c.Response().Header()
if header.Get(HeaderContentType) == "" {
header.Set(HeaderContentType, value)
}
}
-// Request returns `*http.Request`.
-func (c *Context) Request() *http.Request {
+func (c *context) Request() *http.Request {
return c.request
}
-// SetRequest sets `*http.Request`.
-func (c *Context) SetRequest(r *http.Request) {
+func (c *context) SetRequest(r *http.Request) {
c.request = r
}
-// Response returns `*Response`.
-func (c *Context) Response() http.ResponseWriter {
+func (c *context) Response() *Response {
return c.response
}
-// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following
-// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance.
-func (c *Context) SetResponse(r http.ResponseWriter) {
+func (c *context) SetResponse(r *Response) {
c.response = r
}
-// IsTLS returns true if HTTP connection is TLS otherwise false.
-func (c *Context) IsTLS() bool {
+func (c *context) IsTLS() bool {
return c.request.TLS != nil
}
-// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
-func (c *Context) IsWebSocket() bool {
+func (c *context) IsWebSocket() bool {
upgrade := c.request.Header.Get(HeaderUpgrade)
- connection := c.request.Header.Get(HeaderConnection)
- return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade")
+ return strings.EqualFold(upgrade, "websocket")
}
func isValidProto(proto string) bool {
@@ -209,7 +285,7 @@ func isValidProto(proto string) bool {
}
// Scheme returns the HTTP protocol scheme, `http` or `https`.
-func (c *Context) Scheme() string {
+func (c *context) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
if c.IsTLS() {
@@ -230,188 +306,107 @@ func (c *Context) Scheme() string {
return "http"
}
-// RealIP returns the client IP address using the configured extraction strategy.
-//
-// If Echo#IPExtractor is set, it is used to resolve the client IP from the incoming request (typically via proxy
-// headers such as X-Forwarded-For or X-Real-IP).
-// Look into the `ip.go` file for comments and examples.
-//
-// See:
-// - Echo#ExtractIPFromXFFHeader for `X-Forwarded-For` handling with trust checks
-// - Echo#ExtractIPFromRealIPHeader for `X-Real-IP` handling with trust checks
-// - Echo#LegacyIPExtractor for `v4` compatibility (spoofable, no trust checks built in)
-//
-// If no extractor is configured, RealIP falls back to the request RemoteAddr, returning only the host portion.
-//
-// Notes:
-// - No validation or trust enforcement is performed unless implemented by the configured IPExtractor.
-// - When relying on proxy headers, ensure the application is deployed behind trusted intermediaries to avoid spoofing.
-func (c *Context) RealIP() string {
+func (c *context) RealIP() string {
if c.echo != nil && c.echo.IPExtractor != nil {
return c.echo.IPExtractor(c.request)
}
- // req.RemoteAddr is the IP address of the remote end of the connection, which may be a proxy. It is populated by the
- // http.conn.readRequest() method and uses net.Conn.RemoteAddr().String() which we trust.
+ // Fall back to legacy behavior
+ if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
+ i := strings.IndexAny(ip, ",")
+ if i > 0 {
+ xffip := strings.TrimSpace(ip[:i])
+ xffip = strings.TrimPrefix(xffip, "[")
+ xffip = strings.TrimSuffix(xffip, "]")
+ return xffip
+ }
+ return ip
+ }
+ if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
+ ip = strings.TrimPrefix(ip, "[")
+ ip = strings.TrimSuffix(ip, "]")
+ return ip
+ }
ra, _, _ := net.SplitHostPort(c.request.RemoteAddr)
return ra
}
-// Path returns the registered path for the handler.
-func (c *Context) Path() string {
+func (c *context) Path() string {
return c.path
}
-// SetPath sets the registered path for the handler.
-func (c *Context) SetPath(p string) {
+func (c *context) SetPath(p string) {
c.path = p
}
-// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
-//
-// RouteInfo returns generic "empty" struct for these cases:
-// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`)
-// * Router did not find matching route - 404 (route not found)
-// * Router did not find matching route with same method - 405 (method not allowed)
-func (c *Context) RouteInfo() RouteInfo {
- if c.route != nil {
- return c.route.Clone()
+func (c *context) Param(name string) string {
+ for i, n := range c.pnames {
+ if i < len(c.pvalues) {
+ if n == name {
+ return c.pvalues[i]
+ }
+ }
}
- return RouteInfo{}
-}
-
-// Param returns path parameter by name.
-func (c *Context) Param(name string) string {
- return c.pathValues.GetOr(name, "")
+ return ""
}
-// ParamOr returns the path parameter or default value for the provided name.
-//
-// Notes for DefaultRouter implementation:
-// Path parameter could be empty for cases like that:
-// * route `/release-:version/bin` and request URL is `/release-/bin`
-// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg`
-// but not when path parameter is last part of route path
-// * route `/download/file.:ext` will not match request `/download/file.`
-func (c *Context) ParamOr(name, defaultValue string) string {
- return c.pathValues.GetOr(name, defaultValue)
+func (c *context) ParamNames() []string {
+ return c.pnames
}
-// PathValues returns path parameter values.
-func (c *Context) PathValues() PathValues {
- return *c.pathValues
-}
+func (c *context) SetParamNames(names ...string) {
+ c.pnames = names
-// SetPathValues sets path parameters for current request.
-func (c *Context) SetPathValues(pathValues PathValues) {
- if pathValues == nil {
- panic("context SetPathValues called with nil PathValues")
+ l := len(names)
+ if len(c.pvalues) < l {
+ // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
+ // probably those values will be overridden in a Context#SetParamValues
+ newPvalues := make([]string, l)
+ copy(newPvalues, c.pvalues)
+ c.pvalues = newPvalues
}
- c.setPathValues(&pathValues)
-}
-
-// InitializeRoute sets the route related variables of this request to the context.
-func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) {
- c.route = ri
- c.path = ri.Path
- c.setPathValues(pathValues)
}
-func (c *Context) setPathValues(pv *PathValues) {
- // Router accesses c.pathValues by index and may resize it to full capacity during routing
- // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller
- // slice than Router can set when routing Route with maximum amount of parameters.
- pathValues := c.pathValues
- if cap(*c.pathValues) < len(*pv) {
- // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not
- // be smaller than anything router knows as maximum path parameter count to be.
- tmp := make(PathValues, len(*pv))
- c.pathValues = &tmp
- pathValues = c.pathValues
- } else if len(*c.pathValues) != len(*pv) {
- *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work
- }
- copy(*pathValues, *pv)
+func (c *context) ParamValues() []string {
+ return c.pvalues[:len(c.pnames)]
}
-// QueryParam returns the query param for the provided name.
-func (c *Context) QueryParam(name string) string {
- // If the full query map was already built (e.g. by QueryParams), use it. Otherwise look the single
- // key up directly from the raw query, avoiding the url.Values map allocation for the common case of
- // reading only a few params. The result is identical to url.Values.Get on the parsed query.
- if c.query != nil {
- return c.query.Get(name)
+func (c *context) SetParamValues(values ...string) {
+ // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam (or bigger) at all times
+ // It will break the Router#Find code
+ limit := len(values)
+ if limit > len(c.pvalues) {
+ c.pvalues = make([]string, limit)
}
- return getRawQueryParam(c.request.URL.RawQuery, name)
-}
-
-// getRawQueryParam returns the first value for name parsed directly from a raw URL query string. It
-// matches url.Values.Get over url.ParseQuery output: first match wins, '+' decodes to space, percent
-// escapes are decoded, segments containing ';' are skipped, and pairs whose key or value fail to
-// unescape are skipped. It avoids allocating the full url.Values map for single-key lookups.
-func getRawQueryParam(query, name string) string {
- for query != "" {
- var seg string
- seg, query, _ = strings.Cut(query, "&")
- if seg == "" || strings.Contains(seg, ";") {
- continue
- }
- key, value, _ := strings.Cut(seg, "=")
- k, err := url.QueryUnescape(key)
- if err != nil || k != name {
- continue
- }
- v, err := url.QueryUnescape(value)
- if err != nil {
- continue
- }
- return v
+ for i := 0; i < limit; i++ {
+ c.pvalues[i] = values[i]
}
- return ""
}
-// QueryParamOr returns the query param or default value for the provided name.
-// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string
-// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")`
-func (c *Context) QueryParamOr(name, defaultValue string) string {
- value := c.QueryParam(name)
- if value == "" {
- value = defaultValue
+func (c *context) QueryParam(name string) string {
+ if c.query == nil {
+ c.query = c.request.URL.Query()
}
- return value
+ return c.query.Get(name)
}
-// QueryParams returns the query parameters as `url.Values`.
-func (c *Context) QueryParams() url.Values {
+func (c *context) QueryParams() url.Values {
if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query
}
-// QueryString returns the URL query string.
-func (c *Context) QueryString() string {
+func (c *context) QueryString() string {
return c.request.URL.RawQuery
}
-// FormValue returns the form field value for the provided name.
-func (c *Context) FormValue(name string) string {
+func (c *context) FormValue(name string) string {
return c.request.FormValue(name)
}
-// FormValueOr returns the form field value or default value for the provided name.
-// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string
-func (c *Context) FormValueOr(name, defaultValue string) string {
- value := c.FormValue(name)
- if value == "" {
- value = defaultValue
- }
- return value
-}
-
-// FormValues returns the form field values as `url.Values`.
-func (c *Context) FormValues() (url.Values, error) {
+func (c *context) FormParams() (url.Values, error) {
if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) {
- if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil {
+ if err := c.request.ParseMultipartForm(defaultMemory); err != nil {
return nil, err
}
} else {
@@ -422,189 +417,141 @@ func (c *Context) FormValues() (url.Values, error) {
return c.request.Form, nil
}
-// FormFile returns the multipart form file for the provided name.
-func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
+func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
f, fh, err := c.request.FormFile(name)
if err != nil {
return nil, err
}
- _ = f.Close()
+ f.Close()
return fh, nil
}
-// MultipartForm returns the multipart form.
-func (c *Context) MultipartForm() (*multipart.Form, error) {
- err := c.request.ParseMultipartForm(c.formParseMaxMemory)
+func (c *context) MultipartForm() (*multipart.Form, error) {
+ err := c.request.ParseMultipartForm(defaultMemory)
return c.request.MultipartForm, err
}
-// Cookie returns the named cookie provided in the request.
-func (c *Context) Cookie(name string) (*http.Cookie, error) {
+func (c *context) Cookie(name string) (*http.Cookie, error) {
return c.request.Cookie(name)
}
-// SetCookie adds a `Set-Cookie` header in HTTP response.
-func (c *Context) SetCookie(cookie *http.Cookie) {
+func (c *context) SetCookie(cookie *http.Cookie) {
http.SetCookie(c.Response(), cookie)
}
-// Cookies returns the HTTP cookies sent with the request.
-func (c *Context) Cookies() []*http.Cookie {
+func (c *context) Cookies() []*http.Cookie {
return c.request.Cookies()
}
-// Get retrieves data from the context.
-// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)).
-func (c *Context) Get(key string) any {
- // Unlock without defer to avoid the deferred-call overhead on this hot path.
+func (c *context) Get(key string) any {
c.lock.RLock()
- v := c.store[key]
- c.lock.RUnlock()
- return v
+ defer c.lock.RUnlock()
+ return c.store[key]
}
-// Set saves data in the context.
-func (c *Context) Set(key string, val any) {
+func (c *context) Set(key string, val any) {
c.lock.Lock()
+ defer c.lock.Unlock()
+
if c.store == nil {
- c.store = make(map[string]any)
+ c.store = make(Map)
}
c.store[key] = val
- c.lock.Unlock()
}
-// Bind binds path params, query params and the request body into provided type `i`. The default binder
-// binds body based on Content-Type header.
-func (c *Context) Bind(i any) error {
- return c.echo.Binder.Bind(c, i)
+func (c *context) Bind(i any) error {
+ return c.echo.Binder.Bind(i, c)
}
-// Validate validates provided `i`. It is usually called after `Context#Bind()`.
-// Validator must be registered using `Echo#Validator`.
-func (c *Context) Validate(i any) error {
+func (c *context) Validate(i any) error {
if c.echo.Validator == nil {
return ErrValidatorNotRegistered
}
return c.echo.Validator.Validate(i)
}
-// Render renders a template with data and sends a text/html response with status
-// code. Renderer must be registered using `Echo.Renderer`.
-func (c *Context) Render(code int, name string, data any) (err error) {
+func (c *context) Render(code int, name string, data any) (err error) {
if c.echo.Renderer == nil {
return ErrRendererNotRegistered
}
- // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until
- // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write
- // the rendered template to the buffer first.
- //
- // html.Template.ExecuteTemplate() documentations writes:
- // > If an error occurs executing the template or writing its output,
- // > execution stops, but partial results may already have been written to
- // > the output writer.
-
buf := new(bytes.Buffer)
- if err = c.echo.Renderer.Render(c, buf, name, data); err != nil {
+ if err = c.echo.Renderer.Render(buf, name, data, c); err != nil {
return
}
return c.HTMLBlob(code, buf.Bytes())
}
-// HTML sends an HTTP response with status code.
-func (c *Context) HTML(code int, html string) (err error) {
- return c.HTMLBlob(code, stringToBytes(html))
+func (c *context) HTML(code int, html string) (err error) {
+ return c.HTMLBlob(code, []byte(html))
}
-// HTMLBlob sends an HTTP blob response with status code.
-func (c *Context) HTMLBlob(code int, b []byte) (err error) {
+func (c *context) HTMLBlob(code int, b []byte) (err error) {
return c.Blob(code, MIMETextHTMLCharsetUTF8, b)
}
-// String sends a string response with status code.
-func (c *Context) String(code int, s string) (err error) {
- return c.Blob(code, MIMETextPlainCharsetUTF8, stringToBytes(s))
+func (c *context) String(code int, s string) (err error) {
+ return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s))
}
-func (c *Context) jsonPBlob(code int, callback string, i any) (err error) {
+func (c *context) jsonPBlob(code int, callback string, i any) (err error) {
+ indent := ""
+ if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
+ indent = defaultIndent
+ }
c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
- if _, err = c.response.Write(stringToBytes(callback)); err != nil {
- return
- }
- if _, err = c.response.Write(jsonpOpen); err != nil {
+ if _, err = c.response.Write([]byte(callback + "(")); err != nil {
return
}
- if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil {
+ if err = c.echo.JSONSerializer.Serialize(c, i, indent); err != nil {
return
}
- if _, err = c.response.Write(jsonpClose); err != nil {
+ if _, err = c.response.Write([]byte(");")); err != nil {
return
}
return
}
-func (c *Context) json(code int, i any, indent string) error {
+func (c *context) json(code int, i any, indent string) error {
c.writeContentType(MIMEApplicationJSON)
-
- // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until
- // (global) error handler decides correct status code for the error to be sent to the client.
- // For that we need to use writer that can store the proposed status code until the first Write is called.
- resp := c.Response()
- // Reuse the Context-owned delayedStatusWriter to avoid heap-allocating one per JSON response.
- // If we are already nested inside a delayed write (rare: a serializer or handler calling c.JSON
- // re-entrantly), allocate a fresh writer so the outer call's writer (which is &c.dsw) is not
- // clobbered β reusing c.dsw here would make it reference itself.
- if _, nested := resp.(*delayedStatusWriter); nested {
- c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
- } else {
- c.dsw = delayedStatusWriter{ResponseWriter: resp, status: code}
- c.SetResponse(&c.dsw)
- }
- defer c.SetResponse(resp)
-
+ c.response.Status = code
return c.echo.JSONSerializer.Serialize(c, i, indent)
}
-// JSON sends a JSON response with status code.
-func (c *Context) JSON(code int, i any) (err error) {
- return c.json(code, i, "")
+func (c *context) JSON(code int, i any) (err error) {
+ indent := ""
+ if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
+ indent = defaultIndent
+ }
+ return c.json(code, i, indent)
}
-// JSONPretty sends a pretty-print JSON with status code.
-func (c *Context) JSONPretty(code int, i any, indent string) (err error) {
+func (c *context) JSONPretty(code int, i any, indent string) (err error) {
return c.json(code, i, indent)
}
-// JSONBlob sends a JSON blob response with status code.
-func (c *Context) JSONBlob(code int, b []byte) (err error) {
+func (c *context) JSONBlob(code int, b []byte) (err error) {
return c.Blob(code, MIMEApplicationJSON, b)
}
-// JSONP sends a JSONP response with status code. It uses `callback` to construct
-// the JSONP payload.
-func (c *Context) JSONP(code int, callback string, i any) (err error) {
+func (c *context) JSONP(code int, callback string, i any) (err error) {
return c.jsonPBlob(code, callback, i)
}
-// JSONPBlob sends a JSONP blob response with status code. It uses `callback`
-// to construct the JSONP payload.
-func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) {
+func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) {
c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
- if _, err = c.response.Write(stringToBytes(callback)); err != nil {
- return
- }
- if _, err = c.response.Write(jsonpOpen); err != nil {
+ if _, err = c.response.Write([]byte(callback + "(")); err != nil {
return
}
if _, err = c.response.Write(b); err != nil {
return
}
- _, err = c.response.Write(jsonpClose)
+ _, err = c.response.Write([]byte(");"))
return
}
-func (c *Context) xml(code int, i any, indent string) (err error) {
+func (c *context) xml(code int, i any, indent string) (err error) {
c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
enc := xml.NewEncoder(c.response)
@@ -617,18 +564,19 @@ func (c *Context) xml(code int, i any, indent string) (err error) {
return enc.Encode(i)
}
-// XML sends an XML response with status code.
-func (c *Context) XML(code int, i any) (err error) {
- return c.xml(code, i, "")
+func (c *context) XML(code int, i any) (err error) {
+ indent := ""
+ if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
+ indent = defaultIndent
+ }
+ return c.xml(code, i, indent)
}
-// XMLPretty sends a pretty-print XML with status code.
-func (c *Context) XMLPretty(code int, i any, indent string) (err error) {
+func (c *context) XMLPretty(code int, i any, indent string) (err error) {
return c.xml(code, i, indent)
}
-// XMLBlob sends an XML blob response with status code.
-func (c *Context) XMLBlob(code int, b []byte) (err error) {
+func (c *context) XMLBlob(code int, b []byte) (err error) {
c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
if _, err = c.response.Write([]byte(xml.Header)); err != nil {
@@ -638,98 +586,41 @@ func (c *Context) XMLBlob(code int, b []byte) (err error) {
return
}
-// Blob sends a blob response with status code and content type.
-func (c *Context) Blob(code int, contentType string, b []byte) (err error) {
+func (c *context) Blob(code int, contentType string, b []byte) (err error) {
c.writeContentType(contentType)
c.response.WriteHeader(code)
_, err = c.response.Write(b)
return
}
-// Stream sends a streaming response with status code and content type.
-func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) {
+func (c *context) Stream(code int, contentType string, r io.Reader) (err error) {
c.writeContentType(contentType)
c.response.WriteHeader(code)
_, err = io.Copy(c.response, r)
return
}
-// File sends a response with the content of the file.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (c *Context) File(file string) error {
- return fsFile(c, file, c.echo.Filesystem)
-}
-
-// FileFS serves file from given file system.
-//
-// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
-// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
-// including `assets/images` as their prefix.
-func (c *Context) FileFS(file string, filesystem fs.FS) error {
- return fsFile(c, file, filesystem)
-}
-
-func fsFile(c *Context, file string, filesystem fs.FS) error {
- file = path.Clean(file) // `os.Open` and `os.DirFs.Open()` behave differently, later does not like ``, `.`, `..` at all, but we allowed those now need to clean
- f, err := filesystem.Open(file)
- if err != nil {
- return ErrNotFound
- }
- defer f.Close()
-
- fi, _ := f.Stat()
- if fi.IsDir() {
- file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
- f, err = filesystem.Open(file)
- if err != nil {
- return ErrNotFound
- }
- defer f.Close()
- if fi, err = f.Stat(); err != nil {
- return err
- }
- }
- ff, ok := f.(io.ReadSeeker)
- if !ok {
- return errors.New("file does not implement io.ReadSeeker")
- }
- http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
- return nil
-}
-
-// Attachment sends a response as attachment, prompting client to save the file.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (c *Context) Attachment(file, name string) error {
+func (c *context) Attachment(file, name string) error {
return c.contentDisposition(file, name, "attachment")
}
-// Inline sends a response as inline, opening the file in the browser.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (c *Context) Inline(file, name string) error {
+func (c *context) Inline(file, name string) error {
return c.contentDisposition(file, name, "inline")
}
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
-func (c *Context) contentDisposition(file, name, dispositionType string) error {
+func (c *context) contentDisposition(file, name, dispositionType string) error {
c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name)))
return c.File(file)
}
-// NoContent sends a response with no body and a status code.
-func (c *Context) NoContent(code int) error {
+func (c *context) NoContent(code int) error {
c.response.WriteHeader(code)
return nil
}
-// Redirect redirects the request to a provided URL with status code.
-func (c *Context) Redirect(code int, url string) error {
+func (c *context) Redirect(code int, url string) error {
if code < 300 || code > 308 {
return ErrInvalidRedirectCode
}
@@ -738,20 +629,45 @@ func (c *Context) Redirect(code int, url string) error {
return nil
}
-// Logger returns logger in Context
-func (c *Context) Logger() *slog.Logger {
- if c.logger != nil {
- return c.logger
+func (c *context) Error(err error) {
+ c.echo.HTTPErrorHandler(err, c)
+}
+
+func (c *context) Echo() *Echo {
+ return c.echo
+}
+
+func (c *context) Handler() HandlerFunc {
+ return c.handler
+}
+
+func (c *context) SetHandler(h HandlerFunc) {
+ c.handler = h
+}
+
+func (c *context) Logger() Logger {
+ res := c.logger
+ if res != nil {
+ return res
}
return c.echo.Logger
}
-// SetLogger sets logger in Context
-func (c *Context) SetLogger(logger *slog.Logger) {
- c.logger = logger
+func (c *context) SetLogger(l Logger) {
+ c.logger = l
}
-// Echo returns the `Echo` instance.
-func (c *Context) Echo() *Echo {
- return c.echo
+func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
+ c.request = r
+ c.response.reset(w)
+ c.query = nil
+ c.handler = NotFoundHandler
+ c.store = nil
+ c.path = ""
+ c.pnames = nil
+ c.logger = nil
+ // NOTE: Don't reset because it has to have length c.echo.maxParam (or bigger) at all times
+ for i := 0; i < len(c.pvalues); i++ {
+ c.pvalues[i] = ""
+ }
}
diff --git a/context_fs.go b/context_fs.go
new file mode 100644
index 000000000..1c25baf12
--- /dev/null
+++ b/context_fs.go
@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "errors"
+ "io"
+ "io/fs"
+ "net/http"
+ "path/filepath"
+)
+
+func (c *context) File(file string) error {
+ return fsFile(c, file, c.echo.Filesystem)
+}
+
+// FileFS serves file from given file system.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (c *context) FileFS(file string, filesystem fs.FS) error {
+ return fsFile(c, file, filesystem)
+}
+
+func fsFile(c Context, file string, filesystem fs.FS) error {
+ f, err := filesystem.Open(file)
+ if err != nil {
+ return ErrNotFound
+ }
+ defer f.Close()
+
+ fi, _ := f.Stat()
+ if fi.IsDir() {
+ file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
+ f, err = filesystem.Open(file)
+ if err != nil {
+ return ErrNotFound
+ }
+ defer f.Close()
+ if fi, err = f.Stat(); err != nil {
+ return err
+ }
+ }
+ ff, ok := f.(io.ReadSeeker)
+ if !ok {
+ return errors.New("file does not implement io.ReadSeeker")
+ }
+ http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
+ return nil
+}
diff --git a/context_fs_test.go b/context_fs_test.go
new file mode 100644
index 000000000..83232ea45
--- /dev/null
+++ b/context_fs_test.go
@@ -0,0 +1,135 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "github.com/stretchr/testify/assert"
+ "io/fs"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+)
+
+func TestContext_File(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenFile string
+ whenFS fs.FS
+ expectStatus int
+ expectStartsWith []byte
+ expectError string
+ }{
+ {
+ name: "ok, from default file system",
+ whenFile: "_fixture/images/walle.png",
+ whenFS: nil,
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "ok, from custom file system",
+ whenFile: "walle.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, not existent file",
+ whenFile: "not.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: nil,
+ expectError: "code=404, message=Not Found",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ if tc.whenFS != nil {
+ e.Filesystem = tc.whenFS
+ }
+
+ handler := func(ec Context) error {
+ return ec.(*context).File(tc.whenFile)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := handler(c)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestContext_FileFS(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenFile string
+ whenFS fs.FS
+ expectStatus int
+ expectStartsWith []byte
+ expectError string
+ }{
+ {
+ name: "ok",
+ whenFile: "walle.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, not existent file",
+ whenFile: "not.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: nil,
+ expectError: "code=404, message=Not Found",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ handler := func(ec Context) error {
+ return ec.(*context).FileFS(tc.whenFile, tc.whenFS)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := handler(c)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
diff --git a/context_generic.go b/context_generic.go
index 7cf8b296c..f06041bbf 100644
--- a/context_generic.go
+++ b/context_generic.go
@@ -13,12 +13,9 @@ var ErrInvalidKeyType = errors.New("invalid key type")
// ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing.
// Returns ErrInvalidKeyType error if the value is not castable to type T.
-func ContextGet[T any](c *Context, key string) (T, error) {
- c.lock.RLock()
- defer c.lock.RUnlock()
-
- val, ok := c.store[key]
- if !ok {
+func ContextGet[T any](c Context, key string) (T, error) {
+ val := c.Get(key)
+ if val == any(nil) {
var zero T
return zero, ErrNonExistentKey
}
@@ -34,7 +31,7 @@ func ContextGet[T any](c *Context, key string) (T, error) {
// ContextGetOr retrieves a value from the context store or returns a default value when the key
// is missing. Returns ErrInvalidKeyType error if the value is not castable to type T.
-func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) {
+func ContextGetOr[T any](c Context, key string, defaultValue T) (T, error) {
typed, err := ContextGet[T](c, key)
if err == ErrNonExistentKey {
return defaultValue, nil
diff --git a/context_generic_test.go b/context_generic_test.go
index ce468ac3e..9b6d2d04e 100644
--- a/context_generic_test.go
+++ b/context_generic_test.go
@@ -10,7 +10,8 @@ import (
)
func TestContextGetOK(t *testing.T) {
- c := NewContext(nil, nil)
+ e := New()
+ c := e.NewContext(nil, nil)
c.Set("key", int64(123))
@@ -20,7 +21,8 @@ func TestContextGetOK(t *testing.T) {
}
func TestContextGetNonExistentKey(t *testing.T) {
- c := NewContext(nil, nil)
+ e := New()
+ c := e.NewContext(nil, nil)
c.Set("key", int64(123))
@@ -30,7 +32,8 @@ func TestContextGetNonExistentKey(t *testing.T) {
}
func TestContextGetInvalidCast(t *testing.T) {
- c := NewContext(nil, nil)
+ e := New()
+ c := e.NewContext(nil, nil)
c.Set("key", int64(123))
@@ -40,7 +43,8 @@ func TestContextGetInvalidCast(t *testing.T) {
}
func TestContextGetOrOK(t *testing.T) {
- c := NewContext(nil, nil)
+ e := New()
+ c := e.NewContext(nil, nil)
c.Set("key", int64(123))
@@ -50,7 +54,8 @@ func TestContextGetOrOK(t *testing.T) {
}
func TestContextGetOrNonExistentKey(t *testing.T) {
- c := NewContext(nil, nil)
+ e := New()
+ c := e.NewContext(nil, nil)
c.Set("key", int64(123))
@@ -60,7 +65,8 @@ func TestContextGetOrNonExistentKey(t *testing.T) {
}
func TestContextGetOrInvalidCast(t *testing.T) {
- c := NewContext(nil, nil)
+ e := New()
+ c := e.NewContext(nil, nil)
c.Set("key", int64(123))
diff --git a/context_test.go b/context_test.go
index 9b820c6e3..c848d9044 100644
--- a/context_test.go
+++ b/context_test.go
@@ -8,22 +8,20 @@ import (
"crypto/tls"
"encoding/json"
"encoding/xml"
+ "errors"
"fmt"
"io"
- "io/fs"
- "log/slog"
"math"
"mime/multipart"
- "net"
"net/http"
"net/http/httptest"
"net/url"
- "os"
"strings"
"testing"
"text/template"
"time"
+ "github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert"
)
@@ -31,14 +29,13 @@ type Template struct {
templates *template.Template
}
-var testUser = user{ID: 1, Name: "Jon Snow"}
+var testUser = user{1, "Jon Snow"}
func BenchmarkAllocJSONP(b *testing.B) {
e := New()
- e.Logger = slog.New(slog.DiscardHandler)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
b.ResetTimer()
b.ReportAllocs()
@@ -50,10 +47,9 @@ func BenchmarkAllocJSONP(b *testing.B) {
func BenchmarkAllocJSON(b *testing.B) {
e := New()
- e.Logger = slog.New(slog.DiscardHandler)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
b.ResetTimer()
b.ReportAllocs()
@@ -65,10 +61,9 @@ func BenchmarkAllocJSON(b *testing.B) {
func BenchmarkAllocXML(b *testing.B) {
e := New()
- e.Logger = slog.New(slog.DiscardHandler)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
b.ResetTimer()
b.ReportAllocs()
@@ -79,7 +74,7 @@ func BenchmarkAllocXML(b *testing.B) {
}
func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) {
- c := Context{request: &http.Request{
+ c := context{request: &http.Request{
Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
}}
for i := 0; i < b.N; i++ {
@@ -87,7 +82,7 @@ func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) {
}
}
-func (t *Template) Render(c *Context, w io.Writer, name string, data any) error {
+func (t *Template) Render(w io.Writer, name string, data any, c Context) error {
return t.templates.ExecuteTemplate(w, name, data)
}
@@ -96,7 +91,7 @@ func TestContextEcho(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
assert.Equal(t, e, c.Echo())
}
@@ -106,7 +101,7 @@ func TestContextRequest(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
assert.NotNil(t, c.Request())
assert.Equal(t, req, c.Request())
@@ -117,7 +112,7 @@ func TestContextResponse(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
assert.NotNil(t, c.Response())
}
@@ -127,12 +122,12 @@ func TestContextRenderTemplate(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
tmpl := &Template{
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
}
- c.Echo().Renderer = tmpl
+ c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
@@ -140,91 +135,24 @@ func TestContextRenderTemplate(t *testing.T) {
}
}
-func TestContextRenderTemplateError(t *testing.T) {
- // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do
- e := New()
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- tmpl := &Template{
- templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
- }
- c.Echo().Renderer = tmpl
- err := c.Render(http.StatusOK, "not_existing", "Jon Snow")
-
- assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`)
- assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
- assert.Empty(t, rec.Body.String()) // body must not be sent to the client
-}
-
func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- c.Echo().Renderer = nil
+ c.echo.Renderer = nil
assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow"))
}
-func TestContextStream(t *testing.T) {
- e := New()
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
-
- r, w := io.Pipe()
- go func() {
- defer w.Close()
- for i := range 3 {
- fmt.Fprintf(w, "data: index %v\n\n", i)
- time.Sleep(5 * time.Millisecond)
- }
- }()
-
- err := c.Stream(http.StatusOK, "text/event-stream", r)
- if assert.NoError(t, err) {
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType))
- assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String())
- }
-}
-
-func TestContextHTML(t *testing.T) {
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := NewContext(req, rec)
-
- err := c.HTML(http.StatusOK, "Hi, Jon Snow")
- if assert.NoError(t, err) {
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
- }
-}
-
-func TestContextHTMLBlob(t *testing.T) {
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := NewContext(req, rec)
-
- err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow"))
- if assert.NoError(t, err) {
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
- }
-}
-
func TestContextJSON(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
+ err := c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
@@ -236,37 +164,33 @@ func TestContextJSONErrorsOut(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusOK, make(chan bool))
assert.EqualError(t, err, "json: unsupported type: chan bool")
-
- assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
- assert.Empty(t, rec.Body.String()) // body must not be sent to the client
}
-func TestContextJSONWithNotEchoResponse(t *testing.T) {
+func TestContextJSONPrettyURL(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
- c := e.NewContext(req, rec)
-
- c.SetResponse(rec)
-
- err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()})
- assert.EqualError(t, err, "json: unsupported value: NaN")
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
- assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
- assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+ err := c.JSON(http.StatusOK, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
+ }
}
func TestContextJSONPretty(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ")
+ err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
@@ -278,16 +202,16 @@ func TestContextJSONWithEmptyIntent(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- u := user{ID: 1, Name: "Jon Snow"}
+ u := user{1, "Jon Snow"}
emptyIndent := ""
buf := new(bytes.Buffer)
enc := json.NewEncoder(buf)
enc.SetIndent(emptyIndent, emptyIndent)
_ = enc.Encode(u)
- err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent)
+ err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
@@ -299,10 +223,10 @@ func TestContextJSONP(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
callback := "callback"
- err := c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"})
+ err := c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -314,9 +238,9 @@ func TestContextJSONBlob(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
+ data, err := json.Marshal(user{1, "Jon Snow"})
assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data)
if assert.NoError(t, err) {
@@ -330,10 +254,10 @@ func TestContextJSONPBlob(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
callback := "callback"
- data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
+ data, err := json.Marshal(user{1, "Jon Snow"})
assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data)
if assert.NoError(t, err) {
@@ -347,9 +271,9 @@ func TestContextXML(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
+ err := c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -357,13 +281,27 @@ func TestContextXML(t *testing.T) {
}
}
+func TestContextXMLPrettyURL(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.XML(http.StatusOK, user{1, "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
+ }
+}
+
func TestContextXMLPretty(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ")
+ err := c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -375,9 +313,9 @@ func TestContextXMLBlob(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"})
+ data, err := xml.Marshal(user{1, "Jon Snow"})
assert.NoError(t, err)
err = c.XMLBlob(http.StatusOK, data)
if assert.NoError(t, err) {
@@ -391,16 +329,16 @@ func TestContextXMLWithEmptyIntent(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
- u := user{ID: 1, Name: "Jon Snow"}
+ u := user{1, "Jon Snow"}
emptyIndent := ""
buf := new(bytes.Buffer)
enc := xml.NewEncoder(buf)
enc.Indent(emptyIndent, emptyIndent)
_ = enc.Encode(u)
- err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent)
+ err := c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -408,17 +346,71 @@ func TestContextXMLWithEmptyIntent(t *testing.T) {
}
}
-func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
+type responseWriterErr struct {
+}
+
+func (responseWriterErr) Header() http.Header {
+ return http.Header{}
+}
+
+func (responseWriterErr) Write([]byte) (int, error) {
+ return 0, errors.New("responseWriterErr")
+}
+
+func (responseWriterErr) WriteHeader(statusCode int) {
+}
+
+func TestContextXMLError(t *testing.T) {
e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"})
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+ c.response.Writer = responseWriterErr{}
+
+ err := c.XML(http.StatusOK, make(chan bool))
+ assert.EqualError(t, err, "responseWriterErr")
+}
+
+func TestContextString(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+ err := c.String(http.StatusOK, "Hello, World!")
if assert.NoError(t, err) {
- assert.Equal(t, http.StatusCreated, rec.Code)
- assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
- assert.Equal(t, userJSON+"\n", rec.Body.String())
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+ }
+}
+
+func TestContextHTML(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ err := c.HTML(http.StatusOK, "Hello, World!")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+ }
+}
+
+func TestContextStream(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ r := strings.NewReader("response from a stream")
+ err := c.Stream(http.StatusOK, "application/octet-stream", r)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "response from a stream", rec.Body.String())
}
}
@@ -444,7 +436,7 @@ func TestContextAttachment(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
err := c.Attachment("_fixture/images/walle.png", tc.whenName)
if assert.NoError(t, err) {
@@ -479,7 +471,7 @@ func TestContextInline(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
err := c.Inline("_fixture/images/walle.png", tc.whenName)
if assert.NoError(t, err) {
@@ -496,12 +488,69 @@ func TestContextNoContent(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
c.NoContent(http.StatusOK)
assert.Equal(t, http.StatusOK, rec.Code)
}
+func TestContextError(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ c.Error(errors.New("error"))
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.True(t, c.Response().Committed)
+}
+
+func TestContextReset(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec).(*context)
+
+ c.SetParamNames("foo")
+ c.SetParamValues("bar")
+ c.Set("foe", "ban")
+ c.query = url.Values(map[string][]string{"fon": {"baz"}})
+
+ c.Reset(req, httptest.NewRecorder())
+
+ assert.Len(t, c.ParamValues(), 0)
+ assert.Len(t, c.ParamNames(), 0)
+ assert.Len(t, c.Path(), 0)
+ assert.Len(t, c.QueryParams(), 0)
+ assert.Len(t, c.store, 0)
+}
+
+func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec).(*context)
+ err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
+
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
+ }
+}
+
+func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec).(*context)
+ err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
+
+ if assert.Error(t, err) {
+ assert.False(t, c.response.Committed)
+ }
+}
+
func TestContextCookie(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -510,7 +559,7 @@ func TestContextCookie(t *testing.T) {
req.Header.Add(HeaderCookie, theme)
req.Header.Add(HeaderCookie, user)
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
// Read single
cookie, err := c.Cookie("theme")
@@ -547,237 +596,107 @@ func TestContextCookie(t *testing.T) {
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
-func TestContext_PathValues(t *testing.T) {
- var testCases = []struct {
- name string
- given PathValues
- expect PathValues
- }{
- {
- name: "param exists",
- given: PathValues{
- {Name: "uid", Value: "101"},
- {Name: "fid", Value: "501"},
- },
- expect: PathValues{
- {Name: "uid", Value: "101"},
- {Name: "fid", Value: "501"},
- },
- },
- {
- name: "params is empty",
- given: PathValues{},
- expect: PathValues{},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, nil)
-
- c.SetPathValues(tc.given)
+func TestContextPath(t *testing.T) {
+ e := New()
+ r := e.Router()
- assert.EqualValues(t, tc.expect, c.PathValues())
- })
- }
-}
+ handler := func(c Context) error { return c.String(http.StatusOK, "OK") }
-func TestContext_PathParam(t *testing.T) {
- var testCases = []struct {
- name string
- given PathValues
- whenParamName string
- expect string
- }{
- {
- name: "param exists",
- given: PathValues{
- {Name: "uid", Value: "101"},
- {Name: "fid", Value: "501"},
- },
- whenParamName: "uid",
- expect: "101",
- },
- {
- name: "multiple same param values exists - return first",
- given: PathValues{
- {Name: "uid", Value: "101"},
- {Name: "uid", Value: "202"},
- {Name: "fid", Value: "501"},
- },
- whenParamName: "uid",
- expect: "101",
- },
- {
- name: "param does not exists",
- given: PathValues{
- {Name: "uid", Value: "101"},
- },
- whenParamName: "nope",
- expect: "",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, nil)
+ r.Add(http.MethodGet, "/users/:id", handler)
+ c := e.NewContext(nil, nil)
+ r.Find(http.MethodGet, "/users/1", c)
- c.SetPathValues(tc.given)
+ assert.Equal(t, "/users/:id", c.Path())
- assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName))
- })
- }
+ r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
+ c = e.NewContext(nil, nil)
+ r.Find(http.MethodGet, "/users/1/files/1", c)
+ assert.Equal(t, "/users/:uid/files/:fid", c.Path())
}
-func TestContext_PathParamDefault(t *testing.T) {
- var testCases = []struct {
- name string
- given PathValues
- whenParamName string
- whenDefaultValue string
- expect string
- }{
- {
- name: "param exists",
- given: PathValues{
- {Name: "uid", Value: "101"},
- {Name: "fid", Value: "501"},
- },
- whenParamName: "uid",
- whenDefaultValue: "999",
- expect: "101",
- },
- {
- name: "param exists and is empty",
- given: PathValues{
- {Name: "uid", Value: ""},
- {Name: "fid", Value: "501"},
- },
- whenParamName: "uid",
- whenDefaultValue: "999",
- expect: "", // <-- this is different from QueryParamOr behaviour
- },
- {
- name: "param does not exists",
- given: PathValues{
- {Name: "uid", Value: "101"},
- },
- whenParamName: "nope",
- whenDefaultValue: "999",
- expect: "999",
- },
- }
+func TestContextPathParam(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, nil)
+ // ParamNames
+ c.SetParamNames("uid", "fid")
+ assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
- c.SetPathValues(tc.given)
+ // ParamValues
+ c.SetParamValues("101", "501")
+ assert.EqualValues(t, []string{"101", "501"}, c.ParamValues())
- assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue))
- })
- }
+ // Param
+ assert.Equal(t, "501", c.Param("fid"))
+ assert.Equal(t, "", c.Param("undefined"))
}
-func TestContextGetAndSetPathValuesMutability(t *testing.T) {
- t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) {
- e := New()
- e.contextPathParamAllocSize.Store(1)
-
- req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
- c := e.NewContext(req, nil)
-
- params := PathValues{{Name: "foo", Value: "101"}}
- c.SetPathValues(params)
-
- // round-trip param values with modification
- paramVals := c.PathValues()
- assert.Equal(t, params, c.PathValues())
-
- // PathValues() does not return copy and modifying raw slice mutates value in context
- paramVals[0] = PathValue{Name: "xxx", Value: "yyy"}
- assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues())
- })
-
- t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) {
- e := New()
- e.contextPathParamAllocSize.Store(1)
-
- req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
- c := e.NewContext(req, nil)
- // increase path param capacity in context
- pathValues := PathValues{
- {Name: "aaa", Value: "bbb"},
- {Name: "ccc", Value: "ddd"},
- }
- c.SetPathValues(pathValues)
- assert.Equal(t, pathValues, c.PathValues())
-
- // shouldn't explode during Reset() afterwards!
- assert.NotPanics(t, func() {
- c.Reset(nil, nil)
- })
- assert.Equal(t, PathValues{}, c.PathValues())
- assert.Len(t, *c.pathValues, 0)
- assert.Equal(t, 2, cap(*c.pathValues))
- })
-
- t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) {
- e := New()
-
- req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
- c := e.NewContext(req, nil)
- c.pathValues = &PathValues{
- {Name: "aaa", Value: "bbb"},
- {Name: "ccc", Value: "ddd"},
- }
-
- pathValues := PathValues{
- {Name: "aaa", Value: "bbb"},
- }
- // given pathValues slice is smaller. this should not decrease c.pathValues capacity
- c.SetPathValues(pathValues)
- assert.Equal(t, pathValues, c.PathValues())
-
- // shouldn't explode during Reset() afterwards!
- assert.NotPanics(t, func() {
- c.Reset(nil, nil)
- })
- assert.Equal(t, PathValues{}, c.PathValues())
- assert.Len(t, *c.pathValues, 0)
- assert.Equal(t, 2, cap(*c.pathValues))
+func TestContextGetAndSetParam(t *testing.T) {
+ e := New()
+ r := e.Router()
+ r.Add(http.MethodGet, "/:foo", func(Context) error { return nil })
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+ c.SetParamNames("foo")
+
+ // round-trip param values with modification
+ paramVals := c.ParamValues()
+ assert.EqualValues(t, []string{""}, c.ParamValues())
+ paramVals[0] = "bar"
+ c.SetParamValues(paramVals...)
+ assert.EqualValues(t, []string{"bar"}, c.ParamValues())
+
+ // shouldn't explode during Reset() afterwards!
+ assert.NotPanics(t, func() {
+ c.Reset(nil, nil)
})
-
}
-// Issue #1655
-func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) {
+func TestContextSetParamNamesEchoMaxParam(t *testing.T) {
e := New()
- c := e.NewContext(nil, nil)
+ assert.Equal(t, 0, *e.maxParam)
+
+ expectedOneParam := []string{"one"}
+ expectedTwoParams := []string{"one", "two"}
+ expectedThreeParams := []string{"one", "two", ""}
+
+ {
+ c := e.AcquireContext()
+ c.SetParamNames("1", "2")
+ c.SetParamValues(expectedTwoParams...)
+ assert.Equal(t, 0, *e.maxParam) // has not been changed
+ assert.EqualValues(t, expectedTwoParams, c.ParamValues())
+ e.ReleaseContext(c)
+ }
- assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
- expectedTwoParams := PathValues{
- {Name: "1", Value: "one"},
- {Name: "2", Value: "two"},
+ {
+ c := e.AcquireContext()
+ c.SetParamNames("1", "2", "3")
+ c.SetParamValues(expectedThreeParams...)
+ assert.Equal(t, 0, *e.maxParam) // has not been changed
+ assert.EqualValues(t, expectedThreeParams, c.ParamValues())
+ e.ReleaseContext(c)
}
- c.SetPathValues(expectedTwoParams)
- assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
- assert.Equal(t, expectedTwoParams, c.PathValues())
-
- expectedThreeParams := PathValues{
- {Name: "1", Value: "one"},
- {Name: "2", Value: "two"},
- {Name: "3", Value: "three"},
+
+ { // values is always same size as names length
+ c := e.NewContext(nil, nil)
+ c.SetParamValues([]string{"one", "two"}...) // more values than names should be ok
+ c.SetParamNames("1")
+ assert.Equal(t, 0, *e.maxParam) // has not been changed
+ assert.EqualValues(t, expectedOneParam, c.ParamValues())
+ }
+
+ e.GET("/:id", handlerFunc)
+ assert.Equal(t, 1, *e.maxParam) // has not been changed
+
+ {
+ c := e.NewContext(nil, nil)
+ c.SetParamValues([]string{"one", "two"}...)
+ c.SetParamNames("1")
+ assert.Equal(t, 1, *e.maxParam) // has not been changed
+ assert.EqualValues(t, expectedOneParam, c.ParamValues())
}
- c.SetPathValues(expectedThreeParams)
- assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
- assert.Equal(t, expectedThreeParams, c.PathValues())
}
func TestContextFormValue(t *testing.T) {
@@ -794,151 +713,41 @@ func TestContextFormValue(t *testing.T) {
assert.Equal(t, "Jon Snow", c.FormValue("name"))
assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
- // FormValueOr
- assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope"))
- assert.Equal(t, "default", c.FormValueOr("missing", "default"))
-
- // FormValues
- values, err := c.FormValues()
+ // FormParams
+ params, err := c.FormParams()
if assert.NoError(t, err) {
assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
- }, values)
+ }, params)
}
// Multipart FormParams error
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil)
- values, err = c.FormValues()
- assert.Nil(t, values)
+ params, err = c.FormParams()
+ assert.Nil(t, params)
assert.Error(t, err)
}
-func TestContext_QueryParams(t *testing.T) {
- var testCases = []struct {
- expect url.Values
- name string
- givenURL string
- }{
- {
- name: "multiple values in url",
- givenURL: "/?test=1&test=2&email=jon%40labstack.com",
- expect: url.Values{
- "test": []string{"1", "2"},
- "email": []string{"jon@labstack.com"},
- },
- },
- {
- name: "single value in url",
- givenURL: "/?nope=1",
- expect: url.Values{
- "nope": []string{"1"},
- },
- },
- {
- name: "no query params in url",
- givenURL: "/?",
- expect: url.Values{},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- e := New()
- c := e.NewContext(req, nil)
-
- assert.Equal(t, tc.expect, c.QueryParams())
- })
- }
-}
-
-func TestContext_QueryParam(t *testing.T) {
- var testCases = []struct {
- name string
- givenURL string
- whenParamName string
- expect string
- }{
- {
- name: "value exists in url",
- givenURL: "/?test=1",
- whenParamName: "test",
- expect: "1",
- },
- {
- name: "multiple values exists in url",
- givenURL: "/?test=9&test=8",
- whenParamName: "test",
- expect: "9", // <-- first value in returned
- },
- {
- name: "value does not exists in url",
- givenURL: "/?nope=1",
- whenParamName: "test",
- expect: "",
- },
- {
- name: "value is empty in url",
- givenURL: "/?test=",
- whenParamName: "test",
- expect: "",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- e := New()
- c := e.NewContext(req, nil)
-
- assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName))
- })
- }
-}
-
-func TestContext_QueryParamDefault(t *testing.T) {
- var testCases = []struct {
- name string
- givenURL string
- whenParamName string
- whenDefaultValue string
- expect string
- }{
- {
- name: "value exists in url",
- givenURL: "/?test=1",
- whenParamName: "test",
- whenDefaultValue: "999",
- expect: "1",
- },
- {
- name: "value does not exists in url",
- givenURL: "/?nope=1",
- whenParamName: "test",
- whenDefaultValue: "999",
- expect: "999",
- },
- {
- name: "value is empty in url",
- givenURL: "/?test=",
- whenParamName: "test",
- whenDefaultValue: "999",
- expect: "999",
- },
- }
+func TestContextQueryParam(t *testing.T) {
+ q := make(url.Values)
+ q.Set("name", "Jon Snow")
+ q.Set("email", "jon@labstack.com")
+ req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
+ e := New()
+ c := e.NewContext(req, nil)
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- e := New()
- c := e.NewContext(req, nil)
+ // QueryParam
+ assert.Equal(t, "Jon Snow", c.QueryParam("name"))
+ assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
- assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue))
- })
- }
+ // QueryParams
+ assert.Equal(t, url.Values{
+ "name": []string{"Jon Snow"},
+ "email": []string{"jon@labstack.com"},
+ }, c.QueryParams())
}
func TestContextFormFile(t *testing.T) {
@@ -999,47 +808,16 @@ func TestContextRedirect(t *testing.T) {
assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
}
-func TestContextGet(t *testing.T) {
- var testCases = []struct {
- name string
- given any
- whenKey string
- expect any
- }{
- {
- name: "ok, value exist",
- given: "Jon Snow",
- whenKey: "key",
- expect: "Jon Snow",
- },
- {
- name: "ok, value does not exist",
- given: "Jon Snow",
- whenKey: "nope",
- expect: nil,
- },
- {
- name: "ok, value is nil value",
- given: []byte(nil),
- whenKey: "key",
- expect: []byte(nil),
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- var c = new(Context)
- c.Set("key", tc.given)
-
- v := c.Get(tc.whenKey)
- assert.Equal(t, tc.expect, v)
- })
- }
+func TestContextStore(t *testing.T) {
+ var c Context = new(context)
+ c.Set("name", "Jon Snow")
+ assert.Equal(t, "Jon Snow", c.Get("name"))
}
func BenchmarkContext_Store(b *testing.B) {
e := &Echo{}
- c := &Context{
+ c := &context{
echo: e,
}
@@ -1051,6 +829,42 @@ func BenchmarkContext_Store(b *testing.B) {
}
}
+func TestContextHandler(t *testing.T) {
+ e := New()
+ r := e.Router()
+ b := new(bytes.Buffer)
+
+ r.Add(http.MethodGet, "/handler", func(Context) error {
+ _, err := b.Write([]byte("handler"))
+ return err
+ })
+ c := e.NewContext(nil, nil)
+ r.Find(http.MethodGet, "/handler", c)
+ err := c.Handler()(c)
+ assert.Equal(t, "handler", b.String())
+ assert.NoError(t, err)
+}
+
+func TestContext_SetHandler(t *testing.T) {
+ var c Context = new(context)
+
+ assert.Nil(t, c.Handler())
+
+ c.SetHandler(func(c Context) error {
+ return nil
+ })
+ assert.NotNil(t, c.Handler())
+}
+
+func TestContext_Path(t *testing.T) {
+ path := "/pa/th"
+
+ var c Context = new(context)
+
+ c.SetPath(path)
+ assert.Equal(t, path, c.Path())
+}
+
type validator struct{}
func (*validator) Validate(i any) error {
@@ -1079,7 +893,7 @@ func TestContext_QueryString(t *testing.T) {
}
func TestContext_Request(t *testing.T) {
- var c = new(Context)
+ var c Context = new(context)
assert.Nil(t, c.Request())
@@ -1252,9 +1066,10 @@ func TestContext_Scheme(t *testing.T) {
if tc.givenHeaders != nil {
req.Header = tc.givenHeaders
}
- c := NewContext(req, nil)
+ e := New()
+ c := e.NewContext(req, nil)
if tc.givenIsTLS {
- c.request.TLS = &tls.ConnectionState{}
+ c.Request().TLS = &tls.ConnectionState{}
}
assert.Equal(t, tc.expect, c.Scheme())
@@ -1264,56 +1079,39 @@ func TestContext_Scheme(t *testing.T) {
func TestContext_IsWebSocket(t *testing.T) {
tests := []struct {
- c *Context
+ c Context
ws assert.BoolAssertionFunc
}{
{
- &Context{
+ &context{
request: &http.Request{
- Header: http.Header{
- HeaderUpgrade: []string{"websocket"},
- HeaderConnection: []string{"upgrade"},
- },
+ Header: http.Header{HeaderUpgrade: []string{"websocket"}},
},
},
assert.True,
},
{
- &Context{
+ &context{
request: &http.Request{
- Header: http.Header{
- HeaderUpgrade: []string{"Websocket"},
- HeaderConnection: []string{"Upgrade"},
- },
+ Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
},
},
assert.True,
},
{
- &Context{
+ &context{
request: &http.Request{},
},
assert.False,
},
{
- &Context{
+ &context{
request: &http.Request{
Header: http.Header{HeaderUpgrade: []string{"other"}},
},
},
assert.False,
},
- {
- &Context{
- request: &http.Request{
- Header: http.Header{
- HeaderUpgrade: []string{"websocket"},
- HeaderConnection: []string{"close"},
- },
- },
- },
- assert.False,
- },
}
for i, tt := range tests {
@@ -1332,212 +1130,110 @@ func TestContext_Bind(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u)
assert.NoError(t, err)
- assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u)
+ assert.Equal(t, &user{1, "Jon Snow"}, u)
}
-func TestContext_RealIP(t *testing.T) {
- _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
+func TestContext_Logger(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
- var testCases = []struct {
- name string
- givenIPExtrator IPExtractor
- whenReq *http.Request
- expect string
+ log1 := c.Logger()
+ assert.NotNil(t, log1)
+
+ log2 := log.New("echo2")
+ c.SetLogger(log2)
+ assert.Equal(t, log2, c.Logger())
+
+ // Resetting the context returns the initial logger
+ c.Reset(nil, nil)
+ assert.Equal(t, log1, c.Logger())
+}
+
+func TestContext_RealIP(t *testing.T) {
+ tests := []struct {
+ c Context
+ s string
}{
{
- name: "ip from remote addr",
- givenIPExtrator: nil,
- whenReq: &http.Request{RemoteAddr: "89.89.89.89:1654"},
- expect: "89.89.89.89",
+ &context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
+ },
+ },
+ "127.0.0.1",
},
{
- name: "ip from ip extractor",
- givenIPExtrator: ExtractIPFromRealIPHeader(TrustIPRange(ipv6ForRemoteAddrExternalRange)),
- whenReq: &http.Request{
- Header: http.Header{
- HeaderXRealIP: []string{"[2001:db8::113:199]"},
- HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything
+ &context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}},
},
- RemoteAddr: "[2001:db8::113:1]:8080",
},
- expect: "2001:db8::113:199",
+ "127.0.0.1",
},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- c := e.NewContext(tc.whenReq, nil)
- if tc.givenIPExtrator != nil {
- e.IPExtractor = tc.givenIPExtrator
- }
- assert.Equal(t, tc.expect, c.RealIP())
- })
- }
-}
-
-func TestContext_File(t *testing.T) {
- var testCases = []struct {
- whenFS fs.FS
- name string
- whenFile string
- expectError string
- expectStartsWith []byte
- expectStatus int
- }{
{
- name: "ok, from default file system",
- whenFile: "_fixture/images/walle.png",
- whenFS: nil,
- expectStatus: http.StatusOK,
- expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ &context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}},
+ },
+ },
+ "127.0.0.1",
},
{
- name: "ok, from custom file system",
- whenFile: "walle.png",
- whenFS: os.DirFS("_fixture/images"),
- expectStatus: http.StatusOK,
- expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ &context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
},
{
- name: "nok, not existent file",
- whenFile: "not.png",
- whenFS: os.DirFS("_fixture/images"),
- expectStatus: http.StatusOK,
- expectStartsWith: nil,
- expectError: "Not Found",
+ &context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- if tc.whenFS != nil {
- e.Filesystem = tc.whenFS
- }
-
- handler := func(ec *Context) error {
- return ec.File(tc.whenFile)
- }
-
- req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err := handler(c)
-
- assert.Equal(t, tc.expectStatus, rec.Code)
- if tc.expectError != "" {
- assert.EqualError(t, err, tc.expectError)
- } else {
- assert.NoError(t, err)
- }
-
- body := rec.Body.Bytes()
- if len(body) > len(tc.expectStartsWith) {
- body = body[:len(tc.expectStartsWith)]
- }
- assert.Equal(t, tc.expectStartsWith, body)
- })
- }
-}
-
-func TestContext_FileFS(t *testing.T) {
- var testCases = []struct {
- whenFS fs.FS
- name string
- whenFile string
- expectError string
- expectStartsWith []byte
- expectStatus int
- }{
{
- name: "ok",
- whenFile: "walle.png",
- whenFS: os.DirFS("_fixture/images"),
- expectStatus: http.StatusOK,
- expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ &context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
},
{
- name: "nok, not existent file",
- whenFile: "not.png",
- whenFS: os.DirFS("_fixture/images"),
- expectStatus: http.StatusOK,
- expectStartsWith: nil,
- expectError: "Not Found",
+ &context{
+ request: &http.Request{
+ Header: http.Header{
+ "X-Real-Ip": []string{"192.168.0.1"},
+ },
+ },
+ },
+ "192.168.0.1",
+ },
+ {
+ &context{
+ request: &http.Request{
+ Header: http.Header{
+ "X-Real-Ip": []string{"[2001:db8::1]"},
+ },
+ },
+ },
+ "2001:db8::1",
},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
-
- handler := func(ec *Context) error {
- return ec.FileFS(tc.whenFile, tc.whenFS)
- }
-
- req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err := handler(c)
-
- assert.Equal(t, tc.expectStatus, rec.Code)
- if tc.expectError != "" {
- assert.EqualError(t, err, tc.expectError)
- } else {
- assert.NoError(t, err)
- }
- body := rec.Body.Bytes()
- if len(body) > len(tc.expectStartsWith) {
- body = body[:len(tc.expectStartsWith)]
- }
- assert.Equal(t, tc.expectStartsWith, body)
- })
+ {
+ &context{
+ request: &http.Request{
+ RemoteAddr: "89.89.89.89:1654",
+ },
+ },
+ "89.89.89.89",
+ },
}
-}
-
-func TestLogger(t *testing.T) {
- e := New()
- c := e.NewContext(nil, nil)
-
- log1 := c.Logger()
- assert.NotNil(t, log1)
- assert.Equal(t, e.Logger, log1)
-
- customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil))
- c.SetLogger(customLogger)
- assert.Equal(t, customLogger, c.Logger())
-
- // Resetting the context returns the initial Echo logger
- c.Reset(nil, nil)
- assert.Equal(t, e.Logger, c.Logger())
-}
-
-func TestRouteInfo(t *testing.T) {
- e := New()
- c := e.NewContext(nil, nil)
- orgRI := RouteInfo{
- Name: "root",
- Method: http.MethodGet,
- Path: "/*",
- Parameters: []string{"*"},
+ for _, tt := range tests {
+ assert.Equal(t, tt.s, tt.c.RealIP())
}
- c.route = &orgRI
- ri := c.RouteInfo()
- assert.Equal(t, orgRI, ri)
-
- // Test mutability when middlewares start to change things
-
- // RouteInfo inside context will not be affected when returned instance is changed
- expect := orgRI.Clone()
- ri.Path = "changed"
- ri.Parameters[0] = "changed"
- assert.Equal(t, expect, c.RouteInfo())
-
- // RouteInfo inside context will not be affected when returned instance is changed
- expect = c.RouteInfo()
- orgRI.Name = "changed"
- assert.NotEqual(t, expect, c.RouteInfo())
}
diff --git a/dispatch_pool_test.go b/dispatch_pool_test.go
deleted file mode 100644
index b2912139f..000000000
--- a/dispatch_pool_test.go
+++ /dev/null
@@ -1,101 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// TestContextResetClearsStore guards the clear(c.store) reuse: a pooled Context must not leak store
-// values from a previous request into the next one, and Set must still work after a clear-based Reset.
-func TestContextResetClearsStore(t *testing.T) {
- e := New()
- c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder())
- c.Set("secret", "req1")
- assert.Equal(t, "req1", c.Get("secret"))
-
- c.Reset(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder())
- assert.Nil(t, c.Get("secret"), "store must not leak across Reset")
-
- c.Set("k", "req2") // Set must still work after clear-based reset
- assert.Equal(t, "req2", c.Get("k"))
-}
-
-// TestContextJSONStatusAcrossReset guards the reused delayedStatusWriter (c.dsw): a second JSON
-// response on a pooled+Reset Context must use the new status, not inherit the previous one.
-func TestContextJSONStatusAcrossReset(t *testing.T) {
- e := New()
- c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder())
- assert.NoError(t, c.JSON(http.StatusTeapot, map[string]int{"a": 1}))
-
- rec2 := httptest.NewRecorder()
- c.Reset(httptest.NewRequest(http.MethodGet, "/", nil), rec2)
- assert.NoError(t, c.JSON(http.StatusCreated, map[string]int{"b": 2}))
- assert.Equal(t, http.StatusCreated, rec2.Code)
- assert.JSONEq(t, `{"b":2}`, rec2.Body.String())
-}
-
-// TestNestedJSONUsesFreshDelayedWriter guards the nested c.JSON case: a serializer that calls c.JSON
-// re-entrantly must not corrupt the outer delayedStatusWriter (regression test for c.dsw self-reference).
-func TestNestedJSONUsesFreshDelayedWriter(t *testing.T) {
- e := New()
- e.JSONSerializer = nestedJSONSerializer{}
- rec := httptest.NewRecorder()
- c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
- assert.NoError(t, c.JSON(http.StatusOK, map[string]int{"outer": 1}))
- assert.Equal(t, http.StatusOK, rec.Code)
-}
-
-type nestedJSONSerializer struct{}
-
-func (nestedJSONSerializer) Serialize(c *Context, i any, indent string) error {
- if m, ok := i.(map[string]int); ok && m["outer"] == 1 {
- // re-enter c.JSON once before encoding the outer payload
- if err := c.JSON(http.StatusOK, map[string]int{"inner": 2}); err != nil {
- return err
- }
- }
- return (DefaultJSONSerializer{}).Serialize(c, i, indent)
-}
-
-func (nestedJSONSerializer) Deserialize(c *Context, i any) error {
- return (DefaultJSONSerializer{}).Deserialize(c, i)
-}
-
-// TestGlobalMiddlewareRunsOnNotFoundAndMethodNotAllowed pins the dispatch contract: global (Use) and
-// pre (Pre) middleware must execute even when the router returns 404 / 405 / OPTIONS handlers.
-func TestGlobalMiddlewareRunsOnNotFoundAndMethodNotAllowed(t *testing.T) {
- cases := []struct {
- name, method, path string
- code int
- }{
- {"404", http.MethodGet, "/missing", http.StatusNotFound},
- {"405", http.MethodPost, "/", http.StatusMethodNotAllowed},
- {"OPTIONS", http.MethodOptions, "/", http.StatusNoContent},
- }
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- var pre, use bool
- e.Pre(func(n HandlerFunc) HandlerFunc {
- return func(c *Context) error { pre = true; return n(c) }
- })
- e.Use(func(n HandlerFunc) HandlerFunc {
- return func(c *Context) error { use = true; return n(c) }
- })
- e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "ok") })
-
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, httptest.NewRequest(tc.method, tc.path, nil))
-
- assert.True(t, pre, "pre-middleware must run on %s", tc.name)
- assert.True(t, use, "global middleware must run on %s", tc.name)
- assert.Equal(t, tc.code, rec.Code)
- })
- }
-}
diff --git a/echo.go b/echo.go
index ee3a53cd0..489e16c23 100644
--- a/echo.go
+++ b/echo.go
@@ -9,33 +9,30 @@ Example:
package main
import (
- "log/slog"
- "net/http"
+ "net/http"
- "github.com/labstack/echo/v5"
- "github.com/labstack/echo/v5/middleware"
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v4/middleware"
)
// Handler
- func hello(c *echo.Context) error {
- return c.String(http.StatusOK, "Hello, World!")
+ func hello(c echo.Context) error {
+ return c.String(http.StatusOK, "Hello, World!")
}
func main() {
- // Echo instance
- e := echo.New()
+ // Echo instance
+ e := echo.New()
- // Middleware
- e.Use(middleware.RequestLogger())
- e.Use(middleware.Recover())
+ // Middleware
+ e.Use(middleware.Logger())
+ e.Use(middleware.Recover())
- // Routes
- e.GET("/", hello)
+ // Routes
+ e.GET("/", hello)
- // Start server
- if err := e.Start(":8080"); err != nil {
- slog.Error("failed to start server", "error", err)
- }
+ // Start server
+ e.Logger.Fatal(e.Start(":1323"))
}
Learn more at https://echo.labstack.com
@@ -44,99 +41,126 @@ package echo
import (
stdContext "context"
+ "crypto/tls"
"encoding/json"
"errors"
"fmt"
- "io/fs"
- "log/slog"
+ stdLog "log"
+ "net"
"net/http"
- "net/url"
"os"
- "os/signal"
- "path"
- "path/filepath"
- "strings"
+ "reflect"
+ "runtime"
"sync"
- "sync/atomic"
- "syscall"
-
- "github.com/labstack/echo/v5/internal/pathutil"
+ "time"
+
+ "github.com/labstack/gommon/color"
+ "github.com/labstack/gommon/log"
+ "golang.org/x/crypto/acme"
+ "golang.org/x/crypto/acme/autocert"
+ "golang.org/x/net/http2"
+ "golang.org/x/net/http2/h2c"
)
// Echo is the top-level framework instance.
//
// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these
// fields from handlers/middlewares and changing field values at the same time leads to data-races.
-// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action.
+// Adding new routes after the server has been started is also not safe!
type Echo struct {
- serveHTTPFunc func(http.ResponseWriter, *http.Request)
-
- Binder Binder
-
- // Filesystem is the file system used for serving static files. Defaults to the current working directory (os.Getwd()).
- //
- // Note: fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
- // so if you have `fs := os.DirFS("/tmp")` and you try to `fs.Open("/tmp/file.txt")` it will fail, but "file.txt"
- // would succeed. `echo.NewDefaultFS("/tmp")` overwrites this behavior and allows you to use Open with a matching
- // absolute path prefix.
- Filesystem fs.FS
-
- Renderer Renderer
- Validator Validator
+ filesystem
+ common
+ // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get
+ // listener address info (on which interface/port was listener bound) without having data races.
+ startupMutex sync.RWMutex
+ colorer *color.Color
+
+ // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns
+ // an error the router is not executed and the request will end up in the global error handler.
+ premiddleware []MiddlewareFunc
+ middleware []MiddlewareFunc
+ maxParam *int
+ router *Router
+ routers map[string]*Router
+ pool sync.Pool
+
+ StdLogger *stdLog.Logger
+ Server *http.Server
+ TLSServer *http.Server
+ Listener net.Listener
+ TLSListener net.Listener
+ AutoTLSManager autocert.Manager
+ HTTPErrorHandler HTTPErrorHandler
+ Binder Binder
JSONSerializer JSONSerializer
+ Validator Validator
+ Renderer Renderer
+ Logger Logger
IPExtractor IPExtractor
- OnAddRoute func(route Route) error
- HTTPErrorHandler HTTPErrorHandler
- Logger *slog.Logger
+ ListenerNetwork string
- contextPool sync.Pool
+ // OnAddRouteHandler is called when Echo adds new route to specific host router.
+ OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc)
+ DisableHTTP2 bool
+ Debug bool
+ HideBanner bool
+ HidePort bool
+}
- router Router
+// Route contains a handler and information for matching against requests.
+type Route struct {
+ Method string `json:"method"`
+ Path string `json:"path"`
+ Name string `json:"name"`
+}
- // premiddleware are middlewares that are called before routing is done
- premiddleware []MiddlewareFunc
+// HTTPError represents an error that occurred while handling a request.
+type HTTPError struct {
+ Internal error `json:"-"` // Stores the error returned by an external dependency
+ Message any `json:"message"`
+ Code int `json:"-"`
+}
- // middleware are middlewares that are called after routing is done and before handler is called
- middleware []MiddlewareFunc
+// MiddlewareFunc defines a function to process middleware.
+type MiddlewareFunc func(next HandlerFunc) HandlerFunc
- // chain is the global middleware chain (e.middleware) compiled once and reused for every request.
- // It terminates in a dispatcher that invokes the route handler stored on the Context during routing.
- // Rebuilt by Use(). See buildRouterChains.
- chain HandlerFunc
- // preChain is the pre-middleware chain (e.premiddleware) compiled once. It performs routing and then
- // invokes chain. Rebuilt by Pre()/Use(). Only used when premiddleware is registered.
- preChain HandlerFunc
+// HandlerFunc defines a function to serve HTTP requests.
+type HandlerFunc func(c Context) error
- contextPathParamAllocSize atomic.Int32
+// HTTPErrorHandler is a centralized HTTP error handler.
+type HTTPErrorHandler func(err error, c Context)
- // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm)
- formParseMaxMemory int64
+// Validator is the interface that wraps the Validate function.
+type Validator interface {
+ Validate(i any) error
}
// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
type JSONSerializer interface {
- Serialize(c *Context, target any, indent string) error
- Deserialize(c *Context, target any) error
+ Serialize(c Context, i any, indent string) error
+ Deserialize(c Context, i any) error
}
-// HTTPErrorHandler is a centralized HTTP error handler.
-type HTTPErrorHandler func(c *Context, err error)
+// Map defines a generic map of type `map[string]any`.
+type Map map[string]any
-// HandlerFunc defines a function to serve HTTP requests.
-type HandlerFunc func(c *Context) error
-
-// MiddlewareFunc defines a function to process middleware.
-type MiddlewareFunc func(next HandlerFunc) HandlerFunc
+// Common struct for Echo & Group.
+type common struct{}
-// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking.
-type MiddlewareConfigurator interface {
- ToMiddleware() (MiddlewareFunc, error)
-}
-
-// Validator is the interface that wraps the Validate function.
-type Validator interface {
- Validate(i any) error
-}
+// HTTP methods
+// NOTE: Deprecated, please use the stdlib constants directly instead.
+const (
+ CONNECT = http.MethodConnect
+ DELETE = http.MethodDelete
+ GET = http.MethodGet
+ HEAD = http.MethodHead
+ OPTIONS = http.MethodOptions
+ PATCH = http.MethodPatch
+ POST = http.MethodPost
+ // PROPFIND = "PROPFIND"
+ PUT = http.MethodPut
+ TRACE = http.MethodTrace
+)
// MIME types
const (
@@ -145,7 +169,7 @@ const (
// Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default.
// No "charset" parameter is defined for this registration.
// Adding one really has no effect on compliant recipients.
- // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n"
+ // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1
MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
MIMEApplicationJavaScript = "application/javascript"
MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8
@@ -172,9 +196,6 @@ const (
REPORT = "REPORT"
// RouteNotFound is special method type for routes handling "route not found" (404) cases
RouteNotFound = "echo_route_not_found"
- // RouteAny is special method type that matches any HTTP method in request. Any has lower
- // priority that other methods that have been registered with Router to that path.
- RouteAny = "echo_route_any"
)
// Headers
@@ -235,7 +256,7 @@ const (
HeaderXFrameOptions = "X-Frame-Options"
HeaderContentSecurityPolicy = "Content-Security-Policy"
HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
- HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101
+ HeaderXCSRFToken = "X-CSRF-Token"
HeaderReferrerPolicy = "Referrer-Policy"
// HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's
@@ -244,277 +265,273 @@ const (
HeaderSecFetchSite = "Sec-Fetch-Site"
)
-// Config is configuration for NewWithConfig function
-type Config struct {
- // Logger is the slog logger instance used for application-wide structured logging.
- // If not set, a default TextHandler writing to stdout is created.
- Logger *slog.Logger
-
- // HTTPErrorHandler is the centralized error handler that processes errors returned
- // by handlers and middleware, converting them to appropriate HTTP responses.
- // If not set, DefaultHTTPErrorHandler(false) is used.
- HTTPErrorHandler HTTPErrorHandler
-
- // Router is the HTTP request router responsible for matching URLs to handlers
- // using a radix tree-based algorithm.
- // If not set, NewRouter(RouterConfig{}) is used.
- Router Router
-
- // OnAddRoute is an optional callback hook executed when routes are registered.
- // Useful for route validation, logging, or custom route processing.
- // If not set, no callback is executed.
- OnAddRoute func(route Route) error
-
- // Filesystem is the fs.FS implementation used for serving static files.
- // Supports os.DirFS, embed.FS, and custom implementations.
- // If not set, defaults to current working directory.
- Filesystem fs.FS
-
- // Binder handles automatic data binding from HTTP requests to Go structs.
- // Supports JSON, XML, form data, query parameters, and path parameters.
- // If not set, DefaultBinder is used.
- Binder Binder
-
- // Validator provides optional struct validation after data binding.
- // Commonly used with third-party validation libraries.
- // If not set, Context.Validate() returns ErrValidatorNotRegistered.
- Validator Validator
-
- // Renderer provides template rendering for generating HTML responses.
- // Requires integration with a template engine like html/template.
- // If not set, Context.Render() returns ErrRendererNotRegistered.
- Renderer Renderer
-
- // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses.
- // Can be replaced with faster alternatives like jsoniter or sonic.
- // If not set, DefaultJSONSerializer using encoding/json is used.
- JSONSerializer JSONSerializer
+const (
+ // Version of Echo
+ Version = "4.15.3"
+ website = "https://echo.labstack.com"
+ // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
+ banner = `
+ ____ __
+ / __/___/ / ___
+ / _// __/ _ \/ _ \
+/___/\__/_//_/\___/ %s
+High performance, minimalist Go web framework
+%s
+____________________________________O/_______
+ O\
+`
+)
- // IPExtractor defines the strategy for extracting the real client IP address
- // from requests, particularly important when behind proxies or load balancers.
- // Used for rate limiting, access control, and logging.
- // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers.
- IPExtractor IPExtractor
+var methods = [...]string{
+ http.MethodConnect,
+ http.MethodDelete,
+ http.MethodGet,
+ http.MethodHead,
+ http.MethodOptions,
+ http.MethodPatch,
+ http.MethodPost,
+ PROPFIND,
+ http.MethodPut,
+ http.MethodTrace,
+ REPORT,
+}
+
+// Errors
+var (
+ ErrBadRequest = NewHTTPError(http.StatusBadRequest) // HTTP 400 Bad Request
+ ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) // HTTP 401 Unauthorized
+ ErrPaymentRequired = NewHTTPError(http.StatusPaymentRequired) // HTTP 402 Payment Required
+ ErrForbidden = NewHTTPError(http.StatusForbidden) // HTTP 403 Forbidden
+ ErrNotFound = NewHTTPError(http.StatusNotFound) // HTTP 404 Not Found
+ ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) // HTTP 405 Method Not Allowed
+ ErrNotAcceptable = NewHTTPError(http.StatusNotAcceptable) // HTTP 406 Not Acceptable
+ ErrProxyAuthRequired = NewHTTPError(http.StatusProxyAuthRequired) // HTTP 407 Proxy AuthRequired
+ ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) // HTTP 408 Request Timeout
+ ErrConflict = NewHTTPError(http.StatusConflict) // HTTP 409 Conflict
+ ErrGone = NewHTTPError(http.StatusGone) // HTTP 410 Gone
+ ErrLengthRequired = NewHTTPError(http.StatusLengthRequired) // HTTP 411 Length Required
+ ErrPreconditionFailed = NewHTTPError(http.StatusPreconditionFailed) // HTTP 412 Precondition Failed
+ ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) // HTTP 413 Payload Too Large
+ ErrRequestURITooLong = NewHTTPError(http.StatusRequestURITooLong) // HTTP 414 URI Too Long
+ ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) // HTTP 415 Unsupported Media Type
+ ErrRequestedRangeNotSatisfiable = NewHTTPError(http.StatusRequestedRangeNotSatisfiable) // HTTP 416 Range Not Satisfiable
+ ErrExpectationFailed = NewHTTPError(http.StatusExpectationFailed) // HTTP 417 Expectation Failed
+ ErrTeapot = NewHTTPError(http.StatusTeapot) // HTTP 418 I'm a teapot
+ ErrMisdirectedRequest = NewHTTPError(http.StatusMisdirectedRequest) // HTTP 421 Misdirected Request
+ ErrUnprocessableEntity = NewHTTPError(http.StatusUnprocessableEntity) // HTTP 422 Unprocessable Entity
+ ErrLocked = NewHTTPError(http.StatusLocked) // HTTP 423 Locked
+ ErrFailedDependency = NewHTTPError(http.StatusFailedDependency) // HTTP 424 Failed Dependency
+ ErrTooEarly = NewHTTPError(http.StatusTooEarly) // HTTP 425 Too Early
+ ErrUpgradeRequired = NewHTTPError(http.StatusUpgradeRequired) // HTTP 426 Upgrade Required
+ ErrPreconditionRequired = NewHTTPError(http.StatusPreconditionRequired) // HTTP 428 Precondition Required
+ ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) // HTTP 429 Too Many Requests
+ ErrRequestHeaderFieldsTooLarge = NewHTTPError(http.StatusRequestHeaderFieldsTooLarge) // HTTP 431 Request Header Fields Too Large
+ ErrUnavailableForLegalReasons = NewHTTPError(http.StatusUnavailableForLegalReasons) // HTTP 451 Unavailable For Legal Reasons
+ ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) // HTTP 500 Internal Server Error
+ ErrNotImplemented = NewHTTPError(http.StatusNotImplemented) // HTTP 501 Not Implemented
+ ErrBadGateway = NewHTTPError(http.StatusBadGateway) // HTTP 502 Bad Gateway
+ ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) // HTTP 503 Service Unavailable
+ ErrGatewayTimeout = NewHTTPError(http.StatusGatewayTimeout) // HTTP 504 Gateway Timeout
+ ErrHTTPVersionNotSupported = NewHTTPError(http.StatusHTTPVersionNotSupported) // HTTP 505 HTTP Version Not Supported
+ ErrVariantAlsoNegotiates = NewHTTPError(http.StatusVariantAlsoNegotiates) // HTTP 506 Variant Also Negotiates
+ ErrInsufficientStorage = NewHTTPError(http.StatusInsufficientStorage) // HTTP 507 Insufficient Storage
+ ErrLoopDetected = NewHTTPError(http.StatusLoopDetected) // HTTP 508 Loop Detected
+ ErrNotExtended = NewHTTPError(http.StatusNotExtended) // HTTP 510 Not Extended
+ ErrNetworkAuthenticationRequired = NewHTTPError(http.StatusNetworkAuthenticationRequired) // HTTP 511 Network Authentication Required
+
+ ErrValidatorNotRegistered = errors.New("validator not registered")
+ ErrRendererNotRegistered = errors.New("renderer not registered")
+ ErrInvalidRedirectCode = errors.New("invalid redirect status code")
+ ErrCookieNotFound = errors.New("cookie not found")
+ ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
+ ErrInvalidListenerNetwork = errors.New("invalid listener network")
+)
- // FormParseMaxMemory is default value for memory limit that is used
- // when parsing multipart forms (See (*http.Request).ParseMultipartForm)
- FormParseMaxMemory int64
+// NotFoundHandler is the handler that router uses in case there was no matching route found. Returns an error that results
+// HTTP 404 status code.
+var NotFoundHandler = func(c Context) error {
+ return ErrNotFound
}
-// NewWithConfig creates an instance of Echo with given configuration.
-func NewWithConfig(config Config) *Echo {
- e := New()
- if config.Logger != nil {
- e.Logger = config.Logger
- }
- if config.HTTPErrorHandler != nil {
- e.HTTPErrorHandler = config.HTTPErrorHandler
- }
- if config.Router != nil {
- e.router = config.Router
- }
- if config.OnAddRoute != nil {
- e.OnAddRoute = config.OnAddRoute
- }
- if config.Filesystem != nil {
- e.Filesystem = config.Filesystem
- }
- if config.Binder != nil {
- e.Binder = config.Binder
- }
- if config.Validator != nil {
- e.Validator = config.Validator
+// MethodNotAllowedHandler is the handler thar router uses in case there was no matching route found but there was
+// another matching routes for that requested URL. Returns an error that results HTTP 405 Method Not Allowed status code.
+var MethodNotAllowedHandler = func(c Context) error {
+ // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed)
+ // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned
+ routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string)
+ if ok && routerAllowMethods != "" {
+ c.Response().Header().Set(HeaderAllow, routerAllowMethods)
}
- if config.Renderer != nil {
- e.Renderer = config.Renderer
- }
- if config.JSONSerializer != nil {
- e.JSONSerializer = config.JSONSerializer
- }
- if config.IPExtractor != nil {
- e.IPExtractor = config.IPExtractor
- }
- if config.FormParseMaxMemory > 0 {
- e.formParseMaxMemory = config.FormParseMaxMemory
- }
- return e
+ return ErrMethodNotAllowed
}
// New creates an instance of Echo.
-func New() *Echo {
- dir, _ := os.Getwd()
- logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
- e := &Echo{
- Logger: logger,
- Filesystem: NewDefaultFS(dir),
- Binder: &DefaultBinder{},
- JSONSerializer: &DefaultJSONSerializer{},
- formParseMaxMemory: defaultMemory,
- }
-
- e.serveHTTPFunc = e.serveHTTP
- e.router = NewRouter(RouterConfig{})
- e.HTTPErrorHandler = DefaultHTTPErrorHandler(false)
- e.contextPool.New = func() any {
- return newContext(nil, nil, e)
- }
- e.buildRouterChains()
- return e
+func New() (e *Echo) {
+ e = &Echo{
+ filesystem: createFilesystem(),
+ Server: new(http.Server),
+ TLSServer: new(http.Server),
+ AutoTLSManager: autocert.Manager{
+ Prompt: autocert.AcceptTOS,
+ },
+ Logger: log.New("echo"),
+ colorer: color.New(),
+ maxParam: new(int),
+ ListenerNetwork: "tcp",
+ }
+ e.Server.Handler = e
+ e.TLSServer.Handler = e
+ e.HTTPErrorHandler = e.DefaultHTTPErrorHandler
+ e.Binder = &DefaultBinder{}
+ e.JSONSerializer = &DefaultJSONSerializer{}
+ e.Logger.SetLevel(log.ERROR)
+ e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0)
+ e.pool.New = func() any {
+ return e.NewContext(nil, nil)
+ }
+ e.router = NewRouter(e)
+ e.routers = map[string]*Router{}
+ return
}
-// buildRouterChains compiles the global and pre-middleware chains once so that ServeHTTP does not have to
-// re-wrap middleware closures on every request. It must be called whenever e.middleware or e.premiddleware
-// changes (i.e. from Use/Pre). This is safe because middleware must not be mutated after the server starts.
-func (e *Echo) buildRouterChains() {
- // dispatch is the terminal of the global chain: it invokes the handler resolved during routing.
- dispatch := func(c *Context) error {
- return c.handler(c)
- }
- e.chain = applyMiddleware(dispatch, e.middleware...)
-
- // route performs routing (storing the matched handler on the Context) and then runs the global chain.
- route := func(c *Context) error {
- c.handler = e.router.Route(c)
- return e.chain(c)
+// NewContext returns a Context instance.
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context {
+ return &context{
+ request: r,
+ response: NewResponse(w, e),
+ store: make(Map),
+ echo: e,
+ pvalues: make([]string, *e.maxParam),
+ handler: NotFoundHandler,
}
- e.preChain = applyMiddleware(route, e.premiddleware...)
-}
-
-// NewContext returns a new Context instance.
-//
-// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway
-// these arguments are useful when creating context for tests and cases like that.
-func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context {
- return newContext(r, w, e)
}
// Router returns the default router.
-func (e *Echo) Router() Router {
+func (e *Echo) Router() *Router {
return e.router
}
-// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response
-// with status code. `exposeError` parameter decides if returned message will contain also error message or not
+// Routers returns the map of host => router.
+func (e *Echo) Routers() map[string]*Router {
+ return e.routers
+}
+
+// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response
+// with status code.
//
-// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately)
-// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
+// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from
// handler. Then the error that global error handler received will be ignored because we have already "committed" the
// response and status code header has been sent to the client.
-func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
- return func(c *Context, err error) {
- if r, _ := UnwrapResponse(c.response); r != nil && r.Committed {
- return
- }
+func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
- code := http.StatusInternalServerError
- var sc HTTPStatusCoder
- if errors.As(err, &sc) {
- if tmp := sc.StatusCode(); tmp != 0 {
- code = tmp
- }
- }
+ if c.Response().Committed {
+ return
+ }
- var result any
- switch m := sc.(type) {
- case json.Marshaler: // this type knows how to format itself to JSON
- result = m
- case *HTTPError:
- sText := m.Message
- if sText == "" {
- sText = http.StatusText(code)
- }
- msg := map[string]any{"message": sText}
- if exposeError {
- if wrappedErr := m.Unwrap(); wrappedErr != nil {
- msg["error"] = wrappedErr.Error()
- }
- }
- result = msg
- default:
- msg := map[string]any{"message": http.StatusText(code)}
- if exposeError {
- msg["error"] = err.Error()
+ he, ok := err.(*HTTPError)
+ if ok {
+ if he.Internal != nil {
+ if herr, ok := he.Internal.(*HTTPError); ok {
+ he = herr
}
- result = msg
}
+ } else {
+ he = &HTTPError{
+ Code: http.StatusInternalServerError,
+ Message: http.StatusText(http.StatusInternalServerError),
+ }
+ }
+
+ // Issue #1426
+ code := he.Code
+ message := he.Message
- var cErr error
- if c.Request().Method == http.MethodHead { // Issue #608
- cErr = c.NoContent(code)
+ switch m := he.Message.(type) {
+ case string:
+ if e.Debug {
+ message = Map{"message": m, "error": err.Error()}
} else {
- cErr = c.JSON(code, result)
- }
- if cErr != nil {
- c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected
+ message = Map{"message": m}
}
+ case json.Marshaler:
+ // do nothing - this type knows how to format itself to JSON
+ case error:
+ message = Map{"message": m.Error()}
+ }
+
+ // Send response
+ if c.Request().Method == http.MethodHead { // Issue #608
+ err = c.NoContent(he.Code)
+ } else {
+ err = c.JSON(code, message)
+ }
+ if err != nil {
+ e.Logger.Error(err)
}
}
-// Pre adds middleware to the chain which is run before router tries to find matching route.
-// Meaning middleware is executed even for 404 (not found) cases.
+// Pre adds middleware to the chain which is run before router.
func (e *Echo) Pre(middleware ...MiddlewareFunc) {
e.premiddleware = append(e.premiddleware, middleware...)
- e.buildRouterChains()
}
-// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed.
+// Use adds middleware to the chain which is run after router.
func (e *Echo) Use(middleware ...MiddlewareFunc) {
e.middleware = append(e.middleware, middleware...)
- e.buildRouterChains()
}
// CONNECT registers a new CONNECT route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodConnect, path, h, m...)
}
// DELETE registers a new DELETE route for a path with matching handler in the router
-// with optional route-level middleware. Panics on error.
-func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// with optional route-level middleware.
+func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodDelete, path, h, m...)
}
// GET registers a new GET route for a path with matching handler in the router
-// with optional route-level middleware. Panics on error.
-func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// with optional route-level middleware.
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodGet, path, h, m...)
}
// HEAD registers a new HEAD route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodHead, path, h, m...)
}
// OPTIONS registers a new OPTIONS route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodOptions, path, h, m...)
}
// PATCH registers a new PATCH route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodPatch, path, h, m...)
}
// POST registers a new POST route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodPost, path, h, m...)
}
// PUT registers a new PUT route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodPut, path, h, m...)
}
// TRACE registers a new TRACE route for a path with matching handler in the
-// router with optional route-level middleware. Panics on error.
-func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// router with optional route-level middleware.
+func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(http.MethodTrace, path, h, m...)
}
@@ -523,8 +540,8 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
// wildcard/match-any character (`/*`, `/download/*` etc).
//
-// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
-func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
+func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return e.Add(RouteNotFound, path, h, m...)
}
@@ -533,274 +550,388 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) Ro
//
// Note: this method only adds specific set of supported HTTP methods as handler and is not true
// "catch-any-arbitrary-method" way of matching requests.
-func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
- return e.Add(RouteAny, path, handler, middleware...)
+func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
+ routes := make([]*Route, len(methods))
+ for i, m := range methods {
+ routes[i] = e.Add(m, path, handler, middleware...)
+ }
+ return routes
}
// Match registers a new route for multiple HTTP methods and path with matching
-// handler in the router with optional route-level middleware. Panics on error.
-func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
- errs := make([]error, 0)
- ris := make(Routes, 0)
- for _, m := range methods {
- ri, err := e.AddRoute(Route{
- Method: m,
- Path: path,
- Handler: handler,
- Middlewares: middleware,
- })
- if err != nil {
- errs = append(errs, err)
- continue
- }
- ris = append(ris, ri)
+// handler in the router with optional route-level middleware.
+func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
+ routes := make([]*Route, len(methods))
+ for i, m := range methods {
+ routes[i] = e.Add(m, path, handler, middleware...)
}
- if len(errs) > 0 {
- panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ return routes
+}
+
+func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route,
+ m ...MiddlewareFunc) *Route {
+ return get(path, func(c Context) error {
+ return c.File(file)
+ }, m...)
+}
+
+// File registers a new route with path to serve a static file with optional route-level middleware.
+func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route {
+ return e.file(path, file, e.GET, m...)
+}
+
+func (e *Echo) add(host, method, path string, handler HandlerFunc, middlewares ...MiddlewareFunc) *Route {
+ router := e.findRouter(host)
+ //FIXME: when handler+middleware are both nil ... make it behave like handler removal
+ name := handlerName(handler)
+ route := router.add(method, path, name, func(c Context) error {
+ h := applyMiddleware(handler, middlewares...)
+ return h(c)
+ })
+
+ if e.OnAddRouteHandler != nil {
+ e.OnAddRouteHandler(host, *route, handler, middlewares)
}
- return ris
+
+ return route
}
-// Static registers a new route with path prefix to serve static files from the provided root directory.
-func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
- subFs := MustSubFS(e.Filesystem, fsRoot)
- return e.Add(
- http.MethodGet,
- pathPrefix+"*",
- StaticDirectoryHandler(subFs, false),
- middleware...,
- )
+// Add registers a new route for an HTTP method and path with matching handler
+// in the router with optional route-level middleware.
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
+ return e.add("", method, path, handler, middleware...)
}
-// StaticFS registers a new route with path prefix to serve static files from the provided file system.
-//
-// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
-// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
-// including `assets/images` as their prefix.
-func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
- return e.Add(
- http.MethodGet,
- pathPrefix+"*",
- StaticDirectoryHandler(filesystem, false),
- middleware...,
- )
+// Host creates a new router group for the provided host and optional host-level middleware.
+func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) {
+ e.routers[name] = NewRouter(e)
+ g = &Group{host: name, echo: e}
+ g.Use(m...)
+ return
}
-// StaticDirectoryHandler creates handler function to serve files from provided file system
-// When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
-func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
- return func(c *Context) error {
- p := c.Param("*")
- if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
- // By default the router matches routes against the raw, still-encoded request path
- // (unless UseEscapedPathForMatching is enabled), so an encoded path separator is not
- // treated as a segment boundary during routing. Unescaping it here would let it act as
- // a separator and resolve a file outside the path the router authorized, bypassing
- // route-level middleware (e.g. auth on a sibling route). No real filename contains a
- // separator, so reject it as not found, carrying the reason internally for operators.
- if pathutil.HasEncodedPathSeparator(p) {
- return NewHTTPError(http.StatusNotFound, http.StatusText(http.StatusNotFound)).
- Wrap(fmt.Errorf("rejected encoded path separator in static path %q", p))
- }
- tmpPath, err := url.PathUnescape(p)
- if err != nil {
- return fmt.Errorf("failed to unescape path variable: %w", err)
- }
- p = tmpPath
- }
+// Group creates a new router group with prefix and optional group-level middleware.
+func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) {
+ g = &Group{prefix: prefix, echo: e}
+ g.Use(m...)
+ return
+}
- // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid.
- // Use path.Clean (not filepath.Clean): fs.FS paths are always forward-slash, so a backslash must stay a literal
- // character rather than being interpreted as a separator on Windows (which would resolve a file across a boundary
- // the router never matched on, the same Windows backslash traversal class as GHSA-pgvm-wxw2-hrv9).
- name := path.Clean(strings.TrimPrefix(p, "/"))
- fi, err := fs.Stat(fileSystem, name)
- if err != nil {
- return ErrNotFound
- }
+// URI generates an URI from handler.
+func (e *Echo) URI(handler HandlerFunc, params ...any) string {
+ name := handlerName(handler)
+ return e.Reverse(name, params...)
+}
- // If the request is for a directory and does not end with "/"
- p = c.Request().URL.Path // path must not be empty.
- if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
- // Redirect to ends with "/"
- return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
- }
- return fsFile(c, name, fileSystem)
- }
+// URL is an alias for `URI` function.
+func (e *Echo) URL(h HandlerFunc, params ...any) string {
+ return e.URI(h, params...)
}
-// FileFS registers a new route with path to serve file from the provided file system.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
- return e.GET(path, StaticFileHandler(file, filesystem), m...)
+// Reverse generates a URL from route name and provided parameters.
+func (e *Echo) Reverse(name string, params ...any) string {
+ return e.router.Reverse(name, params...)
}
-// StaticFileHandler creates handler function to serve file from provided file system.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
- return func(c *Context) error {
- return fsFile(c, file, filesystem)
- }
+// Routes returns the registered routes for default router.
+// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes.
+func (e *Echo) Routes() []*Route {
+ return e.router.Routes()
}
-// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
- handler := func(c *Context) error {
- return c.File(file)
- }
- return e.Add(http.MethodGet, path, handler, middleware...)
+// AcquireContext returns an empty `Context` instance from the pool.
+// You must return the context by calling `ReleaseContext()`.
+func (e *Echo) AcquireContext() Context {
+ return e.pool.Get().(Context)
}
-// AddRoute registers a new Route with default host Router
-func (e *Echo) AddRoute(route Route) (RouteInfo, error) {
- return e.add(route)
+// ReleaseContext returns the `Context` instance back to the pool.
+// You must call it after `AcquireContext()`.
+func (e *Echo) ReleaseContext(c Context) {
+ e.pool.Put(c)
}
-func (e *Echo) add(route Route) (RouteInfo, error) {
- if e.OnAddRoute != nil {
- if err := e.OnAddRoute(route); err != nil {
- return RouteInfo{}, err
+// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
+func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Acquire context
+ c := e.pool.Get().(*context)
+ c.Reset(r, w)
+ var h HandlerFunc
+
+ if e.premiddleware == nil {
+ e.findRouter(r.Host).Find(r.Method, GetPath(r), c)
+ h = c.Handler()
+ h = applyMiddleware(h, e.middleware...)
+ } else {
+ h = func(c Context) error {
+ e.findRouter(r.Host).Find(r.Method, GetPath(r), c)
+ h := c.Handler()
+ h = applyMiddleware(h, e.middleware...)
+ return h(c)
}
+ h = applyMiddleware(h, e.premiddleware...)
}
- ri, err := e.router.Add(route)
- if err != nil {
- return RouteInfo{}, err
+ // Execute chain
+ if err := h(c); err != nil {
+ e.HTTPErrorHandler(err, c)
}
- paramsCount := int32(len(ri.Parameters)) // #nosec G115
- if paramsCount > e.contextPathParamAllocSize.Load() {
- e.contextPathParamAllocSize.Store(paramsCount)
+ // Release context
+ e.pool.Put(c)
+}
+
+// Start starts an HTTP server.
+func (e *Echo) Start(address string) error {
+ e.startupMutex.Lock()
+ e.Server.Addr = address
+ if err := e.configureServer(e.Server); err != nil {
+ e.startupMutex.Unlock()
+ return err
}
- return ri, nil
+ e.startupMutex.Unlock()
+ return e.Server.Serve(e.Listener)
}
-// Add registers a new route for an HTTP method and path with matching handler
-// in the router with optional route-level middleware.
-func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
- ri, err := e.add(
- Route{
- Method: method,
- Path: path,
- Handler: handler,
- Middlewares: middleware,
- Name: "",
- },
- )
- if err != nil {
- panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+// StartTLS starts an HTTPS server.
+// If `certFile` or `keyFile` is `string` the values are treated as file paths.
+// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
+func (e *Echo) StartTLS(address string, certFile, keyFile any) (err error) {
+ e.startupMutex.Lock()
+ var cert []byte
+ if cert, err = filepathOrContent(certFile); err != nil {
+ e.startupMutex.Unlock()
+ return
}
- return ri
+
+ var key []byte
+ if key, err = filepathOrContent(keyFile); err != nil {
+ e.startupMutex.Unlock()
+ return
+ }
+
+ s := e.TLSServer
+ s.TLSConfig = new(tls.Config)
+ s.TLSConfig.Certificates = make([]tls.Certificate, 1)
+ if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil {
+ e.startupMutex.Unlock()
+ return
+ }
+
+ e.configureTLS(address)
+ if err := e.configureServer(s); err != nil {
+ e.startupMutex.Unlock()
+ return err
+ }
+ e.startupMutex.Unlock()
+ return s.Serve(e.TLSListener)
}
-// Group creates a new router group with prefix and optional group-level middleware.
-func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) {
- g = &Group{prefix: prefix, echo: e}
- g.Use(m...)
- return
+func filepathOrContent(fileOrContent any) (content []byte, err error) {
+ switch v := fileOrContent.(type) {
+ case string:
+ return os.ReadFile(v)
+ case []byte:
+ return v, nil
+ default:
+ return nil, ErrInvalidCertOrKeyType
+ }
}
-// PreMiddlewares returns registered pre middlewares. These are middleware to the chain
-// which are run before router tries to find matching route.
-// Use this method to build your own ServeHTTP method.
-//
-// NOTE: returned slice is not a copy. Do not mutate.
-func (e *Echo) PreMiddlewares() []MiddlewareFunc {
- return e.premiddleware
+// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org.
+func (e *Echo) StartAutoTLS(address string) error {
+ e.startupMutex.Lock()
+ s := e.TLSServer
+ s.TLSConfig = new(tls.Config)
+ s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto)
+
+ e.configureTLS(address)
+ if err := e.configureServer(s); err != nil {
+ e.startupMutex.Unlock()
+ return err
+ }
+ e.startupMutex.Unlock()
+ return s.Serve(e.TLSListener)
}
-// Middlewares returns registered route level middlewares. Does not contain any group level
-// middlewares. Use this method to build your own ServeHTTP method.
-//
-// NOTE: returned slice is not a copy. Do not mutate.
-func (e *Echo) Middlewares() []MiddlewareFunc {
- return e.middleware
+func (e *Echo) configureTLS(address string) {
+ s := e.TLSServer
+ s.Addr = address
+ if !e.DisableHTTP2 {
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2")
+ }
}
-// AcquireContext returns an empty `Context` instance from the pool.
-// You must return the context by calling `ReleaseContext()`.
-func (e *Echo) AcquireContext() *Context {
- return e.contextPool.Get().(*Context)
+// StartServer starts a custom http server.
+func (e *Echo) StartServer(s *http.Server) (err error) {
+ e.startupMutex.Lock()
+ if err := e.configureServer(s); err != nil {
+ e.startupMutex.Unlock()
+ return err
+ }
+ if s.TLSConfig != nil {
+ e.startupMutex.Unlock()
+ return s.Serve(e.TLSListener)
+ }
+ e.startupMutex.Unlock()
+ return s.Serve(e.Listener)
}
-// ReleaseContext returns the `Context` instance back to the pool.
-// You must call it after `AcquireContext()`.
-func (e *Echo) ReleaseContext(c *Context) {
- e.contextPool.Put(c)
+func (e *Echo) configureServer(s *http.Server) error {
+ // Setup
+ e.colorer.SetOutput(e.Logger.Output())
+ s.ErrorLog = e.StdLogger
+ s.Handler = e
+ if e.Debug {
+ e.Logger.SetLevel(log.DEBUG)
+ }
+
+ if !e.HideBanner {
+ e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website))
+ }
+
+ if s.TLSConfig == nil {
+ if e.Listener == nil {
+ l, err := newListener(s.Addr, e.ListenerNetwork)
+ if err != nil {
+ return err
+ }
+ e.Listener = l
+ }
+ if !e.HidePort {
+ e.colorer.Printf("β¨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
+ }
+ return nil
+ }
+ if e.TLSListener == nil {
+ l, err := newListener(s.Addr, e.ListenerNetwork)
+ if err != nil {
+ return err
+ }
+ e.TLSListener = tls.NewListener(l, s.TLSConfig)
+ }
+ if !e.HidePort {
+ e.colorer.Printf("β¨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr()))
+ }
+ return nil
}
-// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
-func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- e.serveHTTPFunc(w, r)
+// ListenerAddr returns net.Addr for Listener
+func (e *Echo) ListenerAddr() net.Addr {
+ e.startupMutex.RLock()
+ defer e.startupMutex.RUnlock()
+ if e.Listener == nil {
+ return nil
+ }
+ return e.Listener.Addr()
}
-// serveHTTP implements `http.Handler` interface, which serves HTTP requests.
-func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) {
- c := e.contextPool.Get().(*Context)
- defer e.contextPool.Put(c)
+// TLSListenerAddr returns net.Addr for TLSListener
+func (e *Echo) TLSListenerAddr() net.Addr {
+ e.startupMutex.RLock()
+ defer e.startupMutex.RUnlock()
+ if e.TLSListener == nil {
+ return nil
+ }
+ return e.TLSListener.Addr()
+}
- c.Reset(r, w)
+// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext).
+func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error {
+ e.startupMutex.Lock()
+ // Setup
+ s := e.Server
+ s.Addr = address
+ e.colorer.SetOutput(e.Logger.Output())
+ s.ErrorLog = e.StdLogger
+ s.Handler = h2c.NewHandler(e, h2s)
+ if e.Debug {
+ e.Logger.SetLevel(log.DEBUG)
+ }
- // The global (e.chain) and pre-middleware (e.preChain) chains are compiled once in buildRouterChains and
- // reused here, so no middleware closures are allocated per request.
- var err error
- if e.premiddleware == nil {
- c.handler = e.router.Route(c)
- err = e.chain(c)
- } else {
- err = e.preChain(c)
+ if !e.HideBanner {
+ e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website))
}
- if err != nil {
- e.HTTPErrorHandler(c, err)
+ if e.Listener == nil {
+ l, err := newListener(s.Addr, e.ListenerNetwork)
+ if err != nil {
+ e.startupMutex.Unlock()
+ return err
+ }
+ e.Listener = l
}
+ if !e.HidePort {
+ e.colorer.Printf("β¨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
+ }
+ e.startupMutex.Unlock()
+ return s.Serve(e.Listener)
}
-// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by
-// sending os.Interrupt signal with `ctrl+c`. Method returns only errors that are not http.ErrServerClosed.
-//
-// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration
-// options.
-//
-// In need of customization use:
-//
-// ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
-// defer cancel()
-// sc := echo.StartConfig{Address: ":8080"}
-// if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) {
-// slog.Error(err.Error())
-// }
-//
-// // or standard library `http.Server`
-//
-// s := http.Server{Addr: ":8080", Handler: e}
-// if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
-// slog.Error(err.Error())
-// }
-func (e *Echo) Start(address string) error {
- sc := StartConfig{Address: address}
- ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c
- defer cancel()
- return sc.Start(ctx, e)
+// Close immediately stops the server.
+// It internally calls `http.Server#Close()`.
+func (e *Echo) Close() error {
+ e.startupMutex.Lock()
+ defer e.startupMutex.Unlock()
+ if err := e.TLSServer.Close(); err != nil {
+ return err
+ }
+ return e.Server.Close()
+}
+
+// Shutdown stops the server gracefully.
+// It internally calls `http.Server#Shutdown()`.
+func (e *Echo) Shutdown(ctx stdContext.Context) error {
+ e.startupMutex.Lock()
+ defer e.startupMutex.Unlock()
+ if err := e.TLSServer.Shutdown(ctx); err != nil {
+ return err
+ }
+ return e.Server.Shutdown(ctx)
+}
+
+// NewHTTPError creates a new HTTPError instance.
+func NewHTTPError(code int, message ...any) *HTTPError {
+ he := &HTTPError{Code: code, Message: http.StatusText(code)}
+ if len(message) > 0 {
+ he.Message = message[0]
+ }
+ return he
+}
+
+// Error makes it compatible with `error` interface.
+func (he *HTTPError) Error() string {
+ if he.Internal == nil {
+ return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
+ }
+ return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
+}
+
+// SetInternal sets error to HTTPError.Internal
+func (he *HTTPError) SetInternal(err error) *HTTPError {
+ he.Internal = err
+ return he
+}
+
+// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
+func (he *HTTPError) WithInternal(err error) *HTTPError {
+ return &HTTPError{
+ Code: he.Code,
+ Message: he.Message,
+ Internal: err,
+ }
+}
+
+// Unwrap satisfies the Go 1.13 error wrapper interface.
+func (he *HTTPError) Unwrap() error {
+ return he.Internal
}
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
func WrapHandler(h http.Handler) HandlerFunc {
- return func(c *Context) error {
- req := c.Request()
- req.Pattern = c.Path()
- for _, p := range c.PathValues() {
- req.SetPathValue(p.Name, p.Value)
- }
-
- h.ServeHTTP(c.Response(), req)
+ return func(c Context) error {
+ h.ServeHTTP(c.Response(), c.Request())
return nil
}
}
@@ -808,101 +939,85 @@ func WrapHandler(h http.Handler) HandlerFunc {
// WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc`
func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc {
return func(next HandlerFunc) HandlerFunc {
- return func(c *Context) (err error) {
- req := c.Request()
- req.Pattern = c.Path()
- for _, p := range c.PathValues() {
- req.SetPathValue(p.Name, p.Value)
- }
-
+ return func(c Context) (err error) {
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.SetRequest(r)
- c.SetResponse(NewResponse(w, c.echo.Logger))
+ c.SetResponse(NewResponse(w, c.Echo()))
err = next(c)
- })).ServeHTTP(c.Response(), req)
+ })).ServeHTTP(c.Response(), c.Request())
return
}
}
}
-func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
- for i := len(middleware) - 1; i >= 0; i-- {
- h = middleware[i](h)
+// GetPath returns RawPath, if it's empty returns Path from URL
+// Difference between RawPath and Path is:
+// - Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/.
+// - RawPath is an optional field which only gets set if the default encoding is different from Path.
+func GetPath(r *http.Request) string {
+ path := r.URL.RawPath
+ if path == "" {
+ path = r.URL.Path
}
- return h
+ return path
}
-// defaultFS emulates os.Open behavior with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open`
-// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images`
-// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break
-// all old applications that rely on being able to traverse up from current executable run path.
-// NB: private because you really should use fs.FS implementation instances
-type defaultFS struct {
- fs fs.FS
- prefix string
+func (e *Echo) findRouter(host string) *Router {
+ if len(e.routers) > 0 {
+ if r, ok := e.routers[host]; ok {
+ return r
+ }
+ }
+ return e.router
}
-// NewDefaultFS returns a new defaultFS instance which allows `fs.FS.Open` to have absolute paths as input if it matches
-// then given dir as prefix.
-func NewDefaultFS(dir string) fs.FS {
- return &defaultFS{
- prefix: dir,
- fs: os.DirFS(dir),
+func handlerName(h HandlerFunc) string {
+ t := reflect.ValueOf(h).Type()
+ if t.Kind() == reflect.Func {
+ return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
}
+ return t.String()
}
-func (fs defaultFS) Open(name string) (fs.File, error) {
- // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
- // For example `f.Name()` returns file names as absolute paths (e.g. `/tmp/data.csv`) so in case user wants to open
- // a file with an absolute path we need to remove prefix and then call fs.FS.Open().
- // not to force users to cut prefix from file name we do it here.
- if filepath.IsAbs(name) {
- if strings.HasPrefix(name, fs.prefix) {
- name = name[len(fs.prefix):]
- if len(name) > 1 && os.IsPathSeparator(name[0]) {
- name = name[1:]
- }
- }
- }
- return fs.fs.Open(name)
+// // PathUnescape is wraps `url.PathUnescape`
+// func PathUnescape(s string) (string, error) {
+// return url.PathUnescape(s)
+// }
+
+// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
+// connections. It's used by ListenAndServe and ListenAndServeTLS so
+// dead TCP connections (e.g. closing laptop mid-download) eventually
+// go away.
+type tcpKeepAliveListener struct {
+ *net.TCPListener
}
-func subFS(currentFs fs.FS, root string) (fs.FS, error) {
- root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
- if dFS, ok := currentFs.(*defaultFS); ok {
- // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
- // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
- // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
- if !filepath.IsAbs(root) {
- root = filepath.Join(dFS.prefix, root)
- }
- return &defaultFS{
- prefix: root,
- fs: os.DirFS(root),
- }, nil
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+ if c, err = ln.AcceptTCP(); err != nil {
+ return
+ } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil {
+ return
}
- return fs.Sub(currentFs, root)
+ // Ignore error from setting the KeepAlivePeriod as some systems, such as
+ // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP
+ _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute)
+ return
}
-// MustSubFS creates sub FS from current filesystem or panic on failure.
-// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
-//
-// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
-// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
-// create sub fs which uses necessary prefix for directory path.
-func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
- subFs, err := subFS(currentFs, fsRoot)
+func newListener(address, network string) (*tcpKeepAliveListener, error) {
+ if network != "tcp" && network != "tcp4" && network != "tcp6" {
+ return nil, ErrInvalidListenerNetwork
+ }
+ l, err := net.Listen(network, address)
if err != nil {
- panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
+ return nil, err
}
- return subFs
+ return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil
}
-func sanitizeURI(uri string) string {
- // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
- // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
- if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
- uri = "/" + strings.TrimLeft(uri, `/\`)
+func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
+ for i := len(middleware) - 1; i >= 0; i-- {
+ h = middleware[i](h)
}
- return uri
+ return h
}
diff --git a/echo_fs.go b/echo_fs.go
new file mode 100644
index 000000000..f111095f1
--- /dev/null
+++ b/echo_fs.go
@@ -0,0 +1,176 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "fmt"
+ "io/fs"
+ "net/http"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "github.com/labstack/echo/v4/internal/pathutil"
+)
+
+type filesystem struct {
+ // Filesystem is file system used by Static and File handlers to access files.
+ // Defaults to os.DirFS(".")
+ //
+ // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+ // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+ // including `assets/images` as their prefix.
+ Filesystem fs.FS
+}
+
+func createFilesystem() filesystem {
+ return filesystem{
+ Filesystem: newDefaultFS(),
+ }
+}
+
+// Static registers a new route with path prefix to serve static files from the provided root directory.
+func (e *Echo) Static(pathPrefix, fsRoot string) *Route {
+ subFs := MustSubFS(e.Filesystem, fsRoot)
+ return e.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(subFs, false),
+ )
+}
+
+// StaticFS registers a new route with path prefix to serve static files from the provided file system.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route {
+ return e.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(filesystem, false),
+ )
+}
+
+// StaticDirectoryHandler creates handler function to serve files from provided file system
+// When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
+func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
+ return func(c Context) error {
+ p := c.Param("*")
+ if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
+ // The router matches routes against the raw, still-encoded request path, so an
+ // encoded path separator (%2F or %5C) is not treated as a segment boundary during
+ // routing. Unescaping it here would let it act as a separator and resolve a file
+ // outside the path the router authorized, bypassing route-level middleware (e.g. auth
+ // on a sibling route). No real filename contains a separator, so reject it as not found.
+ if pathutil.HasEncodedPathSeparator(p) {
+ return ErrNotFound
+ }
+ tmpPath, err := url.PathUnescape(p)
+ if err != nil {
+ return fmt.Errorf("failed to unescape path variable: %w", err)
+ }
+ p = tmpPath
+ }
+
+ // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid.
+ // Use path.Clean (not filepath.Clean): fs.FS paths are always forward-slash, so a backslash must stay a literal
+ // character rather than being interpreted as a separator on Windows (which would resolve a file across a boundary
+ // the router never matched on).
+ name := path.Clean(strings.TrimPrefix(p, "/"))
+ fi, err := fs.Stat(fileSystem, name)
+ if err != nil {
+ return ErrNotFound
+ }
+
+ // If the request is for a directory and does not end with "/"
+ p = c.Request().URL.Path // path must not be empty.
+ if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
+ // Redirect to ends with "/"
+ return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
+ }
+ return fsFile(c, name, fileSystem)
+ }
+}
+
+// FileFS registers a new route with path to serve file from the provided file system.
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route {
+ return e.GET(path, StaticFileHandler(file, filesystem), m...)
+}
+
+// StaticFileHandler creates handler function to serve file from provided file system
+func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
+ return func(c Context) error {
+ return fsFile(c, file, filesystem)
+ }
+}
+
+// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`.
+// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface.
+// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/`
+// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not
+// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to
+// traverse up from current executable run path.
+// NB: private because you really should use fs.FS implementation instances
+type defaultFS struct {
+ fs fs.FS
+ prefix string
+}
+
+func newDefaultFS() *defaultFS {
+ dir, _ := os.Getwd()
+ return &defaultFS{
+ prefix: dir,
+ fs: nil,
+ }
+}
+
+func (fs defaultFS) Open(name string) (fs.File, error) {
+ if fs.fs == nil {
+ return os.Open(name)
+ }
+ return fs.fs.Open(name)
+}
+
+func subFS(currentFs fs.FS, root string) (fs.FS, error) {
+ root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
+ if dFS, ok := currentFs.(*defaultFS); ok {
+ // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
+ // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
+ // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
+ if !filepath.IsAbs(root) {
+ root = filepath.Join(dFS.prefix, root)
+ }
+ return &defaultFS{
+ prefix: root,
+ fs: os.DirFS(root),
+ }, nil
+ }
+ return fs.Sub(currentFs, root)
+}
+
+// MustSubFS creates sub FS from current filesystem or panic on failure.
+// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
+//
+// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
+// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
+// create sub fs which uses necessary prefix for directory path.
+func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
+ subFs, err := subFS(currentFs, fsRoot)
+ if err != nil {
+ panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
+ }
+ return subFs
+}
+
+func sanitizeURI(uri string) string {
+ // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
+ // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
+ if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
+ uri = "/" + strings.TrimLeft(uri, `/\`)
+ }
+ return uri
+}
diff --git a/static_encoded_separator_test.go b/echo_fs_encoded_separator_test.go
similarity index 81%
rename from static_encoded_separator_test.go
rename to echo_fs_encoded_separator_test.go
index a10eafabc..16752e41d 100644
--- a/static_encoded_separator_test.go
+++ b/echo_fs_encoded_separator_test.go
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
package echo
import (
@@ -9,8 +12,8 @@ import (
"github.com/stretchr/testify/assert"
)
-// Regression for GHSA-vfp3-v2gw-7wfq: an encoded slash (%2F) must not let a static
-// file request resolve across a path separator and bypass route-level middleware.
+// Regression for GHSA-vfp3-v2gw-7wfq (v4 backport): an encoded path separator (%2F or %5C)
+// must not let a static file request resolve across a separator and bypass route-level middleware.
func TestStaticDirectoryHandler_EncodedSeparatorDoesNotBypassRoute(t *testing.T) {
fsys := fstest.MapFS{
"admin/secret.txt": {Data: []byte("TOP-SECRET")},
@@ -18,9 +21,9 @@ func TestStaticDirectoryHandler_EncodedSeparatorDoesNotBypassRoute(t *testing.T)
}
e := New()
g := e.Group("/admin", func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error { return c.String(http.StatusForbidden, "denied") }
+ return func(c Context) error { return c.String(http.StatusForbidden, "denied") }
})
- g.GET("/*", func(c *Context) error { return c.String(http.StatusOK, "reached-protected-handler") })
+ g.GET("/*", func(c Context) error { return c.String(http.StatusOK, "reached-protected-handler") })
e.StaticFS("/", fsys)
cases := []struct {
@@ -31,7 +34,7 @@ func TestStaticDirectoryHandler_EncodedSeparatorDoesNotBypassRoute(t *testing.T)
{"/admin/secret.txt", http.StatusForbidden, "denied"}, // protected route fires
{"/admin%2Fsecret.txt", http.StatusNotFound, ""}, // encoded slash rejected, no disclosure
{"/admin%2fsecret.txt", http.StatusNotFound, ""}, // lower-case hex variant
- {"/admin%5Csecret.txt", http.StatusNotFound, ""}, // encoded backslash variant
+ {"/admin%5Csecret.txt", http.StatusNotFound, ""}, // encoded backslash (Windows separator) neutralized by path.Clean
{"/admin%252Fsecret.txt", http.StatusNotFound, ""}, // double-encoded: single unescape -> literal filename, not a separator
{"/index.html", http.StatusOK, "public"}, // legitimate static file still served
}
diff --git a/echo_fs_test.go b/echo_fs_test.go
new file mode 100644
index 000000000..75f32dfb0
--- /dev/null
+++ b/echo_fs_test.go
@@ -0,0 +1,274 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "github.com/stretchr/testify/assert"
+ "io/fs"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+)
+
+func TestEcho_StaticFS(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenPrefix string
+ givenFs fs.FS
+ givenFsRoot string
+ whenURL string
+ expectStatus int
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("./_fixture/images"),
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "ok, from sub fs",
+ givenPrefix: "/images",
+ givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"),
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("_fixture/scripts"),
+ whenURL: "/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("_fixture/images"),
+ whenURL: "/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: "/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/static",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory 404 (request URL without slash)",
+ givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder", // no trailing slash
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Prefixed directory redirect (without slash redirect to slash)",
+ givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder", // no trailing slash
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending with slash)",
+ givenPrefix: "/assets/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending without slash)",
+ givenPrefix: "/assets",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: `/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: `/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ // An encoded slash (%2f) is rejected outright (GHSA-vfp3-v2gw-7wfq): the router matches
+ // on the raw path so %2f is not a separator, and unescaping it here would let it act as
+ // one. No redirect is emitted, closing the open-redirect vector.
+ name: "encoded slash is rejected, not redirected",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: "/open.redirect.hackercom%2f..",
+ expectStatus: http.StatusNotFound,
+ expectHeaderLocation: "",
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ tmpFs := tc.givenFs
+ if tc.givenFsRoot != "" {
+ tmpFs = MustSubFS(tmpFs, tc.givenFsRoot)
+ }
+ e.StaticFS(tc.givenPrefix, tmpFs)
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
+
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
+}
+
+func TestEcho_FileFS(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenPath string
+ whenFile string
+ whenFS fs.FS
+ givenURL string
+ expectCode int
+ expectStartsWith []byte
+ }{
+ {
+ name: "ok",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, requesting invalid path",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/walle.png",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ {
+ name: "nok, serving not existent file from filesystem",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "not-existent.png",
+ givenURL: "/walle",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestEcho_StaticPanic(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRoot string
+ }{
+ {
+ name: "panics for ../",
+ givenRoot: "../assets",
+ },
+ {
+ name: "panics for /",
+ givenRoot: "/assets",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Filesystem = os.DirFS("./")
+
+ assert.Panics(t, func() {
+ e.Static("../assets", tc.givenRoot)
+ })
+ })
+ }
+}
diff --git a/echo_test.go b/echo_test.go
index 0af0b96ec..08cc7162b 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -6,23 +6,23 @@ package echo
import (
"bytes"
stdContext "context"
+ "crypto/tls"
"errors"
"fmt"
"io"
- "io/fs"
- "log/slog"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
- "path/filepath"
- "runtime"
+ "reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/net/http2"
)
type user struct {
@@ -62,99 +62,33 @@ func TestEcho(t *testing.T) {
// Router
assert.NotNil(t, e.Router())
- e.HTTPErrorHandler(c, errors.New("error"))
-
+ // DefaultHTTPErrorHandler
+ e.DefaultHTTPErrorHandler(errors.New("error"), c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
-func TestNewWithConfig(t *testing.T) {
- e := NewWithConfig(Config{})
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- e.GET("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "Hello, World!")
- })
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusTeapot, rec.Code)
- assert.Equal(t, `Hello, World!`, rec.Body.String())
-}
-
-func TestNewDefaultFS(t *testing.T) {
- tempDir := t.TempDir()
- filename := filepath.Join(tempDir, "file.txt")
- if err := os.WriteFile(filename, []byte("hello"), 0644); err != nil {
- t.Fatalf("failed to write file: %v", err)
- }
-
- var testCases = []struct {
- name string
- givenDir string
- whenName string
- expectedError string
- }{
- {
- name: "ok, can open absolute path",
- givenDir: tempDir,
- whenName: filename,
- },
- {
- name: "ok, can open path to fs",
- givenDir: tempDir,
- whenName: "file.txt",
- },
- {
- name: "nok, can not use ./ in path",
- givenDir: tempDir,
- whenName: "./file.txt",
- expectedError: `open ./file.txt: invalid argument`,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- myFs := NewDefaultFS(tc.givenDir)
-
- f, err := myFs.Open(tc.whenName)
- if tc.expectedError != "" {
- assert.EqualError(t, err, tc.expectedError)
- return
- }
- if err != nil {
- t.Fatalf("failed to read file: %v", err)
- }
- defer f.Close()
-
- contents, err := io.ReadAll(f)
- assert.NoError(t, err)
- assert.Equal(t, []byte("hello"), contents)
- })
- }
-}
-
-func TestEcho_StaticFS(t *testing.T) {
+func TestEchoStatic(t *testing.T) {
var testCases = []struct {
- givenFs fs.FS
name string
givenPrefix string
- givenFsRoot string
+ givenRoot string
whenURL string
+ expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
- expectStatus int
}{
{
name: "ok",
givenPrefix: "/images",
- givenFs: os.DirFS("./_fixture/images"),
+ givenRoot: "_fixture/images",
whenURL: "/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
- name: "ok, from sub fs",
+ name: "ok with relative path for root points to directory",
givenPrefix: "/images",
- givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"),
+ givenRoot: "./_fixture/images",
whenURL: "/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
@@ -162,7 +96,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "No file",
givenPrefix: "/images",
- givenFs: os.DirFS("_fixture/scripts"),
+ givenRoot: "_fixture/scripts",
whenURL: "/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@@ -170,7 +104,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Directory",
givenPrefix: "/images",
- givenFs: os.DirFS("_fixture/images"),
+ givenRoot: "_fixture/images",
whenURL: "/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@@ -178,7 +112,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Directory Redirect",
givenPrefix: "/",
- givenFs: os.DirFS("_fixture/"),
+ givenRoot: "_fixture",
whenURL: "/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
@@ -187,7 +121,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/static",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/static/",
@@ -196,7 +130,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Prefixed directory 404 (request URL without slash)",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@@ -204,7 +138,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Prefixed directory redirect (without slash redirect to slash)",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
@@ -213,7 +147,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Directory with index.html",
givenPrefix: "/",
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "",
@@ -221,7 +155,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "",
@@ -229,7 +163,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "",
@@ -237,7 +171,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "Sub-directory with index.html",
givenPrefix: "/",
- givenFs: os.DirFS("_fixture"),
+ givenRoot: "_fixture",
whenURL: "/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "",
@@ -245,7 +179,7 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
- givenFs: os.DirFS("_fixture/"),
+ givenRoot: "_fixture/",
whenURL: `/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@@ -253,40 +187,20 @@ func TestEcho_StaticFS(t *testing.T) {
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
- givenFs: os.DirFS("_fixture/"),
+ givenRoot: "_fixture/",
whenURL: `/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
- {
- // An encoded slash (%2f) is rejected outright (GHSA-vfp3-v2gw-7wfq): by default the
- // router matches on the raw path so %2f is not a separator, and unescaping it here
- // would let it act as one. No redirect is emitted, closing the open-redirect vector.
- name: "encoded slash is rejected, not redirected",
- givenPrefix: "/",
- givenFs: os.DirFS("_fixture/"),
- whenURL: "/open.redirect.hackercom%2f..",
- expectStatus: http.StatusNotFound,
- expectHeaderLocation: "",
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
-
- tmpFs := tc.givenFs
- if tc.givenFsRoot != "" {
- tmpFs = MustSubFS(tmpFs, tc.givenFsRoot)
- }
- e.StaticFS(tc.givenPrefix, tmpFs)
-
+ e.Static(tc.givenPrefix, tc.givenRoot)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
-
e.ServeHTTP(rec, req)
-
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
@@ -305,114 +219,44 @@ func TestEcho_StaticFS(t *testing.T) {
}
}
-func TestEcho_FileFS(t *testing.T) {
- var testCases = []struct {
- whenFS fs.FS
- name string
- whenPath string
- whenFile string
- givenURL string
- expectStartsWith []byte
- expectCode int
- }{
- {
- name: "ok",
- whenPath: "/walle",
- whenFS: os.DirFS("_fixture/images"),
- whenFile: "walle.png",
- givenURL: "/walle",
- expectCode: http.StatusOK,
- expectStartsWith: []byte{0x89, 0x50, 0x4e},
- },
- {
- name: "nok, requesting invalid path",
- whenPath: "/walle",
- whenFS: os.DirFS("_fixture/images"),
- whenFile: "walle.png",
- givenURL: "/walle.png",
- expectCode: http.StatusNotFound,
- expectStartsWith: []byte(`{"message":"Not Found"}`),
- },
- {
- name: "nok, serving not existent file from filesystem",
- whenPath: "/walle",
- whenFS: os.DirFS("_fixture/images"),
- whenFile: "not-existent.png",
- givenURL: "/walle",
- expectCode: http.StatusNotFound,
- expectStartsWith: []byte(`{"message":"Not Found"}`),
- },
- }
+func TestEchoStaticRedirectIndex(t *testing.T) {
+ e := New()
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+ // HandlerFunc
+ e.Static("/static", "_fixture")
- req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- rec := httptest.NewRecorder()
+ errCh := make(chan error)
- e.ServeHTTP(rec, req)
+ go func() {
+ errCh <- e.Start(":0")
+ }()
- assert.Equal(t, tc.expectCode, rec.Code)
+ err := waitForServerStart(e, errCh, false)
+ assert.NoError(t, err)
- body := rec.Body.Bytes()
- if len(body) > len(tc.expectStartsWith) {
- body = body[:len(tc.expectStartsWith)]
+ addr := e.ListenerAddr().String()
+ if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default
+ defer func(Body io.ReadCloser) {
+ err := Body.Close()
+ if err != nil {
+ assert.Fail(t, err.Error())
}
- assert.Equal(t, tc.expectStartsWith, body)
- })
- }
-}
+ }(resp.Body)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
-func TestEcho_StaticPanic(t *testing.T) {
- var testCases = []struct {
- name string
- givenRoot string
- }{
- {
- name: "panics for ../",
- givenRoot: "../assets",
- },
- {
- name: "panics for /",
- givenRoot: "/assets",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- e.Filesystem = os.DirFS("./")
+ if body, err := io.ReadAll(resp.Body); err == nil {
+ assert.Equal(t, true, strings.HasPrefix(string(body), ""))
+ } else {
+ assert.Fail(t, err.Error())
+ }
- assert.Panics(t, func() {
- e.Static("../assets", tc.givenRoot)
- })
- })
+ } else {
+ assert.NoError(t, err)
}
-}
-
-func TestEchoStaticRedirectIndex(t *testing.T) {
- e := New()
-
- // HandlerFunc
- ri := e.Static("/static", "_fixture")
- assert.Equal(t, http.MethodGet, ri.Method)
- assert.Equal(t, "/static*", ri.Path)
- assert.Equal(t, "GET:/static*", ri.Name)
- assert.Equal(t, []string{"*"}, ri.Parameters)
- ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
- defer cancel()
- addr, err := startOnRandomPort(ctx, e)
- if err != nil {
- assert.Fail(t, err.Error())
+ if err := e.Close(); err != nil {
+ t.Fatal(err)
}
-
- code, body, err := doGet(fmt.Sprintf("http://%v/static", addr))
- assert.NoError(t, err)
- assert.True(t, strings.HasPrefix(body, ""))
- assert.Equal(t, http.StatusOK, code)
}
func TestEchoFile(t *testing.T) {
@@ -421,8 +265,8 @@ func TestEchoFile(t *testing.T) {
givenPath string
givenFile string
whenPath string
- expectStartsWith string
expectCode int
+ expectStartsWith string
}{
{
name: "ok",
@@ -471,37 +315,36 @@ func TestEchoMiddleware(t *testing.T) {
buf := new(bytes.Buffer)
e.Pre(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- // before route match is found RouteInfo does not exist
- assert.Equal(t, RouteInfo{}, c.RouteInfo())
+ return func(c Context) error {
+ assert.Empty(t, c.Path())
buf.WriteString("-1")
return next(c)
}
})
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("1")
return next(c)
}
})
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("2")
return next(c)
}
})
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("3")
return next(c)
}
})
// Route
- e.GET("/", func(c *Context) error {
+ e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
@@ -514,11 +357,11 @@ func TestEchoMiddleware(t *testing.T) {
func TestEchoMiddlewareError(t *testing.T) {
e := New()
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return errors.New("error")
}
})
- e.GET("/", notFoundHandler)
+ e.GET("/", NotFoundHandler)
c, _ := request(http.MethodGet, "/", e)
assert.Equal(t, http.StatusInternalServerError, c)
}
@@ -527,7 +370,7 @@ func TestEchoHandler(t *testing.T) {
e := New()
// HandlerFunc
- e.GET("/ok", func(c *Context) error {
+ e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
@@ -538,256 +381,230 @@ func TestEchoHandler(t *testing.T) {
func TestEchoWrapHandler(t *testing.T) {
e := New()
-
- var actualID string
- var actualPattern string
- e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("test"))
- actualID = r.PathValue("id")
- actualPattern = r.Pattern
- })))
-
- req := httptest.NewRequest(http.MethodGet, "/123", nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "test", rec.Body.String())
- assert.Equal(t, "123", actualID)
- assert.Equal(t, "/:id", actualPattern)
+ c := e.NewContext(req, rec)
+ h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte("test"))
+ if err != nil {
+ assert.Fail(t, err.Error())
+ }
+ }))
+ if assert.NoError(t, h(c)) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "test", rec.Body.String())
+ }
}
func TestEchoWrapMiddleware(t *testing.T) {
e := New()
-
- var actualID string
- var actualPattern string
- e.Use(WrapMiddleware(func(h http.Handler) http.Handler {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ buf := new(bytes.Buffer)
+ mw := WrapMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- actualID = r.PathValue("id")
- actualPattern = r.Pattern
+ buf.Write([]byte("mw"))
h.ServeHTTP(w, r)
})
- }))
-
- e.GET("/:id", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
})
-
- req := httptest.NewRequest(http.MethodGet, "/123", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusTeapot, rec.Code)
- assert.Equal(t, "OK", rec.Body.String())
- assert.Equal(t, "123", actualID)
- assert.Equal(t, "/:id", actualPattern)
+ h := mw(func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Equal(t, "mw", buf.String())
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "OK", rec.Body.String())
+ }
}
func TestEchoConnect(t *testing.T) {
e := New()
-
- ri := e.CONNECT("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodConnect, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodConnect+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodConnect, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodConnect, "/", e)
}
func TestEchoDelete(t *testing.T) {
e := New()
-
- ri := e.DELETE("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodDelete, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodDelete+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodDelete, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodDelete, "/", e)
}
func TestEchoGet(t *testing.T) {
e := New()
-
- ri := e.GET("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodGet, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodGet+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodGet, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodGet, "/", e)
}
func TestEchoHead(t *testing.T) {
e := New()
-
- ri := e.HEAD("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodHead, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodHead+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodHead, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodHead, "/", e)
}
func TestEchoOptions(t *testing.T) {
e := New()
-
- ri := e.OPTIONS("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodOptions, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodOptions+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodOptions, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodOptions, "/", e)
}
func TestEchoPatch(t *testing.T) {
e := New()
-
- ri := e.PATCH("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodPatch, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodPatch+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodPatch, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodPatch, "/", e)
}
func TestEchoPost(t *testing.T) {
e := New()
-
- ri := e.POST("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodPost, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodPost+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodPost, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodPost, "/", e)
}
func TestEchoPut(t *testing.T) {
e := New()
-
- ri := e.PUT("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodPut, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodPut+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodPut, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
+ testMethod(t, http.MethodPut, "/", e)
}
func TestEchoTrace(t *testing.T) {
e := New()
+ testMethod(t, http.MethodTrace, "/", e)
+}
- ri := e.TRACE("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
+func TestEchoAny(t *testing.T) { // JFC
+ e := New()
+ e.Any("/", func(c Context) error {
+ return c.String(http.StatusOK, "Any")
})
-
- assert.Equal(t, http.MethodTrace, ri.Method)
- assert.Equal(t, "/", ri.Path)
- assert.Equal(t, http.MethodTrace+":/", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodTrace, "/", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, "OK", body)
}
-func TestEcho_Any(t *testing.T) {
+func TestEchoMatch(t *testing.T) { // JFC
e := New()
-
- ri := e.Any("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK from ANY")
+ e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error {
+ return c.String(http.StatusOK, "Match")
})
+}
- assert.Equal(t, RouteAny, ri.Method)
- assert.Equal(t, "/activate", ri.Path)
- assert.Equal(t, RouteAny+":/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodTrace, "/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK from ANY`, body)
+func TestEchoURL(t *testing.T) {
+ e := New()
+ static := func(Context) error { return nil }
+ getUser := func(Context) error { return nil }
+ getAny := func(Context) error { return nil }
+ getFile := func(Context) error { return nil }
+
+ e.GET("/static/file", static)
+ e.GET("/users/:id", getUser)
+ e.GET("/documents/*", getAny)
+ g := e.Group("/group")
+ g.GET("/users/:uid/files/:fid", getFile)
+
+ assert.Equal(t, "/static/file", e.URL(static))
+ assert.Equal(t, "/users/:id", e.URL(getUser))
+ assert.Equal(t, "/users/1", e.URL(getUser, "1"))
+ assert.Equal(t, "/users/1", e.URL(getUser, "1"))
+ assert.Equal(t, "/documents/foo.txt", e.URL(getAny, "foo.txt"))
+ assert.Equal(t, "/documents/*", e.URL(getAny))
+ assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1"))
+ assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1"))
}
-func TestEcho_Any_hasLowerPriority(t *testing.T) {
+func TestEchoRoutes(t *testing.T) {
e := New()
+ routes := []*Route{
+ {http.MethodGet, "/users/:user/events", ""},
+ {http.MethodGet, "/users/:user/events/public", ""},
+ {http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
+ {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
+ }
+ for _, r := range routes {
+ e.Add(r.Method, r.Path, func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ }
- e.Any("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "ANY")
- })
- e.GET("/activate", func(c *Context) error {
- return c.String(http.StatusLocked, "GET")
+ if assert.Equal(t, len(routes), len(e.Routes())) {
+ for _, r := range e.Routes() {
+ found := false
+ for _, rr := range routes {
+ if r.Method == rr.Method && r.Path == rr.Path {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("Route %s %s not found", r.Method, r.Path)
+ }
+ }
+ }
+}
+
+func TestEchoRoutesHandleAdditionalHosts(t *testing.T) {
+ e := New()
+ domain2Router := e.Host("domain2.router.com")
+ routes := []*Route{
+ {http.MethodGet, "/users/:user/events", ""},
+ {http.MethodGet, "/users/:user/events/public", ""},
+ {http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
+ {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
+ }
+ for _, r := range routes {
+ domain2Router.Add(r.Method, r.Path, func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ }
+ e.Add(http.MethodGet, "/api", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
})
- status, body := request(http.MethodTrace, "/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `ANY`, body)
+ domain2Routes := e.Routers()["domain2.router.com"].Routes()
- status, body = request(http.MethodGet, "/activate", e)
- assert.Equal(t, http.StatusLocked, status)
- assert.Equal(t, `GET`, body)
+ assert.Len(t, domain2Routes, len(routes))
+ for _, r := range domain2Routes {
+ found := false
+ for _, rr := range routes {
+ if r.Method == rr.Method && r.Path == rr.Path {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("Route %s %s not found", r.Method, r.Path)
+ }
+ }
}
-func TestEchoMatch(t *testing.T) { // JFC
+func TestEchoRoutesHandleDefaultHost(t *testing.T) {
e := New()
- ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error {
- return c.String(http.StatusOK, "Match")
+ routes := []*Route{
+ {http.MethodGet, "/users/:user/events", ""},
+ {http.MethodGet, "/users/:user/events/public", ""},
+ {http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
+ {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
+ }
+ for _, r := range routes {
+ e.Add(r.Method, r.Path, func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ }
+ e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
})
- assert.Len(t, ris, 2)
+
+ defaultRouterRoutes := e.Routes()
+ assert.Len(t, defaultRouterRoutes, len(routes))
+ for _, r := range defaultRouterRoutes {
+ found := false
+ for _, rr := range routes {
+ if r.Method == rr.Method && r.Path == rr.Path {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("Route %s %s not found", r.Method, r.Path)
+ }
+ }
}
func TestEchoServeHTTPPathEncoding(t *testing.T) {
e := New()
- e.GET("/with/slash", func(c *Context) error {
+ e.GET("/with/slash", func(c Context) error {
return c.String(http.StatusOK, "/with/slash")
})
- e.GET("/:id", func(c *Context) error {
+ e.GET("/:id", func(c Context) error {
return c.String(http.StatusOK, c.Param("id"))
})
@@ -824,16 +641,117 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) {
}
}
+func TestEchoHost(t *testing.T) {
+ okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) }
+ teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) }
+ acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) }
+ teapotMiddleware := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return teapotHandler })
+
+ e := New()
+ e.GET("/", acceptHandler)
+ e.GET("/foo", acceptHandler)
+
+ ok := e.Host("ok.com")
+ ok.GET("/", okHandler)
+ ok.GET("/foo", okHandler)
+
+ teapot := e.Host("teapot.com")
+ teapot.GET("/", teapotHandler)
+ teapot.GET("/foo", teapotHandler)
+
+ middle := e.Host("middleware.com", teapotMiddleware)
+ middle.GET("/", okHandler)
+ middle.GET("/foo", okHandler)
+
+ var testCases = []struct {
+ name string
+ whenHost string
+ whenPath string
+ expectBody string
+ expectStatus int
+ }{
+ {
+ name: "No Host Root",
+ whenHost: "",
+ whenPath: "/",
+ expectBody: http.StatusText(http.StatusAccepted),
+ expectStatus: http.StatusAccepted,
+ },
+ {
+ name: "No Host Foo",
+ whenHost: "",
+ whenPath: "/foo",
+ expectBody: http.StatusText(http.StatusAccepted),
+ expectStatus: http.StatusAccepted,
+ },
+ {
+ name: "OK Host Root",
+ whenHost: "ok.com",
+ whenPath: "/",
+ expectBody: http.StatusText(http.StatusOK),
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "OK Host Foo",
+ whenHost: "ok.com",
+ whenPath: "/foo",
+ expectBody: http.StatusText(http.StatusOK),
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "Teapot Host Root",
+ whenHost: "teapot.com",
+ whenPath: "/",
+ expectBody: http.StatusText(http.StatusTeapot),
+ expectStatus: http.StatusTeapot,
+ },
+ {
+ name: "Teapot Host Foo",
+ whenHost: "teapot.com",
+ whenPath: "/foo",
+ expectBody: http.StatusText(http.StatusTeapot),
+ expectStatus: http.StatusTeapot,
+ },
+ {
+ name: "Middleware Host",
+ whenHost: "middleware.com",
+ whenPath: "/",
+ expectBody: http.StatusText(http.StatusTeapot),
+ expectStatus: http.StatusTeapot,
+ },
+ {
+ name: "Middleware Host Foo",
+ whenHost: "middleware.com",
+ whenPath: "/foo",
+ expectBody: http.StatusText(http.StatusTeapot),
+ expectStatus: http.StatusTeapot,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.whenPath, nil)
+ req.Host = tc.whenHost
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ assert.Equal(t, tc.expectBody, rec.Body.String())
+ })
+ }
+}
+
func TestEchoGroup(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("0")
return next(c)
}
}))
- h := func(c *Context) error {
+ h := func(c Context) error {
return c.NoContent(http.StatusOK)
}
@@ -846,7 +764,7 @@ func TestEchoGroup(t *testing.T) {
// Group
g1 := e.Group("/group1")
g1.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("1")
return next(c)
}
@@ -856,14 +774,14 @@ func TestEchoGroup(t *testing.T) {
// Nested groups with middleware
g2 := e.Group("/group2")
g2.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("2")
return next(c)
}
})
g3 := g2.Group("/group3")
g3.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
buf.WriteString("3")
return next(c)
}
@@ -882,12 +800,20 @@ func TestEchoGroup(t *testing.T) {
assert.Equal(t, "023", buf.String())
}
-func TestEcho_RouteNotFound(t *testing.T) {
- var testCases = []struct {
- expectRoute any
- name string
- whenURL string
- expectCode int
+func TestEchoNotFound(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/files", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestEcho_RouteNotFound(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectRoute any
+ expectCode int
}{
{
name: "404, route to static not found handler /a/c/xx",
@@ -919,10 +845,10 @@ func TestEcho_RouteNotFound(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
e := New()
- okHandler := func(c *Context) error {
+ okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
- notFoundHandler := func(c *Context) error {
+ notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}
@@ -946,18 +872,10 @@ func TestEcho_RouteNotFound(t *testing.T) {
}
}
-func TestEchoNotFound(t *testing.T) {
- e := New()
- req := httptest.NewRequest(http.MethodGet, "/files", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusNotFound, rec.Code)
-}
-
func TestEchoMethodNotAllowed(t *testing.T) {
e := New()
- e.GET("/", func(c *Context) error {
+ e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "Echo!")
})
req := httptest.NewRequest(http.MethodPost, "/", nil)
@@ -968,133 +886,348 @@ func TestEchoMethodNotAllowed(t *testing.T) {
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
}
-func TestEcho_OnAddRoute(t *testing.T) {
- exampleRoute := Route{
- Method: http.MethodGet,
- Path: "/api/files/:id",
- Handler: notFoundHandler,
- Middlewares: nil,
- Name: "x",
+func TestEchoContext(t *testing.T) {
+ e := New()
+ c := e.AcquireContext()
+ assert.IsType(t, new(context), c)
+ e.ReleaseContext(c)
+}
+
+func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error {
+ ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
+ defer cancel()
+
+ ticker := time.NewTicker(5 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-ticker.C:
+ var addr net.Addr
+ if isTLS {
+ addr = e.TLSListenerAddr()
+ } else {
+ addr = e.ListenerAddr()
+ }
+ if addr != nil && strings.Contains(addr.String(), ":") {
+ return nil // was started
+ }
+ case err := <-errChan:
+ if err == http.ErrServerClosed {
+ return nil
+ }
+ return err
+ }
}
+}
+
+func TestEchoStart(t *testing.T) {
+ e := New()
+ errChan := make(chan error)
+
+ go func() {
+ err := e.Start(":0")
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ err := waitForServerStart(e, errChan, false)
+ assert.NoError(t, err)
+
+ assert.NoError(t, e.Close())
+}
+func TestEcho_StartTLS(t *testing.T) {
var testCases = []struct {
- whenRoute Route
- whenError error
name string
+ addr string
+ certFile string
+ keyFile string
expectError string
- expectAdded []string
- expectLen int
}{
{
- name: "ok",
- whenRoute: exampleRoute,
- whenError: nil,
- expectAdded: []string{"/static", "/api/files/:id"},
- expectError: "",
- expectLen: 2,
+ name: "ok",
+ addr: ":0",
},
{
- name: "nok, error is returned",
- whenRoute: exampleRoute,
- whenError: errors.New("nope"),
- expectAdded: []string{"/static"},
- expectError: "nope",
- expectLen: 1,
+ name: "nok, invalid certFile",
+ addr: ":0",
+ certFile: "not existing",
+ expectError: "open not existing: no such file or directory",
+ },
+ {
+ name: "nok, invalid keyFile",
+ addr: ":0",
+ keyFile: "not existing",
+ expectError: "open not existing: no such file or directory",
+ },
+ {
+ name: "nok, failed to create cert out of certFile and keyFile",
+ addr: ":0",
+ keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
+ expectError: "tls: found a certificate rather than a key in the PEM for the private key",
+ },
+ {
+ name: "nok, invalid tls address",
+ addr: "nope",
+ expectError: "listen tcp: address nope: missing port in address",
},
}
+
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
-
e := New()
+ errChan := make(chan error)
- added := make([]string, 0)
- cnt := 0
- e.OnAddRoute = func(route Route) error {
- if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests
- return tc.whenError
+ go func() {
+ certFile := "_fixture/certs/cert.pem"
+ if tc.certFile != "" {
+ certFile = tc.certFile
+ }
+ keyFile := "_fixture/certs/key.pem"
+ if tc.keyFile != "" {
+ keyFile = tc.keyFile
}
- cnt++
- added = append(added, route.Path)
- return nil
- }
-
- e.GET("/static", notFoundHandler)
- var err error
- _, err = e.AddRoute(tc.whenRoute)
+ err := e.StartTLS(tc.addr, certFile, keyFile)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+ err := waitForServerStart(e, errChan, true)
if tc.expectError != "" {
- assert.EqualError(t, err, tc.expectError)
+ if _, ok := err.(*os.PathError); ok {
+ assert.Error(t, err) // error messages for unix and windows are different. so test only error type here
+ } else {
+ assert.EqualError(t, err, tc.expectError)
+ }
} else {
assert.NoError(t, err)
}
- assert.Len(t, e.Router().Routes(), tc.expectLen)
- assert.Equal(t, tc.expectAdded, added)
+ assert.NoError(t, e.Close())
})
}
}
-func TestEchoContext(t *testing.T) {
+func TestEchoStartTLSAndStart(t *testing.T) {
+ // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
e := New()
- c := e.AcquireContext()
- assert.IsType(t, new(Context), c)
- e.ReleaseContext(c)
-}
+ e.GET("/", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
-func TestPreMiddlewares(t *testing.T) {
- e := New()
- assert.Equal(t, 0, len(e.PreMiddlewares()))
+ errTLSChan := make(chan error)
+ go func() {
+ certFile := "_fixture/certs/cert.pem"
+ keyFile := "_fixture/certs/key.pem"
+ err := e.StartTLS("localhost:", certFile, keyFile)
+ if err != nil {
+ errTLSChan <- err
+ }
+ }()
- e.Pre(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- return next(c)
+ err := waitForServerStart(e, errTLSChan, true)
+ assert.NoError(t, err)
+ defer func() {
+ if err := e.Shutdown(stdContext.Background()); err != nil {
+ t.Error(err)
}
- })
+ }()
+
+ // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
+ client := &http.Client{Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+ }}
+ res, err := client.Get("https://" + e.TLSListenerAddr().String())
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, res.StatusCode)
+
+ errChan := make(chan error)
+ go func() {
+ err := e.Start("localhost:")
+ if err != nil {
+ errChan <- err
+ }
+ }()
+ err = waitForServerStart(e, errChan, false)
+ assert.NoError(t, err)
+
+ // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
+ res, err = http.Get("http://" + e.ListenerAddr().String())
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, res.StatusCode)
- assert.Equal(t, 1, len(e.PreMiddlewares()))
+ // see if HTTPS works after HTTP listener is also added
+ res, err = client.Get("https://" + e.TLSListenerAddr().String())
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, res.StatusCode)
}
-func TestMiddlewares(t *testing.T) {
- e := New()
- assert.Equal(t, 0, len(e.Middlewares()))
+func TestEchoStartTLSByteString(t *testing.T) {
+ cert, err := os.ReadFile("_fixture/certs/cert.pem")
+ require.NoError(t, err)
+ key, err := os.ReadFile("_fixture/certs/key.pem")
+ require.NoError(t, err)
- e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- return next(c)
- }
- })
+ testCases := []struct {
+ cert any
+ key any
+ expectedErr error
+ name string
+ }{
+ {
+ cert: "_fixture/certs/cert.pem",
+ key: "_fixture/certs/key.pem",
+ expectedErr: nil,
+ name: `ValidCertAndKeyFilePath`,
+ },
+ {
+ cert: cert,
+ key: key,
+ expectedErr: nil,
+ name: `ValidCertAndKeyByteString`,
+ },
+ {
+ cert: cert,
+ key: 1,
+ expectedErr: ErrInvalidCertOrKeyType,
+ name: `InvalidKeyType`,
+ },
+ {
+ cert: 0,
+ key: key,
+ expectedErr: ErrInvalidCertOrKeyType,
+ name: `InvalidCertType`,
+ },
+ {
+ cert: 0,
+ key: 1,
+ expectedErr: ErrInvalidCertOrKeyType,
+ name: `InvalidCertAndKeyTypes`,
+ },
+ }
+
+ for _, test := range testCases {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ e := New()
+ e.HideBanner = true
- assert.Equal(t, 1, len(e.Middlewares()))
+ errChan := make(chan error)
+
+ go func() {
+ errChan <- e.StartTLS(":0", test.cert, test.key)
+ }()
+
+ err := waitForServerStart(e, errChan, true)
+ if test.expectedErr != nil {
+ assert.EqualError(t, err, test.expectedErr.Error())
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.NoError(t, e.Close())
+ })
+ }
}
-func TestEcho_Start(t *testing.T) {
- e := New()
- e.GET("/", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
- rndPort, err := net.Listen("tcp", ":0")
- if err != nil {
- t.Fatal(err)
+func TestEcho_StartAutoTLS(t *testing.T) {
+ var testCases = []struct {
+ name string
+ addr string
+ expectError string
+ }{
+ {
+ name: "ok",
+ addr: ":0",
+ },
+ {
+ name: "nok, invalid address",
+ addr: "nope",
+ expectError: "listen tcp: address nope: missing port in address",
+ },
}
- defer rndPort.Close()
- errChan := make(chan error, 1)
- go func() {
- errChan <- e.Start(rndPort.Addr().String())
- }()
- select {
- case <-time.After(250 * time.Millisecond):
- t.Fatal("start did not error out")
- case err := <-errChan:
- expectContains := "bind: address already in use"
- if runtime.GOOS == "windows" {
- expectContains = "bind: Only one usage of each socket address"
- }
- assert.Contains(t, err.Error(), expectContains)
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ errChan := make(chan error)
+
+ go func() {
+ errChan <- e.StartAutoTLS(tc.addr)
+ }()
+
+ err := waitForServerStart(e, errChan, true)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.NoError(t, e.Close())
+ })
}
}
+func TestEcho_StartH2CServer(t *testing.T) {
+ var testCases = []struct {
+ name string
+ addr string
+ expectError string
+ }{
+ {
+ name: "ok",
+ addr: ":0",
+ },
+ {
+ name: "nok, invalid address",
+ addr: "nope",
+ expectError: "listen tcp: address nope: missing port in address",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Debug = true
+ h2s := &http2.Server{}
+
+ errChan := make(chan error)
+ go func() {
+ err := e.StartH2CServer(tc.addr, h2s)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ err := waitForServerStart(e, errChan, false)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.NoError(t, e.Close())
+ })
+ }
+}
+
+func testMethod(t *testing.T, method, path string, e *Echo) {
+ p := reflect.ValueOf(path)
+ h := reflect.ValueOf(func(c Context) error {
+ return c.String(http.StatusOK, method)
+ })
+ i := any(e)
+ reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h})
+ _, body := request(method, path, e)
+ assert.Equal(t, method, body)
+}
+
func request(method, path string, e *Echo) (int, string) {
req := httptest.NewRequest(method, path, nil)
rec := httptest.NewRecorder()
@@ -1102,143 +1235,589 @@ func request(method, path string, e *Echo) (int, string) {
return rec.Code, rec.Body.String()
}
-type customError struct {
- Code int
- Message string
+func TestHTTPError(t *testing.T) {
+ t.Run("non-internal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]any{
+ "code": 12,
+ })
+
+ assert.Equal(t, "code=400, message=map[code:12]", err.Error())
+ })
+
+ t.Run("internal and SetInternal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]any{
+ "code": 12,
+ })
+ err.SetInternal(errors.New("internal error"))
+ assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
+ })
+
+ t.Run("internal and WithInternal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]any{
+ "code": 12,
+ })
+ err = err.WithInternal(errors.New("internal error"))
+ assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
+ })
+}
+
+func TestHTTPError_Unwrap(t *testing.T) {
+ t.Run("non-internal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]any{
+ "code": 12,
+ })
+
+ assert.Nil(t, errors.Unwrap(err))
+ })
+
+ t.Run("unwrap internal and SetInternal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]any{
+ "code": 12,
+ })
+ err.SetInternal(errors.New("internal error"))
+ assert.Equal(t, "internal error", errors.Unwrap(err).Error())
+ })
+
+ t.Run("unwrap internal and WithInternal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]any{
+ "code": 12,
+ })
+ err = err.WithInternal(errors.New("internal error"))
+ assert.Equal(t, "internal error", errors.Unwrap(err).Error())
+ })
}
-func (ce *customError) StatusCode() int {
- return ce.Code
+type customError struct {
+ s string
}
func (ce *customError) MarshalJSON() ([]byte, error) {
- return fmt.Appendf(nil, `{"x":"%v"}`, ce.Message), nil
+ return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil
}
func (ce *customError) Error() string {
- return ce.Message
+ return ce.s
}
func TestDefaultHTTPErrorHandler(t *testing.T) {
var testCases = []struct {
- whenError error
- name string
- whenMethod string
- expectBody string
- expectLogged string
- expectStatus int
- givenExposeError bool
- givenLoggerFunc bool
+ name string
+ givenDebug bool
+ whenPath string
+ expectCode int
+ expectBody string
}{
{
- name: "ok, expose error = true, HTTPError, no wrapped err",
- givenExposeError: true,
- whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
- expectStatus: http.StatusTeapot,
- expectBody: `{"message":"my_error"}` + "\n",
+ name: "with Debug=true plain response contains error message",
+ givenDebug: true,
+ whenPath: "/plain",
+ expectCode: http.StatusInternalServerError,
+ expectBody: "{\n \"error\": \"an error occurred\",\n \"message\": \"Internal Server Error\"\n}\n",
},
{
- name: "ok, expose error = true, HTTPError + wrapped error",
- givenExposeError: true,
- whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")),
- expectStatus: http.StatusTeapot,
- expectBody: `{"error":"internal_error","message":"my_error"}` + "\n",
+ name: "with Debug=true special handling for HTTPError",
+ givenDebug: true,
+ whenPath: "/badrequest",
+ expectCode: http.StatusBadRequest,
+ expectBody: "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n",
},
{
- name: "ok, expose error = true, HTTPError + wrapped HTTPError",
- givenExposeError: true,
- whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
- expectStatus: http.StatusTeapot,
- expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n",
+ name: "with Debug=true complex errors are serialized to pretty JSON",
+ givenDebug: true,
+ whenPath: "/servererror",
+ expectCode: http.StatusInternalServerError,
+ expectBody: "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n",
},
{
- name: "ok, expose error = false, HTTPError",
- whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
- expectStatus: http.StatusTeapot,
- expectBody: `{"message":"my_error"}` + "\n",
+ name: "with Debug=true if the body is already set HTTPErrorHandler should not add anything to response body",
+ givenDebug: true,
+ whenPath: "/early-return",
+ expectCode: http.StatusOK,
+ expectBody: "OK",
},
{
- name: "ok, expose error = false, HTTPError, no message",
- whenError: &HTTPError{Code: http.StatusTeapot, Message: ""},
- expectStatus: http.StatusTeapot,
- expectBody: `{"message":"I'm a teapot"}` + "\n",
+ name: "with Debug=true internal error should be reflected in the message",
+ givenDebug: true,
+ whenPath: "/internal-error",
+ expectCode: http.StatusBadRequest,
+ expectBody: "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n",
},
{
- name: "ok, expose error = false, HTTPError + internal HTTPError",
- whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
- expectStatus: http.StatusTooEarly,
- expectBody: `{"message":"my_error"}` + "\n",
+ name: "with Debug=false the error response is shortened",
+ whenPath: "/plain",
+ expectCode: http.StatusInternalServerError,
+ expectBody: "{\"message\":\"Internal Server Error\"}\n",
},
{
- name: "ok, expose error = true, Error",
- givenExposeError: true,
- whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
- expectStatus: http.StatusInternalServerError,
- expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n",
+ name: "with Debug=false the error response is shortened",
+ whenPath: "/badrequest",
+ expectCode: http.StatusBadRequest,
+ expectBody: "{\"message\":\"Invalid request\"}\n",
},
{
- name: "ok, expose error = false, Error",
- whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
- expectStatus: http.StatusInternalServerError,
- expectBody: `{"message":"Internal Server Error"}` + "\n",
+ name: "with Debug=false No difference for error response with non plain string errors",
+ whenPath: "/servererror",
+ expectCode: http.StatusInternalServerError,
+ expectBody: "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n",
},
{
- name: "ok, http.HEAD, expose error = true, Error",
- givenExposeError: true,
- whenMethod: http.MethodHead,
- whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
- expectStatus: http.StatusInternalServerError,
- expectBody: ``,
+ name: "with Debug=false when httpError contains an error",
+ whenPath: "/error-in-httperror",
+ expectCode: http.StatusBadRequest,
+ expectBody: "{\"message\":\"error in httperror\"}\n",
},
{
- name: "ok, custom error implement MarshalJSON + HTTPStatusCoder",
- whenMethod: http.MethodGet,
- whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"},
- expectStatus: http.StatusTeapot,
- expectBody: `{"x":"custom error msg"}` + "\n",
+ name: "with Debug=false when httpError contains an error",
+ whenPath: "/customerror-in-httperror",
+ expectCode: http.StatusBadRequest,
+ expectBody: "{\"x\":\"custom error msg\"}\n",
},
}
-
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- buf := new(bytes.Buffer)
e := New()
- e.Logger = slog.New(slog.DiscardHandler)
- e.Any("/path", func(c *Context) error {
- return tc.whenError
+ e.Debug = tc.givenDebug // With Debug=true plain response contains error message
+
+ e.Any("/plain", func(c Context) error {
+ return errors.New("an error occurred")
})
- e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError)
+ e.Any("/badrequest", func(c Context) error { // and special handling for HTTPError
+ return NewHTTPError(http.StatusBadRequest, "Invalid request")
+ })
- method := http.MethodGet
- if tc.whenMethod != "" {
- method = tc.whenMethod
- }
- c, b := request(method, "/path", e)
+ e.Any("/servererror", func(c Context) error { // complex errors are serialized to pretty JSON
+ return NewHTTPError(http.StatusInternalServerError, map[string]any{
+ "code": 33,
+ "message": "Something bad happened",
+ "error": "stackinfo",
+ })
+ })
+
+ // if the body is already set HTTPErrorHandler should not add anything to response body
+ e.Any("/early-return", func(c Context) error {
+ err := c.String(http.StatusOK, "OK")
+ if err != nil {
+ assert.Fail(t, err.Error())
+ }
+ return errors.New("ERROR")
+ })
+
+ // internal error should be reflected in the message
+ e.GET("/internal-error", func(c Context) error {
+ err := errors.New("internal error message body")
+ return NewHTTPError(http.StatusBadRequest).SetInternal(err)
+ })
+
+ e.GET("/error-in-httperror", func(c Context) error {
+ return NewHTTPError(http.StatusBadRequest, errors.New("error in httperror"))
+ })
+
+ e.GET("/customerror-in-httperror", func(c Context) error {
+ return NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"})
+ })
- assert.Equal(t, tc.expectStatus, c)
+ c, b := request(http.MethodGet, tc.whenPath, e)
+ assert.Equal(t, tc.expectCode, c)
assert.Equal(t, tc.expectBody, b)
- assert.Equal(t, tc.expectLogged, buf.String())
})
}
}
-func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) {
+func TestEchoClose(t *testing.T) {
e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- resp := httptest.NewRecorder()
- c := e.NewContext(req, resp)
+ errCh := make(chan error)
+
+ go func() {
+ errCh <- e.Start(":0")
+ }()
+
+ err := waitForServerStart(e, errCh, false)
+ assert.NoError(t, err)
+
+ if err := e.Close(); err != nil {
+ t.Fatal(err)
+ }
- c.orgResponse.Committed = true
- errHandler := DefaultHTTPErrorHandler(false)
+ assert.NoError(t, e.Close())
- errHandler(c, errors.New("my_error"))
- assert.Equal(t, http.StatusOK, resp.Code)
+ err = <-errCh
+ assert.Equal(t, err.Error(), "http: Server closed")
}
-func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
+func TestEchoShutdown(t *testing.T) {
e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
+ errCh := make(chan error)
+
+ go func() {
+ errCh <- e.Start(":0")
+ }()
+
+ err := waitForServerStart(e, errCh, false)
+ assert.NoError(t, err)
+
+ if err := e.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second)
+ defer cancel()
+ assert.NoError(t, e.Shutdown(ctx))
+
+ err = <-errCh
+ assert.Equal(t, err.Error(), "http: Server closed")
+}
+
+var listenerNetworkTests = []struct {
+ test string
+ network string
+ address string
+}{
+ {"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
+ {"tcp ipv6 address", "tcp", "[::1]:1323"},
+ {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
+ {"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
+}
+
+func supportsIPv6() bool {
+ addrs, _ := net.InterfaceAddrs()
+ for _, addr := range addrs {
+ // Check if any interface has local IPv6 assigned
+ if strings.Contains(addr.String(), "::1") {
+ return true
+ }
+ }
+ return false
+}
+
+func TestEchoListenerNetwork(t *testing.T) {
+ hasIPv6 := supportsIPv6()
+ for _, tt := range listenerNetworkTests {
+ if !hasIPv6 && strings.Contains(tt.address, "::") {
+ t.Skip("Skipping testing IPv6 for " + tt.address + ", not available")
+ continue
+ }
+ t.Run(tt.test, func(t *testing.T) {
+ e := New()
+ e.ListenerNetwork = tt.network
+
+ // HandlerFunc
+ e.GET("/ok", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ errCh := make(chan error)
+
+ go func() {
+ errCh <- e.Start(tt.address)
+ }()
+
+ err := waitForServerStart(e, errCh, false)
+ assert.NoError(t, err)
+
+ if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
+ defer func(Body io.ReadCloser) {
+ err := Body.Close()
+ if err != nil {
+ assert.Fail(t, err.Error())
+ }
+ }(resp.Body)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ if body, err := io.ReadAll(resp.Body); err == nil {
+ assert.Equal(t, "OK", string(body))
+ } else {
+ assert.Fail(t, err.Error())
+ }
+
+ } else {
+ assert.Fail(t, err.Error())
+ }
+
+ if err := e.Close(); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
+
+func TestEchoListenerNetworkInvalid(t *testing.T) {
+ e := New()
+ e.ListenerNetwork = "unix"
+
+ // HandlerFunc
+ e.GET("/ok", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
+}
+
+func TestEcho_OnAddRouteHandler(t *testing.T) {
+ type rr struct {
+ host string
+ route Route
+ handler HandlerFunc
+ middleware []MiddlewareFunc
+ }
+ dummyHandler := func(Context) error { return nil }
+ e := New()
+
+ added := make([]rr, 0)
+ e.OnAddRouteHandler = func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) {
+ added = append(added, rr{
+ host: host,
+ route: route,
+ handler: handler,
+ middleware: middleware,
+ })
+ }
+
+ e.GET("/static", dummyHandler)
+ e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc {
+ return func(c Context) error {
+ return next(c)
+ }
+ })
+
+ assert.Len(t, added, 2)
+
+ assert.Equal(t, "", added[0].host)
+ assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[0].route)
+ assert.Len(t, added[0].middleware, 0)
+
+ assert.Equal(t, "domain.site", added[1].host)
+ assert.Equal(t, Route{Method: http.MethodGet, Path: "/static/*", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[1].route)
+ assert.Len(t, added[1].middleware, 1)
+}
+
+func TestEchoReverse(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenRouteName string
+ whenParams []any
+ expect string
+ }{
+ {
+ name: "ok, not existing path returns empty url",
+ whenRouteName: "not-existing",
+ expect: "",
+ },
+ {
+ name: "ok,static with no params",
+ whenRouteName: "/static",
+ expect: "/static",
+ },
+ {
+ name: "ok,static with non existent param",
+ whenRouteName: "/static",
+ whenParams: []any{"missing param"},
+ expect: "/static",
+ },
+ {
+ name: "ok, wildcard with no params",
+ whenRouteName: "/static/*",
+ expect: "/static/*",
+ },
+ {
+ name: "ok, wildcard with params",
+ whenRouteName: "/static/*",
+ whenParams: []any{"foo.txt"},
+ expect: "/static/foo.txt",
+ },
+ {
+ name: "ok, single param without param",
+ whenRouteName: "/params/:foo",
+ expect: "/params/:foo",
+ },
+ {
+ name: "ok, single param with param",
+ whenRouteName: "/params/:foo",
+ whenParams: []any{"one"},
+ expect: "/params/one",
+ },
+ {
+ name: "ok, multi param without params",
+ whenRouteName: "/params/:foo/bar/:qux",
+ expect: "/params/:foo/bar/:qux",
+ },
+ {
+ name: "ok, multi param with one param",
+ whenRouteName: "/params/:foo/bar/:qux",
+ whenParams: []any{"one"},
+ expect: "/params/one/bar/:qux",
+ },
+ {
+ name: "ok, multi param with all params",
+ whenRouteName: "/params/:foo/bar/:qux",
+ whenParams: []any{"one", "two"},
+ expect: "/params/one/bar/two",
+ },
+ {
+ name: "ok, multi param + wildcard with all params",
+ whenRouteName: "/params/:foo/bar/:qux/*",
+ whenParams: []any{"one", "two", "three"},
+ expect: "/params/one/bar/two/three",
+ },
+ {
+ name: "ok, backslash is not escaped",
+ whenRouteName: "/backslash",
+ whenParams: []any{"test"},
+ expect: `/a\b/test`,
+ },
+ {
+ name: "ok, escaped colon verbs",
+ whenRouteName: "/params:customVerb",
+ whenParams: []any{"PATCH"},
+ expect: `/params:PATCH`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ dummyHandler := func(Context) error { return nil }
+
+ e.GET("/static", dummyHandler).Name = "/static"
+ e.GET("/static/*", dummyHandler).Name = "/static/*"
+ e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
+ e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
+ e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
+ e.GET("/a\\b/:x", dummyHandler).Name = "/backslash"
+ e.GET("/params\\::customVerb", dummyHandler).Name = "/params:customVerb"
+
+ assert.Equal(t, tc.expect, e.Reverse(tc.whenRouteName, tc.whenParams...))
+ })
+ }
+}
+
+func TestEchoReverseHandleHostProperly(t *testing.T) {
+ dummyHandler := func(Context) error { return nil }
+
+ e := New()
+
+ // routes added to the default router are different form different hosts
+ e.GET("/static", dummyHandler).Name = "default-host /static"
+ e.GET("/static/*", dummyHandler).Name = "xxx"
+
+ // different host
+ h := e.Host("the_host")
+ h.GET("/static", dummyHandler).Name = "host2 /static"
+ h.GET("/static/v2/*", dummyHandler).Name = "xxx"
+
+ assert.Equal(t, "/static", e.Reverse("default-host /static"))
+ // when actual route does not have params and we provide some to Reverse we should get that route url back
+ assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param"))
+
+ host2Router := e.Routers()["the_host"]
+ assert.Equal(t, "/static", host2Router.Reverse("host2 /static"))
+ assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param"))
+
+ assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx"))
+ assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt"))
+
+}
+
+func TestEcho_ListenerAddr(t *testing.T) {
+ e := New()
+
+ addr := e.ListenerAddr()
+ assert.Nil(t, addr)
+
+ errCh := make(chan error)
+ go func() {
+ errCh <- e.Start(":0")
+ }()
+
+ err := waitForServerStart(e, errCh, false)
+ assert.NoError(t, err)
+}
+
+func TestEcho_TLSListenerAddr(t *testing.T) {
+ cert, err := os.ReadFile("_fixture/certs/cert.pem")
+ require.NoError(t, err)
+ key, err := os.ReadFile("_fixture/certs/key.pem")
+ require.NoError(t, err)
+
+ e := New()
+
+ addr := e.TLSListenerAddr()
+ assert.Nil(t, addr)
+
+ errCh := make(chan error)
+ go func() {
+ errCh <- e.StartTLS(":0", cert, key)
+ }()
+
+ err = waitForServerStart(e, errCh, true)
+ assert.NoError(t, err)
+}
+
+func TestEcho_StartServer(t *testing.T) {
+ cert, err := os.ReadFile("_fixture/certs/cert.pem")
+ require.NoError(t, err)
+ key, err := os.ReadFile("_fixture/certs/key.pem")
+ require.NoError(t, err)
+ certs, err := tls.X509KeyPair(cert, key)
+ require.NoError(t, err)
+
+ var testCases = []struct {
+ name string
+ addr string
+ TLSConfig *tls.Config
+ expectError string
+ }{
+ {
+ name: "ok",
+ addr: ":0",
+ },
+ {
+ name: "ok, start with TLS",
+ addr: ":0",
+ TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}},
+ },
+ {
+ name: "nok, invalid address",
+ addr: "nope",
+ expectError: "listen tcp: address nope: missing port in address",
+ },
+ {
+ name: "nok, invalid tls address",
+ addr: "nope",
+ TLSConfig: &tls.Config{InsecureSkipVerify: true},
+ expectError: "listen tcp: address nope: missing port in address",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Debug = true
+
+ server := new(http.Server)
+ server.Addr = tc.addr
+ if tc.TLSConfig != nil {
+ server.TLSConfig = tc.TLSConfig
+ }
+
+ errCh := make(chan error)
+ go func() {
+ errCh <- e.StartServer(server)
+ }()
+
+ err := waitForServerStart(e, errCh, tc.TLSConfig != nil)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.NoError(t, e.Close())
+ })
+ }
+}
+
+func benchmarkEchoRoutes(b *testing.B, routes []*Route) {
+ e := New()
+ req := httptest.NewRequest("GET", "/", nil)
u := req.URL
w := httptest.NewRecorder()
@@ -1246,7 +1825,7 @@ func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
// Add routes
for _, route := range routes {
- e.Add(route.Method, route.Path, func(c *Context) error {
+ e.Add(route.Method, route.Path, func(c Context) error {
return nil
})
}
diff --git a/echotest/context.go b/echotest/context.go
deleted file mode 100644
index ca3bd1056..000000000
--- a/echotest/context.go
+++ /dev/null
@@ -1,183 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echotest
-
-import (
- "bytes"
- "io"
- "mime/multipart"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "testing"
-
- "github.com/labstack/echo/v5"
-)
-
-// ContextConfig is configuration for creating echo.Context for testing purposes.
-type ContextConfig struct {
- // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)`
- Request *http.Request
-
- // Response will be used instead of default `httptest.NewRecorder()`
- Response *httptest.ResponseRecorder
-
- // QueryValues will be set as Request.URL.RawQuery value
- QueryValues url.Values
-
- // Headers will be set as Request.Header value
- Headers http.Header
-
- // PathValues initializes context.PathValues with given value.
- PathValues echo.PathValues
-
- // RouteInfo initializes context.RouteInfo() with given value
- RouteInfo *echo.RouteInfo
-
- // FormValues creates form-urlencoded form out of given values. If there is no
- // `content-type` header it will be set to `application/x-www-form-urlencoded`
- // In case Request was not set the Request.Method is set to `POST`
- //
- // FormValues, MultipartForm and JSONBody are mutually exclusive.
- FormValues url.Values
-
- // MultipartForm creates multipart form out of given value. If there is no
- // `content-type` header it will be set to `multipart/form-data`
- // In case Request was not set the Request.Method is set to `POST`
- //
- // FormValues, MultipartForm and JSONBody are mutually exclusive.
- MultipartForm *MultipartForm
-
- // JSONBody creates JSON body out of given bytes. If there is no
- // `content-type` header it will be set to `application/json`
- // In case Request was not set the Request.Method is set to `POST`
- //
- // FormValues, MultipartForm and JSONBody are mutually exclusive.
- JSONBody []byte
-}
-
-// MultipartForm is used to create multipart form out of given value
-type MultipartForm struct {
- Fields map[string]string
- Files []MultipartFormFile
-}
-
-// MultipartFormFile is used to create file in multipart form out of given value
-type MultipartFormFile struct {
- Fieldname string
- Filename string
- Content []byte
-}
-
-// ToContext converts ContextConfig to echo.Context
-func (conf ContextConfig) ToContext(t *testing.T) *echo.Context {
- c, _ := conf.ToContextRecorder(t)
- return c
-}
-
-// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder
-func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) {
- if conf.Response == nil {
- conf.Response = httptest.NewRecorder()
- }
- isDefaultRequest := false
- if conf.Request == nil {
- isDefaultRequest = true
- conf.Request = httptest.NewRequest(http.MethodGet, "/", nil)
- }
-
- if len(conf.QueryValues) > 0 {
- conf.Request.URL.RawQuery = conf.QueryValues.Encode()
- }
- if len(conf.Headers) > 0 {
- conf.Request.Header = conf.Headers
- }
- if len(conf.FormValues) > 0 {
- body := strings.NewReader(url.Values(conf.FormValues).Encode())
- conf.Request.Body = io.NopCloser(body)
- conf.Request.ContentLength = int64(body.Len())
-
- if conf.Request.Header.Get(echo.HeaderContentType) == "" {
- conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
- }
- if isDefaultRequest {
- conf.Request.Method = http.MethodPost
- }
- } else if conf.MultipartForm != nil {
- var body bytes.Buffer
- mw := multipart.NewWriter(&body)
- for field, value := range conf.MultipartForm.Fields {
- if err := mw.WriteField(field, value); err != nil {
- t.Fatal(err)
- }
- }
- for _, file := range conf.MultipartForm.Files {
- fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
- if err != nil {
- t.Fatal(err)
- }
- if _, err = fw.Write(file.Content); err != nil {
- t.Fatal(err)
- }
- }
- if err := mw.Close(); err != nil {
- t.Fatal(err)
- }
-
- conf.Request.Body = io.NopCloser(&body)
- conf.Request.ContentLength = int64(body.Len())
- if conf.Request.Header.Get(echo.HeaderContentType) == "" {
- conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType())
- }
- if isDefaultRequest {
- conf.Request.Method = http.MethodPost
- }
- } else if conf.JSONBody != nil {
- body := bytes.NewReader(conf.JSONBody)
- conf.Request.Body = io.NopCloser(body)
- conf.Request.ContentLength = int64(body.Len())
-
- if conf.Request.Header.Get(echo.HeaderContentType) == "" {
- conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
- }
- if isDefaultRequest {
- conf.Request.Method = http.MethodPost
- }
- }
-
- ec := echo.NewContext(conf.Request, conf.Response, echo.New())
- if conf.RouteInfo == nil {
- conf.RouteInfo = &echo.RouteInfo{
- Name: "",
- Method: conf.Request.Method,
- Path: "/test",
- Parameters: []string{},
- }
- for _, p := range conf.PathValues {
- conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name)
- }
- }
- ec.InitializeRoute(conf.RouteInfo, &conf.PathValues)
- return ec, conf.Response
-}
-
-// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking
-func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder {
- c, rec := conf.ToContextRecorder(t)
-
- errHandler := echo.DefaultHTTPErrorHandler(false)
- for _, opt := range opts {
- switch o := opt.(type) {
- case echo.HTTPErrorHandler:
- errHandler = o
- }
- }
-
- err := handler(c)
- if err != nil {
- errHandler(c, err)
- }
- return rec
-}
diff --git a/echotest/context_external_test.go b/echotest/context_external_test.go
deleted file mode 100644
index d98257148..000000000
--- a/echotest/context_external_test.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package echotest_test
-
-import (
- "net/http"
- "testing"
-
- "github.com/labstack/echo/v5"
- "github.com/labstack/echo/v5/echotest"
- "github.com/stretchr/testify/assert"
-)
-
-func TestToContext_JSONBody(t *testing.T) {
- c := echotest.ContextConfig{
- JSONBody: echotest.LoadBytes(t, "testdata/test.json"),
- }.ToContext(t)
-
- payload := struct {
- Field string `json:"field"`
- }{}
- if err := c.Bind(&payload); err != nil {
- t.Fatal(err)
- }
-
- assert.Equal(t, "value", payload.Field)
- assert.Equal(t, http.MethodPost, c.Request().Method)
- assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType))
-}
diff --git a/echotest/context_test.go b/echotest/context_test.go
deleted file mode 100644
index 66815e4b0..000000000
--- a/echotest/context_test.go
+++ /dev/null
@@ -1,157 +0,0 @@
-package echotest
-
-import (
- "net/http"
- "net/url"
- "strings"
- "testing"
-
- "github.com/labstack/echo/v5"
- "github.com/stretchr/testify/assert"
-)
-
-func TestServeWithHandler(t *testing.T) {
- handler := func(c *echo.Context) error {
- return c.String(http.StatusOK, c.QueryParam("key"))
- }
- testConf := ContextConfig{
- QueryValues: url.Values{"key": []string{"value"}},
- }
-
- resp := testConf.ServeWithHandler(t, handler)
-
- assert.Equal(t, http.StatusOK, resp.Code)
- assert.Equal(t, "value", resp.Body.String())
-}
-
-func TestServeWithHandler_error(t *testing.T) {
- handler := func(c *echo.Context) error {
- return echo.NewHTTPError(http.StatusBadRequest, "something went wrong")
- }
- testConf := ContextConfig{
- QueryValues: url.Values{"key": []string{"value"}},
- }
-
- customErrHandler := echo.DefaultHTTPErrorHandler(true)
-
- resp := testConf.ServeWithHandler(t, handler, customErrHandler)
-
- assert.Equal(t, http.StatusBadRequest, resp.Code)
- assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String())
-}
-
-func TestToContext_QueryValues(t *testing.T) {
- testConf := ContextConfig{
- QueryValues: url.Values{"t": []string{"2006-01-02"}},
- }
- c := testConf.ToContext(t)
-
- v, err := echo.QueryParam[string](c, "t")
-
- assert.NoError(t, err)
- assert.Equal(t, "2006-01-02", v)
-}
-
-func TestToContext_Headers(t *testing.T) {
- testConf := ContextConfig{
- Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}},
- }
- c := testConf.ToContext(t)
-
- id := c.Request().Header.Get(echo.HeaderXRequestID)
-
- assert.Equal(t, "ABC", id)
-}
-
-func TestToContext_PathValues(t *testing.T) {
- testConf := ContextConfig{
- PathValues: echo.PathValues{{
- Name: "key",
- Value: "value",
- }},
- }
- c := testConf.ToContext(t)
-
- key := c.Param("key")
-
- assert.Equal(t, "value", key)
-}
-
-func TestToContext_RouteInfo(t *testing.T) {
- testConf := ContextConfig{
- RouteInfo: &echo.RouteInfo{
- Name: "my_route",
- Method: http.MethodGet,
- Path: "/:id",
- Parameters: []string{"id"},
- },
- }
- c := testConf.ToContext(t)
-
- ri := c.RouteInfo()
-
- assert.Equal(t, echo.RouteInfo{
- Name: "my_route",
- Method: http.MethodGet,
- Path: "/:id",
- Parameters: []string{"id"},
- }, ri)
-}
-
-func TestToContext_FormValues(t *testing.T) {
- testConf := ContextConfig{
- FormValues: url.Values{"key": []string{"value"}},
- }
- c := testConf.ToContext(t)
-
- assert.Equal(t, "value", c.FormValue("key"))
- assert.Equal(t, http.MethodPost, c.Request().Method)
- assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType))
-}
-
-func TestToContext_MultipartForm(t *testing.T) {
- testConf := ContextConfig{
- MultipartForm: &MultipartForm{
- Fields: map[string]string{
- "key": "value",
- },
- Files: []MultipartFormFile{
- {
- Fieldname: "file",
- Filename: "test.json",
- Content: LoadBytes(t, "testdata/test.json"),
- },
- },
- },
- }
- c := testConf.ToContext(t)
-
- assert.Equal(t, "value", c.FormValue("key"))
- assert.Equal(t, http.MethodPost, c.Request().Method)
- assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary="))
-
- fv, err := c.FormFile("file")
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, "test.json", fv.Filename)
- assert.Equal(t, int64(23), fv.Size)
-}
-
-func TestToContext_JSONBody(t *testing.T) {
- testConf := ContextConfig{
- JSONBody: LoadBytes(t, "testdata/test.json"),
- }
- c := testConf.ToContext(t)
-
- payload := struct {
- Field string `json:"field"`
- }{}
- if err := c.Bind(&payload); err != nil {
- t.Fatal(err)
- }
-
- assert.Equal(t, "value", payload.Field)
- assert.Equal(t, http.MethodPost, c.Request().Method)
- assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType))
-}
diff --git a/echotest/reader.go b/echotest/reader.go
deleted file mode 100644
index 0caceca02..000000000
--- a/echotest/reader.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echotest
-
-import (
- "os"
- "path/filepath"
- "runtime"
- "testing"
-)
-
-type loadBytesOpts func([]byte) []byte
-
-// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file.
-func TrimNewlineEnd(bytes []byte) []byte {
- bLen := len(bytes)
- if bLen > 1 && bytes[bLen-1] == '\n' {
- bytes = bytes[:bLen-1]
- }
- return bytes
-}
-
-// LoadBytes is helper to load file contents relative to current (where test file is) package
-// directory.
-func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte {
- bytes := loadBytes(t, name, 2)
-
- for _, f := range opts {
- bytes = f(bytes)
- }
-
- return bytes
-}
-
-func loadBytes(t *testing.T, name string, callDepth int) []byte {
- _, b, _, _ := runtime.Caller(callDepth)
- basepath := filepath.Dir(b)
-
- path := filepath.Join(basepath, name) // relative path
- bytes, err := os.ReadFile(path)
- if err != nil {
- t.Fatal(err)
- }
- return bytes[:]
-}
diff --git a/echotest/reader_external_test.go b/echotest/reader_external_test.go
deleted file mode 100644
index 43fd57416..000000000
--- a/echotest/reader_external_test.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package echotest_test
-
-import (
- "strings"
- "testing"
-
- "github.com/labstack/echo/v5/echotest"
- "github.com/stretchr/testify/assert"
-)
-
-const testJSONContent = `{
- "field": "value"
-}`
-
-func TestLoadBytesOK(t *testing.T) {
- data := echotest.LoadBytes(t, "testdata/test.json")
- assert.Equal(t, []byte(testJSONContent+"\n"), data)
-}
-
-func TestLoadBytes_custom(t *testing.T) {
- data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte {
- return []byte(strings.ToUpper(string(bytes)))
- })
- assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data)
-}
diff --git a/echotest/reader_test.go b/echotest/reader_test.go
deleted file mode 100644
index 23b3c2dd2..000000000
--- a/echotest/reader_test.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package echotest
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-const testJSONContent = `{
- "field": "value"
-}`
-
-func TestLoadBytesOK(t *testing.T) {
- data := LoadBytes(t, "testdata/test.json")
- assert.Equal(t, []byte(testJSONContent+"\n"), data)
-}
-
-func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) {
- data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd)
- assert.Equal(t, []byte(testJSONContent), data)
-}
diff --git a/echotest/testdata/test.json b/echotest/testdata/test.json
deleted file mode 100644
index 94ae65f17..000000000
--- a/echotest/testdata/test.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "field": "value"
-}
diff --git a/go.mod b/go.mod
index a2480a285..55990b60b 100644
--- a/go.mod
+++ b/go.mod
@@ -1,16 +1,23 @@
-module github.com/labstack/echo/v5
+module github.com/labstack/echo/v4
go 1.25.0
require (
+ github.com/labstack/gommon v0.5.0
github.com/stretchr/testify v1.11.1
- golang.org/x/net v0.49.0
- golang.org/x/time v0.14.0
+ github.com/valyala/fasttemplate v1.2.2
+ golang.org/x/crypto v0.50.0
+ golang.org/x/net v0.53.0
+ golang.org/x/time v0.15.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.22 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- golang.org/x/text v0.33.0 // indirect
+ github.com/valyala/bytebufferpool v1.0.0 // indirect
+ golang.org/x/sys v0.43.0 // indirect
+ golang.org/x/text v0.36.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index f1e80fc13..77f53a71d 100644
--- a/go.sum
+++ b/go.sum
@@ -1,15 +1,29 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/labstack/gommon v0.5.0 h1:6VSQ2NOzsnEJ5W6+84E0RbcaDDmgB6NIAzWCczTEe6c=
+github.com/labstack/gommon v0.5.0/go.mod h1:Rzlg7HHy1maLfzBYGg9NZcVuz1sA68HHhLjhcEllYE0=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
+github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
+github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
-golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
-golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
-golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
-golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
-golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
+github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
+github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
+github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
+github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
+golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
+golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
+golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
+golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
+golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
+golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
+golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
+golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
+golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/group.go b/group.go
index 8092bc904..cb37b123f 100644
--- a/group.go
+++ b/group.go
@@ -4,7 +4,6 @@
package echo
import (
- "io/fs"
"net/http"
)
@@ -12,167 +11,119 @@ import (
// routes that share a common middleware or functionality that should be separate
// from the parent echo instance while still inheriting from it.
type Group struct {
- echo *Echo
+ common
+ host string
prefix string
+ echo *Echo
middleware []MiddlewareFunc
}
// Use implements `Echo#Use()` for sub-routes within the Group.
-// Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...)
+ if len(g.middleware) == 0 {
+ return
+ }
+ // group level middlewares are different from Echo `Pre` and `Use` middlewares (those are global). Group level middlewares
+ // are only executed if they are added to the Router with route.
+ // So we register catch all route (404 is a safe way to emulate route match) for this group and now during routing the
+ // Router would find route to match our request path and therefore guarantee the middleware(s) will get executed.
+ g.RouteNotFound("", NotFoundHandler)
+ g.RouteNotFound("/*", NotFoundHandler)
}
-// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
-func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
+func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodConnect, path, h, m...)
}
-// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
-func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
+func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodDelete, path, h, m...)
}
-// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
-func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// GET implements `Echo#GET()` for sub-routes within the Group.
+func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodGet, path, h, m...)
}
-// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
-func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
+func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodHead, path, h, m...)
}
-// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
-func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
+func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodOptions, path, h, m...)
}
-// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
-func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
+func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodPatch, path, h, m...)
}
-// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
-func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// POST implements `Echo#POST()` for sub-routes within the Group.
+func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodPost, path, h, m...)
}
-// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
-func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// PUT implements `Echo#PUT()` for sub-routes within the Group.
+func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodPut, path, h, m...)
}
-// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
-func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
+func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(http.MethodTrace, path, h, m...)
}
-// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
-func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
- return g.Add(RouteAny, path, handler, middleware...)
-}
-
-// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
-func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
- errs := make([]error, 0)
- ris := make(Routes, 0)
- for _, m := range methods {
- ri, err := g.AddRoute(Route{
- Method: m,
- Path: path,
- Handler: handler,
- Middlewares: middleware,
- })
- if err != nil {
- errs = append(errs, err)
- continue
- }
- ris = append(ris, ri)
+// Any implements `Echo#Any()` for sub-routes within the Group.
+func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
+ routes := make([]*Route, len(methods))
+ for i, m := range methods {
+ routes[i] = g.Add(m, path, handler, middleware...)
}
- if len(errs) > 0 {
- panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ return routes
+}
+
+// Match implements `Echo#Match()` for sub-routes within the Group.
+func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
+ routes := make([]*Route, len(methods))
+ for i, m := range methods {
+ routes[i] = g.Add(m, path, handler, middleware...)
}
- return ris
+ return routes
}
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
-// Important! Group middlewares are only executed in case there was exact route match and not
-// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
-// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
sg = g.echo.Group(g.prefix+prefix, m...)
+ sg.host = g.host
return
}
-// Static implements `Echo#Static()` for sub-routes within the Group.
-func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
- subFs := MustSubFS(g.echo.Filesystem, fsRoot)
- return g.StaticFS(pathPrefix, subFs, middleware...)
-}
-
-// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
-//
-// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
-// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
-// including `assets/images` as their prefix.
-func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
- return g.Add(
- http.MethodGet,
- pathPrefix+"*",
- StaticDirectoryHandler(filesystem, false),
- middleware...,
- )
-}
-
-// FileFS implements `Echo#FileFS()` for sub-routes within the Group.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
- return g.GET(path, StaticFileHandler(file, filesystem), m...)
-}
-
-// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
-//
-// Avoid using the leading `/` slash as most of the Go standard library fs.FS implementations require relative paths for
-// file operations.
-func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
- handler := func(c *Context) error {
- return c.File(file)
- }
- return g.Add(http.MethodGet, path, handler, middleware...)
+// File implements `Echo#File()` for sub-routes within the Group.
+func (g *Group) File(path, file string) {
+ g.file(path, file, g.GET)
}
// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
//
-// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
-func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
+func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
return g.Add(RouteNotFound, path, h, m...)
}
-// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
-func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
- ri, err := g.AddRoute(Route{
- Method: method,
- Path: path,
- Handler: handler,
- Middlewares: middleware,
- })
- if err != nil {
- panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
- }
- return ri
-}
-
-// AddRoute registers a new Routable with Router
-func (g *Group) AddRoute(route Route) (RouteInfo, error) {
- // Combine middleware into a new slice to avoid accidentally passing the same slice for
+// Add implements `Echo#Add()` for sub-routes within the Group.
+func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
+ // Combine into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
- groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
- return g.echo.add(groupRoute)
+ m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
+ m = append(m, g.middleware...)
+ m = append(m, middleware...)
+ return g.echo.add(g.host, method, g.prefix+path, handler, m...)
}
diff --git a/group_fs.go b/group_fs.go
new file mode 100644
index 000000000..c1b7ec2d3
--- /dev/null
+++ b/group_fs.go
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "io/fs"
+ "net/http"
+)
+
+// Static implements `Echo#Static()` for sub-routes within the Group.
+func (g *Group) Static(pathPrefix, fsRoot string) {
+ subFs := MustSubFS(g.echo.Filesystem, fsRoot)
+ g.StaticFS(pathPrefix, subFs)
+}
+
+// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) {
+ g.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(filesystem, false),
+ )
+}
+
+// FileFS implements `Echo#FileFS()` for sub-routes within the Group.
+func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route {
+ return g.GET(path, StaticFileHandler(file, filesystem), m...)
+}
diff --git a/group_fs_test.go b/group_fs_test.go
new file mode 100644
index 000000000..caa200940
--- /dev/null
+++ b/group_fs_test.go
@@ -0,0 +1,103 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "github.com/stretchr/testify/assert"
+ "io/fs"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+)
+
+func TestGroup_FileFS(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenPath string
+ whenFile string
+ whenFS fs.FS
+ givenURL string
+ expectCode int
+ expectStartsWith []byte
+ }{
+ {
+ name: "ok",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/assets/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, requesting invalid path",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/assets/walle.png",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ {
+ name: "nok, serving not existent file from filesystem",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "not-existent.png",
+ givenURL: "/assets/walle",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ g := e.Group("/assets")
+ g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestGroup_StaticPanic(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRoot string
+ }{
+ {
+ name: "panics for ../",
+ givenRoot: "../images",
+ },
+ {
+ name: "panics for /",
+ givenRoot: "/images",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Filesystem = os.DirFS("./")
+
+ g := e.Group("/assets")
+
+ assert.Panics(t, func() {
+ g.Static("/images", tc.givenRoot)
+ })
+ })
+ }
+}
diff --git a/group_method_handling_test.go b/group_method_handling_test.go
deleted file mode 100644
index 4a8cf0979..000000000
--- a/group_method_handling_test.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// These tests lock in v5's method-handling semantics for routes registered through
-// a Group. v5 resolves method mismatches (405) and OPTIONS at the router level and
-// does NOT register any implicit per-group catch-all route.
-//
-// They double as a regression gate. Registering a group-level catch-all β whether
-// manually via g.RouteNotFound("/*", ...) or automatically (as proposed in #2996 to
-// fix CORS-on-group preflight) β makes that catch-all match every method, which masks
-// both 405 and v5's automatic OPTIONS response as 404 β demonstrated directly by
-// TestGroupRoute_catchAllMasksMethodHandling below. If that masking becomes the
-// default (e.g. #2996 lands), the first two tests below fail.
-
-// A method mismatch on an existing group route must return 405 with the allowed
-// methods, not be masked to 404.
-func TestGroupRoute_methodMismatchReturns405(t *testing.T) {
- e := New()
- g := e.Group("/api")
- g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") })
-
- req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusMethodNotAllowed, rec.Code,
- "POST to a GET-only group route must be 405, not masked to 404")
- assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow),
- "405 response must advertise the allowed methods")
-}
-
-// OPTIONS on an existing group route is answered automatically by Echo (204 +
-// Allow). This is the behavior CORS preflight relies on, so it must not be masked.
-func TestGroupRoute_automaticOPTIONS(t *testing.T) {
- e := New()
- g := e.Group("/api")
- g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") })
-
- req := httptest.NewRequest(http.MethodOptions, "/api/users", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusNoContent, rec.Code,
- "OPTIONS on a registered group route must be auto-answered (204), not masked to 404")
- assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow),
- "automatic OPTIONS response must advertise the allowed methods")
-}
-
-// A matched concrete route resolves to its own handler; only a genuinely unmatched
-// path under the prefix is a 404.
-func TestGroupRoute_concreteRoutesResolve(t *testing.T) {
- e := New()
- g := e.Group("/api")
- g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") })
-
- status, body := request(http.MethodGet, "/api/users", e)
- assert.Equal(t, http.StatusOK, status)
- assert.Equal(t, "users", body)
-
- status, _ = request(http.MethodGet, "/api/nope", e)
- assert.Equal(t, http.StatusNotFound, status)
-}
-
-// A group prefix must not affect routing of routes registered outside the group.
-func TestGroup_doesNotAffectRootRoutes(t *testing.T) {
- e := New()
- e.GET("/health", func(c *Context) error { return c.String(http.StatusOK, "root") })
- g := e.Group("/api")
- g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") })
-
- status, body := request(http.MethodGet, "/health", e)
- assert.Equal(t, http.StatusOK, status)
- assert.Equal(t, "root", body)
-}
-
-// Characterization of the regression the 405/OPTIONS tests above guard against:
-// registering a group-wide catch-all (the manual equivalent of #2996's auto-
-// registration) makes it match every method, so method mismatches and the automatic
-// OPTIONS response are masked as 404 even though the concrete route still resolves.
-// If a future change teaches the catch-all to preserve method semantics, update this.
-func TestGroupRoute_catchAllMasksMethodHandling(t *testing.T) {
- e := New()
- g := e.Group("/api")
- g.GET("/users", func(c *Context) error { return c.String(http.StatusOK, "users") })
- g.RouteNotFound("/*", func(c *Context) error { return c.NoContent(http.StatusNotFound) })
-
- // The concrete route still resolves.
- status, body := request(http.MethodGet, "/api/users", e)
- assert.Equal(t, http.StatusOK, status)
- assert.Equal(t, "users", body)
-
- // But the catch-all masks the method mismatch (would be 405) ...
- post := httptest.NewRequest(http.MethodPost, "/api/users", nil)
- postRec := httptest.NewRecorder()
- e.ServeHTTP(postRec, post)
- assert.Equal(t, http.StatusNotFound, postRec.Code,
- "a group-wide catch-all masks the 405 method-mismatch as 404")
-
- // ... and the automatic OPTIONS response (would be 204).
- opts := httptest.NewRequest(http.MethodOptions, "/api/users", nil)
- optsRec := httptest.NewRecorder()
- e.ServeHTTP(optsRec, opts)
- assert.Equal(t, http.StatusNotFound, optsRec.Code,
- "a group-wide catch-all masks the automatic OPTIONS (204) response as 404")
-}
diff --git a/group_test.go b/group_test.go
index 7078b6497..78c2ed485 100644
--- a/group_test.go
+++ b/group_test.go
@@ -4,70 +4,31 @@
package echo
import (
- "io/fs"
"net/http"
"net/http/httptest"
"os"
- "strings"
"testing"
"github.com/stretchr/testify/assert"
)
-func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
- e := New()
-
- called := false
- mw := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- called = true
- return c.NoContent(http.StatusTeapot)
- }
- }
- // even though group has middleware it will not be executed when there are no routes under that group
- _ = e.Group("/group", mw)
-
- status, body := request(http.MethodGet, "/group/nope", e)
- assert.Equal(t, http.StatusNotFound, status)
- assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
-
- assert.False(t, called)
-}
-
-func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
- e := New()
-
- called := false
- mw := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- called = true
- return c.NoContent(http.StatusTeapot)
- }
- }
- // even though group has middleware and routes when we have no match on some route the middlewares for that
- // group will not be executed
- g := e.Group("/group", mw)
- g.GET("/yes", handlerFunc)
-
- status, body := request(http.MethodGet, "/group/nope", e)
- assert.Equal(t, http.StatusNotFound, status)
- assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
-
- assert.False(t, called)
-}
-
-func TestGroup_multiLevelGroup(t *testing.T) {
- e := New()
-
- api := e.Group("/api")
- users := api.Group("/users")
- users.GET("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- status, body := request(http.MethodGet, "/api/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
+// TODO: Fix me
+func TestGroup(t *testing.T) {
+ g := New().Group("/group")
+ h := func(Context) error { return nil }
+ g.CONNECT("/", h)
+ g.DELETE("/", h)
+ g.GET("/", h)
+ g.HEAD("/", h)
+ g.OPTIONS("/", h)
+ g.PATCH("/", h)
+ g.POST("/", h)
+ g.PUT("/", h)
+ g.TRACE("/", h)
+ g.Any("/", h)
+ g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
+ g.Static("/static", "/tmp")
+ g.File("/walle", "_fixture/images//walle.png")
}
func TestGroupFile(t *testing.T) {
@@ -87,29 +48,29 @@ func TestGroupRouteMiddleware(t *testing.T) {
// Ensure middleware slices are not re-used
e := New()
g := e.Group("/group")
- h := func(*Context) error { return nil }
+ h := func(Context) error { return nil }
m1 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return next(c)
}
}
m2 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return next(c)
}
}
m3 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return next(c)
}
}
m4 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return c.NoContent(404)
}
}
m5 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return c.NoContent(405)
}
}
@@ -128,17 +89,17 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
e := New()
g := e.Group("/group")
m1 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
return next(c)
}
}
m2 := func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- return c.String(http.StatusOK, c.RouteInfo().Path)
+ return func(c Context) error {
+ return c.String(http.StatusOK, c.Path())
}
}
- h := func(c *Context) error {
- return c.String(http.StatusOK, c.RouteInfo().Path)
+ h := func(c Context) error {
+ return c.String(http.StatusOK, c.Path())
}
g.Use(m1)
g.GET("/help", h, m2)
@@ -162,155 +123,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
}
-func TestGroup_CONNECT(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.CONNECT("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodConnect, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodConnect, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_DELETE(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.DELETE("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodDelete, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodDelete, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_HEAD(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.HEAD("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodHead, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodHead+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodHead, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_OPTIONS(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.OPTIONS("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodOptions, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodOptions, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_PATCH(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.PATCH("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodPatch, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodPatch, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_POST(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.POST("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodPost, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodPost+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodPost, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_PUT(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.PUT("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodPut, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodPut+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodPut, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
-func TestGroup_TRACE(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.TRACE("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
-
- assert.Equal(t, http.MethodTrace, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodTrace, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
-}
-
func TestGroup_RouteNotFound(t *testing.T) {
var testCases = []struct {
- expectRoute any
name string
whenURL string
+ expectRoute any
expectCode int
}{
{
@@ -344,10 +161,10 @@ func TestGroup_RouteNotFound(t *testing.T) {
e := New()
g := e.Group("/group")
- okHandler := func(c *Context) error {
+ okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
- notFoundHandler := func(c *Context) error {
+ notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}
@@ -371,416 +188,44 @@ func TestGroup_RouteNotFound(t *testing.T) {
}
}
-func TestGroup_Any(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- ri := users.Any("/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK from ANY")
- })
-
- assert.Equal(t, RouteAny, ri.Method)
- assert.Equal(t, "/users/activate", ri.Path)
- assert.Equal(t, RouteAny+":/users/activate", ri.Name)
- assert.Nil(t, ri.Parameters)
-
- status, body := request(http.MethodTrace, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK from ANY`, body)
-}
-
-func TestGroup_Match(t *testing.T) {
- e := New()
-
- myMethods := []string{http.MethodGet, http.MethodPost}
- users := e.Group("/users")
- ris := users.Match(myMethods, "/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
- assert.Len(t, ris, 2)
-
- for _, m := range myMethods {
- status, body := request(m, "/users/activate", e)
- assert.Equal(t, http.StatusTeapot, status)
- assert.Equal(t, `OK`, body)
- }
-}
-
-func TestGroup_MatchWithErrors(t *testing.T) {
- e := New()
-
- users := e.Group("/users")
- users.GET("/activate", func(c *Context) error {
- return c.String(http.StatusOK, "OK")
- })
- myMethods := []string{http.MethodGet, http.MethodPost}
-
- errs := func() (errs []error) {
- defer func() {
- if r := recover(); r != nil {
- if tmpErr, ok := r.([]error); ok {
- errs = tmpErr
- return
- }
- panic(r)
- }
- }()
-
- users.Match(myMethods, "/activate", func(c *Context) error {
- return c.String(http.StatusTeapot, "OK")
- })
- return nil
- }()
- assert.Len(t, errs, 1)
- assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
-
- for _, m := range myMethods {
- status, body := request(m, "/users/activate", e)
-
- expect := http.StatusTeapot
- if m == http.MethodGet {
- expect = http.StatusOK
- }
- assert.Equal(t, expect, status)
- assert.Equal(t, `OK`, body)
- }
-}
-
-func TestGroup_Static(t *testing.T) {
- e := New()
-
- g := e.Group("/books")
- ri := g.Static("/download", "_fixture")
- assert.Equal(t, http.MethodGet, ri.Method)
- assert.Equal(t, "/books/download*", ri.Path)
- assert.Equal(t, "GET:/books/download*", ri.Name)
- assert.Equal(t, []string{"*"}, ri.Parameters)
-
- req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusOK, rec.Code)
- body := rec.Body.String()
- assert.True(t, strings.HasPrefix(body, ""))
-}
-
-func TestGroup_StaticMultiTest(t *testing.T) {
- var testCases = []struct {
- name string
- givenPrefix string
- givenRoot string
- whenURL string
- expectHeaderLocation string
- expectBodyStartsWith string
- expectBodyNotContains string
- expectStatus int
- }{
- {
- name: "ok",
- givenPrefix: "/images",
- givenRoot: "_fixture/images",
- whenURL: "/test/images/walle.png",
- expectStatus: http.StatusOK,
- expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
- },
- {
- name: "ok, without prefix",
- givenPrefix: "",
- givenRoot: "_fixture/images",
- whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
- expectStatus: http.StatusOK,
- expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
- },
- {
- name: "nok, without prefix does not serve dir index",
- givenPrefix: "",
- givenRoot: "_fixture/images",
- whenURL: "/test/", // `/test` + `*` creates route `/test*`
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
- {
- name: "No file",
- givenPrefix: "/images",
- givenRoot: "_fixture/scripts",
- whenURL: "/test/images/bolt.png",
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
- {
- name: "Directory",
- givenPrefix: "/images",
- givenRoot: "_fixture/images",
- whenURL: "/test/images/",
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
- {
- name: "Directory Redirect",
- givenPrefix: "/",
- givenRoot: "_fixture",
- whenURL: "/test/folder",
- expectStatus: http.StatusMovedPermanently,
- expectHeaderLocation: "/test/folder/",
- expectBodyStartsWith: "",
- },
- {
- name: "Directory Redirect with non-root path",
- givenPrefix: "/static",
- givenRoot: "_fixture",
- whenURL: "/test/static",
- expectStatus: http.StatusMovedPermanently,
- expectHeaderLocation: "/test/static/",
- expectBodyStartsWith: "",
- },
- {
- name: "Prefixed directory 404 (request URL without slash)",
- givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
- givenRoot: "_fixture",
- whenURL: "/test/folder", // no trailing slash
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
- {
- name: "Prefixed directory redirect (without slash redirect to slash)",
- givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
- givenRoot: "_fixture",
- whenURL: "/test/folder", // no trailing slash
- expectStatus: http.StatusMovedPermanently,
- expectHeaderLocation: "/test/folder/",
- expectBodyStartsWith: "",
- },
- {
- name: "Directory with index.html",
- givenPrefix: "/",
- givenRoot: "_fixture",
- whenURL: "/test/",
- expectStatus: http.StatusOK,
- expectBodyStartsWith: "",
- },
- {
- name: "Prefixed directory with index.html (prefix ending with slash)",
- givenPrefix: "/assets/",
- givenRoot: "_fixture",
- whenURL: "/test/assets/",
- expectStatus: http.StatusOK,
- expectBodyStartsWith: "",
- },
- {
- name: "Prefixed directory with index.html (prefix ending without slash)",
- givenPrefix: "/assets",
- givenRoot: "_fixture",
- whenURL: "/test/assets/",
- expectStatus: http.StatusOK,
- expectBodyStartsWith: "",
- },
- {
- name: "Sub-directory with index.html",
- givenPrefix: "/",
- givenRoot: "_fixture",
- whenURL: "/test/folder/",
- expectStatus: http.StatusOK,
- expectBodyStartsWith: "",
- },
- {
- name: "nok, URL encoded path traversal (single encoding, slash - unix separator)",
- givenRoot: "_fixture/dist/public",
- whenURL: "/%2e%2e%2fprivate.txt",
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- expectBodyNotContains: `private file`,
- },
- {
- name: "nok, URL encoded path traversal (single encoding, backslash - windows separator)",
- givenRoot: "_fixture/dist/public",
- whenURL: "/%2e%2e%5cprivate.txt",
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- expectBodyNotContains: `private file`,
- },
- {
- name: "do not allow directory traversal (backslash - windows separator)",
- givenPrefix: "/",
- givenRoot: "_fixture/",
- whenURL: `/test/..\\middleware/basic_auth.go`,
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
- {
- name: "do not allow directory traversal (slash - unix separator)",
- givenPrefix: "/",
- givenRoot: "_fixture/",
- whenURL: `/test/../middleware/basic_auth.go`,
- expectStatus: http.StatusNotFound,
- expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
-
- g := e.Group("/test")
- g.Static(tc.givenPrefix, tc.givenRoot)
-
- req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, tc.expectStatus, rec.Code)
- body := rec.Body.String()
- if tc.expectBodyStartsWith != "" {
- assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
- } else {
- assert.Equal(t, "", body)
- }
- if tc.expectBodyNotContains != "" {
- assert.NotContains(t, body, tc.expectBodyNotContains)
- }
-
- if tc.expectHeaderLocation != "" {
- assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
- } else {
- _, ok := rec.Result().Header["Location"]
- assert.False(t, ok)
- }
- })
- }
-}
-
-func TestGroup_FileFS(t *testing.T) {
- var testCases = []struct {
- whenFS fs.FS
- name string
- whenPath string
- whenFile string
- givenURL string
- expectStartsWith []byte
- expectCode int
- }{
- {
- name: "ok",
- whenPath: "/walle",
- whenFS: os.DirFS("_fixture/images"),
- whenFile: "walle.png",
- givenURL: "/assets/walle",
- expectCode: http.StatusOK,
- expectStartsWith: []byte{0x89, 0x50, 0x4e},
- },
- {
- name: "nok, requesting invalid path",
- whenPath: "/walle",
- whenFS: os.DirFS("_fixture/images"),
- whenFile: "walle.png",
- givenURL: "/assets/walle.png",
- expectCode: http.StatusNotFound,
- expectStartsWith: []byte(`{"message":"Not Found"}`),
- },
- {
- name: "nok, serving not existent file from filesystem",
- whenPath: "/walle",
- whenFS: os.DirFS("_fixture/images"),
- whenFile: "not-existent.png",
- givenURL: "/assets/walle",
- expectCode: http.StatusNotFound,
- expectStartsWith: []byte(`{"message":"Not Found"}`),
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- g := e.Group("/assets")
- g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
-
- req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
- rec := httptest.NewRecorder()
-
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, tc.expectCode, rec.Code)
-
- body := rec.Body.Bytes()
- if len(body) > len(tc.expectStartsWith) {
- body = body[:len(tc.expectStartsWith)]
- }
- assert.Equal(t, tc.expectStartsWith, body)
- })
- }
-}
-
-func TestGroup_StaticPanic(t *testing.T) {
- var testCases = []struct {
- name string
- givenRoot string
- }{
- {
- name: "panics for ../",
- givenRoot: "../images",
- },
- {
- name: "panics for /",
- givenRoot: "/images",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- e.Filesystem = os.DirFS("./")
-
- g := e.Group("/assets")
-
- assert.Panics(t, func() {
- g.Static("/images", tc.givenRoot)
- })
- })
- }
-}
-
func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
var testCases = []struct {
- expectBody any
- name string
- whenURL string
- expectCode int
- givenCustom404 bool
- expectMiddlewareCalled bool
+ name string
+ givenCustom404 bool
+ whenURL string
+ expectBody any
+ expectCode int
}{
{
- name: "ok, custom 404 handler is called with middleware",
- givenCustom404: true,
- whenURL: "/group/test3",
- expectBody: "404 GET /group/*",
- expectCode: http.StatusNotFound,
- expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added
+ name: "ok, custom 404 handler is called with middleware",
+ givenCustom404: true,
+ whenURL: "/group/test3",
+ expectBody: "GET /group/*",
+ expectCode: http.StatusNotFound,
},
{
- name: "ok, default group 404 handler is not called with middleware",
- givenCustom404: false,
- whenURL: "/group/test3",
- expectBody: "404 GET /*",
- expectCode: http.StatusNotFound,
- expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
+ name: "ok, default group 404 handler is called with middleware",
+ givenCustom404: false,
+ whenURL: "/group/test3",
+ expectBody: "{\"message\":\"Not Found\"}\n",
+ expectCode: http.StatusNotFound,
},
{
- name: "ok, (no slash) default group 404 handler is called with middleware",
- givenCustom404: false,
- whenURL: "/group",
- expectBody: "404 GET /*",
- expectCode: http.StatusNotFound,
- expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
+ name: "ok, (no slash) default group 404 handler is called with middleware",
+ givenCustom404: false,
+ whenURL: "/group",
+ expectBody: "{\"message\":\"Not Found\"}\n",
+ expectCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- okHandler := func(c *Context) error {
+ okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
- notFoundHandler := func(c *Context) error {
- return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
+ notFoundHandler := func(c Context) error {
+ return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}
e := New()
@@ -792,7 +237,7 @@ func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
middlewareCalled := false
g.Use(func(next HandlerFunc) HandlerFunc {
- return func(c *Context) error {
+ return func(c Context) error {
middlewareCalled = true
return next(c)
}
@@ -806,7 +251,7 @@ func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
e.ServeHTTP(rec, req)
- assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
+ assert.True(t, middlewareCalled)
assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectBody, rec.Body.String())
})
diff --git a/httperror.go b/httperror.go
deleted file mode 100644
index 8cb10c8ef..000000000
--- a/httperror.go
+++ /dev/null
@@ -1,162 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "errors"
- "fmt"
- "net/http"
-)
-
-// The following errors can produce HTTP status code by implementing HTTPStatusCoder interface
-var (
- ErrBadRequest = &httpError{http.StatusBadRequest} // 400
- ErrUnauthorized = &httpError{http.StatusUnauthorized} // 401
- ErrForbidden = &httpError{http.StatusForbidden} // 403
- ErrNotFound = &httpError{http.StatusNotFound} // 404
- ErrMethodNotAllowed = &httpError{http.StatusMethodNotAllowed} // 405
- ErrRequestTimeout = &httpError{http.StatusRequestTimeout} // 408
- ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413
- ErrUnsupportedMediaType = &httpError{http.StatusUnsupportedMediaType} // 415
- ErrTooManyRequests = &httpError{http.StatusTooManyRequests} // 429
- ErrInternalServerError = &httpError{http.StatusInternalServerError} // 500
- ErrBadGateway = &httpError{http.StatusBadGateway} // 502
- ErrServiceUnavailable = &httpError{http.StatusServiceUnavailable} // 503
-)
-
-// The following errors fall into 500 (InternalServerError) category
-var (
- ErrValidatorNotRegistered = errors.New("validator not registered")
- ErrRendererNotRegistered = errors.New("renderer not registered")
- ErrInvalidRedirectCode = errors.New("invalid redirect status code")
- ErrCookieNotFound = errors.New("cookie not found")
- ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
- ErrInvalidListenerNetwork = errors.New("invalid listener network")
-)
-
-// HTTPStatusCoder is an interface that errors can implement to produce status code for HTTP response
-type HTTPStatusCoder interface {
- StatusCode() int
-}
-
-// StatusCode returns status code from err if it implements HTTPStatusCoder interface.
-// If err does not implement the interface, it returns 0.
-func StatusCode(err error) int {
- var sc HTTPStatusCoder
- if errors.As(err, &sc) {
- return sc.StatusCode()
- }
- return 0
-}
-
-// ResolveResponseStatus returns the Response and HTTP status code that should be (or has been) sent for rw,
-// given an optional error.
-//
-// This function is useful for middleware and handlers that need to figure out the HTTP status
-// code to return based on the error that occurred or what was set in the response.
-//
-// Precedence rules:
-// 1. If the response has already been committed, the committed status wins (err is ignored).
-// 2. Otherwise, start with 200 OK (net/http default if WriteHeader is never called).
-// 3. If the response has a non-zero suggested status, use it.
-// 4. If err != nil, it overrides the suggested status:
-// - StatusCode(err) if non-zero
-// - otherwise 500 Internal Server Error.
-func ResolveResponseStatus(rw http.ResponseWriter, err error) (resp *Response, status int) {
- resp, _ = UnwrapResponse(rw)
-
- // once committed (sent to the client), the wire status is fixed; err cannot change it.
- if resp != nil && resp.Committed {
- if resp.Status == 0 {
- // unlikely path, but fall back to net/http implicit default if handler never calls WriteHeader
- return resp, http.StatusOK
- }
- return resp, resp.Status
- }
-
- // net/http implicit default if handler never calls WriteHeader.
- status = http.StatusOK
-
- // suggested status written from middleware/handlers, if present.
- if resp != nil && resp.Status != 0 {
- status = resp.Status
- }
-
- // error overrides suggested status (matches typical Echo error-handler semantics).
- if err != nil {
- if s := StatusCode(err); s != 0 {
- status = s
- } else {
- status = http.StatusInternalServerError
- }
- }
-
- return resp, status
-}
-
-// NewHTTPError creates a new instance of HTTPError
-func NewHTTPError(code int, message string) *HTTPError {
- return &HTTPError{
- Code: code,
- Message: message,
- }
-}
-
-// HTTPError represents an error that occurred while handling a request.
-type HTTPError struct {
- // Code is status code for HTTP response
- Code int `json:"-"`
- Message string `json:"message"`
- err error
-}
-
-// StatusCode returns status code for HTTP response
-func (he *HTTPError) StatusCode() int {
- return he.Code
-}
-
-// Error makes it compatible with the ` error ` interface.
-func (he *HTTPError) Error() string {
- msg := he.Message
- if msg == "" {
- msg = http.StatusText(he.Code)
- }
- if he.err == nil {
- return fmt.Sprintf("code=%d, message=%v", he.Code, msg)
- }
- return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error())
-}
-
-// Wrap returns a new HTTPError with given errors wrapped inside
-func (he HTTPError) Wrap(err error) error {
- return &HTTPError{
- Code: he.Code,
- Message: he.Message,
- err: err,
- }
-}
-
-func (he *HTTPError) Unwrap() error {
- return he.err
-}
-
-type httpError struct {
- code int
-}
-
-func (he httpError) StatusCode() int {
- return he.code
-}
-
-func (he httpError) Error() string {
- return http.StatusText(he.code) // does not include status code
-}
-
-func (he httpError) Wrap(err error) error {
- return &HTTPError{
- Code: he.code,
- Message: http.StatusText(he.code),
- err: err,
- }
-}
diff --git a/httperror_external_test.go b/httperror_external_test.go
deleted file mode 100644
index 91acdca25..000000000
--- a/httperror_external_test.go
+++ /dev/null
@@ -1,52 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-// run tests as external package to get real feel for API
-package echo_test
-
-import (
- "encoding/json"
- "fmt"
- "github.com/labstack/echo/v5"
- "net/http"
- "net/http/httptest"
-)
-
-func ExampleDefaultHTTPErrorHandler() {
- e := echo.New()
- e.GET("/api/endpoint", func(c *echo.Context) error {
- return &apiError{
- Code: http.StatusBadRequest,
- Body: map[string]any{"message": "custom error"},
- }
- })
-
- req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil)
- resp := httptest.NewRecorder()
-
- e.ServeHTTP(resp, req)
-
- fmt.Printf("%d %s", resp.Code, resp.Body.String())
-
- // Output: 400 {"error":{"message":"custom error"}}
-}
-
-type apiError struct {
- Code int
- Body any
-}
-
-func (e *apiError) StatusCode() int {
- return e.Code
-}
-
-func (e *apiError) MarshalJSON() ([]byte, error) {
- type body struct {
- Error any `json:"error"`
- }
- return json.Marshal(body{Error: e.Body})
-}
-
-func (e *apiError) Error() string {
- return http.StatusText(e.Code)
-}
diff --git a/httperror_test.go b/httperror_test.go
deleted file mode 100644
index 778a186ce..000000000
--- a/httperror_test.go
+++ /dev/null
@@ -1,186 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package echo
-
-import (
- "errors"
- "fmt"
- "net/http"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestHTTPError_StatusCode(t *testing.T) {
- var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}
-
- code := 0
- var sc HTTPStatusCoder
- if errors.As(err, &sc) {
- code = sc.StatusCode()
- }
- assert.Equal(t, http.StatusBadRequest, code)
-}
-
-func TestHTTPError_Error(t *testing.T) {
- var testCases = []struct {
- name string
- error error
- expect string
- }{
- {
- name: "ok, without message",
- error: &HTTPError{Code: http.StatusBadRequest},
- expect: "code=400, message=Bad Request",
- },
- {
- name: "ok, with message",
- error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"},
- expect: "code=400, message=my error message",
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- assert.Equal(t, tc.expect, tc.error.Error())
- })
- }
-}
-
-func TestHTTPError_WrapUnwrap(t *testing.T) {
- err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
- wrapped := err.Wrap(errors.New("my_error")).(*HTTPError)
-
- err.Code = http.StatusOK
- err.Message = "changed"
-
- assert.Equal(t, http.StatusBadRequest, wrapped.Code)
- assert.Equal(t, "bad", wrapped.Message)
-
- assert.Equal(t, errors.New("my_error"), wrapped.Unwrap())
- assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error())
-}
-
-func TestNewHTTPError(t *testing.T) {
- err := NewHTTPError(http.StatusBadRequest, "bad")
- err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
-
- assert.Equal(t, err2, err)
-}
-
-func TestStatusCode(t *testing.T) {
- var testCases = []struct {
- name string
- err error
- expect int
- }{
- {
- name: "ok, HTTPError",
- err: &HTTPError{Code: http.StatusNotFound},
- expect: http.StatusNotFound,
- },
- {
- name: "ok, sentinel error",
- err: ErrNotFound,
- expect: http.StatusNotFound,
- },
- {
- name: "ok, wrapped HTTPError",
- err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}),
- expect: http.StatusTeapot,
- },
- {
- name: "nok, normal error",
- err: errors.New("error"),
- expect: 0,
- },
- {
- name: "nok, nil",
- err: nil,
- expect: 0,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- assert.Equal(t, tc.expect, StatusCode(tc.err))
- })
- }
-}
-
-func TestResolveResponseStatus(t *testing.T) {
- someErr := errors.New("some error")
-
- var testCases = []struct {
- name string
- whenResp http.ResponseWriter
- whenErr error
- expectStatus int
- expectResp bool
- }{
- {
- name: "nil resp, nil err -> 200",
- whenResp: nil,
- whenErr: nil,
- expectStatus: http.StatusOK,
- expectResp: false,
- },
- {
- name: "resp suggested status used when no error",
- whenResp: &Response{Status: http.StatusCreated},
- whenErr: nil,
- expectStatus: http.StatusCreated,
- expectResp: true,
- },
- {
- name: "error overrides suggested status with StatusCode(err)",
- whenResp: &Response{Status: http.StatusAccepted},
- whenErr: ErrBadRequest,
- expectStatus: http.StatusBadRequest,
- expectResp: true,
- },
- {
- name: "error overrides suggested status with 500 when StatusCode(err)==0",
- whenResp: &Response{Status: http.StatusAccepted},
- whenErr: ErrInternalServerError,
- expectStatus: http.StatusInternalServerError,
- expectResp: true,
- },
- {
- name: "nil resp, error -> 500 when StatusCode(err)==0",
- whenResp: nil,
- whenErr: someErr,
- expectStatus: http.StatusInternalServerError,
- expectResp: false,
- },
- {
- name: "committed response wins over error",
- whenResp: &Response{Committed: true, Status: http.StatusNoContent},
- whenErr: someErr,
- expectStatus: http.StatusNoContent,
- expectResp: true,
- },
- {
- name: "committed response with status 0 falls back to 200 (defensive)",
- whenResp: &Response{Committed: true, Status: 0},
- whenErr: someErr,
- expectStatus: http.StatusOK,
- expectResp: true,
- },
- {
- name: "resp with status 0 and no error -> 200",
- whenResp: &Response{Status: 0},
- whenErr: nil,
- expectStatus: http.StatusOK,
- expectResp: true,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- resp, status := ResolveResponseStatus(tc.whenResp, tc.whenErr)
-
- assert.Equal(t, tc.expectResp, resp != nil)
- assert.Equal(t, tc.expectStatus, status)
- })
- }
-}
diff --git a/internal/pathutil/pathutil.go b/internal/pathutil/pathutil.go
index e9171752e..9934fa31c 100644
--- a/internal/pathutil/pathutil.go
+++ b/internal/pathutil/pathutil.go
@@ -11,9 +11,9 @@ package pathutil
// Backslash is included as defense-in-depth against Windows-style separators even
// though fs.FS itself only uses forward slashes.
//
-// Such sequences let an attacker smuggle a separator past the router, which by
-// default matches on the raw encoded path, so they must be rejected before
-// unescaping when resolving static files.
+// Such sequences let an attacker smuggle a separator past the router, which
+// matches on the raw encoded path, so they must be rejected before unescaping
+// when resolving static files.
func HasEncodedPathSeparator(s string) bool {
for i := 0; i+2 < len(s); i++ {
if s[i] != '%' {
diff --git a/ip.go b/ip.go
index c864e0689..dce51f55d 100644
--- a/ip.go
+++ b/ip.go
@@ -202,8 +202,8 @@ func (c *ipChecker) trust(ip net.IP) bool {
// See https://echo.labstack.com/guide/ip-address for more details.
type IPExtractor func(*http.Request) string
-// ExtractIPDirect extracts an IP address using an actual IP address.
-// Use this if your server faces to internet directly (i.e.: uses no proxy).
+// ExtractIPDirect extracts IP address using actual IP address.
+// Use this if your server faces to internet directory (i.e.: uses no proxy).
func ExtractIPDirect() IPExtractor {
return extractIP
}
@@ -219,24 +219,30 @@ func extractIP(req *http.Request) string {
return host
}
-// ExtractIPFromRealIPHeader extracts IP address using `x-real-ip` header.
+// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header.
// Use this if you put proxy which uses this header.
func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
checker := newIPChecker(options)
return func(req *http.Request) string {
+ directIP := extractIP(req)
realIP := req.Header.Get(HeaderXRealIP)
- if realIP != "" {
+ if realIP == "" {
+ return directIP
+ }
+
+ if checker.trust(net.ParseIP(directIP)) {
realIP = strings.TrimPrefix(realIP, "[")
realIP = strings.TrimSuffix(realIP, "]")
- if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) {
+ if rIP := net.ParseIP(realIP); rIP != nil {
return realIP
}
}
- return extractIP(req)
+
+ return directIP
}
}
-// ExtractIPFromXFFHeader extracts IP address using `x-forwarded-for` header.
+// ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header.
// Use this if you put proxy which uses this header.
// This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]).
func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
@@ -265,45 +271,3 @@ func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
return strings.TrimSpace(ips[0])
}
}
-
-// LegacyIPExtractor returns an IPExtractor that derives the client IP address
-// from common proxy headers, falling back to the request's remote address.
-//
-// Resolution order:
-// 1. X-Forwarded-For: returns the first IP in the comma-separated list.
-// If multiple values are present, only the left-most (original client)
-// is used. Surrounding brackets (for IPv6) are stripped.
-// 2. X-Real-IP: used if X-Forwarded-For is absent. Surrounding brackets
-// (for IPv6) are stripped.
-// 3. req.RemoteAddr: used as a fallback; the host portion is extracted
-// via net.SplitHostPort.
-//
-// Notes:
-// - No validation is performed on header values.
-// - This function trusts headers as-is and is therefore not safe against
-// spoofing unless the application is behind a trusted proxy that is
-// configured to strip/replace/modify headers correctly.
-//
-// Use ExtractIPFromXFFHeader or ExtractIPFromRealIPHeader instead of LegacyIPExtractor.
-func LegacyIPExtractor() IPExtractor {
- return legacyIPExtractor
-}
-
-func legacyIPExtractor(req *http.Request) string {
- if ip := req.Header.Get(HeaderXForwardedFor); ip != "" {
- i := strings.IndexAny(ip, ",")
- if i > 0 {
- ip = strings.TrimSpace(ip[:i])
- }
- ip = strings.TrimPrefix(ip, "[")
- ip = strings.TrimSuffix(ip, "]")
- return ip
- }
- if ip := req.Header.Get(HeaderXRealIP); ip != "" {
- ip = strings.TrimPrefix(ip, "[")
- ip = strings.TrimSuffix(ip, "]")
- return ip
- }
- ra, _, _ := net.SplitHostPort(req.RemoteAddr)
- return ra
-}
diff --git a/ip_test.go b/ip_test.go
index b20368616..e850b78cb 100644
--- a/ip_test.go
+++ b/ip_test.go
@@ -22,8 +22,8 @@ func mustParseCIDR(s string) *net.IPNet {
func TestIPChecker_TrustOption(t *testing.T) {
var testCases = []struct {
name string
- whenIP string
givenOptions []TrustOption
+ whenIP string
expect bool
}{
{
@@ -490,14 +490,14 @@ func TestExtractIPDirect(t *testing.T) {
}
func TestExtractIPFromRealIPHeader(t *testing.T) {
- _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
+ _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.0/24")
_, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
var testCases = []struct {
- whenRequest http.Request
name string
- expectIP string
givenTrustOptions []TrustOption
+ whenRequest http.Request
+ expectIP string
}{
{
name: "request has no headers, extracts IP from request remote addr",
@@ -518,36 +518,42 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
},
{
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
whenRequest: http.Request{
Header: http.Header{
- HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted
+ HeaderXRealIP: []string{"203.0.113.199"},
},
- RemoteAddr: "203.0.113.1:8080",
+ RemoteAddr: "8.8.8.8:8080", // <-- this is untrusted
},
- expectIP: "203.0.113.1",
+ expectIP: "8.8.8.8",
},
{
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
whenRequest: http.Request{
Header: http.Header{
- HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted
+ HeaderXRealIP: []string{"[bc01:1010::9090:1888]"},
},
- RemoteAddr: "[2001:db8::113:1]:8080",
+ RemoteAddr: "[fe64:aa10::1]:8080", // <-- this is untrusted
},
- expectIP: "2001:db8::113:1",
+ expectIP: "fe64:aa10::1",
},
{
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
- TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.0/24"
},
whenRequest: http.Request{
Header: http.Header{
- HeaderXRealIP: []string{"203.0.113.199"},
+ HeaderXRealIP: []string{"8.8.8.8"},
},
RemoteAddr: "203.0.113.1:8080",
},
- expectIP: "203.0.113.199",
+ expectIP: "8.8.8.8",
},
{
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
@@ -556,11 +562,11 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
},
whenRequest: http.Request{
Header: http.Header{
- HeaderXRealIP: []string{"[2001:db8::113:199]"},
+ HeaderXRealIP: []string{"[fe64:db8::113:199]"},
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
- expectIP: "2001:db8::113:199",
+ expectIP: "fe64:db8::113:199",
},
{
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
@@ -569,12 +575,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
},
whenRequest: http.Request{
Header: http.Header{
- HeaderXRealIP: []string{"203.0.113.199"},
- HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything
+ HeaderXRealIP: []string{"8.8.8.8"},
+ HeaderXForwardedFor: []string{"1.1.1.1 ,8.8.8.8"}, // <-- should not affect anything
},
RemoteAddr: "203.0.113.1:8080",
},
- expectIP: "203.0.113.199",
+ expectIP: "8.8.8.8",
},
{
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
@@ -583,12 +589,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) {
},
whenRequest: http.Request{
Header: http.Header{
- HeaderXRealIP: []string{"[2001:db8::113:199]"},
- HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything
+ HeaderXRealIP: []string{"[fe64:db8::113:199]"},
+ HeaderXForwardedFor: []string{"[feab:cde9::113:198], [fe64:db8::113:199]"}, // <-- should not affect anything
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
- expectIP: "2001:db8::113:199",
+ expectIP: "fe64:db8::113:199",
},
}
@@ -605,10 +611,10 @@ func TestExtractIPFromXFFHeader(t *testing.T) {
_, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
var testCases = []struct {
- whenRequest http.Request
name string
- expectIP string
givenTrustOptions []TrustOption
+ whenRequest http.Request
+ expectIP string
}{
{
name: "request has no headers, extracts IP from request remote addr",
@@ -714,75 +720,3 @@ func TestExtractIPFromXFFHeader(t *testing.T) {
})
}
}
-
-func TestLegacyIPExtractor(t *testing.T) {
- var testCases = []struct {
- name string
- whenReq *http.Request
- expect string
- expectedError string
- }{
- {
- name: "extract first ip from X-Forwarded-For",
- whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"203.0.113.10, 198.51.100.7"}}},
- expect: "203.0.113.10",
- },
- {
- name: "extract single ip from X-Forwarded-For",
- whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"203.0.113.10"}}},
- expect: "203.0.113.10",
- },
- {
- name: "trim brackets from ipv6 in X-Forwarded-For when multiple values",
- whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"[2001:db8::1], 198.51.100.7"}}},
- expect: "2001:db8::1",
- },
- {
- name: "prefer X-Forwarded-For over X-Real-Ip",
- whenReq: &http.Request{
- Header: http.Header{
- "X-Forwarded-For": []string{"203.0.113.10"},
- "X-Real-Ip": []string{"198.51.100.7"},
- },
- },
- expect: "203.0.113.10",
- },
- {
- name: "extract from X-Real-Ip",
- whenReq: &http.Request{Header: http.Header{"X-Real-Ip": []string{"[2001:db8::1]"}}},
- expect: "2001:db8::1",
- },
- {
- name: "extract plain ipv4 from X-Real-Ip",
- whenReq: &http.Request{Header: http.Header{"X-Real-Ip": []string{"203.0.113.10"}}},
- expect: "203.0.113.10",
- },
- {
- name: "fallback to RemoteAddr host",
- whenReq: &http.Request{RemoteAddr: "203.0.113.10:12345"},
- expect: "203.0.113.10",
- },
- {
- name: "fallback to RemoteAddr ipv6 host",
- whenReq: &http.Request{RemoteAddr: "[2001:db8::1]:12345"},
- expect: "2001:db8::1",
- },
- {
- name: "returns empty string when RemoteAddr is invalid and no headers exist",
- whenReq: &http.Request{RemoteAddr: "not-a-host-port"},
- expect: "",
- },
- {
- name: "trim brackets from single ipv6 in X-Forwarded-For",
- whenReq: &http.Request{Header: http.Header{"X-Forwarded-For": []string{"[2001:db8::1]"}}},
- expect: "2001:db8::1",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- ip := LegacyIPExtractor()(tc.whenReq)
- assert.Equal(t, tc.expect, ip)
- })
- }
-}
diff --git a/json.go b/json.go
index a969ccb8c..589cda55f 100644
--- a/json.go
+++ b/json.go
@@ -5,6 +5,8 @@ package echo
import (
"encoding/json"
+ "fmt"
+ "net/http"
)
// DefaultJSONSerializer implements JSON encoding using encoding/json.
@@ -12,18 +14,21 @@ type DefaultJSONSerializer struct{}
// Serialize converts an interface into a json and writes it to the response.
// You can optionally use the indent parameter to produce pretty JSONs.
-func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error {
+func (d DefaultJSONSerializer) Serialize(c Context, i any, indent string) error {
enc := json.NewEncoder(c.Response())
if indent != "" {
enc.SetIndent("", indent)
}
- return enc.Encode(target)
+ return enc.Encode(i)
}
// Deserialize reads a JSON from a request body and converts it into an interface.
-func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error {
- if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil {
- return ErrBadRequest.Wrap(err)
+func (d DefaultJSONSerializer) Deserialize(c Context, i any) error {
+ err := json.NewDecoder(c.Request().Body).Decode(i)
+ if ute, ok := err.(*json.UnmarshalTypeError); ok {
+ return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
+ } else if se, ok := err.(*json.SyntaxError); ok {
+ return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
}
- return nil
+ return err
}
diff --git a/json_test.go b/json_test.go
index 1804b3e82..0b15ed1a1 100644
--- a/json_test.go
+++ b/json_test.go
@@ -17,7 +17,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
// Echo
assert.Equal(t, e, c.Echo())
@@ -34,15 +34,15 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
enc := new(DefaultJSONSerializer)
- err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "")
+ err := enc.Serialize(c, user{1, "Jon Snow"}, "")
if assert.NoError(t, err) {
assert.Equal(t, userJSON+"\n", rec.Body.String())
}
req = httptest.NewRequest(http.MethodPost, "/", nil)
rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ")
+ c = e.NewContext(req, rec).(*context)
+ err = enc.Serialize(c, user{1, "Jon Snow"}, " ")
if assert.NoError(t, err) {
assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
@@ -54,7 +54,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ c := e.NewContext(req, rec).(*context)
// Echo
assert.Equal(t, e, c.Echo())
@@ -80,10 +80,10 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
var userUnmarshalSyntaxError = user{}
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent))
rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
+ c = e.NewContext(req, rec).(*context)
err = enc.Deserialize(c, &userUnmarshalSyntaxError)
assert.IsType(t, &HTTPError{}, err)
- assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value")
+ assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value")
var userUnmarshalTypeError = struct {
ID string `json:"id"`
@@ -92,9 +92,9 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
+ c = e.NewContext(req, rec).(*context)
err = enc.Deserialize(c, &userUnmarshalTypeError)
assert.IsType(t, &HTTPError{}, err)
- assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string")
+ assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string")
}
diff --git a/log.go b/log.go
new file mode 100644
index 000000000..42164c325
--- /dev/null
+++ b/log.go
@@ -0,0 +1,41 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "github.com/labstack/gommon/log"
+ "io"
+)
+
+// Logger defines the logging interface.
+type Logger interface {
+ Output() io.Writer
+ SetOutput(w io.Writer)
+ Prefix() string
+ SetPrefix(p string)
+ Level() log.Lvl
+ SetLevel(v log.Lvl)
+ SetHeader(h string)
+ Print(i ...any)
+ Printf(format string, args ...any)
+ Printj(j log.JSON)
+ Debug(i ...any)
+ Debugf(format string, args ...any)
+ Debugj(j log.JSON)
+ Info(i ...any)
+ Infof(format string, args ...any)
+ Infoj(j log.JSON)
+ Warn(i ...any)
+ Warnf(format string, args ...any)
+ Warnj(j log.JSON)
+ Error(i ...any)
+ Errorf(format string, args ...any)
+ Errorj(j log.JSON)
+ Fatal(i ...any)
+ Fatalj(j log.JSON)
+ Fatalf(format string, args ...any)
+ Panic(i ...any)
+ Panicj(j log.JSON)
+ Panicf(format string, args ...any)
+}
diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md
deleted file mode 100644
index 77cb226dd..000000000
--- a/middleware/DEVELOPMENT.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# Development Guidelines for middlewares
-
-## Best practices:
-
-* Do not use `panic` in middleware creator functions in case of invalid configuration.
-* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
- because previous middlewares up in call chain could have logic for dealing with returned errors.
-* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
- want to create middleware with panics or with returning errors on configuration errors.
-* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.
-
diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go
index 8a9500a93..4a46098e3 100644
--- a/middleware/basic_auth.go
+++ b/middleware/basic_auth.go
@@ -4,153 +4,105 @@
package middleware
import (
- "bytes"
- "cmp"
"encoding/base64"
- "errors"
+ "net/http"
"strconv"
"strings"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
-// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
-//
-// SECURITY: The Validator function is responsible for securely comparing credentials.
-// See BasicAuthValidator documentation for guidance on preventing timing attacks.
+// BasicAuthConfig defines the config for BasicAuth middleware.
type BasicAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
- // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
- // this function would be called once for each header until first valid result is returned
+ // Validator is a function to validate BasicAuth credentials.
// Required.
Validator BasicAuthValidator
- // Realm is a string to define realm attribute of BasicAuthWithConfig.
+ // Realm is a string to define realm attribute of BasicAuth.
// Default value "Restricted".
Realm string
-
- // AllowedCheckLimit set how many headers are allowed to be checked. This is useful
- // environments like corporate test environments with application proxies restricting
- // access to environment with their own auth scheme.
- // Defaults to 1.
- AllowedCheckLimit uint
}
-// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
-//
-// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
-// valid usernames or passwords, validator implementations MUST use constant-time
-// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead
-// of standard string equality (==) or switch statements.
-//
-// Example of SECURE implementation:
-//
-// import "crypto/subtle"
-//
-// validator := func(c *echo.Context, username, password string) (bool, error) {
-// // Fetch expected credentials from database/config
-// expectedUser := "admin"
-// expectedPass := "secretpassword"
-//
-// // Use constant-time comparison to prevent timing attacks
-// userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1
-// passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1
-//
-// if userMatch && passMatch {
-// return true, nil
-// }
-// return false, nil
-// }
-//
-// Example of INSECURE implementation (DO NOT USE):
-//
-// // VULNERABLE TO TIMING ATTACKS - DO NOT USE
-// validator := func(c *echo.Context, username, password string) (bool, error) {
-// if username == "admin" && password == "secret" { // Timing leak!
-// return true, nil
-// }
-// return false, nil
-// }
-type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
+// BasicAuthValidator defines a function to validate BasicAuth credentials.
+// The function should return a boolean indicating whether the credentials are valid,
+// and an error if any error occurs during the validation process.
+type BasicAuthValidator func(string, string, echo.Context) (bool, error)
const (
basic = "basic"
defaultRealm = "Restricted"
)
+// DefaultBasicAuthConfig is the default BasicAuth middleware config.
+var DefaultBasicAuthConfig = BasicAuthConfig{
+ Skipper: DefaultSkipper,
+ Realm: defaultRealm,
+}
+
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
- return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
+ c := DefaultBasicAuthConfig
+ c.Validator = fn
+ return BasicAuthWithConfig(c)
}
-// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
+// BasicAuthWithConfig returns an BasicAuth middleware with config.
+// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
-func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
if config.Validator == nil {
- return nil, errors.New("echo basic-auth middleware requires a validator function")
+ panic("echo: basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
+ config.Skipper = DefaultBasicAuthConfig.Skipper
}
- realm := defaultRealm
- if config.Realm != "" {
- realm = config.Realm
+ if config.Realm == "" {
+ config.Realm = defaultRealm
}
- realm = strconv.Quote(realm)
- limit := cmp.Or(config.AllowedCheckLimit, 1)
+
+ // Pre-compute the quoted realm for WWW-Authenticate header (RFC 7617)
+ quotedRealm := strconv.Quote(config.Realm)
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
- var lastError error
+ auth := c.Request().Header.Get(echo.HeaderAuthorization)
l := len(basic)
- i := uint(0)
- for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
- if i >= limit {
- break
- }
- if len(auth) <= l+1 || !strings.EqualFold(auth[:l], basic) {
- continue
- }
- i++
+ if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
// Invalid base64 shouldn't be treated as error
// instead should be treated as invalid client input
- b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
- if errDecode != nil {
- lastError = echo.ErrBadRequest.Wrap(errDecode)
- continue
+ b, err := base64.StdEncoding.DecodeString(auth[l+1:])
+ if err != nil {
+ return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err)
}
- before, after, ok := bytes.Cut(b, []byte{':'})
+
+ cred := string(b)
+ user, pass, ok := strings.Cut(cred, ":")
if ok {
- valid, errValidate := config.Validator(c, string(before), string(after))
- if errValidate != nil {
- lastError = errValidate
+ // Verify credentials
+ valid, err := config.Validator(user, pass, c)
+ if err != nil {
+ return err
} else if valid {
return next(c)
}
}
}
- if lastError != nil {
- return lastError
- }
-
// Need to return `401` for browsers to pop-up login box.
- c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
+ // Realm is case-insensitive, so we can use "basic" directly. See RFC 7617.
+ c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+quotedRealm)
return echo.ErrUnauthorized
}
- }, nil
+ }
}
diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go
index 42386354f..2d3192615 100644
--- a/middleware/basic_auth_test.go
+++ b/middleware/basic_auth_test.go
@@ -4,7 +4,6 @@
package middleware
import (
- "crypto/subtle"
"encoding/base64"
"errors"
"net/http"
@@ -12,177 +11,116 @@ import (
"strings"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
- validatorFunc := func(c *echo.Context, u, p string) (bool, error) {
- // Use constant-time comparison to prevent timing attacks
- userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1
- passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1
+ e := echo.New()
- if userMatch && passMatch {
+ mockValidator := func(u, p string, c echo.Context) (bool, error) {
+ if u == "joe" && p == "secret" {
return true, nil
}
-
- // Special case for testing error handling
- if u == "error" {
- return false, errors.New(p)
- }
-
return false, nil
}
- defaultConfig := BasicAuthConfig{Validator: validatorFunc}
- var testCases = []struct {
- name string
- givenConfig BasicAuthConfig
- whenAuth []string
- expectHeader string
- expectErr string
+ tests := []struct {
+ name string
+ authHeader string
+ expectedCode int
+ expectedAuth string
+ skipperResult bool
+ expectedErr bool
+ expectedErrMsg string
}{
{
- name: "ok",
- givenConfig: defaultConfig,
- whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
- },
- {
- name: "ok, multiple",
- givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2},
- whenAuth: []string{
- "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
- basic + " NOT_BASE64",
- basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
- },
+ name: "Valid credentials",
+ authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ expectedCode: http.StatusOK,
},
{
- name: "nok, multiple, valid out of limit",
- givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1},
- whenAuth: []string{
- "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
- basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")),
- // limit only check first and should not check auth below
- basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
- },
- expectHeader: basic + ` realm="Restricted"`,
- expectErr: "Unauthorized",
+ name: "Case-insensitive header scheme",
+ authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ expectedCode: http.StatusOK,
},
{
- name: "nok, invalid Authorization header",
- givenConfig: defaultConfig,
- whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
- expectHeader: basic + ` realm="Restricted"`,
- expectErr: "Unauthorized",
+ name: "Invalid credentials",
+ authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
+ expectedCode: http.StatusUnauthorized,
+ expectedAuth: basic + ` realm="someRealm"`,
+ expectedErr: true,
+ expectedErrMsg: "Unauthorized",
},
{
- name: "nok, not base64 Authorization header",
- givenConfig: defaultConfig,
- whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
- expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3",
+ name: "Invalid base64 string",
+ authHeader: basic + " invalidString",
+ expectedCode: http.StatusBadRequest,
+ expectedErr: true,
+ expectedErrMsg: "Bad Request",
},
{
- name: "nok, missing Authorization header",
- givenConfig: defaultConfig,
- expectHeader: basic + ` realm="Restricted"`,
- expectErr: "Unauthorized",
+ name: "Missing Authorization header",
+ expectedCode: http.StatusUnauthorized,
+ expectedErr: true,
+ expectedErrMsg: "Unauthorized",
},
{
- name: "ok, realm",
- givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
- whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ name: "Invalid Authorization header",
+ authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
+ expectedCode: http.StatusUnauthorized,
+ expectedErr: true,
+ expectedErrMsg: "Unauthorized",
},
{
- name: "ok, realm, case-insensitive header scheme",
- givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
- whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
- },
- {
- name: "nok, realm, invalid Authorization header",
- givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
- whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
- expectHeader: basic + ` realm="someRealm"`,
- expectErr: "Unauthorized",
- },
- {
- name: "nok, validator func returns an error",
- givenConfig: defaultConfig,
- whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
- expectErr: "my_error",
- },
- {
- name: "ok, skipped",
- givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool {
- return true
- }},
- whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ name: "Skipped Request",
+ authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
+ expectedCode: http.StatusOK,
+ skipperResult: true,
},
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := echo.New()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
- config := tc.givenConfig
-
- mw, err := config.ToMiddleware()
- assert.NoError(t, err)
+ if tt.authHeader != "" {
+ req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
+ }
- h := mw(func(c *echo.Context) error {
- return c.String(http.StatusTeapot, "test")
+ h := BasicAuthWithConfig(BasicAuthConfig{
+ Validator: mockValidator,
+ Realm: "someRealm",
+ Skipper: func(c echo.Context) bool {
+ return tt.skipperResult
+ },
+ })(func(c echo.Context) error {
+ return c.String(http.StatusOK, "test")
})
- if len(tc.whenAuth) != 0 {
- for _, a := range tc.whenAuth {
- req.Header.Add(echo.HeaderAuthorization, a)
- }
- }
- err = h(c)
+ err := h(c)
- if tc.expectErr != "" {
- assert.Equal(t, http.StatusOK, res.Code)
- assert.EqualError(t, err, tc.expectErr)
+ if tt.expectedErr {
+ var he *echo.HTTPError
+ errors.As(err, &he)
+ assert.Equal(t, tt.expectedCode, he.Code)
+ if tt.expectedAuth != "" {
+ assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
+ }
} else {
- assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
- }
- if tc.expectHeader != "" {
- assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
+ assert.Equal(t, tt.expectedCode, res.Code)
}
})
}
}
-func TestBasicAuth_panic(t *testing.T) {
- assert.Panics(t, func() {
- mw := BasicAuth(nil)
- assert.NotNil(t, mw)
- })
-
- mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) {
- return true, nil
- })
- assert.NotNil(t, mw)
-}
-
-func TestBasicAuthWithConfig_panic(t *testing.T) {
- assert.Panics(t, func() {
- mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
- assert.NotNil(t, mw)
- })
-
- mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) {
- return true, nil
- }})
- assert.NotNil(t, mw)
-}
-
func TestBasicAuthRealm(t *testing.T) {
e := echo.New()
- mockValidator := func(c *echo.Context, u, p string) (bool, error) {
+ mockValidator := func(u, p string, c echo.Context) (bool, error) {
return false, nil // Always fail to trigger WWW-Authenticate header
}
@@ -227,13 +165,15 @@ func TestBasicAuthRealm(t *testing.T) {
h := BasicAuthWithConfig(BasicAuthConfig{
Validator: mockValidator,
Realm: tt.realm,
- })(func(c *echo.Context) error {
+ })(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
err := h(c)
- assert.Equal(t, echo.ErrUnauthorized, err)
+ var he *echo.HTTPError
+ errors.As(err, &he)
+ assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
})
}
diff --git a/middleware/body_dump.go b/middleware/body_dump.go
index 0443a67ab..add778d67 100644
--- a/middleware/body_dump.go
+++ b/middleware/body_dump.go
@@ -10,9 +10,8 @@ import (
"io"
"net"
"net/http"
- "sync"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// BodyDumpConfig defines the config for BodyDump middleware.
@@ -20,127 +19,78 @@ type BodyDumpConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
- // Handler receives request, response payloads and handler error if there are any.
+ // Handler receives request and response payload.
// Required.
Handler BodyDumpHandler
-
- // MaxRequestBytes limits how much of the request body to dump.
- // If the request body exceeds this limit, only the first MaxRequestBytes
- // are dumped. The handler callback receives truncated data.
- // Default: 5 * MB (5,242,880 bytes)
- // Set to -1 to disable limits (not recommended in production).
- MaxRequestBytes int64
-
- // MaxResponseBytes limits how much of the response body to dump.
- // If the response body exceeds this limit, only the first MaxResponseBytes
- // are dumped. The handler callback receives truncated data.
- // Default: 5 * MB (5,242,880 bytes)
- // Set to -1 to disable limits (not recommended in production).
- MaxResponseBytes int64
}
// BodyDumpHandler receives the request and response payload.
-type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
+type BodyDumpHandler func(echo.Context, []byte, []byte)
type bodyDumpResponseWriter struct {
io.Writer
http.ResponseWriter
}
+// DefaultBodyDumpConfig is the default BodyDump middleware config.
+var DefaultBodyDumpConfig = BodyDumpConfig{
+ Skipper: DefaultSkipper,
+}
+
// BodyDump returns a BodyDump middleware.
//
// BodyDump middleware captures the request and response payload and calls the
// registered handler.
-//
-// SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion
-// attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended
-// in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
- return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
+ c := DefaultBodyDumpConfig
+ c.Handler = handler
+ return BodyDumpWithConfig(c)
}
// BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`.
-//
-// SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default
-// to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable
-// limits if needed for your use case.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
-func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
if config.Handler == nil {
- return nil, errors.New("echo body-dump middleware requires a handler function")
+ panic("echo: body-dump middleware requires a handler function")
}
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
- }
- if config.MaxRequestBytes == 0 {
- config.MaxRequestBytes = 5 * MB
- }
- if config.MaxResponseBytes == 0 {
- config.MaxResponseBytes = 5 * MB
+ config.Skipper = DefaultBodyDumpConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
- reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
- reqBuf.Reset()
- defer bodyDumpBufferPool.Put(reqBuf)
-
- var bodyReader io.Reader = c.Request().Body
- if config.MaxRequestBytes > 0 {
- bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes)
- }
- _, readErr := io.Copy(reqBuf, bodyReader)
- if readErr != nil && readErr != io.EOF {
- return readErr
- }
- if config.MaxRequestBytes > 0 {
- // Drain any remaining body data to prevent connection issues
- _, _ = io.Copy(io.Discard, c.Request().Body)
- _ = c.Request().Body.Close()
- }
-
- reqBody := make([]byte, reqBuf.Len())
- copy(reqBody, reqBuf.Bytes())
- c.Request().Body = io.NopCloser(bytes.NewReader(reqBody))
-
- // response part
- resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
- resBuf.Reset()
- defer bodyDumpBufferPool.Put(resBuf)
-
- var respWriter io.Writer
- if config.MaxResponseBytes > 0 {
- respWriter = &limitedWriter{
- response: c.Response(),
- dumpBuf: resBuf,
- limit: config.MaxResponseBytes,
+ // Request
+ reqBody := []byte{}
+ if c.Request().Body != nil {
+ var readErr error
+ reqBody, readErr = io.ReadAll(c.Request().Body)
+ if readErr != nil {
+ return readErr
}
- } else {
- respWriter = io.MultiWriter(c.Response(), resBuf)
}
- writer := &bodyDumpResponseWriter{
- Writer: respWriter,
- ResponseWriter: c.Response(),
- }
- c.SetResponse(writer)
+ c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset
+
+ // Response
+ resBody := new(bytes.Buffer)
+ mw := io.MultiWriter(c.Response().Writer, resBody)
+ writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
+ c.Response().Writer = writer
- err := next(c)
+ if err = next(c); err != nil {
+ c.Error(err)
+ }
// Callback
- config.Handler(c, reqBody, resBuf.Bytes(), err)
+ config.Handler(c, reqBody, resBody.Bytes())
- return err
+ return
}
- }, nil
+ }
}
func (w *bodyDumpResponseWriter) WriteHeader(code int) {
@@ -165,34 +115,3 @@ func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
-
-var bodyDumpBufferPool = sync.Pool{
- New: func() any {
- return new(bytes.Buffer)
- },
-}
-
-type limitedWriter struct {
- response http.ResponseWriter
- dumpBuf *bytes.Buffer
- dumped int64
- limit int64
-}
-
-func (w *limitedWriter) Write(b []byte) (n int, err error) {
- // Always write full data to actual response (don't truncate client response)
- n, err = w.response.Write(b)
- if err != nil {
- return n, err
- }
-
- // Write to dump buffer only up to limit
- if w.dumped < w.limit {
- remaining := w.limit - w.dumped
- toDump := min(int64(n), remaining)
- w.dumpBuf.Write(b[:toDump])
- w.dumped += toDump
- }
-
- return n, nil
-}
diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go
index e5f64541a..7a7dee3d9 100644
--- a/middleware/body_dump_test.go
+++ b/middleware/body_dump_test.go
@@ -11,7 +11,7 @@ import (
"strings"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -21,7 +21,7 @@ func TestBodyDump(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
+ h := func(c echo.Context) error {
body, err := io.ReadAll(c.Request().Body)
if err != nil {
return err
@@ -31,11 +31,10 @@ func TestBodyDump(t *testing.T) {
requestBody := ""
responseBody := ""
- mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
- }}.ToMiddleware()
- assert.NoError(t, err)
+ })
if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
@@ -44,76 +43,51 @@ func TestBodyDump(t *testing.T) {
assert.Equal(t, hw, rec.Body.String())
}
-}
-
-func TestBodyDump_skipper(t *testing.T) {
- e := echo.New()
-
- isCalled := false
- mw, err := BodyDumpConfig{
- Skipper: func(c *echo.Context) bool {
- return true
+ // Must set default skipper
+ BodyDumpWithConfig(BodyDumpConfig{
+ Skipper: nil,
+ Handler: func(c echo.Context, reqBody, resBody []byte) {
+ requestBody = string(reqBody)
+ responseBody = string(resBody)
},
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- isCalled = true
- },
- }.ToMiddleware()
- assert.NoError(t, err)
-
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
- return errors.New("some error")
- }
-
- err = mw(h)(c)
- assert.EqualError(t, err, "some error")
- assert.False(t, isCalled)
+ })
}
-func TestBodyDump_fails(t *testing.T) {
+func TestBodyDumpFails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
+ h := func(c echo.Context) error {
return errors.New("some error")
}
- mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware()
- assert.NoError(t, err)
+ mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
- err = mw(h)(c)
- assert.EqualError(t, err, "some error")
- assert.Equal(t, http.StatusOK, rec.Code)
-
-}
+ if !assert.Error(t, mw(h)(c)) {
+ t.FailNow()
+ }
-func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
- mw := BodyDumpWithConfig(BodyDumpConfig{
+ mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
- assert.NotNil(t, mw)
})
assert.NotPanics(t, func() {
- mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}})
- assert.NotNil(t, mw)
- })
-}
-
-func TestBodyDump_panic(t *testing.T) {
- assert.Panics(t, func() {
- mw := BodyDump(nil)
- assert.NotNil(t, mw)
- })
+ mw = BodyDumpWithConfig(BodyDumpConfig{
+ Skipper: func(c echo.Context) bool {
+ return true
+ },
+ Handler: func(c echo.Context, reqBody, resBody []byte) {
+ },
+ })
- assert.NotPanics(t, func() {
- BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {})
+ if !assert.Error(t, mw(h)(c)) {
+ t.FailNow()
+ }
})
}
@@ -121,6 +95,7 @@ func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
}
+
assert.PanicsWithError(t, "response writer flushing is not supported", func() {
bdrw.Flush()
})
@@ -131,6 +106,7 @@ func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu,
}
+
bdrw.Flush()
assert.Equal(t, 1, trwu.unwrapCalled)
}
@@ -140,6 +116,7 @@ func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: trwu,
}
+
result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}
@@ -149,6 +126,7 @@ func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
+
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}
@@ -158,6 +136,7 @@ func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
bdrw := bodyDumpResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
+
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
@@ -176,14 +155,14 @@ func TestBodyDump_ReadError(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
+ h := func(c echo.Context) error {
// This handler should not be reached if body read fails
body, _ := io.ReadAll(c.Request().Body)
return c.String(http.StatusOK, string(body))
}
requestBodyReceived := ""
- mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {
+ mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
requestBodyReceived = string(reqBody)
})
@@ -223,359 +202,3 @@ func (f *failingReadCloser) Read(p []byte) (n int, err error) {
func (f *failingReadCloser) Close() error {
return nil
}
-
-func TestBodyDump_RequestWithinLimit(t *testing.T) {
- e := echo.New()
- requestData := "Hello, World!"
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- body, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(body))
- }
-
- requestBodyDumped := ""
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- requestBodyDumped = string(reqBody)
- },
- MaxRequestBytes: 1 * MB, // 1MB limit
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped")
- assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request")
-}
-
-func TestBodyDump_RequestExceedsLimit(t *testing.T) {
- e := echo.New()
- // Create 2KB of data but limit to 1KB
- largeData := strings.Repeat("A", 2*1024)
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- body, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(body))
- }
-
- requestBodyDumped := ""
- limit := int64(1024) // 1KB limit
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- requestBodyDumped = string(reqBody)
- },
- MaxRequestBytes: limit,
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit")
- assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes")
- // Handler should receive truncated data (what was dumped)
- assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String())
-}
-
-func TestBodyDump_RequestAtExactLimit(t *testing.T) {
- e := echo.New()
- exactData := strings.Repeat("B", 1024)
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- body, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(body))
- }
-
- requestBodyDumped := ""
- limit := int64(1024)
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- requestBodyDumped = string(reqBody)
- },
- MaxRequestBytes: limit,
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data")
- assert.Equal(t, exactData, requestBodyDumped)
-}
-
-func TestBodyDump_ResponseWithinLimit(t *testing.T) {
- e := echo.New()
- responseData := "Response data"
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, responseData)
- }
-
- responseBodyDumped := ""
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- responseBodyDumped = string(resBody)
- },
- MaxRequestBytes: 1 * MB,
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped")
- assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response")
-}
-
-func TestBodyDump_ResponseExceedsLimit(t *testing.T) {
- e := echo.New()
- largeResponse := strings.Repeat("X", 2*1024) // 2KB
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, largeResponse)
- }
-
- responseBodyDumped := ""
- limit := int64(1024) // 1KB limit
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- responseBodyDumped = string(resBody)
- },
- MaxRequestBytes: 1 * MB,
- MaxResponseBytes: limit,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- // Dump should be truncated
- assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit")
- assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped)
- // Client should still receive full response!
- assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation")
-}
-
-func TestBodyDump_ClientGetsFullResponse(t *testing.T) {
- e := echo.New()
- // This is critical - even when dump is limited, client gets everything
- largeResponse := strings.Repeat("DATA", 500) // 2KB
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- // Write response in chunks to test incremental writes
- for range 4 {
- c.Response().Write([]byte(strings.Repeat("DATA", 125)))
- }
- return nil
- }
-
- responseBodyDumped := ""
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- responseBodyDumped = string(resBody)
- },
- MaxRequestBytes: 1 * MB,
- MaxResponseBytes: 512, // Very small limit
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited")
- assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response")
-}
-
-func TestBodyDump_BothLimitsSimultaneous(t *testing.T) {
- e := echo.New()
- largeRequest := strings.Repeat("Q", 2*1024)
- largeResponse := strings.Repeat("R", 2*1024)
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- io.ReadAll(c.Request().Body) // Consume request
- return c.String(http.StatusOK, largeResponse)
- }
-
- requestBodyDumped := ""
- responseBodyDumped := ""
- limit := int64(1024)
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- requestBodyDumped = string(reqBody)
- responseBodyDumped = string(resBody)
- },
- MaxRequestBytes: limit,
- MaxResponseBytes: limit,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited")
- assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited")
-}
-
-func TestBodyDump_DefaultConfig(t *testing.T) {
- e := echo.New()
- smallData := "test"
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- body, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(body))
- }
-
- requestBodyDumped := ""
- // Use default config which should have 1MB limits
- config := BodyDumpConfig{}
- config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) {
- requestBodyDumped = string(reqBody)
- }
- mw, err := config.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, smallData, requestBodyDumped)
-}
-
-func TestBodyDump_LargeRequestDosPrevention(t *testing.T) {
- e := echo.New()
- // Simulate a very large request (10MB) that could cause OOM
- largeSize := 10 * 1024 * 1024 // 10MB
- largeData := strings.Repeat("M", largeSize)
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- body, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(body))
- }
-
- requestBodyDumped := ""
- limit := int64(1 * MB) // Only dump 1MB max
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- requestBodyDumped = string(reqBody)
- },
- MaxRequestBytes: limit,
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- // Verify only 1MB was dumped, not 10MB
- assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS")
- assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request")
-}
-
-func TestBodyDump_LargeResponseDosPrevention(t *testing.T) {
- e := echo.New()
- // Simulate a very large response (10MB)
- largeSize := 10 * 1024 * 1024 // 10MB
- largeResponse := strings.Repeat("R", largeSize)
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, largeResponse)
- }
-
- responseBodyDumped := ""
- limit := int64(1 * MB) // Only dump 1MB max
- mw, err := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- responseBodyDumped = string(resBody)
- },
- MaxRequestBytes: 1 * MB,
- MaxResponseBytes: limit,
- }.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- assert.NoError(t, err)
- // Verify only 1MB was dumped, not 10MB
- assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS")
- assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response")
- // Client still gets full response
- assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response")
-}
-
-func BenchmarkBodyDump_WithLimit(b *testing.B) {
- e := echo.New()
- requestData := strings.Repeat("data", 256) // 1KB
- responseData := strings.Repeat("resp", 256) // 1KB
-
- h := func(c *echo.Context) error {
- io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, responseData)
- }
-
- mw, _ := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
- // Simulate logging
- _ = len(reqBody) + len(resBody)
- },
- MaxRequestBytes: 1 * MB,
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- mw(h)(c)
- }
-}
-
-func BenchmarkBodyDump_BufferPooling(b *testing.B) {
- e := echo.New()
- requestData := strings.Repeat("x", 1024)
- responseData := "response"
-
- h := func(c *echo.Context) error {
- io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, responseData)
- }
-
- mw, _ := BodyDumpConfig{
- Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {},
- MaxRequestBytes: 1 * MB,
- MaxResponseBytes: 1 * MB,
- }.ToMiddleware()
-
- b.ReportAllocs()
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- mw(h)(c)
- }
-}
diff --git a/middleware/body_limit.go b/middleware/body_limit.go
index 4f1963e18..1450e9293 100644
--- a/middleware/body_limit.go
+++ b/middleware/body_limit.go
@@ -4,20 +4,24 @@
package middleware
import (
+ "fmt"
"io"
"net/http"
"sync"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/gommon/bytes"
)
-// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
+// BodyLimitConfig defines the config for BodyLimit middleware.
type BodyLimitConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
- // LimitBytes is maximum allowed size in bytes for a request body
- LimitBytes int64
+ // Maximum allowed size for a request body, it can be specified
+ // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
+ Limit string `yaml:"limit"`
+ limit int64
}
type limitedReader struct {
@@ -26,43 +30,50 @@ type limitedReader struct {
read int64
}
+// DefaultBodyLimitConfig is the default BodyLimit middleware config.
+var DefaultBodyLimitConfig = BodyLimitConfig{
+ Skipper: DefaultSkipper,
+}
+
// BodyLimit returns a BodyLimit middleware.
//
-// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
-// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
+// BodyLimit middleware sets the maximum allowed size for a request body, if the
+// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
+// response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
-func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
- return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
+// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
+// G, T or P.
+func BodyLimit(limit string) echo.MiddlewareFunc {
+ c := DefaultBodyLimitConfig
+ c.Limit = limit
+ return BodyLimitWithConfig(c)
}
-// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
-// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
-// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
-// makes it super secure.
+// BodyLimitWithConfig returns a BodyLimit middleware with config.
+// See: `BodyLimit()`.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
-func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
+ config.Skipper = DefaultBodyLimitConfig.Skipper
}
- pool := sync.Pool{
- New: func() any {
- return &limitedReader{BodyLimitConfig: config}
- },
+
+ limit, err := bytes.Parse(config.Limit)
+ if err != nil {
+ panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
}
+ config.limit = limit
+ pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
+
req := c.Request()
// Based on content length
- if req.ContentLength > config.LimitBytes {
+ if req.ContentLength > config.limit {
return echo.ErrStatusRequestEntityTooLarge
}
@@ -77,13 +88,13 @@ func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
return next(c)
}
- }, nil
+ }
}
func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
- if r.read > r.LimitBytes {
+ if r.read > r.limit {
return n, echo.ErrStatusRequestEntityTooLarge
}
return
@@ -97,3 +108,11 @@ func (r *limitedReader) Reset(reader io.ReadCloser) {
r.reader = reader
r.read = 0
}
+
+func limitedReaderPool(c BodyLimitConfig) sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ return &limitedReader{BodyLimitConfig: c}
+ },
+ }
+}
diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go
index 68d904da8..d14c2b649 100644
--- a/middleware/body_limit_test.go
+++ b/middleware/body_limit_test.go
@@ -10,17 +10,17 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
-func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
+func TestBodyLimit(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
+ h := func(c echo.Context) error {
body, err := io.ReadAll(c.Request().Body)
if err != nil {
return err
@@ -29,76 +29,41 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
}
// Based on content length (within limit)
- mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
- assert.NoError(t, err)
-
- err = mw(h)(c)
- if assert.NoError(t, err) {
+ if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
- // Based on content read (overlimit)
- mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
- assert.NoError(t, err)
- he := mw(h)(c).(echo.HTTPStatusCoder)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+ // Based on content length (overlimit)
+ he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
-
- mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
- assert.NoError(t, err)
- err = mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "Hello, World!", rec.Body.String())
+ if assert.NoError(t, BodyLimit("2M")(h)(c)) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+ }
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
- mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
- assert.NoError(t, err)
- he = mw(h)(c).(echo.HTTPStatusCoder)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
-}
-
-func TestBodyLimitAfterDecompressUsesDecodedSize(t *testing.T) {
- e := echo.New()
- body := "ok"
- gz, err := gzipString(body)
- assert.NoError(t, err)
- assert.Greater(t, len(gz), len(body))
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err = Decompress()(BodyLimit(int64(len(body)))(func(c *echo.Context) error {
- body, readErr := io.ReadAll(c.Request().Body)
- if readErr != nil {
- return readErr
- }
- return c.String(http.StatusOK, string(body))
- }))(c)
-
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, body, rec.Body.String())
+ he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
config := BodyLimitConfig{
- Skipper: DefaultSkipper,
- LimitBytes: 2,
+ Skipper: DefaultSkipper,
+ Limit: "2B",
+ limit: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
@@ -107,8 +72,8 @@ func TestBodyLimitReader(t *testing.T) {
// read all should return ErrStatusRequestEntityTooLarge
_, err := io.ReadAll(reader)
- he := err.(echo.HTTPStatusCoder)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+ he := err.(*echo.HTTPError)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
@@ -118,74 +83,91 @@ func TestBodyLimitReader(t *testing.T) {
assert.Equal(t, nil, err)
}
-func TestBodyLimit_skipper(t *testing.T) {
+func TestBodyLimitWithConfig_Skipper(t *testing.T) {
e := echo.New()
- h := func(c *echo.Context) error {
+ h := func(c echo.Context) error {
body, err := io.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
- mw, err := BodyLimitConfig{
- Skipper: func(c *echo.Context) bool {
+ mw := BodyLimitWithConfig(BodyLimitConfig{
+ Skipper: func(c echo.Context) bool {
return true
},
- LimitBytes: 2,
- }.ToMiddleware()
- assert.NoError(t, err)
+ Limit: "2B", // if not skipped this limit would make request to fail limit check
+ })
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- err = mw(h)(c)
+ err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimitWithConfig(t *testing.T) {
- e := echo.New()
- hw := []byte("Hello, World!")
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
- body, err := io.ReadAll(c.Request().Body)
- if err != nil {
- return err
- }
- return c.String(http.StatusOK, string(body))
+ var testCases = []struct {
+ name string
+ givenLimit string
+ whenBody []byte
+ expectBody []byte
+ expectError string
+ }{
+ {
+ name: "ok, body is less than limit",
+ givenLimit: "10B",
+ whenBody: []byte("123456789"),
+ expectBody: []byte("123456789"),
+ expectError: "",
+ },
+ {
+ name: "nok, body is more than limit",
+ givenLimit: "9B",
+ whenBody: []byte("1234567890"),
+ expectBody: []byte(nil),
+ expectError: "code=413, message=Request Entity Too Large",
+ },
}
- mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
-
- err := mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, hw, rec.Body.Bytes())
-}
-
-func TestBodyLimit(t *testing.T) {
- e := echo.New()
- hw := []byte("Hello, World!")
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := func(c *echo.Context) error {
- body, err := io.ReadAll(c.Request().Body)
- if err != nil {
- return err
- }
- return c.String(http.StatusOK, string(body))
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ h := func(c echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+ mw := BodyLimitWithConfig(BodyLimitConfig{
+ Limit: tc.givenLimit,
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(tc.whenBody))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := mw(h)(c)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ // not testing status as middlewares return error instead of committing it and OK cases are anyway 200
+ assert.Equal(t, tc.expectBody, rec.Body.Bytes())
+ })
}
+}
- mw := BodyLimit(2 * MB)
-
- err := mw(h)(c)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, hw, rec.Body.Bytes())
+func TestBodyLimit_panicOnInvalidLimit(t *testing.T) {
+ assert.PanicsWithError(
+ t,
+ "echo: invalid body-limit=",
+ func() { BodyLimit("") },
+ )
}
diff --git a/middleware/compress.go b/middleware/compress.go
index 7754d5db8..4a2497b49 100644
--- a/middleware/compress.go
+++ b/middleware/compress.go
@@ -7,18 +7,13 @@ import (
"bufio"
"bytes"
"compress/gzip"
- "errors"
"io"
"net"
"net/http"
"strings"
"sync"
- "github.com/labstack/echo/v5"
-)
-
-const (
- gzipScheme = "gzip"
+ "github.com/labstack/echo/v4"
)
// GzipConfig defines the config for Gzip middleware.
@@ -28,7 +23,7 @@ type GzipConfig struct {
// Gzip compression level.
// Optional. Default value -1.
- Level int
+ Level int `yaml:"level"`
// Length threshold before gzip compression is applied.
// Optional. Default value 0.
@@ -55,36 +50,42 @@ type gzipResponseWriter struct {
code int
}
-// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
-func Gzip() echo.MiddlewareFunc {
- return GzipWithConfig(GzipConfig{})
+const (
+ gzipScheme = "gzip"
+)
+
+// DefaultGzipConfig is the default Gzip middleware config.
+var DefaultGzipConfig = GzipConfig{
+ Skipper: DefaultSkipper,
+ Level: -1,
+ MinLength: 0,
}
-// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
-func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
+// Gzip returns a middleware which compresses HTTP response using gzip compression
+// scheme.
+func Gzip() echo.MiddlewareFunc {
+ return GzipWithConfig(DefaultGzipConfig)
}
-// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
-func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+// GzipWithConfig return Gzip middleware with config.
+// See: `Gzip()`.
+func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
+ // Defaults
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
- }
- if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
- return nil, errors.New("invalid gzip level")
+ config.Skipper = DefaultGzipConfig.Skipper
}
if config.Level == 0 {
- config.Level = -1
+ config.Level = DefaultGzipConfig.Level
}
if config.MinLength < 0 {
- config.MinLength = 0
+ config.MinLength = DefaultGzipConfig.MinLength
}
pool := gzipCompressPool(config)
bpool := bufferPool()
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -97,18 +98,13 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
}
- rw := res
+ rw := res.Writer
w.Reset(rw)
+
buf := bpool.Get().(*bytes.Buffer)
buf.Reset()
- grw := &gzipResponseWriter{
- Writer: w,
- ResponseWriter: rw,
- minLength: config.MinLength,
- buffer: buf,
- }
- c.SetResponse(grw)
+ grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
defer func() {
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
@@ -123,25 +119,26 @@ func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue #424, #407.
- c.SetResponse(rw)
+ res.Writer = rw
w.Reset(io.Discard)
} else if !grw.minLengthExceeded {
// Write uncompressed response
- c.SetResponse(rw)
+ res.Writer = rw
if grw.wroteHeader {
grw.ResponseWriter.WriteHeader(grw.code)
}
- _, _ = grw.buffer.WriteTo(rw)
+ grw.buffer.WriteTo(rw)
w.Reset(io.Discard)
}
- _ = w.Close()
+ w.Close()
bpool.Put(buf)
pool.Put(w)
}()
+ res.Writer = grw
}
return next(c)
}
- }, nil
+ }
}
func (w *gzipResponseWriter) WriteHeader(code int) {
@@ -189,7 +186,7 @@ func (w *gzipResponseWriter) Flush() {
w.ResponseWriter.WriteHeader(w.code)
}
- _, _ = w.Writer.Write(w.buffer.Bytes())
+ w.Writer.Write(w.buffer.Bytes())
}
if gw, ok := w.Writer.(*gzip.Writer); ok {
@@ -198,14 +195,14 @@ func (w *gzipResponseWriter) Flush() {
_ = http.NewResponseController(w.ResponseWriter).Flush()
}
-func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- return http.NewResponseController(w.ResponseWriter).Hijack()
-}
-
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
+func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return http.NewResponseController(w.ResponseWriter).Hijack()
+}
+
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
if p, ok := w.ResponseWriter.(http.Pusher); ok {
return p.Push(target, opts)
diff --git a/middleware/compress_test.go b/middleware/compress_test.go
index 084ffc9c7..c9083ee28 100644
--- a/middleware/compress_test.go
+++ b/middleware/compress_test.go
@@ -11,216 +11,91 @@ import (
"net/http/httptest"
"os"
"testing"
- "time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
-func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
- // Skip if no Accept-Encoding header
- h := Gzip()(func(c *echo.Context) error {
- c.Response().Write([]byte("test")) // For Content-Type sniffing
- return nil
- })
-
+func TestGzip(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- err := h(c)
- assert.NoError(t, err)
-
- assert.Equal(t, "test", rec.Body.String())
-}
-
-func TestMustGzipWithConfig_panics(t *testing.T) {
- assert.Panics(t, func() {
- GzipWithConfig(GzipConfig{Level: 999})
- })
-}
-
-func TestGzip_AcceptEncodingHeader(t *testing.T) {
- h := Gzip()(func(c *echo.Context) error {
+ // Skip if no Accept-Encoding header
+ h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
+ h(c)
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
-
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err := h(c)
- assert.NoError(t, err)
+ assert.Equal(t, "test", rec.Body.String())
+ // Gzip
+ req = httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ h(c)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
-
r, err := gzip.NewReader(rec.Body)
- assert.NoError(t, err)
- buf := new(bytes.Buffer)
- defer r.Close()
- buf.ReadFrom(r)
- assert.Equal(t, "test", buf.String())
-}
+ if assert.NoError(t, err) {
+ buf := new(bytes.Buffer)
+ defer r.Close()
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
+ }
-func TestGzip_chunked(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
+ chunkBuf := make([]byte, 5)
+
+ // Gzip chunked
+ req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ rec = httptest.NewRecorder()
- chunkChan := make(chan struct{})
- waitChan := make(chan struct{})
- h := Gzip()(func(c *echo.Context) error {
- rc := http.NewResponseController(c.Response())
+ c = e.NewContext(req, rec)
+ Gzip()(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
- c.Response().Write([]byte("first\n"))
- rc.Flush()
+ c.Response().Write([]byte("test\n"))
+ c.Response().Flush()
+
+ // Read the first part of the data
+ assert.True(t, rec.Flushed)
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ r.Reset(rec.Body)
- chunkChan <- struct{}{}
- <-waitChan
+ _, err = io.ReadFull(r, chunkBuf)
+ assert.NoError(t, err)
+ assert.Equal(t, "test\n", string(chunkBuf))
// Write and flush the second part of the data
- c.Response().Write([]byte("second\n"))
- rc.Flush()
+ c.Response().Write([]byte("test\n"))
+ c.Response().Flush()
- chunkChan <- struct{}{}
- <-waitChan
+ _, err = io.ReadFull(r, chunkBuf)
+ assert.NoError(t, err)
+ assert.Equal(t, "test\n", string(chunkBuf))
// Write the final part of the data and return
- c.Response().Write([]byte("third"))
-
- chunkChan <- struct{}{}
+ c.Response().Write([]byte("test"))
return nil
- })
-
- go func() {
- err := h(c)
- chunkChan <- struct{}{}
- assert.NoError(t, err)
- }()
+ })(c)
- <-chunkChan // wait for first write
- waitChan <- struct{}{}
-
- <-chunkChan // wait for second write
- waitChan <- struct{}{}
-
- <-chunkChan // wait for final write in handler
- <-chunkChan // wait for return from handler
- time.Sleep(5 * time.Millisecond) // to have time for flushing
-
- assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
-
- r, err := gzip.NewReader(rec.Body)
- assert.NoError(t, err)
buf := new(bytes.Buffer)
+ defer r.Close()
buf.ReadFrom(r)
- assert.Equal(t, "first\nsecond\nthird", buf.String())
-}
-
-func TestGzip_NoContent(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := Gzip()(func(c *echo.Context) error {
- return c.NoContent(http.StatusNoContent)
- })
- if assert.NoError(t, h(c)) {
- assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
- assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
- assert.Equal(t, 0, len(rec.Body.Bytes()))
- }
-}
-
-func TestGzip_Empty(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := Gzip()(func(c *echo.Context) error {
- return c.String(http.StatusOK, "")
- })
- if assert.NoError(t, h(c)) {
- assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
- assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType))
- r, err := gzip.NewReader(rec.Body)
- if assert.NoError(t, err) {
- var buf bytes.Buffer
- buf.ReadFrom(r)
- assert.Equal(t, "", buf.String())
- }
- }
-}
-
-func TestGzip_ErrorReturned(t *testing.T) {
- e := echo.New()
- e.Use(Gzip())
- e.GET("/", func(c *echo.Context) error {
- return echo.ErrNotFound
- })
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusNotFound, rec.Code)
- assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
-}
-
-func TestGzipWithConfig_invalidLevel(t *testing.T) {
- mw, err := GzipConfig{Level: 12}.ToMiddleware()
- assert.EqualError(t, err, "invalid gzip level")
- assert.Nil(t, mw)
-}
-
-// Issue #806
-func TestGzipWithStatic(t *testing.T) {
- e := echo.New()
- e.Filesystem = os.DirFS("../")
-
- e.Use(Gzip())
- e.Static("/test", "_fixture/images")
- req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec := httptest.NewRecorder()
-
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, http.StatusOK, rec.Code)
- // Data is written out in chunks when Content-Length == "", so only
- // validate the content length if it's not set.
- if cl := rec.Header().Get("Content-Length"); cl != "" {
- assert.Equal(t, cl, rec.Body.Len())
- }
- r, err := gzip.NewReader(rec.Body)
- if assert.NoError(t, err) {
- defer r.Close()
- want, err := os.ReadFile("../_fixture/images/walle.png")
- if assert.NoError(t, err) {
- buf := new(bytes.Buffer)
- buf.ReadFrom(r)
- assert.Equal(t, want, buf.Bytes())
- }
- }
+ assert.Equal(t, "test", buf.String())
}
func TestGzipWithMinLength(t *testing.T) {
e := echo.New()
// Minimal response length
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
- e.GET("/", func(c *echo.Context) error {
+ e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("foobarfoobar"))
return nil
})
@@ -243,7 +118,7 @@ func TestGzipWithMinLengthTooShort(t *testing.T) {
e := echo.New()
// Minimal response length
e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
- e.GET("/", func(c *echo.Context) error {
+ e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
@@ -259,7 +134,7 @@ func TestGzipWithResponseWithoutBody(t *testing.T) {
e := echo.New()
e.Use(Gzip())
- e.GET("/", func(c *echo.Context) error {
+ e.GET("/", func(c echo.Context) error {
return c.Redirect(http.StatusMovedPermanently, "http://localhost")
})
@@ -286,14 +161,13 @@ func TestGzipWithMinLengthChunked(t *testing.T) {
var r *gzip.Reader = nil
c := e.NewContext(req, rec)
- next := func(c *echo.Context) error {
- rc := http.NewResponseController(c.Response())
+ GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
c.Response().Write([]byte("test\n"))
- rc.Flush()
+ c.Response().Flush()
// Read the first part of the data
assert.True(t, rec.Flushed)
@@ -309,7 +183,7 @@ func TestGzipWithMinLengthChunked(t *testing.T) {
// Write and flush the second part of the data
c.Response().Write([]byte("test\n"))
- rc.Flush()
+ c.Response().Flush()
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(t, err)
@@ -318,10 +192,8 @@ func TestGzipWithMinLengthChunked(t *testing.T) {
// Write the final part of the data and return
c.Response().Write([]byte("test"))
return nil
- }
- err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)
+ })(c)
- assert.NoError(t, err)
assert.NotNil(t, r)
buf := new(bytes.Buffer)
@@ -338,7 +210,7 @@ func TestGzipWithMinLengthNoContent(t *testing.T) {
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error {
+ h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
@@ -348,11 +220,106 @@ func TestGzipWithMinLengthNoContent(t *testing.T) {
}
}
+func TestGzipNoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Gzip()(func(c echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
+ }
+}
+
+func TestGzipEmpty(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Gzip()(func(c echo.Context) error {
+ return c.String(http.StatusOK, "")
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType))
+ r, err := gzip.NewReader(rec.Body)
+ if assert.NoError(t, err) {
+ var buf bytes.Buffer
+ buf.ReadFrom(r)
+ assert.Equal(t, "", buf.String())
+ }
+ }
+}
+
+func TestGzipErrorReturned(t *testing.T) {
+ e := echo.New()
+ e.Use(Gzip())
+ e.GET("/", func(c echo.Context) error {
+ return echo.ErrNotFound
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
+ e := echo.New()
+ // Invalid level
+ e.Use(GzipWithConfig(GzipConfig{Level: 12}))
+ e.GET("/", func(c echo.Context) error {
+ c.Response().Write([]byte("test"))
+ return nil
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.Contains(t, rec.Body.String(), `{"message":"invalid pool object"}`)
+}
+
+// Issue #806
+func TestGzipWithStatic(t *testing.T) {
+ e := echo.New()
+ e.Use(Gzip())
+ e.Static("/test", "../_fixture/images")
+ req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ // Data is written out in chunks when Content-Length == "", so only
+ // validate the content length if it's not set.
+ if cl := rec.Header().Get("Content-Length"); cl != "" {
+ assert.Equal(t, cl, rec.Body.Len())
+ }
+ r, err := gzip.NewReader(rec.Body)
+ if assert.NoError(t, err) {
+ defer r.Close()
+ want, err := os.ReadFile("../_fixture/images/walle.png")
+ if assert.NoError(t, err) {
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(r)
+ assert.Equal(t, want, buf.Bytes())
+ }
+ }
+}
+
func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
bdrw := gzipResponseWriter{
ResponseWriter: trwu,
}
+
result := bdrw.Unwrap()
assert.Equal(t, trwu, result)
}
@@ -362,6 +329,7 @@ func TestGzipResponseWriter_CanHijack(t *testing.T) {
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
+
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "can hijack")
}
@@ -371,6 +339,7 @@ func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
bdrw := gzipResponseWriter{
ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
+
_, _, err := bdrw.Hijack()
assert.EqualError(t, err, "feature not supported")
}
@@ -381,7 +350,7 @@ func BenchmarkGzip(b *testing.B) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- h := Gzip()(func(c *echo.Context) error {
+ h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go
index 68465199a..5d9ae9755 100644
--- a/middleware/context_timeout.go
+++ b/middleware/context_timeout.go
@@ -8,18 +8,51 @@ import (
"errors"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
+// ContextTimeout Middleware
+//
+// ContextTimeout provides request timeout functionality using Go's context mechanism.
+// It is the recommended replacement for the deprecated Timeout middleware.
+//
+//
+// Basic Usage:
+//
+// e.Use(middleware.ContextTimeout(30 * time.Second))
+//
+// With Configuration:
+//
+// e.Use(middleware.ContextTimeoutWithConfig(middleware.ContextTimeoutConfig{
+// Timeout: 30 * time.Second,
+// Skipper: middleware.DefaultSkipper,
+// }))
+//
+// Handler Example:
+//
+// e.GET("/task", func(c echo.Context) error {
+// ctx := c.Request().Context()
+//
+// result, err := performTaskWithContext(ctx)
+// if err != nil {
+// if errors.Is(err, context.DeadlineExceeded) {
+// return echo.NewHTTPError(http.StatusServiceUnavailable, "timeout")
+// }
+// return err
+// }
+//
+// return c.JSON(http.StatusOK, result)
+// })
+
// ContextTimeoutConfig defines the config for ContextTimeout middleware.
type ContextTimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
- // ErrorHandler is a function when error arises in middeware execution.
- ErrorHandler func(c *echo.Context, err error) error
+ // ErrorHandler is a function when error arises in middleware execution.
+ ErrorHandler func(err error, c echo.Context) error
- // Timeout configures a timeout for the middleware
+ // Timeout configures a timeout for the middleware, defaults to 0 for no timeout
Timeout time.Duration
}
@@ -31,7 +64,11 @@ func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
// ContextTimeoutWithConfig returns a Timeout middleware with config.
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
+ mw, err := config.ToMiddleware()
+ if err != nil {
+ panic(err)
+ }
+ return mw
}
// ToMiddleware converts Config to middleware.
@@ -43,16 +80,16 @@ func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
config.Skipper = DefaultSkipper
}
if config.ErrorHandler == nil {
- config.ErrorHandler = func(c *echo.Context, err error) error {
+ config.ErrorHandler = func(err error, c echo.Context) error {
if err != nil && errors.Is(err, context.DeadlineExceeded) {
- return echo.ErrServiceUnavailable.Wrap(err)
+ return echo.ErrServiceUnavailable.WithInternal(err)
}
return err
}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -63,7 +100,7 @@ func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
c.SetRequest(c.Request().WithContext(timeoutContext))
if err := next(c); err != nil {
- return config.ErrorHandler(c, err)
+ return config.ErrorHandler(err, c)
}
return nil
}
diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go
index c7ba76beb..e69bcd268 100644
--- a/middleware/context_timeout_test.go
+++ b/middleware/context_timeout_test.go
@@ -6,7 +6,6 @@ package middleware
import (
"context"
"errors"
- "github.com/labstack/echo/v5"
"net/http"
"net/http/httptest"
"net/url"
@@ -14,13 +13,14 @@ import (
"testing"
"time"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestContextTimeoutSkipper(t *testing.T) {
t.Parallel()
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
- Skipper: func(context *echo.Context) bool {
+ Skipper: func(context echo.Context) bool {
return true
},
Timeout: 10 * time.Millisecond,
@@ -32,7 +32,7 @@ func TestContextTimeoutSkipper(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)
- err := m(func(c *echo.Context) error {
+ err := m(func(c echo.Context) error {
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
return err
}
@@ -65,7 +65,7 @@ func TestContextTimeoutErrorOutInHandler(t *testing.T) {
c := e.NewContext(req, rec)
rec.Code = 1 // we want to be sure that even 200 will not be sent
- err := m(func(c *echo.Context) error {
+ err := m(func(c echo.Context) error {
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
// to handle returned error and this can be done only then handler has not yet committed (written status code)
// the response.
@@ -91,7 +91,7 @@ func TestContextTimeoutSuccessfulRequest(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)
- err := m(func(c *echo.Context) error {
+ err := m(func(c echo.Context) error {
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
})(c)
@@ -115,7 +115,7 @@ func TestContextTimeoutTestRequestClone(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)
- err := m(func(c *echo.Context) error {
+ err := m(func(c echo.Context) error {
// Cookie test
cookie, err := c.Request().Cookie("cookie")
if assert.NoError(t, err) {
@@ -150,24 +150,23 @@ func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)
- err := m(func(c *echo.Context) error {
+ err := m(func(c echo.Context) error {
if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil {
return err
}
return c.String(http.StatusOK, "Hello, World!")
})(c)
+ assert.IsType(t, &echo.HTTPError{}, err)
assert.Error(t, err)
- if assert.IsType(t, &echo.HTTPError{}, err) {
- assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
- assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
- }
+ assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
+ assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
}
func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
t.Parallel()
- timeoutErrorHandler := func(c *echo.Context, err error) error {
+ timeoutErrorHandler := func(err error, c echo.Context) error {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return &echo.HTTPError{
@@ -192,7 +191,7 @@ func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
e := echo.New()
c := e.NewContext(req, rec)
- err := m(func(c *echo.Context) error {
+ err := m(func(c echo.Context) error {
// NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
// for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.
diff --git a/middleware/cors.go b/middleware/cors.go
index 96ed16985..a1f445321 100644
--- a/middleware/cors.go
+++ b/middleware/cors.go
@@ -4,13 +4,12 @@
package middleware
import (
- "errors"
- "fmt"
"net/http"
+ "regexp"
"strconv"
"strings"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// CORSConfig defines the config for CORS middleware.
@@ -20,41 +19,29 @@ type CORSConfig struct {
// AllowOrigins determines the value of the Access-Control-Allow-Origin
// response header. This header defines a list of origins that may access the
- // resource.
- //
- // Origin consist of following parts: `scheme + "://" + host + optional ":" + port`
- // Wildcard can be used, but has to be set explicitly []string{"*"}
- // Example: `https://example.com`, `http://example.com:8080`, `*`
+ // resource. The wildcard characters '*' and '?' are supported and are
+ // converted to regex fragments '.*' and '.' accordingly.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
- // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
//
- // Mandatory.
- AllowOrigins []string
-
- // UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the
- // origin as an argument and returns
- // - string, allowed origin
- // - bool, true if allowed or false otherwise.
- // - error, if an error is returned, it is returned immediately by the handler.
- // If this option is set, AllowOrigins is ignored.
+ // Optional. Default value []string{"*"}.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
+ AllowOrigins []string `yaml:"allow_origins"`
+
+ // AllowOriginFunc is a custom function to validate the origin. It takes the
+ // origin as an argument and returns true if allowed or false otherwise. If
+ // an error is returned, it is returned by the handler. If this option is
+ // set, AllowOrigins is ignored.
//
// Security: use extreme caution when handling the origin, and carefully
- // validate any logic. Remember that attackers may register hostile (sub)domain names.
+ // validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
- // Sub-domain checks example:
- // UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
- // if strings.HasSuffix(origin, ".example.com") {
- // return origin, true, nil
- // }
- // return "", false, nil
- // },
- //
// Optional.
- UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error)
+ AllowOriginFunc func(origin string) (bool, error) `yaml:"-"`
// AllowMethods determines the value of the Access-Control-Allow-Methods
// response header. This header specified the list of methods allowed when
@@ -66,16 +53,16 @@ type CORSConfig struct {
// from `Allow` header that echo.Router set into context.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
- AllowMethods []string
+ AllowMethods []string `yaml:"allow_methods"`
// AllowHeaders determines the value of the Access-Control-Allow-Headers
// response header. This header is used in response to a preflight request to
// indicate which HTTP headers can be used when making the actual request.
//
- // Optional. Defaults to empty list. No domains allowed for CORS.
+ // Optional. Default value []string{}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
- AllowHeaders []string
+ AllowHeaders []string `yaml:"allow_headers"`
// AllowCredentials determines the value of the
// Access-Control-Allow-Credentials response header. This header indicates
@@ -92,7 +79,16 @@ type CORSConfig struct {
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
- AllowCredentials bool
+ AllowCredentials bool `yaml:"allow_credentials"`
+
+ // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
+ // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
+ //
+ // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
+ // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
+ //
+ // Optional. Default value is false.
+ UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"`
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// defines a list of headers that clients are allowed to access.
@@ -100,7 +96,7 @@ type CORSConfig struct {
// Optional. Default value []string{}, in which case the header is not set.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
- ExposeHeaders []string
+ ExposeHeaders []string `yaml:"expose_headers"`
// MaxAge determines the value of the Access-Control-Max-Age response header.
// This header indicates how long (in seconds) the results of a preflight
@@ -110,16 +106,19 @@ type CORSConfig struct {
// Optional. Default value 0 - meaning header is not sent.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
- MaxAge int
+ MaxAge int `yaml:"max_age"`
+}
+
+// DefaultCORSConfig is the default CORS middleware config.
+var DefaultCORSConfig = CORSConfig{
+ Skipper: DefaultSkipper,
+ AllowOrigins: []string{"*"},
+ AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See also [MDN: Cross-Origin Resource Sharing (CORS)].
//
-// Origin consist of following parts: `scheme + "://" + host + optional ":" + port`
-// Wildcard `*` can be used, but has to be set explicitly.
-// Example: `https://example.com`, `http://example.com:8080`, `*`
-//
// Security: Poorly configured CORS can compromise security because it allows
// relaxation of the browser's Same-Origin policy. See [Exploiting CORS
// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin
@@ -128,29 +127,45 @@ type CORSConfig struct {
// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors
-func CORS(allowOrigins ...string) echo.MiddlewareFunc {
- c := CORSConfig{
- AllowOrigins: allowOrigins,
- }
- return CORSWithConfig(c)
+func CORS() echo.MiddlewareFunc {
+ return CORSWithConfig(DefaultCORSConfig)
}
-// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
+// CORSWithConfig returns a CORS middleware with config.
// See: [CORS].
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
-func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
+ config.Skipper = DefaultCORSConfig.Skipper
+ }
+ if len(config.AllowOrigins) == 0 {
+ config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
hasCustomAllowMethods := true
if len(config.AllowMethods) == 0 {
hasCustomAllowMethods = false
- config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}
+ config.AllowMethods = DefaultCORSConfig.AllowMethods
+ }
+
+ allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins))
+ for _, origin := range config.AllowOrigins {
+ if origin == "*" {
+ continue // "*" is handled differently and does not need regexp
+ }
+ pattern := regexp.QuoteMeta(origin)
+ pattern = strings.ReplaceAll(pattern, "\\*", ".*")
+ pattern = strings.ReplaceAll(pattern, "\\?", ".")
+ pattern = "^" + pattern + "$"
+
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ // this is to preserve previous behaviour - invalid patterns were just ignored.
+ // If we would turn this to panic, users with invalid patterns
+ // would have applications crashing in production due unrecovered panic.
+ // TODO: this should be turned to error/panic in `v5`
+ continue
+ }
+ allowOriginPatterns = append(allowOriginPatterns, re)
}
allowMethods := strings.Join(config.AllowMethods, ",")
@@ -162,29 +177,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
maxAge = strconv.Itoa(config.MaxAge)
}
- allowOriginFunc := config.UnsafeAllowOriginFunc
- if config.UnsafeAllowOriginFunc == nil {
- if len(config.AllowOrigins) == 0 {
- return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided")
- }
- allowOriginFunc = config.defaultAllowOriginFunc
- for _, origin := range config.AllowOrigins {
- if origin == "*" {
- if config.AllowCredentials {
- return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc")
- }
- allowOriginFunc = config.starAllowOriginFunc
- break
- }
- if err := validateOrigin(origin, "allow origin"); err != nil {
- return nil, err
- }
- }
- config.AllowOrigins = append([]string(nil), config.AllowOrigins...)
- }
-
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -192,6 +186,7 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
req := c.Request()
res := c.Response()
origin := req.Header.Get(echo.HeaderOrigin)
+ allowOrigin := ""
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
@@ -216,51 +211,76 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" {
- if preflight { // req.Method=OPTIONS
- return c.NoContent(http.StatusNoContent)
+ if !preflight {
+ return next(c)
}
- return next(c) // let non-browser calls through
+ return c.NoContent(http.StatusNoContent)
}
- allowedOrigin, allowed, err := allowOriginFunc(c, origin)
- if err != nil {
- return err
- }
- if !allowed {
- // Origin existed and was NOT allowed
- if preflight {
- // From: https://github.com/labstack/echo/issues/2767
- // If the request's origin isn't allowed by the CORS configuration,
- // the middleware should simply omit the relevant CORS headers from the response
- // and let the browser fail the CORS check (if any).
- return c.NoContent(http.StatusNoContent)
+ if config.AllowOriginFunc != nil {
+ allowed, err := config.AllowOriginFunc(origin)
+ if err != nil {
+ return err
+ }
+ if allowed {
+ allowOrigin = origin
+ }
+ } else {
+ // Check allowed origins
+ for _, o := range config.AllowOrigins {
+ if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
+ allowOrigin = origin
+ break
+ }
+ if o == "*" || o == origin {
+ allowOrigin = o
+ break
+ }
+ if matchSubdomain(origin, o) {
+ allowOrigin = origin
+ break
+ }
+ }
+
+ checkPatterns := false
+ if allowOrigin == "" {
+ // to avoid regex cost by invalid (long) domains (253 is domain name max limit)
+ if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
+ checkPatterns = true
+ }
+ }
+ if checkPatterns {
+ for _, re := range allowOriginPatterns {
+ if match := re.MatchString(origin); match {
+ allowOrigin = origin
+ break
+ }
+ }
}
- // From: https://github.com/labstack/echo/issues/2767
- // no CORS middleware should block non-preflight requests;
- // such requests should be let through. One reason is that not all requests that
- // carry an Origin header participate in the CORS protocol.
- return next(c)
}
- // Origin existed and was allowed
+ // Origin not allowed
+ if allowOrigin == "" {
+ if !preflight {
+ return next(c)
+ }
+ return c.NoContent(http.StatusNoContent)
+ }
- res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
+ res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
- // Simple request will be let though
+ // Simple request
if !preflight {
if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
}
return next(c)
}
- // Below code is for Preflight (OPTIONS) request
- //
- // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if
- // at the end of handler chain is actual OPTIONS route or 404/405 route which
- // response code will confuse browsers
+
+ // Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
@@ -283,18 +303,5 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
return c.NoContent(http.StatusNoContent)
}
- }, nil
-}
-
-func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
- return "*", true, nil
-}
-
-func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
- for _, allowedOrigin := range config.AllowOrigins {
- if strings.EqualFold(allowedOrigin, origin) {
- return allowedOrigin, true, nil
- }
}
- return "", false, nil
}
diff --git a/middleware/cors_test.go b/middleware/cors_test.go
index 5de4ca063..5461e9362 100644
--- a/middleware/cors_test.go
+++ b/middleware/cors_test.go
@@ -4,87 +4,72 @@
package middleware
import (
- "cmp"
"errors"
"net/http"
"net/http/httptest"
- "strings"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestCORS(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request
- req.Header.Set(echo.HeaderOrigin, "http://example.com")
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- mw := CORS("*")
- handler := mw(func(c *echo.Context) error {
- return nil
- })
-
- err := handler(c)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusNoContent, rec.Code)
- assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
-}
-
-func TestCORSConfig(t *testing.T) {
var testCases = []struct {
name string
- givenConfig *CORSConfig
+ givenMW echo.MiddlewareFunc
whenMethod string
whenHeaders map[string]string
expectHeaders map[string]string
notExpectHeaders map[string]string
- expectErr string
}{
{
- name: "ok, wildcard origin",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"*"},
- },
+ name: "ok, wildcard origin",
whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"},
},
{
- name: "ok, wildcard AllowedOrigin with no Origin header in request",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"*"},
- },
+ name: "ok, wildcard AllowedOrigin with no Origin header in request",
notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""},
},
+ {
+ name: "ok, invalid pattern is ignored",
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{
+ "\xff", // Invalid UTF-8 makes regexp.Compile to error
+ "*.example.com",
+ },
+ }),
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
+ },
{
name: "ok, specific AllowOrigins and AllowCredentials",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"http://localhost", "http://localhost:8080"},
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
- },
- whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"},
+ }),
+ whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
expectHeaders: map[string]string{
- echo.HeaderAccessControlAllowOrigin: "http://localhost",
+ echo.HeaderAccessControlAllowOrigin: "localhost",
echo.HeaderAccessControlAllowCredentials: "true",
},
},
{
name: "ok, preflight request with matching origin for `AllowOrigins`",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"http://localhost"},
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
- echo.HeaderOrigin: "http://localhost",
+ echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
- echo.HeaderAccessControlAllowOrigin: "http://localhost",
+ echo.HeaderAccessControlAllowOrigin: "localhost",
echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
echo.HeaderAccessControlAllowCredentials: "true",
echo.HeaderAccessControlMaxAge: "3600",
@@ -92,14 +77,14 @@ func TestCORSConfig(t *testing.T) {
},
{
name: "ok, preflight request when `Access-Control-Max-Age` is set",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"http://localhost"},
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 1,
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
- echo.HeaderOrigin: "http://localhost",
+ echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
@@ -108,14 +93,14 @@ func TestCORSConfig(t *testing.T) {
},
{
name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"http://localhost"},
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: -1, // forces `Access-Control-Max-Age: 0`
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
- echo.HeaderOrigin: "http://localhost",
+ echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
expectHeaders: map[string]string{
@@ -124,16 +109,16 @@ func TestCORSConfig(t *testing.T) {
},
{
name: "ok, CORS check are skipped",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"http://localhost"},
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"localhost"},
AllowCredentials: true,
- Skipper: func(c *echo.Context) bool {
+ Skipper: func(c echo.Context) bool {
return true
},
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
- echo.HeaderOrigin: "http://localhost",
+ echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
notExpectHeaders: map[string]string{
@@ -144,33 +129,31 @@ func TestCORSConfig(t *testing.T) {
},
},
{
- name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
- givenConfig: &CORSConfig{
+ name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
+ givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
AllowCredentials: true,
MaxAge: 3600,
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
- expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
- },
- {
- name: "nok, preflight request with invalid `AllowOrigins` value",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"http://server", "missing-scheme"},
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*`
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ echo.HeaderAccessControlMaxAge: "3600",
},
- expectErr: `allow origin is missing scheme or host: missing-scheme`,
},
{
name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false",
- givenConfig: &CORSConfig{
+ givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
AllowCredentials: false, // important for this testcase
MaxAge: 3600,
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
@@ -187,23 +170,29 @@ func TestCORSConfig(t *testing.T) {
},
{
name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
- givenConfig: &CORSConfig{
- AllowOrigins: []string{"*"},
- AllowCredentials: true,
- MaxAge: 3600,
- },
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: true,
+ UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase
+ MaxAge: 3600,
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
echo.HeaderContentType: echo.MIMEApplicationJSON,
},
- expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
},
{
name: "ok, preflight request with Access-Control-Request-Headers",
- givenConfig: &CORSConfig{
+ givenMW: CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
- },
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{
echo.HeaderOrigin: "localhost",
@@ -218,28 +207,18 @@ func TestCORSConfig(t *testing.T) {
},
{
name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *",
- givenConfig: &CORSConfig{
- UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) {
- if strings.HasSuffix(origin, ".example.com") {
- allowed = true
- }
- return origin, allowed, nil
- },
- },
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"http://*.example.com"},
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
},
{
name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *",
- givenConfig: &CORSConfig{
- UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
- if strings.HasSuffix(origin, ".example.com") {
- return origin, true, nil
- }
- return "", false, nil
- },
- },
+ givenMW: CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{"http://*.example.com"},
+ }),
whenMethod: http.MethodOptions,
whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"},
expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"},
@@ -249,26 +228,18 @@ func TestCORSConfig(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
- var mw echo.MiddlewareFunc
- var err error
- if tc.givenConfig != nil {
- mw, err = tc.givenConfig.ToMiddleware()
- } else {
- mw, err = CORSConfig{}.ToMiddleware()
- }
- if err != nil {
- if tc.expectErr != "" {
- assert.EqualError(t, err, tc.expectErr)
- return
- }
- t.Fatal(err)
+ mw := CORS()
+ if tc.givenMW != nil {
+ mw = tc.givenMW
}
-
- h := mw(func(c *echo.Context) error {
+ h := mw(func(c echo.Context) error {
return nil
})
- method := cmp.Or(tc.whenMethod, http.MethodGet)
+ method := http.MethodGet
+ if tc.whenMethod != "" {
+ method = tc.whenMethod
+ }
req := httptest.NewRequest(method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -276,7 +247,7 @@ func TestCORSConfig(t *testing.T) {
req.Header.Set(k, v)
}
- err = h(c)
+ err := h(c)
assert.NoError(t, err)
header := rec.Header()
@@ -330,7 +301,98 @@ func Test_allowOriginScheme(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
- h := cors(func(c *echo.Context) error { return echo.ErrNotFound })
+ h := cors(echo.NotFoundHandler)
+ h(c)
+
+ if tt.expected {
+ assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ } else {
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ }
+ }
+}
+
+func Test_allowOriginSubdomain(t *testing.T) {
+ tests := []struct {
+ domain, pattern string
+ expected bool
+ }{
+ {
+ domain: "http://aaa.example.com",
+ pattern: "http://*.example.com",
+ expected: true,
+ },
+ {
+ domain: "http://bbb.aaa.example.com",
+ pattern: "http://*.example.com",
+ expected: true,
+ },
+ {
+ domain: "http://bbb.aaa.example.com",
+ pattern: "http://*.aaa.example.com",
+ expected: true,
+ },
+ {
+ domain: "http://aaa.example.com:8080",
+ pattern: "http://*.example.com:8080",
+ expected: true,
+ },
+
+ {
+ domain: "http://fuga.hoge.com",
+ pattern: "http://*.example.com",
+ expected: false,
+ },
+ {
+ domain: "http://ccc.bbb.example.com",
+ pattern: "http://*.aaa.example.com",
+ expected: false,
+ },
+ {
+ domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
+ .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
+ .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
+ .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
+ pattern: "http://*.example.com",
+ expected: false,
+ },
+ {
+ domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
+ pattern: "http://*.example.com",
+ expected: false,
+ },
+ {
+ domain: "http://ccc.bbb.example.com",
+ pattern: "http://example.com",
+ expected: false,
+ },
+ {
+ domain: "https://prod-preview--aaa.bbb.com",
+ pattern: "https://*--aaa.bbb.com",
+ expected: true,
+ },
+ {
+ domain: "http://ccc.bbb.example.com",
+ pattern: "http://*.example.com",
+ expected: true,
+ },
+ {
+ domain: "http://ccc.bbb.example.com",
+ pattern: "http://foo.[a-z]*.example.com",
+ expected: false,
+ },
+ }
+
+ e := echo.New()
+ for _, tt := range tests {
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, tt.domain)
+ cors := CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{tt.pattern},
+ })
+ h := cors(echo.NotFoundHandler)
h(c)
if tt.expected {
@@ -343,53 +405,50 @@ func Test_allowOriginScheme(t *testing.T) {
func TestCORSWithConfig_AllowMethods(t *testing.T) {
var testCases = []struct {
- name string
- givenAllowOrigins []string
- givenAllowMethods []string
- whenAllowContextKey string
- whenOrigin string
+ name string
+ allowOrigins []string
+ allowContextKey string
+
+ whenOrigin string
+ whenAllowMethods []string
+
expectAllow string
expectAccessControlAllowMethods string
}{
{
- name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
- givenAllowOrigins: []string{"*"},
- givenAllowMethods: []string{http.MethodGet, http.MethodHead},
- whenAllowContextKey: "OPTIONS, GET",
- whenOrigin: "",
- expectAllow: "OPTIONS, GET",
+ name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
+ allowContextKey: "OPTIONS, GET",
+ whenAllowMethods: []string{http.MethodGet, http.MethodHead},
+ whenOrigin: "",
+ expectAllow: "OPTIONS, GET",
},
{
- name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
- givenAllowOrigins: []string{"*"},
- givenAllowMethods: nil,
- whenAllowContextKey: "",
- whenOrigin: "",
- expectAllow: "",
+ name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
+ allowContextKey: "",
+ whenAllowMethods: nil,
+ whenOrigin: "",
+ expectAllow: "",
},
{
name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
- givenAllowOrigins: []string{"*"},
- givenAllowMethods: []string{http.MethodGet, http.MethodHead},
- whenAllowContextKey: "OPTIONS, GET",
+ allowContextKey: "OPTIONS, GET",
+ whenAllowMethods: []string{http.MethodGet, http.MethodHead},
whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "GET,HEAD",
},
{
name: "default AllowMethods, preflight, existing origin, sets both headers",
- givenAllowOrigins: []string{"*"},
- givenAllowMethods: nil,
- whenAllowContextKey: "OPTIONS, GET",
+ allowContextKey: "OPTIONS, GET",
+ whenAllowMethods: nil,
whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "OPTIONS, GET",
},
{
name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
- givenAllowOrigins: []string{"*"},
- givenAllowMethods: nil,
- whenAllowContextKey: "",
+ allowContextKey: "",
+ whenAllowMethods: nil,
whenOrigin: "http://google.com",
expectAllow: "",
expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
@@ -399,13 +458,13 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
cors := CORSWithConfig(CORSConfig{
- AllowOrigins: tc.givenAllowOrigins,
- AllowMethods: tc.givenAllowMethods,
+ AllowOrigins: tc.allowOrigins,
+ AllowMethods: tc.whenAllowMethods,
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
@@ -413,13 +472,11 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) {
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
- if tc.whenAllowContextKey != "" {
- c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey)
+ if tc.allowContextKey != "" {
+ c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
}
- h := cors(func(c *echo.Context) error {
- return c.String(http.StatusOK, "OK")
- })
+ h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
@@ -535,10 +592,10 @@ func TestCorsHeaders(t *testing.T) {
//MaxAge: 3600,
}))
- e.GET("/", func(c *echo.Context) error {
+ e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
- e.POST("/", func(c *echo.Context) error {
+ e.POST("/", func(c echo.Context) error {
return c.String(http.StatusCreated, "OK")
})
@@ -582,17 +639,17 @@ func TestCorsHeaders(t *testing.T) {
}
func Test_allowOriginFunc(t *testing.T) {
- returnTrue := func(c *echo.Context, origin string) (string, bool, error) {
- return origin, true, nil
+ returnTrue := func(origin string) (bool, error) {
+ return true, nil
}
- returnFalse := func(c *echo.Context, origin string) (string, bool, error) {
- return origin, false, nil
+ returnFalse := func(origin string) (bool, error) {
+ return false, nil
}
- returnError := func(c *echo.Context, origin string) (string, bool, error) {
- return origin, true, errors.New("this is a test error")
+ returnError := func(origin string) (bool, error) {
+ return true, errors.New("this is a test error")
}
- allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){
+ allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
@@ -606,21 +663,21 @@ func Test_allowOriginFunc(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
- cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware()
- assert.NoError(t, err)
-
- h := cors(func(c *echo.Context) error { return echo.ErrNotFound })
- err = h(c)
+ cors := CORSWithConfig(CORSConfig{
+ AllowOriginFunc: allowOriginFunc,
+ })
+ h := cors(echo.NotFoundHandler)
+ err := h(c)
- allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin)
+ expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
- if allowed {
- assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ if expected {
+ assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}
diff --git a/middleware/csrf.go b/middleware/csrf.go
index e3616516f..1a35da63c 100644
--- a/middleware/csrf.go
+++ b/middleware/csrf.go
@@ -10,7 +10,7 @@ import (
"strings"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site
@@ -26,24 +26,25 @@ const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_"
type CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
- // TrustedOrigins permits any request with `Sec-Fetch-Site` header whose `Origin` header
- // exactly matches a configured origin.
- // Values should be formatted as Origin header "scheme://host[:port]".
+
+ // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header
+ // exactly matches the specified value.
+ // Values should be formated as Origin header "scheme://host[:port]".
//
// See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
// See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
TrustedOrigins []string
- // AllowSecFetchSiteFunc allows custom behaviour for `Sec-Fetch-Site` requests that are about to
- // fail with CSRF error, to be allowed or replaced with custom error.
+ // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to
+ // fail with CRSF error, to be allowed or replaced with custom error.
// This function applies to `Sec-Fetch-Site` values:
// - `same-site` same registrable domain (subdomain and/or different port)
// - `cross-site` request originates from different site
// See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
- AllowSecFetchSiteFunc func(c *echo.Context) (bool, error)
+ AllowSecFetchSiteFunc func(c echo.Context) (bool, error)
// TokenLength is the length of the generated token.
- TokenLength uint8
+ TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32.
// TokenLookup is a string in the form of ":" or ":,:" that is used
@@ -57,48 +58,49 @@ type CSRFConfig struct {
// - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"`
- // Generator defines a function to generate token.
- // Optional. Defaults tp randomString(TokenLength).
- Generator func() string
-
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
- ContextKey string
+ ContextKey string `yaml:"context_key"`
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
- CookieName string
+ CookieName string `yaml:"cookie_name"`
// Domain of the CSRF cookie.
// Optional. Default value none.
- CookieDomain string
+ CookieDomain string `yaml:"cookie_domain"`
// Path of the CSRF cookie.
// Optional. Default value none.
- CookiePath string
+ CookiePath string `yaml:"cookie_path"`
// Max age (in seconds) of the CSRF cookie.
// Optional. Default value 86400 (24hr).
- CookieMaxAge int
+ CookieMaxAge int `yaml:"cookie_max_age"`
// Indicates if CSRF cookie is secure.
// Optional. Default value false.
- CookieSecure bool
+ CookieSecure bool `yaml:"cookie_secure"`
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
- CookieHTTPOnly bool
+ CookieHTTPOnly bool `yaml:"cookie_http_only"`
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
- CookieSameSite http.SameSite
+ CookieSameSite http.SameSite `yaml:"cookie_same_site"`
// ErrorHandler defines a function which is executed for returning custom errors.
- ErrorHandler func(c *echo.Context, err error) error
+ ErrorHandler CSRFErrorHandler
+
+ generator func(length uint8) string
}
+// CSRFErrorHandler is a function which is executed for creating custom errors.
+type CSRFErrorHandler func(err error, c echo.Context) error
+
// ErrCSRFInvalid is returned when CSRF check fails
-var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"}
+var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
// DefaultCSRFConfig is the default CSRF middleware config.
var DefaultCSRFConfig = CSRFConfig{
@@ -114,26 +116,25 @@ var DefaultCSRFConfig = CSRFConfig{
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
- return CSRFWithConfig(DefaultCSRFConfig)
+ c := DefaultCSRFConfig
+ return CSRFWithConfig(c)
}
-// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
+// CSRFWithConfig returns a CSRF middleware with config.
+// See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
- // Defaults
if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper
}
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}
- if config.Generator == nil {
- config.Generator = createRandomStringGenerator(config.TokenLength)
- }
+
if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup
}
@@ -150,19 +151,23 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
config.CookieSecure = true
}
if len(config.TrustedOrigins) > 0 {
- if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil {
- return nil, err
+ if vErr := validateOrigins(config.TrustedOrigins, "trusted origin"); vErr != nil {
+ return nil, vErr
}
config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...)
}
+ tokenGenerator := randomString
+ if config.generator != nil {
+ tokenGenerator = config.generator
+ }
- extractors, cErr := createExtractors(config.TokenLookup, 1)
+ extractors, cErr := CreateExtractors(config.TokenLookup)
if cErr != nil {
return nil, cErr
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -180,7 +185,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
- token = config.Generator() // Generate token
+ token = tokenGenerator(config.TokenLength)
} else {
token = k.Value // Reuse token
}
@@ -193,7 +198,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastTokenErr error
outer:
for _, extractor := range extractors {
- clientTokens, _, err := extractor(c)
+ clientTokens, err := extractor(c)
if err != nil {
lastExtractorErr = err
continue
@@ -212,11 +217,22 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if lastTokenErr != nil {
finalErr = lastTokenErr
} else if lastExtractorErr != nil {
- finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr)
+ // ugly part to preserve backwards compatible errors. someone could rely on them
+ if lastExtractorErr == errQueryExtractorValueMissing {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
+ } else if lastExtractorErr == errFormExtractorValueMissing {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
+ } else if lastExtractorErr == errHeaderExtractorValueMissing {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
+ } else {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
+ }
+ finalErr = lastExtractorErr
}
+
if finalErr != nil {
if config.ErrorHandler != nil {
- return config.ErrorHandler(c, finalErr)
+ return config.ErrorHandler(finalErr, c)
}
return finalErr
}
@@ -257,7 +273,7 @@ func validateCSRFToken(token, clientToken string) bool {
var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace}
-func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) {
+func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) {
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
// Sec-Fetch-Site values are:
// - `same-origin` exact origin match - allow always
@@ -295,13 +311,13 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error)
}
// we are here when request is state-changing and `cross-site` or `same-site`
- // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc`
+ // Note: if you want to block `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc`
if config.AllowSecFetchSiteFunc != nil {
return config.AllowSecFetchSiteFunc(c)
}
if secFetchSite == "same-site" {
- return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF")
+ return false, nil // fall back to legacy token
}
return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF")
}
diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go
index a13fdc82c..0b3210f07 100644
--- a/middleware/csrf_test.go
+++ b/middleware/csrf_test.go
@@ -11,7 +11,7 @@ import (
"strings"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -57,7 +57,6 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenFormTokens: map[string][]string{
"csrf": {"invalid", "token"},
},
- expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, invalid token from POST form",
@@ -75,7 +74,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{},
- expectError: "code=400, message=Bad Request, err=missing value in the form",
+ expectError: "code=400, message=missing csrf token in the form parameter",
},
{
name: "ok, token from POST header",
@@ -94,7 +93,6 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid", "token"},
},
- expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, invalid token from POST header",
@@ -112,7 +110,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{},
- expectError: "code=400, message=Bad Request, err=missing value in request header",
+ expectError: "code=400, message=missing csrf token in request header",
},
{
name: "ok, token from PUT query param",
@@ -131,7 +129,6 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenQueryTokens: map[string][]string{
"csrf": {"invalid", "token"},
},
- expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, invalid token from PUT query form",
@@ -149,7 +146,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{},
- expectError: "code=400, message=Bad Request, err=missing value in the query string",
+ expectError: "code=400, message=missing csrf token in the query string",
},
{
name: "nok, invalid TokenLookup",
@@ -213,7 +210,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
assert.NoError(t, err)
}
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -261,7 +258,7 @@ func TestCSRFWithConfig(t *testing.T) {
name: "nok, POST without token",
whenMethod: http.MethodPost,
expectEmptyBody: true,
- expectErr: `code=400, message=Bad Request, err=missing value in request header`,
+ expectErr: `code=400, message=missing csrf token in request header`,
},
{
name: "nok, POST empty token",
@@ -328,12 +325,11 @@ func TestCSRFWithConfig(t *testing.T) {
if tc.givenConfig != nil {
config = *tc.givenConfig
}
- if config.Generator == nil {
- config.Generator = func() string {
+ if config.generator == nil {
+ config.generator = func(_ uint8) string {
return "TESTTOKEN"
}
}
-
mw, err := config.ToMiddleware()
if tc.expectMWError != "" {
assert.EqualError(t, err, tc.expectMWError)
@@ -341,7 +337,7 @@ func TestCSRFWithConfig(t *testing.T) {
}
assert.NoError(t, err)
- h := mw(func(c *echo.Context) error {
+ h := mw(func(c echo.Context) error {
cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey))
assert.Equal(t, tc.expectTokenInContext, cToken)
return c.String(http.StatusOK, "test")
@@ -373,7 +369,7 @@ func TestCSRF(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRF()
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -393,7 +389,7 @@ func TestCSRFSetSameSiteMode(t *testing.T) {
CookieSameSite: http.SameSiteStrictMode,
})
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -410,7 +406,7 @@ func TestCSRFWithoutSameSiteMode(t *testing.T) {
csrf := CSRFWithConfig(CSRFConfig{})
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -429,7 +425,7 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
CookieSameSite: http.SameSiteDefaultMode,
})
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -449,7 +445,7 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
}.ToMiddleware()
assert.NoError(t, err)
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -485,12 +481,12 @@ func TestCSRFConfig_skipper(t *testing.T) {
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
- Skipper: func(c *echo.Context) bool {
+ Skipper: func(c echo.Context) bool {
return tc.whenSkip
},
})
- h := csrf(func(c *echo.Context) error {
+ h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -504,13 +500,13 @@ func TestCSRFConfig_skipper(t *testing.T) {
func TestCSRFErrorHandling(t *testing.T) {
cfg := CSRFConfig{
- ErrorHandler: func(c *echo.Context, err error) error {
+ ErrorHandler: func(err error, c echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
},
}
e := echo.New()
- e.POST("/", func(c *echo.Context) error {
+ e.POST("/", func(c echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})
@@ -583,7 +579,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
whenMethod: http.MethodPost,
whenSecFetchSite: "same-site",
expectAllow: false,
- expectErr: `code=403, message=same-site request blocked by CSRF`,
},
{
name: "ok, unsafe POST + same-origin passes",
@@ -641,7 +636,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
whenMethod: http.MethodPut,
whenSecFetchSite: "same-site",
expectAllow: false,
- expectErr: `code=403, message=same-site request blocked by CSRF`,
},
{
name: "nok, unsafe DELETE + cross-site is blocked",
@@ -657,7 +651,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
whenMethod: http.MethodDelete,
whenSecFetchSite: "same-site",
expectAllow: false,
- expectErr: `code=403, message=same-site request blocked by CSRF`,
},
{
name: "nok, unsafe PATCH + cross-site is blocked",
@@ -770,7 +763,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
{
name: "ok, unsafe POST + same-site + custom func allows",
givenConfig: CSRFConfig{
- AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
return true, nil
},
},
@@ -781,7 +774,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
{
name: "ok, unsafe POST + cross-site + custom func allows",
givenConfig: CSRFConfig{
- AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
return true, nil
},
},
@@ -792,7 +785,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
{
name: "nok, unsafe POST + same-site + custom func returns custom error",
givenConfig: CSRFConfig{
- AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func")
},
},
@@ -804,7 +797,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
{
name: "nok, unsafe POST + cross-site + custom func returns false with nil error",
givenConfig: CSRFConfig{
- AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
return false, nil
},
},
@@ -825,7 +818,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func",
givenConfig: CSRFConfig{
TrustedOrigins: []string{"https://trusted.example.com"},
- AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
return false, echo.NewHTTPError(http.StatusTeapot, "should not be called")
},
},
@@ -838,7 +831,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks",
givenConfig: CSRFConfig{
TrustedOrigins: []string{"https://trusted.example.com"},
- AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
return false, echo.NewHTTPError(http.StatusTeapot, "custom block")
},
},
@@ -860,7 +853,8 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
}
res := httptest.NewRecorder()
- c := echo.NewContext(req, res)
+ e := echo.New()
+ c := e.NewContext(req, res)
allow, err := tc.givenConfig.checkSecFetchSiteRequest(c)
diff --git a/middleware/decompress.go b/middleware/decompress.go
index 501ee6c5b..8c418efd7 100644
--- a/middleware/decompress.go
+++ b/middleware/decompress.go
@@ -9,7 +9,7 @@ import (
"net/http"
"sync"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// DecompressConfig defines the config for Decompress middleware.
@@ -19,13 +19,6 @@ type DecompressConfig struct {
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
-
- // MaxDecompressedSize limits the maximum size of decompressed request body in bytes.
- // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error.
- // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes.
- // Default: 100 * MB (104,857,600 bytes)
- // Set to -1 to disable limits (not recommended in production).
- MaxDecompressedSize int64
}
// GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
@@ -36,6 +29,12 @@ type Decompressor interface {
gzipDecompressPool() sync.Pool
}
+// DefaultDecompressConfig defines the config for decompress middleware
+var DefaultDecompressConfig = DecompressConfig{
+ Skipper: DefaultSkipper,
+ GzipDecompressPool: &DefaultGzipDecompressPool{},
+}
+
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
@@ -45,39 +44,24 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
}
// Decompress decompresses request body based if content encoding type is set to "gzip" with default config
-//
-// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks.
-// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production),
-// set MaxDecompressedSize to -1.
func Decompress() echo.MiddlewareFunc {
- return DecompressWithConfig(DecompressConfig{})
+ return DecompressWithConfig(DefaultDecompressConfig)
}
-// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
-//
-// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent
-// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case.
+// DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
-func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
+ config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
- config.GzipDecompressPool = &DefaultGzipDecompressPool{}
- }
- // Apply secure default for decompression limit
- if config.MaxDecompressedSize == 0 {
- config.MaxDecompressedSize = 100 * MB
+ config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -89,10 +73,7 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok || gr == nil {
- if err, isErr := i.(error); isErr {
- return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
- }
- return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool")
+ return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
defer pool.Put(gr)
@@ -109,48 +90,9 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close.
defer gr.Close()
- // Apply decompression size limit to prevent zip bombs
- if config.MaxDecompressedSize > 0 {
- c.Request().Body = &limitedGzipReader{
- Reader: gr,
- remaining: config.MaxDecompressedSize,
- limit: config.MaxDecompressedSize,
- }
- } else {
- // -1 means explicitly unlimited (not recommended)
- c.Request().Body = gr
- }
- c.Request().ContentLength = -1
+ c.Request().Body = gr
return next(c)
}
- }, nil
-}
-
-// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs
-type limitedGzipReader struct {
- *gzip.Reader
- remaining int64
- limit int64
-}
-
-func (r *limitedGzipReader) Read(p []byte) (n int, err error) {
- if r.remaining <= 0 {
- // Limit exceeded - return 413 error
- return 0, echo.ErrStatusRequestEntityTooLarge
- }
-
- // Limit the read to remaining bytes
- if int64(len(p)) > r.remaining {
- p = p[:r.remaining]
}
-
- n, err = r.Reader.Read(p)
- r.remaining -= int64(n)
-
- return n, err
-}
-
-func (r *limitedGzipReader) Close() error {
- return r.Reader.Close()
}
diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go
index 8dc3057ba..52506ce8e 100644
--- a/middleware/decompress_test.go
+++ b/middleware/decompress_test.go
@@ -14,91 +14,61 @@ import (
"sync"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestDecompress(t *testing.T) {
e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- h := Decompress()(func(c *echo.Context) error {
+ // Skip if no Content-Encoding header
+ h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
+ h(c)
+
+ assert.Equal(t, "test", rec.Body.String())
- // Decompress request body
+ // Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err := h(c)
- assert.NoError(t, err)
-
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ h(c)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := io.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
-func TestDecompress_skippedIfNoHeader(t *testing.T) {
+func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- // Skip if no Content-Encoding header
- h := Decompress()(func(c *echo.Context) error {
+ h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
+ h(c)
- err := h(c)
- assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
-}
-
-func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{}.ToMiddleware()
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- c.Response().Write([]byte("test")) // For Content-Type sniffing
- return nil
- })(c)
- assert.NoError(t, err)
-
- assert.Equal(t, "test", rec.Body.String())
-
-}
-
-func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
- e := echo.New()
-
- h := Decompress()(func(c *echo.Context) error {
- c.Response().Write([]byte("test")) // For Content-Type sniffing
- return nil
- })
-
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err := h(c)
- assert.NoError(t, err)
-
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ h(c)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := io.ReadAll(req.Body)
assert.NoError(t, err)
@@ -113,9 +83,7 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
-
e.ServeHTTP(rec, req)
-
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := io.ReadAll(req.Body)
assert.NoError(t, err)
@@ -129,13 +97,10 @@ func TestDecompressNoContent(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := Decompress()(func(c *echo.Context) error {
+ h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
-
- err := h(c)
-
- if assert.NoError(t, err) {
+ if assert.NoError(t, h(c)) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
@@ -145,15 +110,13 @@ func TestDecompressNoContent(t *testing.T) {
func TestDecompressErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Decompress())
- e.GET("/", func(c *echo.Context) error {
+ e.GET("/", func(c echo.Context) error {
return echo.ErrNotFound
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
-
e.ServeHTTP(rec, req)
-
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
@@ -161,7 +124,7 @@ func TestDecompressErrorReturned(t *testing.T) {
func TestDecompressSkipper(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
- Skipper: func(c *echo.Context) bool {
+ Skipper: func(c echo.Context) bool {
return c.Request().URL.Path == "/skip"
},
}))
@@ -170,9 +133,7 @@ func TestDecompressSkipper(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
-
e.ServeHTTP(rec, req)
-
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON)
reqBody, err := io.ReadAll(c.Request().Body)
assert.NoError(t, err)
@@ -201,9 +162,7 @@ func TestDecompressPoolError(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
-
e.ServeHTTP(rec, req)
-
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := io.ReadAll(c.Request().Body)
assert.NoError(t, err)
@@ -215,8 +174,10 @@ func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`
gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- h := Decompress()(func(c *echo.Context) error {
+ h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte(body)) // For Content-Type sniffing
return nil
})
@@ -226,8 +187,6 @@ func BenchmarkDecompress(b *testing.B) {
for i := 0; i < b.N; i++ {
// Decompress
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h(c)
@@ -249,260 +208,3 @@ func gzipString(body string) ([]byte, error) {
return buf.Bytes(), nil
}
-
-func TestDecompress_WithinLimit(t *testing.T) {
- e := echo.New()
- body := strings.Repeat("test data ", 100) // Small payload ~1KB
- gz, _ := gzipString(body)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- b, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(b))
- })(c)
-
- assert.NoError(t, err)
- assert.Equal(t, body, rec.Body.String())
-}
-
-func TestDecompress_ExceedsLimit(t *testing.T) {
- e := echo.New()
- // Create 2KB of data but limit to 1KB
- largeBody := strings.Repeat("A", 2*1024)
- gz, _ := gzipString(largeBody)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- _, readErr := io.ReadAll(c.Request().Body)
- return readErr
- })(c)
-
- // Should return 413 error
- assert.Error(t, err)
- he, ok := err.(echo.HTTPStatusCoder)
- assert.True(t, ok)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
-}
-
-func TestDecompress_AtExactLimit(t *testing.T) {
- e := echo.New()
- exactBody := strings.Repeat("B", 1024) // Exactly 1KB
- gz, _ := gzipString(exactBody)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware()
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- b, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(b))
- })(c)
-
- assert.NoError(t, err)
- assert.Equal(t, exactBody, rec.Body.String())
-}
-
-func TestDecompress_ZipBomb(t *testing.T) {
- e := echo.New()
- // Create highly compressed data that expands to 2MB
- // but limit is 1MB
- largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB
- var buf bytes.Buffer
- gzWriter := gzip.NewWriter(&buf)
- gzWriter.Write(largeBody)
- gzWriter.Close()
-
- req := httptest.NewRequest(http.MethodPost, "/", &buf)
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- _, readErr := io.ReadAll(c.Request().Body)
- return readErr
- })(c)
-
- // Should return 413 error
- assert.Error(t, err)
- he, ok := err.(echo.HTTPStatusCoder)
- assert.True(t, ok)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
-}
-
-func TestDecompress_UnlimitedExplicit(t *testing.T) {
- e := echo.New()
- largeBody := strings.Repeat("X", 10*1024) // 10KB
- gz, _ := gzipString(largeBody)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- b, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(b))
- })(c)
-
- assert.NoError(t, err)
- assert.Equal(t, largeBody, rec.Body.String())
-}
-
-func TestDecompress_DefaultLimit(t *testing.T) {
- e := echo.New()
- smallBody := "test"
- gz, _ := gzipString(smallBody)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- // Use zero value which should apply 100MB default
- h, err := DecompressConfig{}.ToMiddleware()
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- b, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(b))
- })(c)
-
- assert.NoError(t, err)
- assert.Equal(t, smallBody, rec.Body.String())
-}
-
-func TestDecompress_SmallCustomLimit(t *testing.T) {
- e := echo.New()
- body := strings.Repeat("D", 512) // 512 bytes
- gz, _ := gzipString(body)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- b, _ := io.ReadAll(c.Request().Body)
- return c.String(http.StatusOK, string(b))
- })(c)
-
- assert.NoError(t, err)
- assert.Equal(t, body, rec.Body.String())
-}
-
-func TestDecompress_MultipleReads(t *testing.T) {
- e := echo.New()
- // Test that limit is enforced across multiple Read() calls
- largeBody := strings.Repeat("M", 2*1024) // 2KB
- gz, _ := gzipString(largeBody)
-
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- // Read in small chunks
- buf := make([]byte, 256)
- total := 0
- for {
- n, readErr := c.Request().Body.Read(buf)
- total += n
- if readErr != nil {
- if readErr == io.EOF {
- return nil
- }
- return readErr
- }
- }
- })(c)
-
- // Should return 413 error from cumulative reads
- assert.Error(t, err)
- he, ok := err.(echo.HTTPStatusCoder)
- assert.True(t, ok)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
-}
-
-func TestDecompress_LargePayloadDosPrevention(t *testing.T) {
- e := echo.New()
- // Simulate a DoS attack with highly compressed large payload
- largeSize := 10 * 1024 * 1024 // 10MB decompressed
- largeBody := bytes.Repeat([]byte("Z"), largeSize)
- var buf bytes.Buffer
- gzWriter := gzip.NewWriter(&buf)
- gzWriter.Write(largeBody)
- gzWriter.Close()
-
- req := httptest.NewRequest(http.MethodPost, "/", &buf)
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
- assert.NoError(t, err)
-
- err = h(func(c *echo.Context) error {
- _, readErr := io.ReadAll(c.Request().Body)
- return readErr
- })(c)
-
- // Should prevent DoS by returning 413
- assert.Error(t, err)
- he, ok := err.(echo.HTTPStatusCoder)
- assert.True(t, ok)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
-}
-
-func BenchmarkDecompress_WithLimit(b *testing.B) {
- e := echo.New()
- body := strings.Repeat("benchmark data ", 1000) // ~15KB
- gz, _ := gzipString(body)
-
- h, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
-
- b.ReportAllocs()
- b.ResetTimer()
-
- for i := 0; i < b.N; i++ {
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
- req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- h(func(c *echo.Context) error {
- io.ReadAll(c.Request().Body)
- return nil
- })(c)
- }
-}
diff --git a/middleware/extractor.go b/middleware/extractor.go
index f800a49e9..3f2741407 100644
--- a/middleware/extractor.go
+++ b/middleware/extractor.go
@@ -4,11 +4,11 @@
package middleware
import (
+ "errors"
"fmt"
+ "github.com/labstack/echo/v4"
"net/textproto"
"strings"
-
- "github.com/labstack/echo/v5"
)
const (
@@ -17,44 +17,18 @@ const (
extractorLimit = 20
)
-// ExtractorSource is type to indicate source for extracted value
-type ExtractorSource string
-
-const (
- // ExtractorSourceHeader means value was extracted from request header
- ExtractorSourceHeader ExtractorSource = "header"
- // ExtractorSourceQuery means value was extracted from request query parameters
- ExtractorSourceQuery ExtractorSource = "query"
- // ExtractorSourcePathParam means value was extracted from route path parameters
- ExtractorSourcePathParam ExtractorSource = "param"
- // ExtractorSourceCookie means value was extracted from request cookies
- ExtractorSourceCookie ExtractorSource = "cookie"
- // ExtractorSourceForm means value was extracted from request form values
- ExtractorSourceForm ExtractorSource = "form"
-)
-
-// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
-type ValueExtractorError struct {
- message string
-}
-
-// Error returns errors text
-func (e *ValueExtractorError) Error() string {
- return e.message
-}
-
-var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"}
-var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"}
-var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"}
-var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"}
-var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"}
-var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
+var errHeaderExtractorValueMissing = errors.New("missing value in request header")
+var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
+var errQueryExtractorValueMissing = errors.New("missing value in the query string")
+var errParamExtractorValueMissing = errors.New("missing value in path params")
+var errCookieExtractorValueMissing = errors.New("missing value in cookies")
+var errFormExtractorValueMissing = errors.New("missing value in the form")
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
-type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
+type ValuesExtractor func(c echo.Context) ([]string, error)
// CreateExtractors creates ValuesExtractors from given lookups.
-// lookups is a string in the form of ":" or ":,:" that is used
+// Lookups is a string in the form of ":" or ":,:" that is used
// to extract key from the request.
// Possible values:
// - "header:" or "header::"
@@ -69,25 +43,17 @@ type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
//
// Multiple sources example:
// - "header:Authorization,header:X-Api-Key"
-//
-// limit sets the maximum amount how many lookups can be returned.
-func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
- return createExtractors(lookups, limit)
+func CreateExtractors(lookups string) ([]ValuesExtractor, error) {
+ return createExtractors(lookups, "")
}
-func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
+func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
if lookups == "" {
return nil, nil
}
- if limit == 0 {
- limit = 1
- } else if limit > extractorLimit {
- limit = extractorLimit
- }
-
- sources := strings.SplitSeq(lookups, ",")
+ sources := strings.Split(lookups, ",")
var extractors = make([]ValuesExtractor, 0)
- for source := range sources {
+ for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
@@ -95,19 +61,28 @@ func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
switch parts[0] {
case "query":
- extractors = append(extractors, valuesFromQuery(parts[1], limit))
+ extractors = append(extractors, valuesFromQuery(parts[1]))
case "param":
- extractors = append(extractors, valuesFromParam(parts[1], limit))
+ extractors = append(extractors, valuesFromParam(parts[1]))
case "cookie":
- extractors = append(extractors, valuesFromCookie(parts[1], limit))
+ extractors = append(extractors, valuesFromCookie(parts[1]))
case "form":
- extractors = append(extractors, valuesFromForm(parts[1], limit))
+ extractors = append(extractors, valuesFromForm(parts[1]))
case "header":
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
+ } else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
+ // backwards compatibility for JWT and KeyAuth:
+ // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc
+ // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
+ // behaviour for default values and Authorization header.
+ prefix = authScheme
+ if !strings.HasSuffix(prefix, " ") {
+ prefix += " "
+ }
}
- extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit))
+ extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
@@ -119,32 +94,28 @@ func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove
// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `.
// If prefix is left empty the whole value is returned.
-func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor {
+func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
prefixLen := len(valuePrefix)
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
header = textproto.CanonicalMIMEHeaderKey(header)
- if limit == 0 {
- limit = 1
- }
- return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ return func(c echo.Context) ([]string, error) {
values := c.Request().Header.Values(header)
if len(values) == 0 {
- return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
+ return nil, errHeaderExtractorValueMissing
}
- i := uint(0)
result := make([]string, 0)
- for _, value := range values {
+ for i, value := range values {
if prefixLen == 0 {
result = append(result, value)
- i++
- if i >= limit {
+ if i >= extractorLimit-1 {
break
}
- } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
+ continue
+ }
+ if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
- i++
- if i >= limit {
+ if i >= extractorLimit-1 {
break
}
}
@@ -152,102 +123,85 @@ func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtra
if len(result) == 0 {
if prefixLen > 0 {
- return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
+ return nil, errHeaderExtractorValueInvalid
}
- return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
+ return nil, errHeaderExtractorValueMissing
}
- return result, ExtractorSourceHeader, nil
+ return result, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
-func valuesFromQuery(param string, limit uint) ValuesExtractor {
- if limit == 0 {
- limit = 1
- }
- return func(c *echo.Context) ([]string, ExtractorSource, error) {
+func valuesFromQuery(param string) ValuesExtractor {
+ return func(c echo.Context) ([]string, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
- return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
- } else if len(result) > int(limit)-1 {
- result = result[:limit]
+ return nil, errQueryExtractorValueMissing
+ } else if len(result) > extractorLimit-1 {
+ result = result[:extractorLimit]
}
- return result, ExtractorSourceQuery, nil
+ return result, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
-func valuesFromParam(param string, limit uint) ValuesExtractor {
- if limit == 0 {
- limit = 1
- }
- return func(c *echo.Context) ([]string, ExtractorSource, error) {
+func valuesFromParam(param string) ValuesExtractor {
+ return func(c echo.Context) ([]string, error) {
result := make([]string, 0)
- i := uint(0)
- for _, p := range c.PathValues() {
- if param != p.Name {
- continue
- }
- result = append(result, p.Value)
- i++
- if i >= limit {
- break
+ paramVales := c.ParamValues()
+ for i, p := range c.ParamNames() {
+ if param == p {
+ result = append(result, paramVales[i])
+ if i >= extractorLimit-1 {
+ break
+ }
}
}
if len(result) == 0 {
- return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
+ return nil, errParamExtractorValueMissing
}
- return result, ExtractorSourcePathParam, nil
+ return result, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
-func valuesFromCookie(name string, limit uint) ValuesExtractor {
- if limit == 0 {
- limit = 1
- }
- return func(c *echo.Context) ([]string, ExtractorSource, error) {
+func valuesFromCookie(name string) ValuesExtractor {
+ return func(c echo.Context) ([]string, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
- return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
+ return nil, errCookieExtractorValueMissing
}
- i := uint(0)
result := make([]string, 0)
- for _, cookie := range cookies {
- if name != cookie.Name {
- continue
- }
- result = append(result, cookie.Value)
- i++
- if i >= limit {
- break
+ for i, cookie := range cookies {
+ if name == cookie.Name {
+ result = append(result, cookie.Value)
+ if i >= extractorLimit-1 {
+ break
+ }
}
}
if len(result) == 0 {
- return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
+ return nil, errCookieExtractorValueMissing
}
- return result, ExtractorSourceCookie, nil
+ return result, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
-func valuesFromForm(name string, limit uint) ValuesExtractor {
- if limit == 0 {
- limit = 1
- }
- return func(c *echo.Context) ([]string, ExtractorSource, error) {
+func valuesFromForm(name string) ValuesExtractor {
+ return func(c echo.Context) ([]string, error) {
if c.Request().Form == nil {
- _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory)
+ _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
}
values := c.Request().Form[name]
if len(values) == 0 {
- return nil, ExtractorSourceForm, errFormExtractorValueMissing
+ return nil, errFormExtractorValueMissing
}
- if len(values) > int(limit)-1 {
- values = values[:limit]
+ if len(values) > extractorLimit-1 {
+ values = values[:extractorLimit]
}
result := append([]string{}, values...)
- return result, ExtractorSourceForm, nil
+ return result, nil
}
}
diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go
index 04cc7b829..42cbcfeab 100644
--- a/middleware/extractor_test.go
+++ b/middleware/extractor_test.go
@@ -6,26 +6,39 @@ package middleware
import (
"bytes"
"fmt"
+ "github.com/labstack/echo/v4"
+ "github.com/stretchr/testify/assert"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
-
- "github.com/labstack/echo/v5"
- "github.com/stretchr/testify/assert"
)
+type pathParam struct {
+ name string
+ value string
+}
+
+func setPathParams(c echo.Context, params []pathParam) {
+ names := make([]string, 0, len(params))
+ values := make([]string, 0, len(params))
+ for _, pp := range params {
+ names = append(names, pp.name)
+ values = append(values, pp.value)
+ }
+ c.SetParamNames(names...)
+ c.SetParamValues(values...)
+}
+
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
- givenPathValues echo.PathValues
- whenLookups string
- whenLimit uint
+ givenPathParams []pathParam
+ whenLoopups string
expectValues []string
- expectSource ExtractorSource
expectCreateError string
expectError string
}{
@@ -36,9 +49,8 @@ func TestCreateExtractors(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
return req
},
- whenLookups: "header:Authorization:Bearer ",
+ whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
- expectSource: ExtractorSourceHeader,
},
{
name: "ok, form",
@@ -50,9 +62,8 @@ func TestCreateExtractors(t *testing.T) {
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
- whenLookups: "form:name",
+ whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
- expectSource: ExtractorSourceForm,
},
{
name: "ok, cookie",
@@ -61,18 +72,16 @@ func TestCreateExtractors(t *testing.T) {
req.Header.Set(echo.HeaderCookie, "_csrf=token")
return req
},
- whenLookups: "cookie:_csrf",
+ whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
- expectSource: ExtractorSourceCookie,
},
{
name: "ok, param",
- givenPathValues: echo.PathValues{
- {Name: "id", Value: "123"},
+ givenPathParams: []pathParam{
+ {name: "id", value: "123"},
},
- whenLookups: "param:id",
+ whenLoopups: "param:id",
expectValues: []string{"123"},
- expectSource: ExtractorSourcePathParam,
},
{
name: "ok, query",
@@ -80,13 +89,12 @@ func TestCreateExtractors(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
return req
},
- whenLookups: "query:id",
+ whenLoopups: "query:id",
expectValues: []string{"999"},
- expectSource: ExtractorSourceQuery,
},
{
name: "nok, invalid lookup",
- whenLookups: "query",
+ whenLoopups: "query",
expectCreateError: "extractor source for lookup could not be split into needed parts: query",
},
}
@@ -101,11 +109,11 @@ func TestCreateExtractors(t *testing.T) {
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- if tc.givenPathValues != nil {
- c.SetPathValues(tc.givenPathValues)
+ if tc.givenPathParams != nil {
+ setPathParams(c, tc.givenPathParams)
}
- extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit)
+ extractors, err := CreateExtractors(tc.whenLoopups)
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
@@ -113,9 +121,8 @@ func TestCreateExtractors(t *testing.T) {
assert.NoError(t, err)
for _, e := range extractors {
- values, source, eErr := e(c)
+ values, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
- assert.Equal(t, tc.expectSource, source)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
@@ -136,7 +143,6 @@ func TestValuesFromHeader(t *testing.T) {
givenRequest func(req *http.Request)
whenName string
whenValuePrefix string
- whenLimit uint
expectValues []string
expectError string
}{
@@ -162,7 +168,6 @@ func TestValuesFromHeader(t *testing.T) {
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
- whenLimit: 2,
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
},
{
@@ -208,7 +213,6 @@ func TestValuesFromHeader(t *testing.T) {
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
- whenLimit: extractorLimit,
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
@@ -223,7 +227,6 @@ func TestValuesFromHeader(t *testing.T) {
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
- whenLimit: extractorLimit,
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
@@ -242,11 +245,10 @@ func TestValuesFromHeader(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit)
+ extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
- values, source, err := extractor(c)
+ values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
- assert.Equal(t, ExtractorSourceHeader, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -261,7 +263,6 @@ func TestValuesFromQuery(t *testing.T) {
name string
givenQueryPart string
whenName string
- whenLimit uint
expectValues []string
expectError string
}{
@@ -275,7 +276,6 @@ func TestValuesFromQuery(t *testing.T) {
name: "ok, multiple value",
givenQueryPart: "?id=123&id=456&name=test",
whenName: "id",
- whenLimit: 2,
expectValues: []string{"123", "456"},
},
{
@@ -290,8 +290,7 @@ func TestValuesFromQuery(t *testing.T) {
"&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
"&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
"&id=21&id=22&id=23&id=24&id=25",
- whenName: "id",
- whenLimit: extractorLimit,
+ whenName: "id",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
@@ -307,11 +306,10 @@ func TestValuesFromQuery(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- extractor := valuesFromQuery(tc.whenName, tc.whenLimit)
+ extractor := valuesFromQuery(tc.whenName)
- values, source, err := extractor(c)
+ values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
- assert.Equal(t, ExtractorSourceQuery, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -322,56 +320,53 @@ func TestValuesFromQuery(t *testing.T) {
}
func TestValuesFromParam(t *testing.T) {
- examplePathValues := echo.PathValues{
- {Name: "id", Value: "123"},
- {Name: "gid", Value: "456"},
- {Name: "gid", Value: "789"},
+ examplePathParams := []pathParam{
+ {name: "id", value: "123"},
+ {name: "gid", value: "456"},
+ {name: "gid", value: "789"},
}
- examplePathValues20 := make(echo.PathValues, 0)
+ examplePathParams20 := make([]pathParam, 0)
for i := 1; i < 25; i++ {
- examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)})
+ examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)})
}
var testCases = []struct {
name string
- givenPathValues echo.PathValues
+ givenPathParams []pathParam
whenName string
- whenLimit uint
expectValues []string
expectError string
}{
{
name: "ok, single value",
- givenPathValues: examplePathValues,
+ givenPathParams: examplePathParams,
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
- givenPathValues: examplePathValues,
+ givenPathParams: examplePathParams,
whenName: "gid",
- whenLimit: 2,
expectValues: []string{"456", "789"},
},
{
name: "nok, no values",
- givenPathValues: nil,
+ givenPathParams: nil,
whenName: "nope",
expectValues: nil,
expectError: errParamExtractorValueMissing.Error(),
},
{
name: "nok, no matching value",
- givenPathValues: examplePathValues,
+ givenPathParams: examplePathParams,
whenName: "nope",
expectValues: nil,
expectError: errParamExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
- givenPathValues: examplePathValues20,
+ givenPathParams: examplePathParams20,
whenName: "id",
- whenLimit: extractorLimit,
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
@@ -386,15 +381,14 @@ func TestValuesFromParam(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- if tc.givenPathValues != nil {
- c.SetPathValues(tc.givenPathValues)
+ if tc.givenPathParams != nil {
+ setPathParams(c, tc.givenPathParams)
}
- extractor := valuesFromParam(tc.whenName, tc.whenLimit)
+ extractor := valuesFromParam(tc.whenName)
- values, source, err := extractor(c)
+ values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
- assert.Equal(t, ExtractorSourcePathParam, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -413,7 +407,6 @@ func TestValuesFromCookie(t *testing.T) {
name string
givenRequest func(req *http.Request)
whenName string
- whenLimit uint
expectValues []string
expectError string
}{
@@ -430,7 +423,6 @@ func TestValuesFromCookie(t *testing.T) {
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
},
whenName: "_csrf",
- whenLimit: 2,
expectValues: []string{"token", "token2"},
},
{
@@ -454,8 +446,7 @@ func TestValuesFromCookie(t *testing.T) {
req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
}
},
- whenName: "_csrf",
- whenLimit: extractorLimit,
+ whenName: "_csrf",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
@@ -474,11 +465,10 @@ func TestValuesFromCookie(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- extractor := valuesFromCookie(tc.whenName, tc.whenLimit)
+ extractor := valuesFromCookie(tc.whenName)
- values, source, err := extractor(c)
+ values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
- assert.Equal(t, ExtractorSourceCookie, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -537,7 +527,6 @@ func TestValuesFromForm(t *testing.T) {
name string
givenRequest *http.Request
whenName string
- whenLimit uint
expectValues []string
expectError string
}{
@@ -553,7 +542,6 @@ func TestValuesFromForm(t *testing.T) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
- whenLimit: 2,
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
@@ -562,7 +550,6 @@ func TestValuesFromForm(t *testing.T) {
w.WriteField("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
- whenLimit: 2,
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
@@ -577,7 +564,6 @@ func TestValuesFromForm(t *testing.T) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
- whenLimit: 2,
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
@@ -593,8 +579,7 @@ func TestValuesFromForm(t *testing.T) {
v.Add("id[]", fmt.Sprintf("%v", i))
}
}),
- whenName: "id[]",
- whenLimit: extractorLimit,
+ whenName: "id[]",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
@@ -610,11 +595,10 @@ func TestValuesFromForm(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- extractor := valuesFromForm(tc.whenName, tc.whenLimit)
+ extractor := valuesFromForm(tc.whenName)
- values, source, err := extractor(c)
+ values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
- assert.Equal(t, ExtractorSourceForm, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
diff --git a/middleware/key_auth.go b/middleware/key_auth.go
index 2dcb98039..79bee207c 100644
--- a/middleware/key_auth.go
+++ b/middleware/key_auth.go
@@ -4,25 +4,19 @@
package middleware
import (
- "cmp"
"errors"
- "fmt"
+ "github.com/labstack/echo/v4"
"net/http"
-
- "github.com/labstack/echo/v5"
)
// KeyAuthConfig defines the config for KeyAuth middleware.
-//
-// SECURITY: The Validator function is responsible for securely comparing API keys.
-// See KeyAuthValidator documentation for guidance on preventing timing attacks.
type KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of ":" or ":,:" that is used
// to extract key from the request.
- // Optional. Default value "header:Authorization:Bearer ".
+ // Optional. Default value "header:Authorization".
// Possible values:
// - "header:" or "header::"
// `` is argument value to cut/trim prefix of the extracted value. This is useful if header
@@ -36,22 +30,16 @@ type KeyAuthConfig struct {
// - "header:Authorization,header:X-Api-Key"
KeyLookup string
- // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is
- // useful environments like corporate test environments with application proxies restricting
- // access to environment with their own auth scheme.
- AllowedCheckLimit uint
+ // AuthScheme to be used in the Authorization header.
+ // Optional. Default value "Bearer".
+ AuthScheme string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
- // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
- // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
+ // ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
- //
- // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
- // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
- // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
ErrorHandler KeyAuthErrorHandler
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
@@ -63,55 +51,31 @@ type KeyAuthConfig struct {
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
-//
-// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
-// valid API keys, validator implementations MUST use constant-time comparison.
-// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==)
-// or switch statements.
-//
-// Example of SECURE implementation:
-//
-// import "crypto/subtle"
-//
-// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
-// // Fetch valid keys from database/config
-// validKeys := []string{"key1", "key2", "key3"}
-//
-// for _, validKey := range validKeys {
-// // Use constant-time comparison to prevent timing attacks
-// if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
-// return true, nil
-// }
-// }
-// return false, nil
-// }
-//
-// Example of INSECURE implementation (DO NOT USE):
-//
-// // VULNERABLE TO TIMING ATTACKS - DO NOT USE
-// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
-// switch key { // Timing leak!
-// case "valid-key":
-// return true, nil
-// default:
-// return false, nil
-// }
-// }
-type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
+type KeyAuthValidator func(auth string, c echo.Context) (bool, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
-type KeyAuthErrorHandler func(c *echo.Context, err error) error
-
-// ErrKeyMissing denotes an error raised when key value could not be extracted from request
-var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
+type KeyAuthErrorHandler func(err error, c echo.Context) error
-// ErrInvalidKey denotes an error raised when key value is invalid by validator
-var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
+// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
+type ErrKeyAuthMissing struct {
+ Err error
+}
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
var DefaultKeyAuthConfig = KeyAuthConfig{
- Skipper: DefaultSkipper,
- KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
+ Skipper: DefaultSkipper,
+ KeyLookup: "header:" + echo.HeaderAuthorization,
+ AuthScheme: "Bearer",
+}
+
+// Error returns errors text
+func (e *ErrKeyAuthMissing) Error() string {
+ return e.Err.Error()
+}
+
+// Unwrap unwraps error
+func (e *ErrKeyAuthMissing) Unwrap() error {
+ return e.Err
}
// KeyAuth returns an KeyAuth middleware.
@@ -125,39 +89,31 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
return KeyAuthWithConfig(c)
}
-// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
-//
-// For first valid key it calls the next handler.
-// For invalid key, it sends "401 - Unauthorized" response.
-// For missing key, it sends "400 - Bad Request" response.
+// KeyAuthWithConfig returns an KeyAuth middleware with config.
+// See `KeyAuth()`.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
-func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
+ // Defaults
+ if config.AuthScheme == "" {
+ config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
+ }
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
- return nil, errors.New("echo key-auth middleware requires a validator function")
+ panic("echo: key-auth middleware requires a validator function")
}
- limit := cmp.Or(config.AllowedCheckLimit, 1)
-
- extractors, cErr := createExtractors(config.KeyLookup, limit)
+ extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme)
if cErr != nil {
- return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr)
- }
- if len(extractors) == 0 {
- return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
+ panic(cErr)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -165,41 +121,59 @@ func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastExtractorErr error
var lastValidatorErr error
for _, extractor := range extractors {
- keys, source, extrErr := extractor(c)
- if extrErr != nil {
- lastExtractorErr = extrErr
+ keys, err := extractor(c)
+ if err != nil {
+ lastExtractorErr = err
continue
}
for _, key := range keys {
- valid, err := config.Validator(c, key, source)
+ valid, err := config.Validator(key, c)
if err != nil {
lastValidatorErr = err
continue
}
- if !valid {
- lastValidatorErr = ErrInvalidKey
- continue
+ if valid {
+ return next(c)
}
- return next(c)
+ lastValidatorErr = errors.New("invalid key")
}
}
- // prioritize validator errors over extracting errors
+ // we are here only when we did not successfully extract and validate any of keys
err := lastValidatorErr
- if err == nil {
- err = lastExtractorErr
+ if err == nil { // prioritize validator errors over extracting errors
+ // ugly part to preserve backwards compatible errors. someone could rely on them
+ if lastExtractorErr == errQueryExtractorValueMissing {
+ err = errors.New("missing key in the query string")
+ } else if lastExtractorErr == errCookieExtractorValueMissing {
+ err = errors.New("missing key in cookies")
+ } else if lastExtractorErr == errFormExtractorValueMissing {
+ err = errors.New("missing key in the form")
+ } else if lastExtractorErr == errHeaderExtractorValueMissing {
+ err = errors.New("missing key in request header")
+ } else if lastExtractorErr == errHeaderExtractorValueInvalid {
+ err = errors.New("invalid key in the request header")
+ } else {
+ err = lastExtractorErr
+ }
+ err = &ErrKeyAuthMissing{Err: err}
}
+
if config.ErrorHandler != nil {
- tmpErr := config.ErrorHandler(c, err)
+ tmpErr := config.ErrorHandler(err, c)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
}
return tmpErr
}
- if lastValidatorErr == nil {
- return ErrKeyMissing.Wrap(err)
+ if lastValidatorErr != nil { // prioritize validator errors over extracting errors
+ return &echo.HTTPError{
+ Code: http.StatusUnauthorized,
+ Message: "Unauthorized",
+ Internal: lastValidatorErr,
+ }
}
- return echo.ErrUnauthorized.Wrap(err)
+ return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
- }, nil
+ }
}
diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go
index 49a917ed3..447f0bee8 100644
--- a/middleware/key_auth_test.go
+++ b/middleware/key_auth_test.go
@@ -4,34 +4,30 @@
package middleware
import (
- "crypto/subtle"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
-func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) {
- // Use constant-time comparison to prevent timing attacks
- if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 {
+func testKeyValidator(key string, c echo.Context) (bool, error) {
+ switch key {
+ case "valid-key":
return true, nil
- }
-
- // Special case for testing error handling
- if key == "error-key" { // Error path doesn't need constant-time
+ case "error-key":
return false, errors.New("some user defined error")
+ default:
+ return false, nil
}
-
- return false, nil
}
func TestKeyAuth(t *testing.T) {
handlerCalled := false
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
@@ -71,7 +67,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
- conf.Skipper = func(context *echo.Context) bool {
+ conf.Skipper = func(context echo.Context) bool {
return true
}
},
@@ -83,7 +79,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
},
expectHandlerCalled: false,
- expectError: "code=401, message=Unauthorized, err=code=401, message=invalid key",
+ expectError: "code=401, message=Unauthorized, internal=invalid key",
},
{
name: "nok, defaults, invalid scheme in header",
@@ -91,13 +87,24 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
},
expectHandlerCalled: false,
- expectError: "code=401, message=missing key, err=invalid value in request header",
+ expectError: "code=400, message=invalid key in the request header",
},
{
name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {},
expectHandlerCalled: false,
- expectError: "code=401, message=missing key, err=missing value in request header",
+ expectError: "code=400, message=missing key in request header",
+ },
+ {
+ name: "ok, custom key lookup from multiple places, query and header",
+ givenRequest: func(req *http.Request) {
+ req.URL.RawQuery = "key=invalid-key"
+ req.Header.Set("API-Key", "valid-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key,header:API-Key"
+ },
+ expectHandlerCalled: true,
},
{
name: "ok, custom key lookup, header",
@@ -117,7 +124,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: false,
- expectError: "code=401, message=missing key, err=missing value in request header",
+ expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup, query",
@@ -137,7 +144,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: false,
- expectError: "code=401, message=missing key, err=missing value in the query string",
+ expectError: "code=400, message=missing key in the query string",
},
{
name: "ok, custom key lookup, form",
@@ -162,7 +169,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: false,
- expectError: "code=401, message=missing key, err=missing value in the form",
+ expectError: "code=400, message=missing key in the form",
},
{
name: "ok, custom key lookup, cookie",
@@ -186,18 +193,20 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key"
},
expectHandlerCalled: false,
- expectError: "code=401, message=missing key, err=missing value in cookies",
+ expectError: "code=400, message=missing key in cookies",
},
{
name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token"
- conf.ErrorHandler = func(c *echo.Context, err error) error {
- return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
+ conf.ErrorHandler = func(err error, context echo.Context) error {
+ httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
+ httpError.Internal = err
+ return httpError
}
},
expectHandlerCalled: false,
- expectError: "code=418, message=custom, err=missing value in request header",
+ expectError: "code=418, message=custom, internal=missing key in request header",
},
{
name: "nok, custom errorHandler, error from validator",
@@ -205,12 +214,14 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
- conf.ErrorHandler = func(c *echo.Context, err error) error {
- return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
+ conf.ErrorHandler = func(err error, context echo.Context) error {
+ httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
+ httpError.Internal = err
+ return httpError
}
},
expectHandlerCalled: false,
- expectError: "code=418, message=custom, err=some user defined error",
+ expectError: "code=418, message=custom, internal=some user defined error",
},
{
name: "nok, defaults, error from validator",
@@ -219,33 +230,14 @@ func TestKeyAuthWithConfig(t *testing.T) {
},
whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false,
- expectError: "code=401, message=Unauthorized, err=some user defined error",
- },
- {
- name: "ok, custom validator checks source",
- givenRequest: func(req *http.Request) {
- q := req.URL.Query()
- q.Add("key", "valid-key")
- req.URL.RawQuery = q.Encode()
- },
- whenConfig: func(conf *KeyAuthConfig) {
- conf.KeyLookup = "query:key"
- conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
- if source == ExtractorSourceQuery {
- return true, nil
- }
- return false, errors.New("invalid source")
- }
-
- },
- expectHandlerCalled: true,
+ expectError: "code=401, message=Unauthorized, internal=some user defined error",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handlerCalled := false
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
@@ -280,96 +272,108 @@ func TestKeyAuthWithConfig(t *testing.T) {
}
}
-func TestKeyAuthWithConfig_errors(t *testing.T) {
+func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
+ assert.PanicsWithError(
+ t,
+ "extractor source for lookup could not be split into needed parts: a",
+ func() {
+ handler := func(c echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+ KeyAuthWithConfig(KeyAuthConfig{
+ Validator: testKeyValidator,
+ KeyLookup: "a",
+ })(handler)
+ },
+ )
+}
+
+func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
+ assert.PanicsWithValue(
+ t,
+ "echo: key-auth middleware requires a validator function",
+ func() {
+ handler := func(c echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+ KeyAuthWithConfig(KeyAuthConfig{
+ Validator: nil,
+ })(handler)
+ },
+ )
+}
+
+func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
- name string
- whenConfig KeyAuthConfig
- expectError string
+ name string
+ whenContinueOnIgnoredError bool
+ givenKey string
+ expectStatus int
+ expectBody string
}{
{
- name: "ok, no error",
- whenConfig: KeyAuthConfig{
- Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
- return false, nil
- },
- },
+ name: "no error handler is called",
+ whenContinueOnIgnoredError: true,
+ givenKey: "valid-key",
+ expectStatus: http.StatusTeapot,
+ expectBody: "",
},
{
- name: "ok, missing validator func",
- whenConfig: KeyAuthConfig{
- Validator: nil,
- },
- expectError: "echo key-auth middleware requires a validator function",
+ name: "ContinueOnIgnoredError is false and error handler is called for missing token",
+ whenContinueOnIgnoredError: false,
+ givenKey: "",
+ // empty response with 200. This emulates previous behaviour when error handler swallowed the error
+ expectStatus: http.StatusOK,
+ expectBody: "",
},
{
- name: "ok, extractor source can not be split",
- whenConfig: KeyAuthConfig{
- KeyLookup: "nope",
- Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
- return false, nil
- },
- },
- expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
+ name: "error handler is called for missing token",
+ whenContinueOnIgnoredError: true,
+ givenKey: "",
+ expectStatus: http.StatusTeapot,
+ expectBody: "public-auth",
},
{
- name: "ok, no extractors",
- whenConfig: KeyAuthConfig{
- KeyLookup: "nope:nope",
- Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
- return false, nil
- },
- },
- expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
+ name: "error handler is called for invalid token",
+ whenContinueOnIgnoredError: true,
+ givenKey: "x.x.x",
+ expectStatus: http.StatusUnauthorized,
+ expectBody: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- mw, err := tc.whenConfig.ToMiddleware()
- if tc.expectError != "" {
- assert.Nil(t, mw)
- assert.EqualError(t, err, tc.expectError)
- } else {
- assert.NotNil(t, mw)
- assert.NoError(t, err)
- }
- })
- }
-}
+ e := echo.New()
-func TestMustKeyAuthWithConfig_panic(t *testing.T) {
- assert.Panics(t, func() {
- KeyAuthWithConfig(KeyAuthConfig{})
- })
-}
+ e.GET("/", func(c echo.Context) error {
+ testValue, _ := c.Get("test").(string)
+ return c.String(http.StatusTeapot, testValue)
+ })
-func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
- handlerCalled := false
- var authValue string
- handler := func(c *echo.Context) error {
- handlerCalled = true
- authValue = c.Get("auth").(string)
- return c.String(http.StatusOK, "test")
- }
- middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
- Validator: testKeyValidator,
- ErrorHandler: func(c *echo.Context, err error) error {
- // could check error to decide if we can swallow the error
- c.Set("auth", "public")
- return nil
- },
- ContinueOnIgnoredError: true,
- })(handler)
+ e.Use(KeyAuthWithConfig(KeyAuthConfig{
+ Validator: testKeyValidator,
+ ErrorHandler: func(err error, c echo.Context) error {
+ if _, ok := err.(*ErrKeyAuthMissing); ok {
+ c.Set("test", "public-auth")
+ return nil
+ }
+ return echo.ErrUnauthorized
+ },
+ KeyLookup: "header:X-API-Key",
+ ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
+ }))
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- // no auth header this time
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenKey != "" {
+ req.Header.Set("X-API-Key", tc.givenKey)
+ }
+ res := httptest.NewRecorder()
- err := middlewareChain(c)
+ e.ServeHTTP(res, req)
- assert.NoError(t, err)
- assert.True(t, handlerCalled)
- assert.Equal(t, "public", authValue)
+ assert.Equal(t, tc.expectStatus, res.Code)
+ assert.Equal(t, tc.expectBody, res.Body.String())
+ })
+ }
}
diff --git a/middleware/logger.go b/middleware/logger.go
new file mode 100644
index 000000000..0766dbf55
--- /dev/null
+++ b/middleware/logger.go
@@ -0,0 +1,420 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "io"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/gommon/color"
+ "github.com/valyala/fasttemplate"
+)
+
+// LoggerConfig defines the config for Logger middleware.
+//
+// # Configuration Examples
+//
+// ## Basic Usage with Default Settings
+//
+// e.Use(middleware.Logger())
+//
+// This uses the default JSON format that logs all common request/response details.
+//
+// ## Custom Simple Format
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n",
+// }))
+//
+// ## JSON Format with Custom Fields
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: `{"timestamp":"${time_rfc3339_nano}","level":"info","remote_ip":"${remote_ip}",` +
+// `"method":"${method}","uri":"${uri}","status":${status},"latency":"${latency_human}",` +
+// `"user_agent":"${user_agent}","error":"${error}"}` + "\n",
+// }))
+//
+// ## Custom Time Format
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: "${time_custom} ${method} ${uri} ${status}\n",
+// CustomTimeFormat: "2006-01-02 15:04:05",
+// }))
+//
+// ## Logging Headers and Parameters
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: `{"time":"${time_rfc3339_nano}","method":"${method}","uri":"${uri}",` +
+// `"status":${status},"auth":"${header:Authorization}","user":"${query:user}",` +
+// `"form_data":"${form:action}","session":"${cookie:session_id}"}` + "\n",
+// }))
+//
+// ## Custom Output (File Logging)
+//
+// file, err := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
+// if err != nil {
+// log.Fatal(err)
+// }
+// defer file.Close()
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Output: file,
+// }))
+//
+// ## Custom Tag Function
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: `{"time":"${time_rfc3339_nano}","user_id":"${custom}","method":"${method}"}` + "\n",
+// CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) {
+// userID := getUserIDFromContext(c) // Your custom logic
+// return buf.WriteString(strconv.Itoa(userID))
+// },
+// }))
+//
+// ## Conditional Logging (Skip Certain Requests)
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Skipper: func(c echo.Context) bool {
+// // Skip logging for health check endpoints
+// return c.Request().URL.Path == "/health" || c.Request().URL.Path == "/metrics"
+// },
+// }))
+//
+// ## Integration with External Logging Service
+//
+// logBuffer := &SyncBuffer{} // Thread-safe buffer for external service
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: `{"timestamp":"${time_rfc3339_nano}","service":"my-api","level":"info",` +
+// `"method":"${method}","uri":"${uri}","status":${status},"latency_ms":${latency},` +
+// `"remote_ip":"${remote_ip}","user_agent":"${user_agent}","error":"${error}"}` + "\n",
+// Output: logBuffer,
+// }))
+//
+// # Available Tags
+//
+// ## Time Tags
+// - time_unix: Unix timestamp (seconds)
+// - time_unix_milli: Unix timestamp (milliseconds)
+// - time_unix_micro: Unix timestamp (microseconds)
+// - time_unix_nano: Unix timestamp (nanoseconds)
+// - time_rfc3339: RFC3339 format (2006-01-02T15:04:05Z07:00)
+// - time_rfc3339_nano: RFC3339 with nanoseconds
+// - time_custom: Uses CustomTimeFormat field
+//
+// ## Request Information
+// - id: Request ID from X-Request-ID header
+// - remote_ip: Client IP address (respects proxy headers)
+// - uri: Full request URI with query parameters
+// - host: Host header value
+// - method: HTTP method (GET, POST, etc.)
+// - path: URL path without query parameters
+// - route: Echo route pattern (e.g., /users/:id)
+// - protocol: HTTP protocol version
+// - referer: Referer header value
+// - user_agent: User-Agent header value
+//
+// ## Response Information
+// - status: HTTP status code
+// - error: Error message if request failed
+// - latency: Request processing time in nanoseconds
+// - latency_human: Human-readable processing time
+// - bytes_in: Request body size in bytes
+// - bytes_out: Response body size in bytes
+//
+// ## Dynamic Tags
+// - header:: Value of specific header (e.g., header:Authorization)
+// - query:: Value of specific query parameter (e.g., query:user_id)
+// - form:: Value of specific form field (e.g., form:username)
+// - cookie:: Value of specific cookie (e.g., cookie:session_id)
+// - custom: Output from CustomTagFunc
+//
+// # Troubleshooting
+//
+// ## Common Issues
+//
+// 1. **Missing logs**: Check if Skipper function is filtering out requests
+// 2. **Invalid JSON**: Ensure CustomTagFunc outputs valid JSON content
+// 3. **Performance issues**: Consider using a buffered writer for high-traffic applications
+// 4. **File permission errors**: Ensure write permissions when logging to files
+//
+// ## Performance Tips
+//
+// - Use time_unix formats for better performance than time_rfc3339
+// - Minimize the number of dynamic tags (header:, query:, form:, cookie:)
+// - Use Skipper to exclude high-frequency, low-value requests (health checks, etc.)
+// - Consider async logging for very high-traffic applications
+type LoggerConfig struct {
+ // Skipper defines a function to skip middleware.
+ // Use this to exclude certain requests from logging (e.g., health checks).
+ //
+ // Example:
+ // Skipper: func(c echo.Context) bool {
+ // return c.Request().URL.Path == "/health"
+ // },
+ Skipper Skipper
+
+ // Format defines the logging format using template tags.
+ // Tags are enclosed in ${} and replaced with actual values.
+ // See the detailed tag documentation above for all available options.
+ //
+ // Default: JSON format with common fields
+ // Example: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n"
+ Format string `yaml:"format"`
+
+ // CustomTimeFormat specifies the time format used by ${time_custom} tag.
+ // Uses Go's reference time: Mon Jan 2 15:04:05 MST 2006
+ //
+ // Default: "2006-01-02 15:04:05.00000"
+ // Example: "2006-01-02 15:04:05" or "15:04:05.000"
+ CustomTimeFormat string `yaml:"custom_time_format"`
+
+ // CustomTagFunc is called when ${custom} tag is encountered.
+ // Use this to add application-specific information to logs.
+ // The function should write valid content for your log format.
+ //
+ // Example:
+ // CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) {
+ // userID := getUserFromContext(c)
+ // return buf.WriteString(`"user_id":"` + userID + `"`)
+ // },
+ CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error)
+
+ // Output specifies where logs are written.
+ // Can be any io.Writer: files, buffers, network connections, etc.
+ //
+ // Default: os.Stdout
+ // Example: Custom file, syslog, or external logging service
+ Output io.Writer
+
+ template *fasttemplate.Template
+ colorer *color.Color
+ pool *sync.Pool
+ timeNow func() time.Time
+}
+
+// DefaultLoggerConfig is the default Logger middleware config.
+var DefaultLoggerConfig = LoggerConfig{
+ Skipper: DefaultSkipper,
+ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
+ `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
+ `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
+ `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
+ CustomTimeFormat: "2006-01-02 15:04:05.00000",
+ colorer: color.New(),
+ timeNow: time.Now,
+}
+
+// Logger returns a middleware that logs HTTP requests using the default configuration.
+//
+// The default format logs requests as JSON with the following fields:
+// - time: RFC3339 nano timestamp
+// - id: Request ID from X-Request-ID header
+// - remote_ip: Client IP address
+// - host: Host header
+// - method: HTTP method
+// - uri: Request URI
+// - user_agent: User-Agent header
+// - status: HTTP status code
+// - error: Error message (if any)
+// - latency: Processing time in nanoseconds
+// - latency_human: Human-readable processing time
+// - bytes_in: Request body size
+// - bytes_out: Response body size
+//
+// Example output:
+//
+// {"time":"2023-01-15T10:30:45.123456789Z","id":"","remote_ip":"127.0.0.1",
+// "host":"localhost:8080","method":"GET","uri":"/users/123","user_agent":"curl/7.81.0",
+// "status":200,"error":"","latency":1234567,"latency_human":"1.234567ms",
+// "bytes_in":0,"bytes_out":42}
+//
+// For custom configurations, use LoggerWithConfig instead.
+//
+// Deprecated: please use middleware.RequestLogger or middleware.RequestLoggerWithConfig instead.
+func Logger() echo.MiddlewareFunc {
+ return LoggerWithConfig(DefaultLoggerConfig)
+}
+
+// LoggerWithConfig returns a Logger middleware with custom configuration.
+//
+// This function allows you to customize all aspects of request logging including:
+// - Log format and fields
+// - Output destination
+// - Time formatting
+// - Custom tags and logic
+// - Request filtering
+//
+// See LoggerConfig documentation for detailed configuration examples and options.
+//
+// Example:
+//
+// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
+// Format: "${time_rfc3339} ${status} ${method} ${uri} ${latency_human}\n",
+// Output: customLogWriter,
+// Skipper: func(c echo.Context) bool {
+// return c.Request().URL.Path == "/health"
+// },
+// }))
+//
+// Deprecated: please use middleware.RequestLoggerWithConfig instead.
+func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
+ // Defaults
+ if config.Skipper == nil {
+ config.Skipper = DefaultLoggerConfig.Skipper
+ }
+ if config.Format == "" {
+ config.Format = DefaultLoggerConfig.Format
+ }
+ writeString := func(buf *bytes.Buffer, in string) (int, error) { return buf.WriteString(in) }
+ if config.Format[0] == '{' { // format looks like JSON, so we need to escape invalid characters
+ writeString = writeJSONSafeString
+ }
+
+ if config.Output == nil {
+ config.Output = DefaultLoggerConfig.Output
+ }
+ timeNow := DefaultLoggerConfig.timeNow
+ if config.timeNow != nil {
+ timeNow = config.timeNow
+ }
+
+ config.template = fasttemplate.New(config.Format, "${", "}")
+ config.colorer = color.New()
+ config.colorer.SetOutput(config.Output)
+ config.pool = &sync.Pool{
+ New: func() any {
+ return bytes.NewBuffer(make([]byte, 256))
+ },
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) (err error) {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ res := c.Response()
+ start := time.Now()
+ if err = next(c); err != nil {
+ c.Error(err)
+ }
+ stop := time.Now()
+ buf := config.pool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer config.pool.Put(buf)
+
+ if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
+ switch tag {
+ case "custom":
+ if config.CustomTagFunc == nil {
+ return 0, nil
+ }
+ return config.CustomTagFunc(c, buf)
+ case "time_unix":
+ return buf.WriteString(strconv.FormatInt(timeNow().Unix(), 10))
+ case "time_unix_milli":
+ return buf.WriteString(strconv.FormatInt(timeNow().UnixMilli(), 10))
+ case "time_unix_micro":
+ return buf.WriteString(strconv.FormatInt(timeNow().UnixMicro(), 10))
+ case "time_unix_nano":
+ return buf.WriteString(strconv.FormatInt(timeNow().UnixNano(), 10))
+ case "time_rfc3339":
+ return buf.WriteString(timeNow().Format(time.RFC3339))
+ case "time_rfc3339_nano":
+ return buf.WriteString(timeNow().Format(time.RFC3339Nano))
+ case "time_custom":
+ return buf.WriteString(timeNow().Format(config.CustomTimeFormat))
+ case "id":
+ id := req.Header.Get(echo.HeaderXRequestID)
+ if id == "" {
+ id = res.Header().Get(echo.HeaderXRequestID)
+ }
+ return writeString(buf, id)
+ case "remote_ip":
+ return writeString(buf, c.RealIP())
+ case "host":
+ return writeString(buf, req.Host)
+ case "uri":
+ return writeString(buf, req.RequestURI)
+ case "method":
+ return writeString(buf, req.Method)
+ case "path":
+ p := req.URL.Path
+ if p == "" {
+ p = "/"
+ }
+ return writeString(buf, p)
+ case "route":
+ return writeString(buf, c.Path())
+ case "protocol":
+ return writeString(buf, req.Proto)
+ case "referer":
+ return writeString(buf, req.Referer())
+ case "user_agent":
+ return writeString(buf, req.UserAgent())
+ case "status":
+ n := res.Status
+ s := config.colorer.Green(n)
+ switch {
+ case n >= 500:
+ s = config.colorer.Red(n)
+ case n >= 400:
+ s = config.colorer.Yellow(n)
+ case n >= 300:
+ s = config.colorer.Cyan(n)
+ }
+ return buf.WriteString(s)
+ case "error":
+ if err != nil {
+ return writeJSONSafeString(buf, err.Error())
+ }
+ case "latency":
+ l := stop.Sub(start)
+ return buf.WriteString(strconv.FormatInt(int64(l), 10))
+ case "latency_human":
+ return buf.WriteString(stop.Sub(start).String())
+ case "bytes_in":
+ cl := req.Header.Get(echo.HeaderContentLength)
+ if cl == "" {
+ cl = "0"
+ }
+ return writeString(buf, cl)
+ case "bytes_out":
+ return buf.WriteString(strconv.FormatInt(res.Size, 10))
+ default:
+ switch {
+ case strings.HasPrefix(tag, "header:"):
+ return writeString(buf, c.Request().Header.Get(tag[7:]))
+ case strings.HasPrefix(tag, "query:"):
+ return writeString(buf, c.QueryParam(tag[6:]))
+ case strings.HasPrefix(tag, "form:"):
+ return writeString(buf, c.FormValue(tag[5:]))
+ case strings.HasPrefix(tag, "cookie:"):
+ cookie, err := c.Cookie(tag[7:])
+ if err == nil {
+ return buf.Write([]byte(cookie.Value))
+ }
+ }
+ }
+ return 0, nil
+ }); err != nil {
+ return
+ }
+
+ if config.Output == nil {
+ _, err = c.Logger().Output().Write(buf.Bytes())
+ return
+ }
+ _, err = config.Output.Write(buf.Bytes())
+ return
+ }
+ }
+}
diff --git a/middleware/logger_strings.go b/middleware/logger_strings.go
new file mode 100644
index 000000000..8476cb046
--- /dev/null
+++ b/middleware/logger_strings.go
@@ -0,0 +1,242 @@
+// SPDX-License-Identifier: BSD-3-Clause
+// SPDX-FileCopyrightText: Copyright 2010 The Go Authors
+//
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+//
+// Go LICENSE https://raw.githubusercontent.com/golang/go/36bca3166e18db52687a4d91ead3f98ffe6d00b8/LICENSE
+/**
+Copyright 2009 The Go Authors.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google LLC nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+package middleware
+
+import (
+ "bytes"
+ "unicode/utf8"
+)
+
+// This function is modified copy from Go standard library encoding/json/encode.go `appendString` function
+// Source: https://github.com/golang/go/blob/36bca3166e18db52687a4d91ead3f98ffe6d00b8/src/encoding/json/encode.go#L999
+func writeJSONSafeString(buf *bytes.Buffer, src string) (int, error) {
+ const hex = "0123456789abcdef"
+
+ written := 0
+ start := 0
+ for i := 0; i < len(src); {
+ if b := src[i]; b < utf8.RuneSelf {
+ if safeSet[b] {
+ i++
+ continue
+ }
+
+ n, err := buf.Write([]byte(src[start:i]))
+ written += n
+ if err != nil {
+ return written, err
+ }
+ switch b {
+ case '\\', '"':
+ n, err := buf.Write([]byte{'\\', b})
+ written += n
+ if err != nil {
+ return written, err
+ }
+ case '\b':
+ n, err := buf.Write([]byte{'\\', 'b'})
+ written += n
+ if err != nil {
+ return n, err
+ }
+ case '\f':
+ n, err := buf.Write([]byte{'\\', 'f'})
+ written += n
+ if err != nil {
+ return written, err
+ }
+ case '\n':
+ n, err := buf.Write([]byte{'\\', 'n'})
+ written += n
+ if err != nil {
+ return written, err
+ }
+ case '\r':
+ n, err := buf.Write([]byte{'\\', 'r'})
+ written += n
+ if err != nil {
+ return written, err
+ }
+ case '\t':
+ n, err := buf.Write([]byte{'\\', 't'})
+ written += n
+ if err != nil {
+ return written, err
+ }
+ default:
+ // This encodes bytes < 0x20 except for \b, \f, \n, \r and \t.
+ n, err := buf.Write([]byte{'\\', 'u', '0', '0', hex[b>>4], hex[b&0xF]})
+ written += n
+ if err != nil {
+ return written, err
+ }
+ }
+ i++
+ start = i
+ continue
+ }
+ srcN := min(len(src)-i, utf8.UTFMax)
+ c, size := utf8.DecodeRuneInString(src[i : i+srcN])
+ if c == utf8.RuneError && size == 1 {
+ n, err := buf.Write([]byte(src[start:i]))
+ written += n
+ if err != nil {
+ return written, err
+ }
+ n, err = buf.Write([]byte(`\ufffd`))
+ written += n
+ if err != nil {
+ return written, err
+ }
+ i += size
+ start = i
+ continue
+ }
+ i += size
+ }
+ n, err := buf.Write([]byte(src[start:]))
+ written += n
+ return written, err
+}
+
+// safeSet holds the value true if the ASCII character with the given array
+// position can be represented inside a JSON string without any further
+// escaping.
+//
+// All values are true except for the ASCII control characters (0-31), the
+// double quote ("), and the backslash character ("\").
+var safeSet = [utf8.RuneSelf]bool{
+ ' ': true,
+ '!': true,
+ '"': false,
+ '#': true,
+ '$': true,
+ '%': true,
+ '&': true,
+ '\'': true,
+ '(': true,
+ ')': true,
+ '*': true,
+ '+': true,
+ ',': true,
+ '-': true,
+ '.': true,
+ '/': true,
+ '0': true,
+ '1': true,
+ '2': true,
+ '3': true,
+ '4': true,
+ '5': true,
+ '6': true,
+ '7': true,
+ '8': true,
+ '9': true,
+ ':': true,
+ ';': true,
+ '<': true,
+ '=': true,
+ '>': true,
+ '?': true,
+ '@': true,
+ 'A': true,
+ 'B': true,
+ 'C': true,
+ 'D': true,
+ 'E': true,
+ 'F': true,
+ 'G': true,
+ 'H': true,
+ 'I': true,
+ 'J': true,
+ 'K': true,
+ 'L': true,
+ 'M': true,
+ 'N': true,
+ 'O': true,
+ 'P': true,
+ 'Q': true,
+ 'R': true,
+ 'S': true,
+ 'T': true,
+ 'U': true,
+ 'V': true,
+ 'W': true,
+ 'X': true,
+ 'Y': true,
+ 'Z': true,
+ '[': true,
+ '\\': false,
+ ']': true,
+ '^': true,
+ '_': true,
+ '`': true,
+ 'a': true,
+ 'b': true,
+ 'c': true,
+ 'd': true,
+ 'e': true,
+ 'f': true,
+ 'g': true,
+ 'h': true,
+ 'i': true,
+ 'j': true,
+ 'k': true,
+ 'l': true,
+ 'm': true,
+ 'n': true,
+ 'o': true,
+ 'p': true,
+ 'q': true,
+ 'r': true,
+ 's': true,
+ 't': true,
+ 'u': true,
+ 'v': true,
+ 'w': true,
+ 'x': true,
+ 'y': true,
+ 'z': true,
+ '{': true,
+ '|': true,
+ '}': true,
+ '~': true,
+ '\u007f': true,
+}
diff --git a/middleware/logger_strings_test.go b/middleware/logger_strings_test.go
new file mode 100644
index 000000000..3d66404c5
--- /dev/null
+++ b/middleware/logger_strings_test.go
@@ -0,0 +1,288 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestWriteJSONSafeString(t *testing.T) {
+ testCases := []struct {
+ name string
+ whenInput string
+ expect string
+ expectN int
+ }{
+ // Basic cases
+ {
+ name: "empty string",
+ whenInput: "",
+ expect: "",
+ expectN: 0,
+ },
+ {
+ name: "simple ASCII without special chars",
+ whenInput: "hello",
+ expect: "hello",
+ expectN: 5,
+ },
+ {
+ name: "single character",
+ whenInput: "a",
+ expect: "a",
+ expectN: 1,
+ },
+ {
+ name: "alphanumeric",
+ whenInput: "Hello123World",
+ expect: "Hello123World",
+ expectN: 13,
+ },
+
+ // Special character escaping
+ {
+ name: "backslash",
+ whenInput: `path\to\file`,
+ expect: `path\\to\\file`,
+ expectN: 14,
+ },
+ {
+ name: "double quote",
+ whenInput: `say "hello"`,
+ expect: `say \"hello\"`,
+ expectN: 13,
+ },
+ {
+ name: "backslash and quote combined",
+ whenInput: `a\b"c`,
+ expect: `a\\b\"c`,
+ expectN: 7,
+ },
+ {
+ name: "single backslash",
+ whenInput: `\`,
+ expect: `\\`,
+ expectN: 2,
+ },
+ {
+ name: "single quote",
+ whenInput: `"`,
+ expect: `\"`,
+ expectN: 2,
+ },
+
+ // Control character escaping
+ {
+ name: "backspace",
+ whenInput: "hello\bworld",
+ expect: `hello\bworld`,
+ expectN: 12,
+ },
+ {
+ name: "form feed",
+ whenInput: "hello\fworld",
+ expect: `hello\fworld`,
+ expectN: 12,
+ },
+ {
+ name: "newline",
+ whenInput: "hello\nworld",
+ expect: `hello\nworld`,
+ expectN: 12,
+ },
+ {
+ name: "carriage return",
+ whenInput: "hello\rworld",
+ expect: `hello\rworld`,
+ expectN: 12,
+ },
+ {
+ name: "tab",
+ whenInput: "hello\tworld",
+ expect: `hello\tworld`,
+ expectN: 12,
+ },
+ {
+ name: "multiple newlines",
+ whenInput: "line1\nline2\nline3",
+ expect: `line1\nline2\nline3`,
+ expectN: 19,
+ },
+
+ // Low control characters (< 0x20)
+ {
+ name: "null byte",
+ whenInput: "hello\x00world",
+ expect: `hello\u0000world`,
+ expectN: 16,
+ },
+ {
+ name: "control character 0x01",
+ whenInput: "test\x01value",
+ expect: `test\u0001value`,
+ expectN: 15,
+ },
+ {
+ name: "control character 0x0e",
+ whenInput: "test\x0evalue",
+ expect: `test\u000evalue`,
+ expectN: 15,
+ },
+ {
+ name: "control character 0x1f",
+ whenInput: "test\x1fvalue",
+ expect: `test\u001fvalue`,
+ expectN: 15,
+ },
+ {
+ name: "multiple control characters",
+ whenInput: "\x00\x01\x02",
+ expect: `\u0000\u0001\u0002`,
+ expectN: 18,
+ },
+
+ // UTF-8 handling
+ {
+ name: "valid UTF-8 Chinese",
+ whenInput: "hello δΈη",
+ expect: "hello δΈη",
+ expectN: 12,
+ },
+ {
+ name: "valid UTF-8 emoji",
+ whenInput: "party π time",
+ expect: "party π time",
+ expectN: 15,
+ },
+ {
+ name: "mixed ASCII and UTF-8",
+ whenInput: "HelloδΈη123",
+ expect: "HelloδΈη123",
+ expectN: 14,
+ },
+ {
+ name: "UTF-8 with special chars",
+ whenInput: "δΈη\n\"test\"",
+ expect: `δΈη\n\"test\"`,
+ expectN: 16,
+ },
+
+ // Invalid UTF-8
+ {
+ name: "invalid UTF-8 sequence",
+ whenInput: "hello\xff\xfeworld",
+ expect: `hello\ufffd\ufffdworld`,
+ expectN: 22,
+ },
+ {
+ name: "incomplete UTF-8 sequence",
+ whenInput: "test\xc3value",
+ expect: `test\ufffdvalue`,
+ expectN: 15,
+ },
+
+ // Complex mixed cases
+ {
+ name: "all common escapes",
+ whenInput: "tab\there\nquote\"backslash\\",
+ expect: `tab\there\nquote\"backslash\\`,
+ expectN: 29,
+ },
+ {
+ name: "mixed controls and UTF-8",
+ whenInput: "hello\tδΈη\ntest\"",
+ expect: `hello\tδΈη\ntest\"`,
+ expectN: 21,
+ },
+ {
+ name: "all control characters",
+ whenInput: "\b\f\n\r\t",
+ expect: `\b\f\n\r\t`,
+ expectN: 10,
+ },
+ {
+ name: "control and low ASCII",
+ whenInput: "a\nb\x00c",
+ expect: `a\nb\u0000c`,
+ expectN: 11,
+ },
+
+ // Edge cases
+ {
+ name: "starts with special char",
+ whenInput: "\\start",
+ expect: `\\start`,
+ expectN: 7,
+ },
+ {
+ name: "ends with special char",
+ whenInput: "end\"",
+ expect: `end\"`,
+ expectN: 5,
+ },
+ {
+ name: "consecutive special chars",
+ whenInput: "\\\\\"\"",
+ expect: `\\\\\"\"`,
+ expectN: 8,
+ },
+ {
+ name: "only special characters",
+ whenInput: "\"\\\n\t",
+ expect: `\"\\\n\t`,
+ expectN: 8,
+ },
+ {
+ name: "spaces and punctuation",
+ whenInput: "Hello, World! How are you?",
+ expect: "Hello, World! How are you?",
+ expectN: 26,
+ },
+ {
+ name: "JSON-like string",
+ whenInput: "{\"key\":\"value\"}",
+ expect: `{\"key\":\"value\"}`,
+ expectN: 19,
+ },
+ }
+
+ for _, tt := range testCases {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+ n, err := writeJSONSafeString(buf, tt.whenInput)
+
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expect, buf.String())
+ assert.Equal(t, tt.expectN, n)
+ })
+ }
+}
+
+func BenchmarkWriteJSONSafeString(b *testing.B) {
+ testCases := []struct {
+ name string
+ input string
+ }{
+ {"simple", "hello world"},
+ {"with escapes", "tab\there\nquote\"backslash\\"},
+ {"utf8", "hello δΈη π"},
+ {"mixed", "Hello\tδΈη\ntest\"value\\path"},
+ {"long simple", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789"},
+ {"long complex", "line1\nline2\tline3\"quote\\slash\x00nullδΈηπ"},
+ }
+
+ for _, tc := range testCases {
+ b.Run(tc.name, func(b *testing.B) {
+ buf := &bytes.Buffer{}
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ buf.Reset()
+ writeJSONSafeString(buf, tc.input)
+ }
+ })
+ }
+}
diff --git a/middleware/logger_test.go b/middleware/logger_test.go
new file mode 100644
index 000000000..e4b783db5
--- /dev/null
+++ b/middleware/logger_test.go
@@ -0,0 +1,540 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "cmp"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "regexp"
+ "strings"
+ "testing"
+ "time"
+ "unsafe"
+
+ "github.com/labstack/echo/v4"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLoggerDefaultMW(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenHeader map[string]string
+ whenStatusCode int
+ whenResponse string
+ whenError error
+ expect string
+ }{
+ {
+ name: "ok, status 200",
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n",
+ },
+ {
+ name: "ok, status 300",
+ whenStatusCode: http.StatusTemporaryRedirect,
+ whenResponse: "test",
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":307,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n",
+ },
+ {
+ name: "ok, handler error = status 500",
+ whenError: errors.New("error"),
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n",
+ },
+ {
+ name: "error with invalid UTF-8 sequences",
+ whenError: errors.New("invalid data: \xFF\xFE"),
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"invalid data: \ufffd\ufffd","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n",
+ },
+ {
+ name: "error with JSON special characters (quotes and backslashes)",
+ whenError: errors.New(`error with "quotes" and \backslash`),
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error with \"quotes\" and \\backslash","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n",
+ },
+ {
+ name: "error with control characters (newlines and tabs)",
+ whenError: errors.New("error\nwith\nnewlines\tand\ttabs"),
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error\nwith\nnewlines\tand\ttabs","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":36}` + "\n",
+ },
+ {
+ name: "ok, remote_ip from X-Real-Ip header",
+ whenHeader: map[string]string{echo.HeaderXRealIP: "127.0.0.1"},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n",
+ },
+ {
+ name: "ok, remote_ip from X-Forwarded-For header",
+ whenHeader: map[string]string{echo.HeaderXForwardedFor: "127.0.0.1"},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if len(tc.whenHeader) > 0 {
+ for k, v := range tc.whenHeader {
+ req.Header.Add(k, v)
+ }
+ }
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ DefaultLoggerConfig.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() }
+ h := Logger()(func(c echo.Context) error {
+ if tc.whenError != nil {
+ return tc.whenError
+ }
+ return c.String(tc.whenStatusCode, tc.whenResponse)
+ })
+ buf := new(bytes.Buffer)
+ e.Logger.SetOutput(buf)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ result := buf.String()
+ // handle everchanging latency numbers
+ result = regexp.MustCompile(`"latency":\d+,`).ReplaceAllString(result, `"latency":1,`)
+ result = regexp.MustCompile(`"latency_human":"[^"]+"`).ReplaceAllString(result, `"latency_human":"1Β΅s"`)
+
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestLoggerWithLoggerConfig(t *testing.T) {
+ // to handle everchanging latency numbers
+ jsonLatency := map[string]*regexp.Regexp{
+ `"latency":1,`: regexp.MustCompile(`"latency":\d+,`),
+ `"latency_human":"1Β΅s"`: regexp.MustCompile(`"latency_human":"[^"]+"`),
+ }
+
+ form := make(url.Values)
+ form.Set("csrf", "token")
+ form.Add("multiple", "1")
+ form.Add("multiple", "2")
+
+ var testCases = []struct {
+ name string
+ givenConfig LoggerConfig
+ whenURI string
+ whenMethod string
+ whenHost string
+ whenPath string
+ whenRoute string
+ whenProto string
+ whenRequestURI string
+ whenHeader map[string]string
+ whenFormValues url.Values
+ whenStatusCode int
+ whenResponse string
+ whenError error
+ whenReplacers map[string]*regexp.Regexp
+ expect string
+ }{
+ {
+ name: "ok, skipper",
+ givenConfig: LoggerConfig{
+ Skipper: func(c echo.Context) bool { return true },
+ },
+ expect: ``,
+ },
+ { // this is an example how format that does not seem to be JSON is not currently escaped
+ name: "ok, NON json string is not escaped: method",
+ givenConfig: LoggerConfig{Format: `method:"${method}"`},
+ whenMethod: `","method":":D"`,
+ expect: `method:"","method":":D""`,
+ },
+ {
+ name: "ok, json string escape: method",
+ givenConfig: LoggerConfig{Format: `{"method":"${method}"}`},
+ whenMethod: `","method":":D"`,
+ expect: `{"method":"\",\"method\":\":D\""}`,
+ },
+ {
+ name: "ok, json string escape: id",
+ givenConfig: LoggerConfig{Format: `{"id":"${id}"}`},
+ whenHeader: map[string]string{echo.HeaderXRequestID: `\"127.0.0.1\"`},
+ expect: `{"id":"\\\"127.0.0.1\\\""}`,
+ },
+ {
+ name: "ok, json string escape: remote_ip",
+ givenConfig: LoggerConfig{Format: `{"remote_ip":"${remote_ip}"}`},
+ whenHeader: map[string]string{echo.HeaderXForwardedFor: `\"127.0.0.1\"`},
+ expect: `{"remote_ip":"\\\"127.0.0.1\\\""}`,
+ },
+ {
+ name: "ok, json string escape: host",
+ givenConfig: LoggerConfig{Format: `{"host":"${host}"}`},
+ whenHost: `\"127.0.0.1\"`,
+ expect: `{"host":"\\\"127.0.0.1\\\""}`,
+ },
+ {
+ name: "ok, json string escape: path",
+ givenConfig: LoggerConfig{Format: `{"path":"${path}"}`},
+ whenPath: `\","` + "\n",
+ expect: `{"path":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: route",
+ givenConfig: LoggerConfig{Format: `{"route":"${route}"}`},
+ whenRoute: `\","` + "\n",
+ expect: `{"route":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: proto",
+ givenConfig: LoggerConfig{Format: `{"protocol":"${protocol}"}`},
+ whenProto: `\","` + "\n",
+ expect: `{"protocol":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: referer",
+ givenConfig: LoggerConfig{Format: `{"referer":"${referer}"}`},
+ whenHeader: map[string]string{"Referer": `\","` + "\n"},
+ expect: `{"referer":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: user_agent",
+ givenConfig: LoggerConfig{Format: `{"user_agent":"${user_agent}"}`},
+ whenHeader: map[string]string{"User-Agent": `\","` + "\n"},
+ expect: `{"user_agent":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: bytes_in",
+ givenConfig: LoggerConfig{Format: `{"bytes_in":"${bytes_in}"}`},
+ whenHeader: map[string]string{echo.HeaderContentLength: `\","` + "\n"},
+ expect: `{"bytes_in":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: query param",
+ givenConfig: LoggerConfig{Format: `{"query":"${query:test}"}`},
+ whenURI: `/?test=1","`,
+ expect: `{"query":"1\",\""}`,
+ },
+ {
+ name: "ok, json string escape: header",
+ givenConfig: LoggerConfig{Format: `{"header":"${header:referer}"}`},
+ whenHeader: map[string]string{"referer": `\","` + "\n"},
+ expect: `{"header":"\\\",\"\n"}`,
+ },
+ {
+ name: "ok, json string escape: form",
+ givenConfig: LoggerConfig{Format: `{"csrf":"${form:csrf}"}`},
+ whenMethod: http.MethodPost,
+ whenFormValues: url.Values{"csrf": {`token","`}},
+ expect: `{"csrf":"token\",\""}`,
+ },
+ {
+ name: "nok, json string escape: cookie - will not accept invalid chars",
+ // net/cookie.go: validCookieValueByte function allows these byte in cookie value
+ // only `0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'`
+ givenConfig: LoggerConfig{Format: `{"cookie":"${cookie:session}"}`},
+ whenHeader: map[string]string{"Cookie": `_ga=GA1.2.000000000.0000000000; session=test\n`},
+ expect: `{"cookie":""}`,
+ },
+ {
+ name: "ok, format time_unix",
+ givenConfig: LoggerConfig{Format: `${time_unix}`},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `1588037200`,
+ },
+ {
+ name: "ok, format time_unix_milli",
+ givenConfig: LoggerConfig{Format: `${time_unix_milli}`},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `1588037200000`,
+ },
+ {
+ name: "ok, format time_unix_micro",
+ givenConfig: LoggerConfig{Format: `${time_unix_micro}`},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `1588037200000000`,
+ },
+ {
+ name: "ok, format time_unix_nano",
+ givenConfig: LoggerConfig{Format: `${time_unix_nano}`},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `1588037200000000000`,
+ },
+ {
+ name: "ok, format time_rfc3339",
+ givenConfig: LoggerConfig{Format: `${time_rfc3339}`},
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ expect: `2020-04-28T01:26:40Z`,
+ },
+ {
+ name: "ok, status 200",
+ whenStatusCode: http.StatusOK,
+ whenResponse: "test",
+ whenReplacers: jsonLatency,
+ expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1Β΅s","bytes_in":0,"bytes_out":4}` + "\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), nil)
+ if tc.whenFormValues != nil {
+ req = httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), strings.NewReader(tc.whenFormValues.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+ }
+
+ for k, v := range tc.whenHeader {
+ req.Header.Add(k, v)
+ }
+ if tc.whenHost != "" {
+ req.Host = tc.whenHost
+ }
+ if tc.whenMethod != "" {
+ req.Method = tc.whenMethod
+ }
+ if tc.whenProto != "" {
+ req.Proto = tc.whenProto
+ }
+ if tc.whenRequestURI != "" {
+ req.RequestURI = tc.whenRequestURI
+ }
+ if tc.whenPath != "" {
+ req.URL.Path = tc.whenPath
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ if tc.whenFormValues != nil {
+ c.FormValue("to trigger form parsing")
+ }
+ if tc.whenRoute != "" {
+ c.SetPath(tc.whenRoute)
+ }
+
+ config := tc.givenConfig
+ if config.timeNow == nil {
+ config.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() }
+ }
+ buf := new(bytes.Buffer)
+ if config.Output == nil {
+ e.Logger.SetOutput(buf)
+ }
+
+ h := LoggerWithConfig(config)(func(c echo.Context) error {
+ if tc.whenError != nil {
+ return tc.whenError
+ }
+ return c.String(cmp.Or(tc.whenStatusCode, http.StatusOK), cmp.Or(tc.whenResponse, "test"))
+ })
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ result := buf.String()
+
+ for replaceTo, replacer := range tc.whenReplacers {
+ result = replacer.ReplaceAllString(result, replaceTo)
+ }
+
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestLoggerTemplate(t *testing.T) {
+ buf := new(bytes.Buffer)
+
+ e := echo.New()
+ e.Use(LoggerWithConfig(LoggerConfig{
+ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
+ `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
+ `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` +
+ `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` +
+ `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
+ Output: buf,
+ }))
+
+ e.GET("/users/:id", func(c echo.Context) error {
+ return c.String(http.StatusOK, "Header Logged")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil)
+ req.RequestURI = "/"
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+ req.Header.Add("Referer", "google.com")
+ req.Header.Add("User-Agent", "echo-tests-agent")
+ req.Header.Add("X-Custom-Header", "AAA-CUSTOM-VALUE")
+ req.Header.Add("X-Request-ID", "6ba7b810-9dad-11d1-80b4-00c04fd430c8")
+ req.Header.Add("Cookie", "_ga=GA1.2.000000000.0000000000; session=ac08034cd216a647fc2eb62f2bcf7b810")
+ req.Form = url.Values{
+ "username": []string{"apagano-form"},
+ "password": []string{"secret-form"},
+ }
+
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ cases := map[string]bool{
+ "apagano-param": true,
+ "apagano-form": true,
+ "AAA-CUSTOM-VALUE": true,
+ "BBB-CUSTOM-VALUE": false,
+ "secret-form": false,
+ "hexvalue": false,
+ "GET": true,
+ "127.0.0.1": true,
+ "\"path\":\"/users/1\"": true,
+ "\"route\":\"/users/:id\"": true,
+ "\"uri\":\"/\"": true,
+ "\"status\":200": true,
+ "\"bytes_in\":0": true,
+ "google.com": true,
+ "echo-tests-agent": true,
+ "6ba7b810-9dad-11d1-80b4-00c04fd430c8": true,
+ "ac08034cd216a647fc2eb62f2bcf7b810": true,
+ }
+
+ for token, present := range cases {
+ assert.True(t, strings.Contains(buf.String(), token) == present, "Case: "+token)
+ }
+}
+
+func TestLoggerCustomTimestamp(t *testing.T) {
+ buf := new(bytes.Buffer)
+ customTimeFormat := "2006-01-02 15:04:05.00000"
+ e := echo.New()
+ e.Use(LoggerWithConfig(LoggerConfig{
+ Format: `{"time":"${time_custom}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
+ `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
+ `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` +
+ `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` +
+ `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
+ CustomTimeFormat: customTimeFormat,
+ Output: buf,
+ }))
+
+ e.GET("/", func(c echo.Context) error {
+ return c.String(http.StatusOK, "custom time stamp test")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ var objs map[string]*json.RawMessage
+ if err := json.Unmarshal(buf.Bytes(), &objs); err != nil {
+ panic(err)
+ }
+ loggedTime := *(*string)(unsafe.Pointer(objs["time"]))
+ _, err := time.Parse(customTimeFormat, loggedTime)
+ assert.Error(t, err)
+}
+
+func TestLoggerCustomTagFunc(t *testing.T) {
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Use(LoggerWithConfig(LoggerConfig{
+ Format: `{"method":"${method}",${custom}}` + "\n",
+ CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) {
+ return buf.WriteString(`"tag":"my-value"`)
+ },
+ Output: buf,
+ }))
+
+ e.GET("/", func(c echo.Context) error {
+ return c.String(http.StatusOK, "custom time stamp test")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String())
+}
+
+func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) {
+ e := echo.New()
+
+ buf := new(bytes.Buffer)
+ mw := LoggerWithConfig(LoggerConfig{
+ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
+ `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
+ `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` +
+ `"bytes_out":${bytes_out}, "protocol":"${protocol}"}` + "\n",
+ Output: buf,
+ })(func(c echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Add("multiple", "1")
+ f.Add("multiple", "2")
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode()))
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(c)
+ buf.Reset()
+ }
+}
+
+func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) {
+ e := echo.New()
+
+ buf := new(bytes.Buffer)
+ mw := LoggerWithConfig(LoggerConfig{
+ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
+ `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
+ `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` +
+ `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` +
+ `"us":"${query:username}", "cf":"${form:csrf}", "Referer2":"${header:Referer}"}` + "\n",
+ Output: buf,
+ })(func(c echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Add("multiple", "1")
+ f.Add("multiple", "2")
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode()))
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(c)
+ buf.Reset()
+ }
+}
diff --git a/middleware/method_override.go b/middleware/method_override.go
index 25ec1f935..3991e1029 100644
--- a/middleware/method_override.go
+++ b/middleware/method_override.go
@@ -6,7 +6,7 @@ package middleware
import (
"net/http"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// MethodOverrideConfig defines the config for MethodOverride middleware.
@@ -20,7 +20,7 @@ type MethodOverrideConfig struct {
}
// MethodOverrideGetter is a function that gets overridden method from the request
-type MethodOverrideGetter func(c *echo.Context) string
+type MethodOverrideGetter func(echo.Context) string
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
var DefaultMethodOverrideConfig = MethodOverrideConfig{
@@ -37,13 +37,9 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
-// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
+// MethodOverrideWithConfig returns a MethodOverride middleware with config.
+// See: `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
-func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper
@@ -53,7 +49,7 @@ func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -67,13 +63,13 @@ func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
return next(c)
}
- }, nil
+ }
}
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
// the request header.
func MethodFromHeader(header string) MethodOverrideGetter {
- return func(c *echo.Context) string {
+ return func(c echo.Context) string {
return c.Request().Header.Get(header)
}
}
@@ -81,7 +77,7 @@ func MethodFromHeader(header string) MethodOverrideGetter {
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
// form parameter.
func MethodFromForm(param string) MethodOverrideGetter {
- return func(c *echo.Context) string {
+ return func(c echo.Context) string {
return c.FormValue(param)
}
}
@@ -89,7 +85,7 @@ func MethodFromForm(param string) MethodOverrideGetter {
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
// the query parameter.
func MethodFromQuery(param string) MethodOverrideGetter {
- return func(c *echo.Context) string {
+ return func(c echo.Context) string {
return c.QueryParam(param)
}
}
diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go
index 525ad10ba..0000d1d80 100644
--- a/middleware/method_override_test.go
+++ b/middleware/method_override_test.go
@@ -9,14 +9,14 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestMethodOverride(t *testing.T) {
e := echo.New()
m := MethodOverride()
- h := func(c *echo.Context) error {
+ h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
@@ -25,68 +25,28 @@ func TestMethodOverride(t *testing.T) {
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
c := e.NewContext(req, rec)
-
- err := m(h)(c)
- assert.NoError(t, err)
-
+ m(h)(c)
assert.Equal(t, http.MethodDelete, req.Method)
-}
-
-func TestMethodOverride_formParam(t *testing.T) {
- e := echo.New()
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
-
// Override with form parameter
- m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
- assert.NoError(t, err)
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
- rec := httptest.NewRecorder()
+ m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
+ req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
+ rec = httptest.NewRecorder()
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
- c := e.NewContext(req, rec)
-
- err = m(h)(c)
- assert.NoError(t, err)
-
+ c = e.NewContext(req, rec)
+ m(h)(c)
assert.Equal(t, http.MethodDelete, req.Method)
-}
-
-func TestMethodOverride_queryParam(t *testing.T) {
- e := echo.New()
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
// Override with query parameter
- m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
- assert.NoError(t, err)
- req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err = m(h)(c)
- assert.NoError(t, err)
-
+ m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
+ req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ m(h)(c)
assert.Equal(t, http.MethodDelete, req.Method)
-}
-
-func TestMethodOverride_ignoreGet(t *testing.T) {
- e := echo.New()
- m := MethodOverride()
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
// Ignore `GET`
- req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err := m(h)(c)
- assert.NoError(t, err)
-
assert.Equal(t, http.MethodGet, req.Method)
}
diff --git a/middleware/middleware.go b/middleware/middleware.go
index 4562d03b5..164e52b4c 100644
--- a/middleware/middleware.go
+++ b/middleware/middleware.go
@@ -9,14 +9,15 @@ import (
"strconv"
"strings"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
-// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
-type Skipper func(c *echo.Context) bool
+// Skipper defines a function to skip middleware. Returning true skips processing
+// the middleware.
+type Skipper func(c echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware.
-type BeforeFunc func(c *echo.Context)
+type BeforeFunc func(c echo.Context)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1)
@@ -53,7 +54,7 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
return nil
}
- // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
+ // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
// We only want to use path part for rewriting and therefore trim prefix if it exists
rawURI := req.RequestURI
if rawURI != "" && rawURI[0] != '/' {
@@ -84,11 +85,13 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
}
// DefaultSkipper returns false which processes the middleware.
-func DefaultSkipper(c *echo.Context) bool {
+func DefaultSkipper(echo.Context) bool {
return false
}
-func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
+func toMiddlewareOrPanic(config interface {
+ ToMiddleware() (echo.MiddlewareFunc, error)
+}) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
index 28407ed5c..7f3dc3866 100644
--- a/middleware/middleware_test.go
+++ b/middleware/middleware_test.go
@@ -102,9 +102,11 @@ type testResponseWriterNoFlushHijack struct {
func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
}
+
func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
return 0, nil
}
+
func (w *testResponseWriterNoFlushHijack) Header() http.Header {
return nil
}
@@ -116,12 +118,15 @@ type testResponseWriterUnwrapper struct {
func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
}
+
func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
return 0, nil
}
+
func (w *testResponseWriterUnwrapper) Header() http.Header {
return nil
}
+
func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
w.unwrapCalled++
return w.rw
diff --git a/middleware/proxy.go b/middleware/proxy.go
index 497aefea4..f26870077 100644
--- a/middleware/proxy.go
+++ b/middleware/proxy.go
@@ -6,10 +6,8 @@ package middleware
import (
"context"
"crypto/tls"
- "errors"
"fmt"
"io"
- "maps"
"math/rand"
"net"
"net/http"
@@ -20,7 +18,7 @@ import (
"sync"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// TODO: Handle TLS proxy
@@ -43,14 +41,14 @@ type ProxyConfig struct {
// of previous retries is less than RetryCount. If the function returns true, the
// request will be retried. The provided error indicates the reason for the request
// failure. When the ProxyTarget is unavailable, the error will be an instance of
- // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error
+ // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error
// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
// only called when the request to the target fails, or an internal error in the Proxy
// middleware has occurred. Successful requests that return a non-200 response code cannot
// be retried.
- RetryFilter func(c *echo.Context, e error) bool
+ RetryFilter func(c echo.Context, e error) bool
// ErrorHandler defines a function which can be used to return custom errors from
// the Proxy middleware. ErrorHandler is only invoked when there has been
@@ -59,7 +57,7 @@ type ProxyConfig struct {
// when a ProxyTarget returns a non-200 response. In these cases, the response
// is already written so errors cannot be modified. ErrorHandler is only
// invoked after all retry attempts have been exhausted.
- ErrorHandler func(c *echo.Context, err error) error
+ ErrorHandler func(c echo.Context, err error) error
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
@@ -93,14 +91,20 @@ type ProxyConfig struct {
type ProxyTarget struct {
Name string
URL *url.URL
- Meta map[string]any
+ Meta echo.Map
}
// ProxyBalancer defines an interface to implement a load balancing technique.
type ProxyBalancer interface {
- AddTarget(target *ProxyTarget) bool
- RemoveTarget(targetName string) bool
- Next(c *echo.Context) (*ProxyTarget, error)
+ AddTarget(*ProxyTarget) bool
+ RemoveTarget(string) bool
+ Next(echo.Context) *ProxyTarget
+}
+
+// TargetProvider defines an interface that gives the opportunity for balancer
+// to return custom errors when selecting target.
+type TargetProvider interface {
+ NextTarget(echo.Context) (*ProxyTarget, error)
}
type commonBalancer struct {
@@ -127,7 +131,7 @@ var DefaultProxyConfig = ProxyConfig{
ContextKey: "target",
}
-func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler {
+func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
if transport, ok := config.Transport.(*http.Transport); ok {
if transport.TLSClientConfig != nil {
@@ -143,13 +147,12 @@ func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- in, _, err := http.NewResponseController(w).Hijack()
+ in, _, err := c.Response().Hijack()
if err != nil {
c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
return
}
defer in.Close()
-
out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
if err != nil {
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
@@ -189,9 +192,7 @@ func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
b := randomBalancer{}
b.targets = targets
- // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand)
- // this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR.
- b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404
+ b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
return &b
}
@@ -235,15 +236,15 @@ func (b *commonBalancer) RemoveTarget(name string) bool {
// Next randomly returns an upstream target.
//
// Note: `nil` is returned in case upstream target list is empty.
-func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
b.mutex.Lock()
defer b.mutex.Unlock()
if len(b.targets) == 0 {
- return nil, nil
+ return nil
} else if len(b.targets) == 1 {
- return b.targets[0], nil
+ return b.targets[0]
}
- return b.targets[b.random.Intn(len(b.targets))], nil
+ return b.targets[b.random.Intn(len(b.targets))]
}
// Next returns an upstream target using round-robin technique. In the case
@@ -254,13 +255,13 @@ func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
// return the original failed target.
//
// Note: `nil` is returned in case upstream target list is empty.
-func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
b.mutex.Lock()
defer b.mutex.Unlock()
if len(b.targets) == 0 {
- return nil, nil
+ return nil
} else if len(b.targets) == 1 {
- return b.targets[0], nil
+ return b.targets[0]
}
var i int
@@ -282,8 +283,9 @@ func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
i = b.i
b.i++
}
+
c.Set(lastIdxKey, i)
- return b.targets[i], nil
+ return b.targets[i]
}
// Proxy returns a Proxy middleware.
@@ -295,26 +297,18 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
return ProxyWithConfig(c)
}
-// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
-//
-// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
+// ProxyWithConfig returns a Proxy middleware with config.
+// See: `Proxy()`
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
-func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Balancer == nil {
+ panic("echo: proxy middleware requires balancer")
+ }
+ // Defaults
if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper
}
- if config.ContextKey == "" {
- config.ContextKey = DefaultProxyConfig.ContextKey
- }
- if config.Balancer == nil {
- return nil, errors.New("echo proxy middleware requires balancer")
- }
if config.RetryFilter == nil {
- config.RetryFilter = func(c *echo.Context, e error) bool {
+ config.RetryFilter = func(c echo.Context, e error) bool {
if httpErr, ok := e.(*echo.HTTPError); ok {
return httpErr.Code == http.StatusBadGateway
}
@@ -322,20 +316,23 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
}
if config.ErrorHandler == nil {
- config.ErrorHandler = func(c *echo.Context, err error) error {
+ config.ErrorHandler = func(c echo.Context, err error) error {
return err
}
}
-
if config.Rewrite != nil {
if config.RegexRewrite == nil {
config.RegexRewrite = make(map[*regexp.Regexp]string)
}
- maps.Copy(config.RegexRewrite, rewriteRulesRegex(config.Rewrite))
+ for k, v := range rewriteRulesRegex(config.Rewrite) {
+ config.RegexRewrite[k] = v
+ }
}
+ provider, isTargetProvider := config.Balancer.(TargetProvider)
+
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) (err error) {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -355,24 +352,21 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
}
- if c.IsWebSocket() { // For HTTP, this is set by Go HTTP reverse proxy.
- // Append, not set, to preserve the incoming chain from upstream proxies.
- prior := req.Header[echo.HeaderXForwardedFor]
- if len(prior) > 0 {
- req.Header.Set(echo.HeaderXForwardedFor, strings.Join(prior, ", ")+", "+c.RealIP())
- } else {
- req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
- }
+ if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
+ req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
retries := config.RetryCount
for {
- tgt, err := config.Balancer.Next(c)
- if err != nil {
- return config.ErrorHandler(c, err)
- }
- if tgt == nil || tgt.URL == nil {
- return config.ErrorHandler(c, echo.NewHTTPError(http.StatusBadGateway, "no proxy target available"))
+ var tgt *ProxyTarget
+ var err error
+ if isTargetProvider {
+ tgt, err = provider.NextTarget(c)
+ if err != nil {
+ return config.ErrorHandler(c, err)
+ }
+ } else {
+ tgt = config.Balancer.Next(c)
}
c.Set(config.ContextKey, tgt)
@@ -391,9 +385,9 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Proxy
switch {
case c.IsWebSocket():
- proxyRaw(c, tgt, config).ServeHTTP(res, req)
+ proxyRaw(tgt, c, config).ServeHTTP(res, req)
default: // even SSE requests
- proxyHTTP(c, tgt, config).ServeHTTP(res, req)
+ proxyHTTP(tgt, c, config).ServeHTTP(res, req)
}
err, hasError := c.Get("_error").(error)
@@ -409,7 +403,7 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
retries--
}
}
- }, nil
+ }
}
// StatusCodeContextCanceled is a custom HTTP status code for situations
@@ -419,7 +413,7 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499
-func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
+func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
desc := tgt.URL.String()
@@ -429,17 +423,15 @@ func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handl
// If the client canceled the request (usually by closing the connection), we can report a
// client error (4xx) instead of a server error (5xx) to correctly identify the situation.
// The Go standard library (at of late 2020) wraps the exported, standard
- // context. Canceled error with unexported garbage value requiring a substring check, see
+ // context.Canceled error with unexported garbage value requiring a substring check, see
// https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
- // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352
- if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") {
- httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err)
+ if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") {
+ httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err))
+ httpError.Internal = err
c.Set("_error", httpError)
} else {
- httpError := echo.NewHTTPError(
- http.StatusBadGateway,
- "remote server unreachable, could not proxy request",
- ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err))
+ httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))
+ httpError.Internal = err
c.Set("_error", httpError)
}
}
diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go
index 5053f7945..dbf07648b 100644
--- a/middleware/proxy_test.go
+++ b/middleware/proxy_test.go
@@ -15,12 +15,11 @@ import (
"net/http/httptest"
"net/url"
"regexp"
- "strings"
"sync"
"testing"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)
@@ -38,7 +37,6 @@ func TestProxy(t *testing.T) {
}))
defer t2.Close()
url2, _ := url.Parse(t2.URL)
-
targets := []*ProxyTarget{
{
Name: "target 1",
@@ -62,7 +60,7 @@ func TestProxy(t *testing.T) {
// Random
e := echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ e.Use(Proxy(rb))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
@@ -84,7 +82,7 @@ func TestProxy(t *testing.T) {
// Round-robin
rrb := NewRoundRobinBalancer(targets)
e = echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
+ e.Use(Proxy(rrb))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
@@ -114,24 +112,68 @@ func TestProxy(t *testing.T) {
// ProxyTarget is set in context
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) (err error) {
+ return func(c echo.Context) (err error) {
next(c)
assert.Contains(t, targets, c.Get("target"), "target is not set in context")
return nil
}
}
+ rrb1 := NewRoundRobinBalancer(targets)
e = echo.New()
e.Use(contextObserver)
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
+ e.Use(Proxy(rrb1))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
}
-func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
- assert.Panics(t, func() {
- ProxyWithConfig(ProxyConfig{Balancer: nil})
- })
+type testProvider struct {
+ commonBalancer
+ target *ProxyTarget
+ err error
+}
+
+func (p *testProvider) Next(c echo.Context) *ProxyTarget {
+ return &ProxyTarget{}
+}
+
+func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) {
+ return p.target, p.err
+}
+
+func TestTargetProvider(t *testing.T) {
+ t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "target 1")
+ }))
+ defer t1.Close()
+ url1, _ := url.Parse(t1.URL)
+
+ e := echo.New()
+ tp := &testProvider{}
+ tp.target = &ProxyTarget{Name: "target 1", URL: url1}
+ e.Use(Proxy(tp))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ assert.Equal(t, "target 1", body)
+}
+
+func TestFailNextTarget(t *testing.T) {
+ url1, err := url.Parse("http://dummy:8080")
+ assert.Nil(t, err)
+
+ e := echo.New()
+ tp := &testProvider{}
+ tp.target = &ProxyTarget{Name: "target 1", URL: url1}
+ tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
+
+ e.Use(Proxy(tp))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
}
func TestProxyRealIPHeader(t *testing.T) {
@@ -141,7 +183,7 @@ func TestProxyRealIPHeader(t *testing.T) {
url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
e := echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
+ e.Use(Proxy(rrb))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
@@ -346,7 +388,7 @@ func TestProxyError(t *testing.T) {
// Random
e := echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ e.Use(Proxy(rb))
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Remote unreachable
@@ -357,202 +399,8 @@ func TestProxyError(t *testing.T) {
assert.Equal(t, http.StatusBadGateway, rec.Code)
}
-func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
- var timeoutStop sync.WaitGroup
- timeoutStop.Add(1)
- HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- timeoutStop.Wait() // wait until we have canceled the request
- w.WriteHeader(http.StatusOK)
- }))
- defer HTTPTarget.Close()
- targetURL, _ := url.Parse(HTTPTarget.URL)
- target := &ProxyTarget{
- Name: "target",
- URL: targetURL,
- }
- rb := NewRandomBalancer(nil)
- assert.True(t, rb.AddTarget(target))
- e := echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- ctx, cancel := context.WithCancel(req.Context())
- req = req.WithContext(ctx)
- go func() {
- time.Sleep(10 * time.Millisecond)
- cancel()
- }()
- e.ServeHTTP(rec, req)
- timeoutStop.Done()
- assert.Equal(t, 499, rec.Code)
-}
-
-type testProvider struct {
- commonBalancer
- target *ProxyTarget
- err error
-}
-
-func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) {
- return p.target, p.err
-}
-
-func TestTargetProvider(t *testing.T) {
- t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "target 1")
- }))
- defer t1.Close()
- url1, _ := url.Parse(t1.URL)
-
- e := echo.New()
- tp := &testProvider{}
- tp.target = &ProxyTarget{Name: "target 1", URL: url1}
- e.Use(Proxy(tp))
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- e.ServeHTTP(rec, req)
- body := rec.Body.String()
- assert.Equal(t, "target 1", body)
-}
-
-func TestFailNextTarget(t *testing.T) {
- url1, err := url.Parse("http://dummy:8080")
- assert.Nil(t, err)
-
- e := echo.New()
- tp := &testProvider{}
- tp.target = &ProxyTarget{Name: "target 1", URL: url1}
- tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
-
- e.Use(Proxy(tp))
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- e.ServeHTTP(rec, req)
- body := rec.Body.String()
- assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
-}
-
-func TestRandomBalancerWithNoTargets(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- // Assert balancer with empty targets does return `nil` on `Next()`
- rb := NewRandomBalancer(nil)
- target, err := rb.Next(c)
- assert.Nil(t, target)
- assert.NoError(t, err)
-}
-
-func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
- // Assert balancer with empty targets does return `nil` on `Next()`
- rrb := NewRoundRobinBalancer([]*ProxyTarget{})
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- target, err := rrb.Next(c)
- assert.Nil(t, target)
- assert.NoError(t, err)
-}
-
-func TestProxyWithNoTargetReturnsBadGateway(t *testing.T) {
- targetURL, _ := url.Parse("http://127.0.0.1:8080")
- target := &ProxyTarget{Name: "target", URL: targetURL}
- emptyAfterRemove := NewRoundRobinBalancer([]*ProxyTarget{target})
- assert.True(t, emptyAfterRemove.RemoveTarget("target"))
-
- testCases := []struct {
- name string
- balancer ProxyBalancer
- }{
- {
- name: "random balancer with nil targets",
- balancer: NewRandomBalancer(nil),
- },
- {
- name: "round-robin balancer with nil targets",
- balancer: NewRoundRobinBalancer(nil),
- },
- {
- name: "round-robin balancer after removing last target",
- balancer: emptyAfterRemove,
- },
- {
- name: "custom balancer with nil target",
- balancer: &customBalancer{},
- },
- {
- name: "custom balancer with nil target URL",
- balancer: &customBalancer{target: &ProxyTarget{Name: "target"}},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := echo.New()
- errorHandlerCalled := false
- e.Use(ProxyWithConfig(ProxyConfig{
- Balancer: tc.balancer,
- ErrorHandler: func(c *echo.Context, err error) error {
- errorHandlerCalled = true
- httpErr, ok := err.(*echo.HTTPError)
- assert.True(t, ok, "expected http error to be passed to handler")
- assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
- return err
- },
- }))
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- assert.NotPanics(t, func() {
- e.ServeHTTP(rec, req)
- })
- assert.True(t, errorHandlerCalled)
- assert.Equal(t, http.StatusBadGateway, rec.Code)
- })
- }
-}
-
-func TestProxyWithNoTargetDoesNotRetry(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }))
- defer server.Close()
- targetURL, _ := url.Parse(server.URL)
-
- balancer := &sequenceBalancer{
- targets: []*ProxyTarget{
- nil,
- {Name: "target", URL: targetURL},
- },
- }
-
- retryFilterCalled := false
- e := echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{
- Balancer: balancer,
- RetryCount: 1,
- RetryFilter: func(c *echo.Context, err error) bool {
- retryFilterCalled = true
- return true
- },
- }))
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- e.ServeHTTP(rec, req)
-
- assert.False(t, retryFilterCalled)
- assert.Equal(t, 1, balancer.calls)
- assert.Equal(t, http.StatusBadGateway, rec.Code)
-}
-
func TestProxyRetries(t *testing.T) {
+
newServer := func(res int) (*url.URL, *httptest.Server) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -583,13 +431,13 @@ func TestProxyRetries(t *testing.T) {
URL: targetURL,
}
- alwaysRetryFilter := func(c *echo.Context, e error) bool { return true }
- neverRetryFilter := func(c *echo.Context, e error) bool { return false }
+ alwaysRetryFilter := func(c echo.Context, e error) bool { return true }
+ neverRetryFilter := func(c echo.Context, e error) bool { return false }
testCases := []struct {
name string
retryCount int
- retryFilters []func(c *echo.Context, e error) bool
+ retryFilters []func(c echo.Context, e error) bool
targets []*ProxyTarget
expectedResponse int
}{
@@ -612,7 +460,7 @@ func TestProxyRetries(t *testing.T) {
{
name: "retry count 1 does retry on handler return true",
retryCount: 1,
- retryFilters: []func(c *echo.Context, e error) bool{
+ retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
},
targets: []*ProxyTarget{
@@ -624,7 +472,7 @@ func TestProxyRetries(t *testing.T) {
{
name: "retry count 1 does not retry on handler return false",
retryCount: 1,
- retryFilters: []func(c *echo.Context, e error) bool{
+ retryFilters: []func(c echo.Context, e error) bool{
neverRetryFilter,
},
targets: []*ProxyTarget{
@@ -636,7 +484,7 @@ func TestProxyRetries(t *testing.T) {
{
name: "retry count 2 returns error when no more retries left",
retryCount: 2,
- retryFilters: []func(c *echo.Context, e error) bool{
+ retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
},
@@ -651,7 +499,7 @@ func TestProxyRetries(t *testing.T) {
{
name: "retry count 2 returns error when retries left but handler returns false",
retryCount: 3,
- retryFilters: []func(c *echo.Context, e error) bool{
+ retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
neverRetryFilter,
@@ -667,7 +515,7 @@ func TestProxyRetries(t *testing.T) {
{
name: "retry count 3 succeeds",
retryCount: 3,
- retryFilters: []func(c *echo.Context, e error) bool{
+ retryFilters: []func(c echo.Context, e error) bool{
alwaysRetryFilter,
alwaysRetryFilter,
alwaysRetryFilter,
@@ -695,7 +543,7 @@ func TestProxyRetries(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
retryFilterCall := 0
- retryFilter := func(c *echo.Context, e error) bool {
+ retryFilter := func(c echo.Context, e error) bool {
if len(tc.retryFilters) == 0 {
assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
}
@@ -771,13 +619,15 @@ func TestProxyRetryWithBackendTimeout(t *testing.T) {
))
var wg sync.WaitGroup
- for range 20 {
- wg.Go(func() {
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, 200, rec.Code)
- })
+ }()
}
wg.Wait()
@@ -808,13 +658,13 @@ func TestProxyErrorHandler(t *testing.T) {
testCases := []struct {
name string
target *ProxyTarget
- errorHandler func(c *echo.Context, e error) error
+ errorHandler func(c echo.Context, e error) error
expectFinalError func(t *testing.T, err error)
}{
{
name: "Error handler not invoked when request success",
target: goodTarget,
- errorHandler: func(c *echo.Context, e error) error {
+ errorHandler: func(c echo.Context, e error) error {
assert.FailNow(t, "error handler should not be invoked")
return e
},
@@ -822,7 +672,7 @@ func TestProxyErrorHandler(t *testing.T) {
{
name: "Error handler invoked when request fails",
target: badTarget,
- errorHandler: func(c *echo.Context, e error) error {
+ errorHandler: func(c echo.Context, e error) error {
httpErr, ok := e.(*echo.HTTPError)
assert.True(t, ok, "expected http error to be passed to handler")
assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
@@ -845,11 +695,10 @@ func TestProxyErrorHandler(t *testing.T) {
))
errorHandlerCalled := false
- dheh := echo.DefaultHTTPErrorHandler(false)
- e.HTTPErrorHandler = func(c *echo.Context, err error) {
+ e.HTTPErrorHandler = func(err error, c echo.Context) {
errorHandlerCalled = true
tc.expectFinalError(t, err)
- dheh(c, err)
+ e.DefaultHTTPErrorHandler(err, c)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -865,7 +714,47 @@ func TestProxyErrorHandler(t *testing.T) {
}
}
+func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
+ var timeoutStop sync.WaitGroup
+ timeoutStop.Add(1)
+ HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ timeoutStop.Wait() // wait until we have canceled the request
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer HTTPTarget.Close()
+ targetURL, _ := url.Parse(HTTPTarget.URL)
+ target := &ProxyTarget{
+ Name: "target",
+ URL: targetURL,
+ }
+ rb := NewRandomBalancer(nil)
+ assert.True(t, rb.AddTarget(target))
+ e := echo.New()
+ e.Use(Proxy(rb))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx, cancel := context.WithCancel(req.Context())
+ req = req.WithContext(ctx)
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ }()
+ e.ServeHTTP(rec, req)
+ timeoutStop.Done()
+ assert.Equal(t, 499, rec.Code)
+}
+
+// Assert balancer with empty targets does return `nil` on `Next()`
+func TestProxyBalancerWithNoTargets(t *testing.T) {
+ rb := NewRandomBalancer(nil)
+ assert.Nil(t, rb.Next(nil))
+
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{})
+ assert.Nil(t, rrb.Next(nil))
+}
+
type testContextKey string
+
type customBalancer struct {
target *ProxyTarget
}
@@ -873,33 +762,15 @@ type customBalancer struct {
func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
return false
}
+
func (b *customBalancer) RemoveTarget(name string) bool {
return false
}
-func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+func (b *customBalancer) Next(c echo.Context) *ProxyTarget {
ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
c.SetRequest(c.Request().WithContext(ctx))
- return b.target, nil
-}
-
-type sequenceBalancer struct {
- targets []*ProxyTarget
- calls int
-}
-
-func (b *sequenceBalancer) AddTarget(target *ProxyTarget) bool {
- return false
-}
-
-func (b *sequenceBalancer) RemoveTarget(name string) bool {
- return false
-}
-
-func (b *sequenceBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
- target := b.targets[b.calls]
- b.calls++
- return target, nil
+ return b.target
}
func TestModifyResponseUseContext(t *testing.T) {
@@ -910,6 +781,7 @@ func TestModifyResponseUseContext(t *testing.T) {
}),
)
defer server.Close()
+
targetURL, _ := url.Parse(server.URL)
e := echo.New()
e.Use(ProxyWithConfig(
@@ -930,9 +802,12 @@ func TestModifyResponseUseContext(t *testing.T) {
},
},
))
+
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
+
e.ServeHTTP(rec, req)
+
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
@@ -1165,108 +1040,3 @@ func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}
-
-// TestProxyWebSocketXForwardedFor verifies that for WebSocket Upgrade requests,
-// the proxy middleware appends c.RealIP() to any existing X-Forwarded-For chain,
-// mirroring net/http/httputil.(*ProxyRequest).SetXForwarded used by the HTTP path.
-//
-// Regression guard for the previous "set only if empty" behavior, which dropped
-// the proxy's own peer IP from the chain whenever upstream proxies had already
-// added entries.
-func TestProxyWebSocketXForwardedFor(t *testing.T) {
- tests := []struct {
- name string
- incomingXFF []string // nil = no incoming X-Forwarded-For header at all
- wantPrefix string // expected join of entries preceding the appended proxy RealIP
- }{
- {
- name: "no incoming XFF, only proxy RealIP is set",
- incomingXFF: nil,
- wantPrefix: "",
- },
- {
- name: "single-line single-entry XFF is preserved with proxy RealIP appended",
- incomingXFF: []string{"203.0.113.1"},
- wantPrefix: "203.0.113.1",
- },
- {
- name: "single-line comma-separated XFF is preserved with proxy RealIP appended",
- incomingXFF: []string{"203.0.113.1, 10.0.0.5"},
- wantPrefix: "203.0.113.1, 10.0.0.5",
- },
- {
- name: "multi-line XFF (multiple header occurrences) is joined with proxy RealIP appended",
- incomingXFF: []string{"203.0.113.1", "10.0.0.5"},
- wantPrefix: "203.0.113.1, 10.0.0.5",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Buffered so the upstream handler never blocks before the client reads.
- headerCh := make(chan http.Header, 1)
-
- upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- wsHandler := func(conn *websocket.Conn) {
- headerCh <- conn.Request().Header.Clone()
- defer conn.Close()
- var msg string
- if err := websocket.Message.Receive(conn, &msg); err == nil {
- _ = websocket.Message.Send(conn, msg)
- }
- }
- websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
- }))
- defer upstream.Close()
-
- tgtURL, _ := url.Parse(upstream.URL)
- e := echo.New()
- e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})}))
- proxySrv := httptest.NewServer(e)
- defer proxySrv.Close()
-
- proxyWSURL, _ := url.Parse(proxySrv.URL)
- proxyWSURL.Scheme = "ws"
-
- origin, _ := url.Parse(proxySrv.URL)
- cfg := &websocket.Config{
- Location: proxyWSURL,
- Origin: origin,
- Version: websocket.ProtocolVersionHybi13,
- Header: http.Header{},
- }
- for _, v := range tt.incomingXFF {
- cfg.Header.Add(echo.HeaderXForwardedFor, v)
- }
-
- wsConn, err := websocket.DialConfig(cfg)
- assert.NoError(t, err)
- defer wsConn.Close()
-
- assert.NoError(t, websocket.Message.Send(wsConn, "ping"))
- var got string
- assert.NoError(t, websocket.Message.Receive(wsConn, &got))
-
- // The handler sends to headerCh before echoing, so it arrives before Receive returns.
- captured := <-headerCh
- xff := captured.Get(echo.HeaderXForwardedFor)
-
- // The middleware uses Header.Set, so the upstream sees exactly one
- // X-Forwarded-For header line. Split it back into entries.
- entries := strings.Split(xff, ", ")
- assert.NotEmpty(t, entries, "X-Forwarded-For must be set by the proxy middleware")
-
- // The tail entry is the proxy's c.RealIP(). When the test client dials
- // via httptest.NewServer the proxy sees 127.0.0.1.
- tail := entries[len(entries)-1]
- assert.Equal(t, "127.0.0.1", tail,
- "proxy RealIP must be appended at the tail of X-Forwarded-For")
-
- // The remaining entries must equal the prior chain, preserving order
- // and joining multi-line headers with ", ".
- gotPrefix := strings.Join(entries[:len(entries)-1], ", ")
- assert.Equal(t, tt.wantPrefix, gotPrefix,
- "prior X-Forwarded-For entries must be preserved before the appended RealIP")
- })
- }
-}
diff --git a/middleware/randomstring_bench_test.go b/middleware/randomstring_bench_test.go
deleted file mode 100644
index d37d3dfd4..000000000
--- a/middleware/randomstring_bench_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package middleware
-
-import (
- "bufio"
- "io"
- "testing"
-)
-
-// randomStringUnpooled is the previous (pre-pooling) implementation, kept here only to A/B benchmark
-// against the current pooled randomString.
-func randomStringUnpooled(length uint8) string {
- reader := randomReaderPool.Get().(*bufio.Reader)
- defer randomReaderPool.Put(reader)
-
- b := make([]byte, length)
- r := make([]byte, length+(length/4))
- var i uint8 = 0
- for {
- if _, err := io.ReadFull(reader, r); err != nil {
- panic("unexpected error reading from crypto/rand")
- }
- for _, rb := range r {
- if rb > randomStringMaxByte {
- continue
- }
- b[i] = randomStringCharset[rb%randomStringCharsetLen]
- i++
- if i == length {
- return string(b)
- }
- }
- }
-}
-
-func BenchmarkRandomString_Unpooled(b *testing.B) {
- b.ReportAllocs()
- for i := 0; i < b.N; i++ {
- _ = randomStringUnpooled(32)
- }
-}
-
-func BenchmarkRandomString_Pooled(b *testing.B) {
- b.ReportAllocs()
- for i := 0; i < b.N; i++ {
- _ = randomString(32)
- }
-}
diff --git a/middleware/randomstring_concurrent_test.go b/middleware/randomstring_concurrent_test.go
deleted file mode 100644
index c1efc31a5..000000000
--- a/middleware/randomstring_concurrent_test.go
+++ /dev/null
@@ -1,37 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package middleware
-
-import (
- "strings"
- "sync"
- "testing"
-)
-
-// TestRandomStringConcurrent guards the pooled scratch buffers in randomString: concurrent callers
-// must not share/alias a buffer and corrupt each other's output. Run with -race.
-func TestRandomStringConcurrent(t *testing.T) {
- const goroutines, iterations = 100, 300
- var wg sync.WaitGroup
- wg.Add(goroutines)
- for g := 0; g < goroutines; g++ {
- go func() {
- defer wg.Done()
- for i := 0; i < iterations; i++ {
- s := randomString(32)
- if len(s) != 32 {
- t.Errorf("expected length 32, got %d (%q)", len(s), s)
- return
- }
- for _, r := range s {
- if !strings.ContainsRune(randomStringCharset, r) {
- t.Errorf("char %q not in charset (%q)", r, s)
- return
- }
- }
- }
- }()
- }
- wg.Wait()
-}
diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go
index 9756daf50..2746a3de1 100644
--- a/middleware/rate_limiter.go
+++ b/middleware/rate_limiter.go
@@ -4,52 +4,37 @@
package middleware
import (
- "errors"
"math"
"net/http"
- "strconv"
"sync"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"golang.org/x/time/rate"
)
-// Rate limit response headers set by stores that implement RateLimiterStoreContext.
-const (
- HeaderXRateLimitLimit = "X-RateLimit-Limit"
- HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
-)
-
// RateLimiterStore is the interface to be implemented by custom stores.
type RateLimiterStore interface {
+ // Stores for the rate limiter have to implement the Allow method
Allow(identifier string) (bool, error)
}
-// RateLimiterStoreContext is an optional interface a RateLimiterStore may implement.
-// When the configured store implements it, the rate limiter calls AllowContext
-// (with the request context) instead of Allow, allowing the store to set response
-// headers such as Retry-After or X-RateLimit-* on the allow/deny decision.
-type RateLimiterStoreContext interface {
- AllowContext(c *echo.Context, identifier string) (bool, error)
-}
-
// RateLimiterConfig defines the configuration for the rate limiter
type RateLimiterConfig struct {
Skipper Skipper
BeforeFunc BeforeFunc
- // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor
+ // IdentifierExtractor uses echo.Context to extract the identifier for a visitor
IdentifierExtractor Extractor
// Store defines a store for the rate limiter
Store RateLimiterStore
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
- ErrorHandler func(c *echo.Context, err error) error
+ ErrorHandler func(context echo.Context, err error) error
// DenyHandler provides a handler to be called when RateLimiter denies access
- DenyHandler func(c *echo.Context, identifier string, err error) error
+ DenyHandler func(context echo.Context, identifier string, err error) error
}
-// Extractor is used to extract data from *echo.Context
-type Extractor func(c *echo.Context) (string, error)
+// Extractor is used to extract data from echo.Context
+type Extractor func(context echo.Context) (string, error)
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
@@ -60,15 +45,23 @@ var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while ext
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{
Skipper: DefaultSkipper,
- IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
- ErrorHandler: func(c *echo.Context, err error) error {
- return ErrExtractorError.Wrap(err)
+ ErrorHandler: func(context echo.Context, err error) error {
+ return &echo.HTTPError{
+ Code: ErrExtractorError.Code,
+ Message: ErrExtractorError.Message,
+ Internal: err,
+ }
},
- DenyHandler: func(c *echo.Context, identifier string, err error) error {
- return ErrRateLimitExceeded.Wrap(err)
+ DenyHandler: func(context echo.Context, identifier string, err error) error {
+ return &echo.HTTPError{
+ Code: ErrRateLimitExceeded.Code,
+ Message: ErrRateLimitExceeded.Message,
+ Internal: err,
+ }
},
}
@@ -79,7 +72,7 @@ RateLimiter returns a rate limiting middleware
limiterStore := middleware.NewRateLimiterMemoryStore(20)
- e.GET("/rate-limited", func(c *echo.Context) error {
+ e.GET("/rate-limited", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}, RateLimiter(limiterStore))
*/
@@ -100,28 +93,23 @@ RateLimiterWithConfig returns a rate limiting middleware
Store: middleware.NewRateLimiterMemoryStore(
middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute}
)
- IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
- ErrorHandler: func(ctx *echo.Context, err error) error {
+ ErrorHandler: func(context echo.Context, err error) error {
return context.JSON(http.StatusTooManyRequests, nil)
},
- DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
+ DenyHandler: func(context echo.Context, identifier string) error {
return context.JSON(http.StatusForbidden, nil)
},
}
- e.GET("/rate-limited", func(c *echo.Context) error {
+ e.GET("/rate-limited", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
-func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper
}
@@ -135,10 +123,10 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
}
if config.Store == nil {
- return nil, errors.New("echo rate limiter store configuration must be provided")
+ panic("Store configuration must be provided")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -148,29 +136,25 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
identifier, err := config.IdentifierExtractor(c)
if err != nil {
- return config.ErrorHandler(c, err)
+ c.Error(config.ErrorHandler(c, err))
+ return nil
}
- var allow bool
- var allowErr error
- if sc, ok := config.Store.(RateLimiterStoreContext); ok {
- allow, allowErr = sc.AllowContext(c, identifier)
- } else {
- allow, allowErr = config.Store.Allow(identifier)
- }
- if !allow {
- return config.DenyHandler(c, identifier, allowErr)
+ if allow, err := config.Store.Allow(identifier); !allow {
+ c.Error(config.DenyHandler(c, identifier, err))
+ return nil
}
return next(c)
}
- }, nil
+ }
}
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
type RateLimiterMemoryStore struct {
- visitors map[string]*Visitor
- mutex sync.Mutex
- rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit
+ visitors map[string]*Visitor
+ mutex sync.Mutex
+ rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
+
burst int
expiresIn time.Duration
lastCleanup time.Time
@@ -191,21 +175,21 @@ for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate
Burst and ExpiresIn will be set to default values.
-Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded up value of the rate.
+Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate.
Example (with 20 requests/sec):
limiterStore := middleware.NewRateLimiterMemoryStore(20)
*/
-func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) {
+func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) {
return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
- Rate: rateLimit,
+ Rate: rate,
})
}
/*
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
-with the provided configuration. Rate must be provided. Burst will be set to the rounded up value of
+with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of
the configured rate if not provided or set to 0.
The built-in memory store is usually capable for modest loads. For higher loads other
@@ -242,7 +226,7 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
type RateLimiterMemoryStoreConfig struct {
- Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
+ Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached.
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
}
@@ -254,64 +238,32 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
// Allow implements RateLimiterStore.Allow
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
- _, allowed := store.allow(identifier)
- return allowed, nil
-}
-
-// AllowContext implements RateLimiterStoreContext: it makes the allow/deny decision
-// and sets the X-RateLimit-* (and Retry-After when denied) response headers.
-func (store *RateLimiterMemoryStore) AllowContext(c *echo.Context, identifier string) (bool, error) {
- limiter, allowed := store.allow(identifier)
- store.setRateLimitHeaders(c, limiter, allowed)
- return allowed, nil
-}
-
-func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) {
store.mutex.Lock()
- defer store.mutex.Unlock()
-
limiter, exists := store.visitors[identifier]
if !exists {
limiter = new(Visitor)
- limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
+ limiter.Limiter = rate.NewLimiter(store.rate, store.burst)
store.visitors[identifier] = limiter
}
now := store.timeNow()
limiter.lastSeen = now
if now.Sub(store.lastCleanup) > store.expiresIn {
- store.cleanupStaleVisitors(now)
- }
- return limiter.Limiter, limiter.AllowN(now, 1)
-}
-
-func (store *RateLimiterMemoryStore) setRateLimitHeaders(c *echo.Context, limiter *rate.Limiter, allowed bool) {
- header := c.Response().Header()
- header.Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst))
-
- remaining := int(math.Floor(limiter.Tokens()))
- if remaining < 0 {
- remaining = 0
- }
- header.Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining))
-
- if !allowed {
- reservation := limiter.ReserveN(store.timeNow(), 1)
- if delay := reservation.Delay(); delay > 0 {
- header.Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds()))))
- }
- reservation.Cancel()
+ store.cleanupStaleVisitors()
}
+ allowed := limiter.AllowN(now, 1)
+ store.mutex.Unlock()
+ return allowed, nil
}
/*
cleanupStaleVisitors helps manage the size of the visitors map by removing stale records
of users who haven't visited again after the configured expiry time has elapsed
*/
-func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) {
+func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
for id, visitor := range store.visitors {
- if now.Sub(visitor.lastSeen) > store.expiresIn {
+ if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn {
delete(store.visitors, id)
}
}
- store.lastCleanup = now
+ store.lastCleanup = store.timeNow()
}
diff --git a/middleware/rate_limiter_context_test.go b/middleware/rate_limiter_context_test.go
deleted file mode 100644
index 629c01e47..000000000
--- a/middleware/rate_limiter_context_test.go
+++ /dev/null
@@ -1,89 +0,0 @@
-// SPDX-License-Identifier: MIT
-// SPDX-FileCopyrightText: Β© 2015 LabStack LLC and Echo contributors
-
-package middleware
-
-import (
- "net/http"
- "net/http/httptest"
- "strconv"
- "testing"
-
- "github.com/labstack/echo/v5"
- "github.com/stretchr/testify/assert"
-)
-
-// ctxAwareStore implements both Allow and the optional AllowContext. AllowContext
-// gives the store the request context so it can set response headers (e.g.
-// Retry-After / X-RateLimit-*) β see #2961.
-type ctxAwareStore struct {
- allowCalled bool
- ctxAllowCalled bool
- allow bool
-}
-
-func (s *ctxAwareStore) Allow(identifier string) (bool, error) {
- s.allowCalled = true
- return s.allow, nil
-}
-
-func (s *ctxAwareStore) AllowContext(c *echo.Context, identifier string) (bool, error) {
- s.ctxAllowCalled = true
- c.Response().Header().Set("Retry-After", "42")
- return s.allow, nil
-}
-
-// When the store implements AllowContext, the middleware must call it instead of
-// Allow, so the store can set rate-limit headers on the response.
-func TestRateLimiter_storeAllowContextIsPreferred(t *testing.T) {
- e := echo.New()
- store := &ctxAwareStore{allow: true}
- mw := RateLimiterWithConfig(RateLimiterConfig{
- Store: store,
- IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil },
- })
- handler := mw(func(c *echo.Context) error { return c.String(http.StatusOK, "ok") })
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- assert.NoError(t, handler(c))
- assert.True(t, store.ctxAllowCalled, "AllowContext should be called when implemented")
- assert.False(t, store.allowCalled, "Allow should not be called when AllowContext is implemented")
- assert.Equal(t, "42", rec.Header().Get("Retry-After"), "store should be able to set headers via the context")
-}
-
-// The built-in memory store implements AllowContext, so it sets X-RateLimit-Limit /
-// X-RateLimit-Remaining on every request and Retry-After when the limit is hit (#2961).
-func TestRateLimiterMemoryStore_AllowContextSetsHeaders(t *testing.T) {
- store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
- e := echo.New()
- e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "ok") },
- RateLimiterWithConfig(RateLimiterConfig{
- Store: store,
- IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil },
- }))
-
- do := func() *httptest.ResponseRecorder {
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- return rec
- }
-
- // Burst of 3: each allowed request advertises the limit and decreasing remaining.
- for i := 0; i < 3; i++ {
- rec := do()
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
- assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining))
- assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter))
- }
-
- // 4th request is denied: 429, remaining 0, and a Retry-After hint.
- rec := do()
- assert.Equal(t, http.StatusTooManyRequests, rec.Code)
- assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining))
- assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter))
-}
diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go
index 267e8d08f..655d4731d 100644
--- a/middleware/rate_limiter_test.go
+++ b/middleware/rate_limiter_test.go
@@ -13,7 +13,7 @@ import (
"testing"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
@@ -21,25 +21,25 @@ import (
func TestRateLimiter(t *testing.T) {
e := echo.New()
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
- mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
+ mw := RateLimiter(inMemoryStore)
testCases := []struct {
- id string
- expectErr string
+ id string
+ code int
}{
- {id: "127.0.0.1"},
- {id: "127.0.0.1"},
- {id: "127.0.0.1"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
}
for _, tc := range testCases {
@@ -49,25 +49,20 @@ func TestRateLimiter(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- err := mw(handler)(c)
- if tc.expectErr != "" {
- assert.EqualError(t, err, tc.expectErr)
- } else {
- assert.NoError(t, err)
- }
- assert.Equal(t, http.StatusOK, rec.Code)
+ _ = mw(handler)(c)
+ assert.Equal(t, tc.code, rec.Code)
}
}
-func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
+func TestRateLimiter_panicBehaviour(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
assert.Panics(t, func() {
- RateLimiterWithConfig(RateLimiterConfig{})
+ RateLimiter(nil)
})
assert.NotPanics(t, func() {
- RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
+ RateLimiter(inMemoryStore)
})
}
@@ -76,27 +71,26 @@ func TestRateLimiterWithConfig(t *testing.T) {
e := echo.New()
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
- mw, err := RateLimiterConfig{
- IdentifierExtractor: func(c *echo.Context) (string, error) {
+ mw := RateLimiterWithConfig(RateLimiterConfig{
+ IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
return "", errors.New("invalid identifier")
}
return id, nil
},
- DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
+ DenyHandler: func(ctx echo.Context, identifier string, err error) error {
return ctx.JSON(http.StatusForbidden, nil)
},
- ErrorHandler: func(ctx *echo.Context, err error) error {
+ ErrorHandler: func(ctx echo.Context, err error) error {
return ctx.JSON(http.StatusBadRequest, nil)
},
Store: inMemoryStore,
- }.ToMiddleware()
- assert.NoError(t, err)
+ })
testCases := []struct {
id string
@@ -119,9 +113,8 @@ func TestRateLimiterWithConfig(t *testing.T) {
c := e.NewContext(req, rec)
- err := mw(handler)(c)
+ _ = mw(handler)(c)
- assert.NoError(t, err)
assert.Equal(t, tc.code, rec.Code)
}
}
@@ -131,12 +124,12 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
e := echo.New()
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
- mw, err := RateLimiterConfig{
- IdentifierExtractor: func(c *echo.Context) (string, error) {
+ mw := RateLimiterWithConfig(RateLimiterConfig{
+ IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
return "", errors.New("invalid identifier")
@@ -144,20 +137,19 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return id, nil
},
Store: inMemoryStore,
- }.ToMiddleware()
- assert.NoError(t, err)
+ })
testCases := []struct {
- id string
- expectErr string
+ id string
+ code int
}{
- {id: "127.0.0.1"},
- {id: "127.0.0.1"},
- {id: "127.0.0.1"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"", http.StatusForbidden},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
}
for _, tc := range testCases {
@@ -168,13 +160,9 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
c := e.NewContext(req, rec)
- err := mw(handler)(c)
- if tc.expectErr != "" {
- assert.EqualError(t, err, tc.expectErr)
- } else {
- assert.NoError(t, err)
- }
- assert.Equal(t, http.StatusOK, rec.Code)
+ _ = mw(handler)(c)
+
+ assert.Equal(t, tc.code, rec.Code)
}
}
@@ -184,26 +172,25 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
e := echo.New()
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
- mw, err := RateLimiterConfig{
+ mw := RateLimiterWithConfig(RateLimiterConfig{
Store: inMemoryStore,
- }.ToMiddleware()
- assert.NoError(t, err)
+ })
testCases := []struct {
- id string
- expectErr string
+ id string
+ code int
}{
- {id: "127.0.0.1"},
- {id: "127.0.0.1"},
- {id: "127.0.0.1"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
- {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
+ {"127.0.0.1", http.StatusTooManyRequests},
}
for _, tc := range testCases {
@@ -214,13 +201,9 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
c := e.NewContext(req, rec)
- err := mw(handler)(c)
- if tc.expectErr != "" {
- assert.EqualError(t, err, tc.expectErr)
- } else {
- assert.NoError(t, err)
- }
- assert.Equal(t, http.StatusOK, rec.Code)
+ _ = mw(handler)(c)
+
+ assert.Equal(t, tc.code, rec.Code)
}
}
}
@@ -229,7 +212,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
e := echo.New()
var beforeFuncRan bool
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
var inMemoryStore = NewRateLimiterMemoryStore(5)
@@ -241,23 +224,21 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
c := e.NewContext(req, rec)
- mw, err := RateLimiterConfig{
- Skipper: func(c *echo.Context) bool {
+ mw := RateLimiterWithConfig(RateLimiterConfig{
+ Skipper: func(c echo.Context) bool {
return true
},
- BeforeFunc: func(c *echo.Context) {
+ BeforeFunc: func(c echo.Context) {
beforeFuncRan = true
},
Store: inMemoryStore,
- IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
- }.ToMiddleware()
- assert.NoError(t, err)
+ })
- err = mw(handler)(c)
+ _ = mw(handler)(c)
- assert.NoError(t, err)
assert.Equal(t, false, beforeFuncRan)
}
@@ -265,7 +246,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
e := echo.New()
var beforeFuncRan bool
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
var inMemoryStore = NewRateLimiterMemoryStore(5)
@@ -277,19 +258,18 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
c := e.NewContext(req, rec)
- mw, err := RateLimiterConfig{
- Skipper: func(c *echo.Context) bool {
+ mw := RateLimiterWithConfig(RateLimiterConfig{
+ Skipper: func(c echo.Context) bool {
return false
},
- BeforeFunc: func(c *echo.Context) {
+ BeforeFunc: func(c echo.Context) {
beforeFuncRan = true
},
Store: inMemoryStore,
- IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
- }.ToMiddleware()
- assert.NoError(t, err)
+ })
_ = mw(handler)(c)
@@ -299,7 +279,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
e := echo.New()
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
@@ -313,20 +293,18 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
c := e.NewContext(req, rec)
- mw, err := RateLimiterConfig{
- BeforeFunc: func(c *echo.Context) {
+ mw := RateLimiterWithConfig(RateLimiterConfig{
+ BeforeFunc: func(c echo.Context) {
beforeRan = true
},
Store: inMemoryStore,
- IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
- }.ToMiddleware()
- assert.NoError(t, err)
+ })
- err = mw(handler)(c)
+ _ = mw(handler)(c)
- assert.NoError(t, err)
assert.Equal(t, true, beforeRan)
}
@@ -394,7 +372,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
}
inMemoryStore.Allow("D")
- inMemoryStore.cleanupStaleVisitors(time.Now())
+ inMemoryStore.cleanupStaleVisitors()
var exists bool
@@ -413,7 +391,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
func TestNewRateLimiterMemoryStore(t *testing.T) {
testCases := []struct {
- rate float64
+ rate rate.Limit
burst int
expiresIn time.Duration
expectedExpiresIn time.Duration
@@ -461,7 +439,7 @@ func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) {
func generateAddressList(count int) []string {
addrs := make([]string, count)
- for i := range count {
+ for i := 0; i < count; i++ {
addrs[i] = randomString(15)
}
return addrs
@@ -477,7 +455,7 @@ func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b
func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) {
addrs := generateAddressList(max)
wg := &sync.WaitGroup{}
- for range parallel {
+ for i := 0; i < parallel; i++ {
wg.Add(1)
go run(wg, store, addrs, max, b)
}
@@ -553,9 +531,11 @@ func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) {
var wg sync.WaitGroup
var allowedCount, deniedCount int32
- for range goroutines {
- wg.Go(func() {
- for range requestsPerGoroutine {
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < requestsPerGoroutine; j++ {
allowed, err := store.Allow("test-user")
assert.NoError(t, err)
if allowed {
@@ -565,7 +545,7 @@ func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) {
}
time.Sleep(time.Millisecond)
}
- })
+ }()
}
wg.Wait()
@@ -596,11 +576,11 @@ func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) {
var wg sync.WaitGroup
identifiers := []string{"user1", "user2", "user3", "user4", "user5"}
- for i := range goroutines {
+ for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(routineID int) {
defer wg.Done()
- for range requestsPerGoroutine {
+ for j := 0; j < requestsPerGoroutine; j++ {
identifier := identifiers[routineID%len(identifiers)]
_, err := store.Allow(identifier)
assert.NoError(t, err)
diff --git a/middleware/recover.go b/middleware/recover.go
index 01fde5152..e6a5940e4 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -8,9 +8,13 @@ import (
"net/http"
"runtime"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/gommon/log"
)
+// LogErrorFunc defines a function for custom logging in the middleware.
+type LogErrorFunc func(c echo.Context, err error, stack []byte) error
+
// RecoverConfig defines the config for Recover middleware.
type RecoverConfig struct {
// Skipper defines a function to skip middleware.
@@ -18,24 +22,41 @@ type RecoverConfig struct {
// Size of the stack to be printed.
// Optional. Default value 4KB.
- StackSize int
+ StackSize int `yaml:"stack_size"`
// DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine.
// Optional. Default value false.
- DisableStackAll bool
+ DisableStackAll bool `yaml:"disable_stack_all"`
// DisablePrintStack disables printing stack trace.
// Optional. Default value as false.
- DisablePrintStack bool
+ DisablePrintStack bool `yaml:"disable_print_stack"`
+
+ // LogLevel is log level to printing stack trace.
+ // Optional. Default value 0 (Print).
+ LogLevel log.Lvl
+
+ // LogErrorFunc defines a function for custom logging in the middleware.
+ // If it's set you don't need to provide LogLevel for config.
+ // If this function returns nil, the centralized HTTPErrorHandler will not be called.
+ LogErrorFunc LogErrorFunc
+
+ // DisableErrorHandler disables the call to centralized HTTPErrorHandler.
+ // The recovered error is then passed back to upstream middleware, instead of swallowing the error.
+ // Optional. Default value false.
+ DisableErrorHandler bool `yaml:"disable_error_handler"`
}
// DefaultRecoverConfig is the default Recover middleware config.
var DefaultRecoverConfig = RecoverConfig{
- Skipper: DefaultSkipper,
- StackSize: 4 << 10, // 4 KB
- DisableStackAll: false,
- DisablePrintStack: false,
+ Skipper: DefaultSkipper,
+ StackSize: 4 << 10, // 4 KB
+ DisableStackAll: false,
+ DisablePrintStack: false,
+ LogLevel: 0,
+ LogErrorFunc: nil,
+ DisableErrorHandler: false,
}
// Recover returns a middleware which recovers from panics anywhere in the chain
@@ -44,13 +65,9 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig)
}
-// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
+// RecoverWithConfig returns a Recover middleware with config.
+// See: `Recover()`.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
-
-// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
-func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper
@@ -60,7 +77,7 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) (err error) {
+ return func(c echo.Context) (returnErr error) {
if config.Skipper(c) {
return next(c)
}
@@ -70,34 +87,47 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if r == http.ErrAbortHandler {
panic(r)
}
- tmpErr, ok := r.(error)
+ err, ok := r.(error)
if !ok {
- tmpErr = fmt.Errorf("%v", r)
+ err = fmt.Errorf("%v", r)
}
+ var stack []byte
+ var length int
+
if !config.DisablePrintStack {
- stack := make([]byte, config.StackSize)
- length := runtime.Stack(stack, !config.DisableStackAll)
- tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr}
+ stack = make([]byte, config.StackSize)
+ length = runtime.Stack(stack, !config.DisableStackAll)
+ stack = stack[:length]
+ }
+
+ if config.LogErrorFunc != nil {
+ err = config.LogErrorFunc(c, err, stack)
+ } else if !config.DisablePrintStack {
+ msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
+ switch config.LogLevel {
+ case log.DEBUG:
+ c.Logger().Debug(msg)
+ case log.INFO:
+ c.Logger().Info(msg)
+ case log.WARN:
+ c.Logger().Warn(msg)
+ case log.ERROR:
+ c.Logger().Error(msg)
+ case log.OFF:
+ // None.
+ default:
+ c.Logger().Print(msg)
+ }
+ }
+
+ if err != nil && !config.DisableErrorHandler {
+ c.Error(err)
+ } else {
+ returnErr = err
}
- err = tmpErr
}
}()
return next(c)
}
- }, nil
-}
-
-// PanicStackError is an error type that wraps an error along with its stack trace.
-// It is returned when config.DisablePrintStack is set to false.
-type PanicStackError struct {
- Stack []byte
- Err error
-}
-
-func (e *PanicStackError) Error() string {
- return fmt.Sprintf("[PANIC RECOVER] %s %s", e.Err.Error(), e.Stack)
-}
-
-func (e *PanicStackError) Unwrap() error {
- return e.Err
+ }
}
diff --git a/middleware/recover_test.go b/middleware/recover_test.go
index 719e0cc3d..8fa34fa5c 100644
--- a/middleware/recover_test.go
+++ b/middleware/recover_test.go
@@ -6,72 +6,42 @@ package middleware
import (
"bytes"
"errors"
- "log/slog"
+ "fmt"
"net/http"
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
+ "github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert"
)
func TestRecover(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
- e.Logger = slog.New(&discardHandler{})
+ e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := Recover()(func(c *echo.Context) error {
+ h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
panic("test")
- })
+ }))
err := h(c)
- assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
-
- var pse *PanicStackError
- if errors.As(err, &pse) {
- assert.Contains(t, string(pse.Stack), "middleware/recover.go")
- } else {
- assert.Fail(t, "not of type PanicStackError")
- }
-
- assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
- assert.Contains(t, buf.String(), "") // nothing is logged
-}
-
-func TestRecover_skipper(t *testing.T) {
- e := echo.New()
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- config := RecoverConfig{
- Skipper: func(c *echo.Context) bool {
- return true
- },
- }
- h := RecoverWithConfig(config)(func(c *echo.Context) error {
- panic("testPANIC")
- })
-
- var err error
- assert.Panics(t, func() {
- err = h(c)
- })
-
assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.Contains(t, buf.String(), "PANIC RECOVER")
}
func TestRecoverErrAbortHandler(t *testing.T) {
e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := Recover()(func(c *echo.Context) error {
+ h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
panic(http.ErrAbortHandler)
- })
+ }))
defer func() {
r := recover()
if r == nil {
@@ -85,66 +55,135 @@ func TestRecoverErrAbortHandler(t *testing.T) {
}
}()
- hErr := h(c)
+ h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
- assert.NotContains(t, hErr.Error(), "PANIC RECOVER")
+ assert.NotContains(t, buf.String(), "PANIC RECOVER")
}
-func TestRecoverWithConfig(t *testing.T) {
- var testCases = []struct {
- name string
- givenNoPanic bool
- whenConfig RecoverConfig
- expectErrContain string
- expectErr string
- }{
- {
- name: "ok, default config",
- whenConfig: DefaultRecoverConfig,
- expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
- },
- {
- name: "ok, no panic",
- givenNoPanic: true,
- whenConfig: DefaultRecoverConfig,
- expectErrContain: "",
- },
- {
- name: "ok, DisablePrintStack",
- whenConfig: RecoverConfig{
- DisablePrintStack: true,
- },
- expectErr: "testPANIC",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
+func TestRecoverWithConfig_LogLevel(t *testing.T) {
+ tests := []struct {
+ logLevel log.Lvl
+ levelName string
+ }{{
+ logLevel: log.DEBUG,
+ levelName: "DEBUG",
+ }, {
+ logLevel: log.INFO,
+ levelName: "INFO",
+ }, {
+ logLevel: log.WARN,
+ levelName: "WARN",
+ }, {
+ logLevel: log.ERROR,
+ levelName: "ERROR",
+ }, {
+ logLevel: log.OFF,
+ levelName: "OFF",
+ }}
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.levelName, func(t *testing.T) {
e := echo.New()
+ e.Logger.SetLevel(log.DEBUG)
+
+ buf := new(bytes.Buffer)
+ e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- config := tc.whenConfig
- h := RecoverWithConfig(config)(func(c *echo.Context) error {
- if tc.givenNoPanic {
- return nil
- }
- panic("testPANIC")
- })
+ config := DefaultRecoverConfig
+ config.LogLevel = tt.logLevel
+ h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
+ panic("test")
+ }))
- err := h(c)
+ h(c)
- if tc.expectErrContain != "" {
- assert.Contains(t, err.Error(), tc.expectErrContain)
- } else if tc.expectErr != "" {
- assert.Contains(t, err.Error(), tc.expectErr)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+
+ output := buf.String()
+ if tt.logLevel == log.OFF {
+ assert.Empty(t, output)
} else {
- assert.NoError(t, err)
+ assert.Contains(t, output, "PANIC RECOVER")
+ assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
}
- assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
})
}
}
+
+func TestRecoverWithConfig_LogErrorFunc(t *testing.T) {
+ e := echo.New()
+ e.Logger.SetLevel(log.DEBUG)
+
+ buf := new(bytes.Buffer)
+ e.Logger.SetOutput(buf)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ testError := errors.New("test")
+ config := DefaultRecoverConfig
+ config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error {
+ msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack)
+ if errors.Is(err, testError) {
+ c.Logger().Debug(msg)
+ } else {
+ c.Logger().Error(msg)
+ }
+ return err
+ }
+
+ t.Run("first branch case for LogErrorFunc", func(t *testing.T) {
+ buf.Reset()
+ h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
+ panic(testError)
+ }))
+
+ h(c)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+
+ output := buf.String()
+ assert.Contains(t, output, "PANIC RECOVER")
+ assert.Contains(t, output, `"level":"DEBUG"`)
+ })
+
+ t.Run("else branch case for LogErrorFunc", func(t *testing.T) {
+ buf.Reset()
+ h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
+ panic("other")
+ }))
+
+ h(c)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+
+ output := buf.String()
+ assert.Contains(t, output, "PANIC RECOVER")
+ assert.Contains(t, output, `"level":"ERROR"`)
+ })
+}
+
+func TestRecoverWithDisabled_ErrorHandler(t *testing.T) {
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger.SetOutput(buf)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := DefaultRecoverConfig
+ config.DisableErrorHandler = true
+ h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
+ panic("test")
+ }))
+ err := h(c)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Contains(t, buf.String(), "PANIC RECOVER")
+ assert.EqualError(t, err, "test")
+}
diff --git a/middleware/redirect.go b/middleware/redirect.go
index bb7045cfe..b772ac131 100644
--- a/middleware/redirect.go
+++ b/middleware/redirect.go
@@ -4,11 +4,10 @@
package middleware
import (
- "errors"
"net/http"
"strings"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// RedirectConfig defines the config for Redirect middleware.
@@ -18,9 +17,7 @@ type RedirectConfig struct {
// Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently.
- Code int
-
- redirect redirectLogic
+ Code int `yaml:"code"`
}
// redirectLogic represents a function that given a scheme, host and uri
@@ -30,33 +27,29 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www."
-// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
-var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
-
-// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
-var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
-
-// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
-var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
-
-// RedirectWWWConfig is the WWW Redirect middleware config.
-var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
-
-// RedirectNonWWWConfig is the non WWW Redirect middleware config.
-var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
+// DefaultRedirectConfig is the default Redirect middleware config.
+var DefaultRedirectConfig = RedirectConfig{
+ Skipper: DefaultSkipper,
+ Code: http.StatusMovedPermanently,
+}
// HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc {
- return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
+ return HTTPSRedirectWithConfig(DefaultRedirectConfig)
}
-// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
+// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
+// See `HTTPSRedirect()`.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- config.redirect = redirectHTTPS
- return toMiddlewareOrPanic(config)
+ return redirect(config, func(scheme, host, uri string) (bool, string) {
+ if scheme != "https" {
+ return true, "https://" + host + uri
+ }
+ return false, ""
+ })
}
// HTTPSWWWRedirect redirects http requests to https www.
@@ -64,13 +57,18 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc {
- return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
+ return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
}
-// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
+// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
+// See `HTTPSWWWRedirect()`.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- config.redirect = redirectHTTPSWWW
- return toMiddlewareOrPanic(config)
+ return redirect(config, func(scheme, host, uri string) (bool, string) {
+ if scheme != "https" && !strings.HasPrefix(host, www) {
+ return true, "https://www." + host + uri
+ }
+ return false, ""
+ })
}
// HTTPSNonWWWRedirect redirects http requests to https non www.
@@ -78,13 +76,19 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
- return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
+ return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
}
-// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
+// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
+// See `HTTPSNonWWWRedirect()`.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- config.redirect = redirectNonHTTPSWWW
- return toMiddlewareOrPanic(config)
+ return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
+ if scheme != "https" {
+ host = strings.TrimPrefix(host, www)
+ return true, "https://" + host + uri
+ }
+ return false, ""
+ })
}
// WWWRedirect redirects non www requests to www.
@@ -92,13 +96,18 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc {
- return WWWRedirectWithConfig(RedirectWWWConfig)
+ return WWWRedirectWithConfig(DefaultRedirectConfig)
}
-// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
+// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
+// See `WWWRedirect()`.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- config.redirect = redirectWWW
- return toMiddlewareOrPanic(config)
+ return redirect(config, func(scheme, host, uri string) (bool, string) {
+ if !strings.HasPrefix(host, www) {
+ return true, scheme + "://www." + host + uri
+ }
+ return false, ""
+ })
}
// NonWWWRedirect redirects www requests to non www.
@@ -106,79 +115,41 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc {
- return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
+ return NonWWWRedirectWithConfig(DefaultRedirectConfig)
}
-// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
+// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
+// See `NonWWWRedirect()`.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- config.redirect = redirectNonWWW
- return toMiddlewareOrPanic(config)
+ return redirect(config, func(scheme, host, uri string) (bool, string) {
+ if strings.HasPrefix(host, www) {
+ return true, scheme + "://" + host[4:] + uri
+ }
+ return false, ""
+ })
}
-// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
-func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
+ config.Skipper = DefaultRedirectConfig.Skipper
}
if config.Code == 0 {
- config.Code = http.StatusMovedPermanently
- }
- if config.redirect == nil {
- return nil, errors.New("redirectConfig is missing redirect function")
+ config.Code = DefaultRedirectConfig.Code
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req, scheme := c.Request(), c.Scheme()
host := req.Host
- if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
+ if ok, url := cb(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url)
}
return next(c)
}
- }, nil
-}
-
-var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
- if scheme != "https" {
- return true, "https://" + host + uri
- }
- return false, ""
-}
-
-var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
- // Redirect if not HTTPS OR missing www prefix (needs either fix)
- if scheme != "https" || !strings.HasPrefix(host, www) {
- host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication
- return true, "https://www." + host + uri
- }
- return false, ""
-}
-
-var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
- // Redirect if not HTTPS OR has www prefix (needs either fix)
- if scheme != "https" || strings.HasPrefix(host, www) {
- host = strings.TrimPrefix(host, www)
- return true, "https://" + host + uri
- }
- return false, ""
-}
-
-var redirectWWW = func(scheme, host, uri string) (bool, string) {
- if !strings.HasPrefix(host, www) {
- return true, scheme + "://www." + host + uri
- }
- return false, ""
-}
-
-var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
- if strings.HasPrefix(host, www) {
- return true, scheme + "://" + host[4:] + uri
}
- return false, ""
}
diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go
index a127ca40c..88068ea2e 100644
--- a/middleware/redirect_test.go
+++ b/middleware/redirect_test.go
@@ -8,7 +8,7 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -58,8 +58,8 @@ func TestRedirectHTTPSWWWRedirect(t *testing.T) {
},
{
whenHost: "www.labstack.com",
- expectLocation: "https://www.labstack.com/",
- expectStatusCode: http.StatusMovedPermanently,
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
},
{
whenHost: "a.com",
@@ -74,12 +74,6 @@ func TestRedirectHTTPSWWWRedirect(t *testing.T) {
{
whenHost: "labstack.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
- expectLocation: "https://www.labstack.com/",
- expectStatusCode: http.StatusMovedPermanently,
- },
- {
- whenHost: "www.labstack.com",
- whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "",
expectStatusCode: http.StatusOK,
},
@@ -120,12 +114,6 @@ func TestRedirectHTTPSNonWWWRedirect(t *testing.T) {
{
whenHost: "www.labstack.com",
whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
- expectLocation: "https://labstack.com/",
- expectStatusCode: http.StatusMovedPermanently,
- },
- {
- whenHost: "labstack.com",
- whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
expectLocation: "",
expectStatusCode: http.StatusOK,
},
@@ -230,7 +218,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenCode int
- givenSkipFunc func(c *echo.Context) bool
+ givenSkipFunc func(c echo.Context) bool
whenHost string
whenHeader http.Header
expectLocation string
@@ -244,7 +232,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) {
},
{
name: "redirect is skipped",
- givenSkipFunc: func(c *echo.Context) bool {
+ givenSkipFunc: func(c echo.Context) bool {
return true // skip always
},
whenHost: "www.labstack.com",
@@ -278,7 +266,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) {
func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {
e := echo.New()
- next := func(c *echo.Context) (err error) {
+ next := func(c echo.Context) (err error) {
return c.NoContent(http.StatusOK)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
diff --git a/middleware/request_id.go b/middleware/request_id.go
index b3de40d19..14bd4fd15 100644
--- a/middleware/request_id.go
+++ b/middleware/request_id.go
@@ -4,7 +4,7 @@
package middleware
import (
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// RequestIDConfig defines the config for RequestID middleware.
@@ -13,45 +13,43 @@ type RequestIDConfig struct {
Skipper Skipper
// Generator defines a function to generate an ID.
- // Optional. Default value random.String(32).
+ // Optional. Defaults to generator for random string of length 32.
Generator func() string
// RequestIDHandler defines a function which is executed for a request id.
- RequestIDHandler func(c *echo.Context, requestID string)
+ RequestIDHandler func(echo.Context, string)
- // TargetHeader defines what header to look for to populate the id.
- // Optional. Default value is `X-Request-Id`
+ // TargetHeader defines what header to look for to populate the id
TargetHeader string
}
-// RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when
-// the header value is empty, generates that value and sets request ID to response
-// as RequestIDConfig.TargetHeader (`X-Request-Id`) value.
-func RequestID() echo.MiddlewareFunc {
- return RequestIDWithConfig(RequestIDConfig{})
+// DefaultRequestIDConfig is the default RequestID middleware config.
+var DefaultRequestIDConfig = RequestIDConfig{
+ Skipper: DefaultSkipper,
+ Generator: generator,
+ TargetHeader: echo.HeaderXRequestID,
}
-// RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration.
-// The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty,
-// generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value.
-func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
+// RequestID returns a X-Request-ID middleware.
+func RequestID() echo.MiddlewareFunc {
+ return RequestIDWithConfig(DefaultRequestIDConfig)
}
-// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
-func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+// RequestIDWithConfig returns a X-Request-ID middleware with config.
+func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
+ // Defaults
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
+ config.Skipper = DefaultRequestIDConfig.Skipper
}
if config.Generator == nil {
- config.Generator = createRandomStringGenerator(32)
+ config.Generator = generator
}
if config.TargetHeader == "" {
config.TargetHeader = echo.HeaderXRequestID
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -69,5 +67,9 @@ func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
return next(c)
}
- }, nil
+ }
+}
+
+func generator() string {
+ return randomString(32)
}
diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go
index 465e6fc42..4e68b126a 100644
--- a/middleware/request_id_test.go
+++ b/middleware/request_id_test.go
@@ -8,7 +8,7 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -17,108 +17,29 @@ func TestRequestID(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
- rid := RequestID()
- h := rid(handler)
- err := h(c)
- assert.NoError(t, err)
- assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
-}
-
-func TestMustRequestIDWithConfig_skipper(t *testing.T) {
- e := echo.New()
- e.GET("/", func(c *echo.Context) error {
- return c.String(http.StatusTeapot, "test")
- })
-
- generatorCalled := false
- e.Use(RequestIDWithConfig(RequestIDConfig{
- Skipper: func(c *echo.Context) bool {
- return true
- },
- Generator: func() string {
- generatorCalled = true
- return "customGenerator"
- },
- }))
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := httptest.NewRecorder()
- e.ServeHTTP(res, req)
-
- assert.Equal(t, http.StatusTeapot, res.Code)
- assert.Equal(t, "test", res.Body.String())
-
- assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
- assert.False(t, generatorCalled)
-}
-
-func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- handler := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
-
- rid := RequestIDWithConfig(RequestIDConfig{
- Generator: func() string { return "customGenerator" },
- })
- h := rid(handler)
- err := h(c)
- assert.NoError(t, err)
- assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
-}
-
-func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- handler := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
-
- called := false
- rid := RequestIDWithConfig(RequestIDConfig{
- Generator: func() string { return "customGenerator" },
- RequestIDHandler: func(c *echo.Context, s string) {
- called = true
- },
- })
- h := rid(handler)
- err := h(c)
- assert.NoError(t, err)
- assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
- assert.True(t, called)
-}
-
-func TestRequestIDWithConfig(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- handler := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
-
- rid, err := RequestIDConfig{}.ToMiddleware()
- assert.NoError(t, err)
+ rid := RequestIDWithConfig(RequestIDConfig{})
h := rid(handler)
h(c)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
- // Custom generator
+ // Custom generator and handler
+ customID := "customGenerator"
+ calledHandler := false
rid = RequestIDWithConfig(RequestIDConfig{
- Generator: func() string { return "customGenerator" },
+ Generator: func() string { return customID },
+ RequestIDHandler: func(_ echo.Context, id string) {
+ calledHandler = true
+ assert.Equal(t, customID, id)
+ },
})
h = rid(handler)
h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
+ assert.True(t, calledHandler)
}
func TestRequestID_IDNotAltered(t *testing.T) {
@@ -128,7 +49,7 @@ func TestRequestID_IDNotAltered(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
@@ -143,7 +64,7 @@ func TestRequestIDConfigDifferentHeader(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- handler := func(c *echo.Context) error {
+ handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
@@ -158,7 +79,7 @@ func TestRequestIDConfigDifferentHeader(t *testing.T) {
rid = RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return customID },
TargetHeader: echo.HeaderXCorrelationID,
- RequestIDHandler: func(_ *echo.Context, id string) {
+ RequestIDHandler: func(_ echo.Context, id string) {
calledHandler = true
assert.Equal(t, customID, id)
},
diff --git a/middleware/request_logger.go b/middleware/request_logger.go
index 827fd3761..211abf464 100644
--- a/middleware/request_logger.go
+++ b/middleware/request_logger.go
@@ -10,7 +10,7 @@ import (
"net/http"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// Example for `slog` https://pkg.go.dev/log/slog
@@ -18,8 +18,9 @@ import (
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true,
// LogURI: true,
+// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
-// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
// slog.String("uri", v.URI),
@@ -40,8 +41,9 @@ import (
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true,
// LogURI: true,
+// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
-// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status)
// } else {
@@ -56,8 +58,9 @@ import (
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true,
// LogStatus: true,
+// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
-// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.Info().
// Str("URI", v.URI).
@@ -79,8 +82,9 @@ import (
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true,
// LogStatus: true,
+// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
-// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.Info("request",
// zap.String("URI", v.URI),
@@ -102,8 +106,9 @@ import (
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true,
// LogStatus: true,
+// LogError: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
-// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// log.WithFields(logrus.Fields{
// "URI": v.URI,
@@ -126,10 +131,10 @@ type RequestLoggerConfig struct {
Skipper Skipper
// BeforeNextFunc defines a function that is called before next middleware or handler is called in chain.
- BeforeNextFunc func(c *echo.Context)
+ BeforeNextFunc func(c echo.Context)
// LogValuesFunc defines a function that is called with values extracted by logger from request/response.
// Mandatory.
- LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error
+ LogValuesFunc func(c echo.Context, v RequestLoggerValues) error
// HandleError instructs logger to call global error handler when next middleware/handler returns an error.
// This is useful when you have custom error handler that can decide to use different status codes.
@@ -160,9 +165,11 @@ type RequestLoggerConfig struct {
LogReferer bool
// LogUserAgent instructs logger to extract request user agent values.
LogUserAgent bool
- // LogStatus instructs logger to extract response status code. If handler chain returns an error,
- // the status code is extracted from the error satisfying echo.StatusCoder interface.
+ // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError,
+ // the status code is extracted from the echo.HTTPError returned
LogStatus bool
+ // LogError instructs logger to extract error returned from executed handler chain.
+ LogError bool
// LogContentLength instructs logger to extract content length header value. Note: this value could be different from
// actual request body size as it could be spoofed etc.
LogContentLength bool
@@ -211,7 +218,7 @@ type RequestLoggerValues struct {
Referer string
// UserAgent is request user agent values.
UserAgent string
- // Status is a response status code. When the handler returns an error satisfying echo.StatusCoder interface, then code from it.
+ // Status is response status code. Then handler returns an echo.HTTPError then code from there.
Status int
// Error is error returned from executed handler chain.
Error error
@@ -221,15 +228,15 @@ type RequestLoggerValues struct {
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
ResponseSize int64
// Headers are list of headers from request. Note: request can contain more than one header with same value so slice
- // of values is what will be returned/logged for each given header.
+ // of values is been logger for each given header.
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
Headers map[string][]string
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
- // with same name so slice of values is what will be returned/logged for each given query param name.
+ // with same name so slice of values is been logger for each given query param name.
QueryParams map[string][]string
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
- // same name so slice of values is what will be returned/logged for each given form value name.
+ // same name so slice of values is been logger for each given form value name.
FormValues map[string][]string
}
@@ -242,6 +249,72 @@ func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc {
return mw
}
+// RequestLogger returns a RequestLogger middleware with default configuration which
+// uses default slog.slog logger.
+//
+// To customize slog output format replace slog default logger:
+// For JSON format: `slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))`
+func RequestLogger() echo.MiddlewareFunc {
+ config := RequestLoggerConfig{
+ LogLatency: true,
+ LogProtocol: false,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: false,
+ LogRoutePath: false,
+ LogRequestID: true,
+ LogReferer: false,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogError: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ LogHeaders: nil,
+ LogQueryParams: nil,
+ LogFormValues: nil,
+ HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+ LogValuesFunc: func(c echo.Context, v RequestLoggerValues) error {
+ if v.Error == nil {
+ slog.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
+ slog.String("method", v.Method),
+ slog.String("uri", v.URI),
+ slog.Int("status", v.Status),
+ slog.Duration("latency", v.Latency),
+ slog.String("host", v.Host),
+ slog.String("bytes_in", v.ContentLength),
+ slog.Int64("bytes_out", v.ResponseSize),
+ slog.String("user_agent", v.UserAgent),
+ slog.String("remote_ip", v.RemoteIP),
+ slog.String("request_id", v.RequestID),
+ )
+ } else {
+ slog.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
+ slog.String("method", v.Method),
+ slog.String("uri", v.URI),
+ slog.Int("status", v.Status),
+ slog.Duration("latency", v.Latency),
+ slog.String("host", v.Host),
+ slog.String("bytes_in", v.ContentLength),
+ slog.Int64("bytes_out", v.ResponseSize),
+ slog.String("user_agent", v.UserAgent),
+ slog.String("remote_ip", v.RemoteIP),
+ slog.String("request_id", v.RequestID),
+
+ slog.String("error", v.Error.Error()),
+ )
+ }
+ return nil
+ },
+ }
+ mw, err := config.ToMiddleware()
+ if err != nil {
+ panic(err)
+ }
+ return mw
+}
+
// ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration.
func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
@@ -266,12 +339,13 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
logFormValues := len(config.LogFormValues) > 0
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
+ res := c.Response()
start := now()
if config.BeforeNextFunc != nil {
@@ -279,11 +353,8 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
err := next(c)
if err != nil && config.HandleError {
- // When global error handler writes the error to the client the Response gets "committed". This state can be
- // checked with `c.Response().Committed` field.
- c.Echo().HTTPErrorHandler(c, err)
+ c.Error(err)
}
- res := c.Response()
v := RequestLoggerValues{
StartTime: start,
@@ -329,26 +400,26 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.LogUserAgent {
v.UserAgent = req.UserAgent()
}
-
- if config.LogStatus || config.LogResponseSize {
- resp, status := echo.ResolveResponseStatus(res, err)
-
- if config.LogStatus {
- v.Status = status
- }
- if config.LogResponseSize {
- v.ResponseSize = -1
- if resp != nil {
- v.ResponseSize = resp.Size
+ if config.LogStatus {
+ v.Status = res.Status
+ if err != nil && !config.HandleError {
+ // this block should not be executed in case of HandleError=true as the global error handler will decide
+ // the status code. In that case status code could be different from what err contains.
+ var httpErr *echo.HTTPError
+ if errors.As(err, &httpErr) {
+ v.Status = httpErr.Code
}
}
}
- if err != nil {
+ if config.LogError && err != nil {
v.Error = err
}
if config.LogContentLength {
v.ContentLength = req.Header.Get(echo.HeaderContentLength)
}
+ if config.LogResponseSize {
+ v.ResponseSize = res.Size
+ }
if logHeaders {
v.Headers = map[string][]string{}
for _, header := range headers {
@@ -378,69 +449,11 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil {
return errOnLog
}
+
// in case of HandleError=true we are returning the error that we already have handled with global error handler
// this is deliberate as this error could be useful for upstream middlewares and default global error handler
// will ignore that error when it bubbles up in middleware chain.
- // Committed response can be checked in custom error handler with following logic
- //
- // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
- // return
- // }
return err
}
}, nil
}
-
-// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger.
-func RequestLogger() echo.MiddlewareFunc {
- return RequestLoggerWithConfig(RequestLoggerConfig{
- LogLatency: true,
- LogRemoteIP: true,
- LogHost: true,
- LogMethod: true,
- LogURI: true,
- LogRequestID: true,
- LogUserAgent: true,
- LogStatus: true,
- LogContentLength: true,
- LogResponseSize: true,
- // forwards error to the global error handler, so it can decide appropriate status code.
- // NB: side-effect of that is - request is now "committed" written to the client. Middlewares up in chain can not
- // change Response status code or response body.
- HandleError: true,
- LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error {
- logger := c.Logger()
- if v.Error == nil {
- logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
- slog.String("method", v.Method),
- slog.String("uri", v.URI),
- slog.Int("status", v.Status),
- slog.Duration("latency", v.Latency),
- slog.String("host", v.Host),
- slog.String("bytes_in", v.ContentLength),
- slog.Int64("bytes_out", v.ResponseSize),
- slog.String("user_agent", v.UserAgent),
- slog.String("remote_ip", v.RemoteIP),
- slog.String("request_id", v.RequestID),
- )
- return nil
- }
-
- logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
- slog.String("method", v.Method),
- slog.String("uri", v.URI),
- slog.Int("status", v.Status),
- slog.Duration("latency", v.Latency),
- slog.String("host", v.Host),
- slog.String("bytes_in", v.ContentLength),
- slog.Int64("bytes_out", v.ResponseSize),
- slog.String("user_agent", v.UserAgent),
- slog.String("remote_ip", v.RemoteIP),
- slog.String("request_id", v.RequestID),
-
- slog.String("error", v.Error.Error()),
- )
- return nil
- },
- })
-}
diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go
index 2232c6f6a..202660287 100644
--- a/middleware/request_logger_test.go
+++ b/middleware/request_logger_test.go
@@ -16,7 +16,7 @@ import (
"testing"
"time"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -26,13 +26,13 @@ func TestRequestLoggerOK(t *testing.T) {
slog.SetDefault(old)
})
- e := echo.New()
- e.IPExtractor = echo.LegacyIPExtractor()
buf := new(bytes.Buffer)
- e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+ slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil)))
+
+ e := echo.New()
e.Use(RequestLogger())
- e.POST("/test", func(c *echo.Context) error {
+ e.POST("/test", func(c echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
@@ -76,12 +76,13 @@ func TestRequestLoggerError(t *testing.T) {
slog.SetDefault(old)
})
- e := echo.New()
buf := new(bytes.Buffer)
- e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+ slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil)))
+
+ e := echo.New()
e.Use(RequestLogger())
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return errors.New("nope")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
@@ -120,13 +121,13 @@ func TestRequestLoggerWithConfig(t *testing.T) {
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogRoutePath: true,
LogURI: true,
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
}))
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
@@ -152,16 +153,16 @@ func TestRequestLogger_skipper(t *testing.T) {
loggerCalled := false
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
- Skipper: func(c *echo.Context) bool {
+ Skipper: func(c echo.Context) bool {
return true
},
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
loggerCalled = true
return nil
},
}))
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
@@ -179,16 +180,16 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) {
var myLoggerInstance int
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
- BeforeNextFunc: func(c *echo.Context) {
+ BeforeNextFunc: func(c echo.Context) {
c.Set("myLoggerInstance", 42)
},
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
myLoggerInstance = c.Get("myLoggerInstance").(int)
return nil
},
}))
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
@@ -206,14 +207,15 @@ func TestRequestLogger_logError(t *testing.T) {
var actual RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogError: true,
LogStatus: true,
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
actual = values
return nil
},
}))
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return echo.NewHTTPError(http.StatusNotAcceptable, "nope")
})
@@ -236,22 +238,23 @@ func TestRequestLogger_HandleError(t *testing.T) {
return time.Unix(1631045377, 0).UTC()
},
HandleError: true,
+ LogError: true,
LogStatus: true,
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
actual = values
return nil
},
}))
// to see if "HandleError" works we create custom error handler that uses its own status codes
- e.HTTPErrorHandler = func(c *echo.Context, err error) {
- if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
+ e.HTTPErrorHandler = func(err error, c echo.Context) {
+ if c.Response().Committed {
return
}
c.JSON(http.StatusTeapot, "custom error handler")
}
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return echo.NewHTTPError(http.StatusForbidden, "nope")
})
@@ -275,14 +278,15 @@ func TestRequestLogger_LogValuesFuncError(t *testing.T) {
var expect RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogError: true,
LogStatus: true,
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
expect = values
return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError")
},
}))
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
@@ -323,13 +327,13 @@ func TestRequestLogger_ID(t *testing.T) {
var expect RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogRequestID: true,
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
}))
- e.GET("/test", func(c *echo.Context) error {
+ e.GET("/test", func(c echo.Context) error {
c.Response().Header().Set(echo.HeaderXRequestID, "321")
return c.String(http.StatusTeapot, "OK")
})
@@ -353,12 +357,12 @@ func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) {
var expect RequestLoggerValues
mw := RequestLoggerWithConfig(RequestLoggerConfig{
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
LogHeaders: []string{"referer", "User-Agent"},
- })(func(c *echo.Context) error {
+ })(func(c echo.Context) error {
c.Request().Header.Set(echo.HeaderXRequestID, "123")
c.FormValue("to force parse form")
return c.String(http.StatusTeapot, "OK")
@@ -383,7 +387,7 @@ func TestRequestLogger_allFields(t *testing.T) {
isFirstNowCall := true
var expect RequestLoggerValues
mw := RequestLoggerWithConfig(RequestLoggerConfig{
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
@@ -399,6 +403,7 @@ func TestRequestLogger_allFields(t *testing.T) {
LogReferer: true,
LogUserAgent: true,
LogStatus: true,
+ LogError: true,
LogContentLength: true,
LogResponseSize: true,
LogHeaders: []string{"accept-encoding", "User-Agent"},
@@ -411,7 +416,7 @@ func TestRequestLogger_allFields(t *testing.T) {
}
return time.Unix(1631045377+10, 0)
},
- })(func(c *echo.Context) error {
+ })(func(c echo.Context) error {
c.Request().Header.Set(echo.HeaderXRequestID, "123")
c.FormValue("to force parse form")
return c.String(http.StatusTeapot, "OK")
@@ -440,7 +445,7 @@ func TestRequestLogger_allFields(t *testing.T) {
assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime)
assert.Equal(t, 10*time.Second, expect.Latency)
assert.Equal(t, "HTTP/1.1", expect.Protocol)
- assert.Equal(t, "192.0.2.1", expect.RemoteIP)
+ assert.Equal(t, "8.8.8.8", expect.RemoteIP)
assert.Equal(t, "example.com", expect.Host)
assert.Equal(t, http.MethodPost, expect.Method)
assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI)
@@ -466,86 +471,12 @@ func TestRequestLogger_allFields(t *testing.T) {
assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"])
}
-func TestTestRequestLogger(t *testing.T) {
- var testCases = []struct {
- name string
- whenStatus int
- whenError error
- expectStatus string
- expectError string
- }{
- {
- name: "ok",
- whenStatus: http.StatusTeapot,
- expectStatus: "418",
- },
- {
- name: "error",
- whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"),
- expectStatus: "502",
- expectError: `"error":"code=502, message=bad gw"`,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := echo.New()
- buf := new(bytes.Buffer)
- e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
-
- e.Use(RequestLogger())
- e.POST("/test", func(c *echo.Context) error {
- if tc.whenError != nil {
- return tc.whenError
- }
- return c.String(tc.whenStatus, "OK")
- })
-
- f := make(url.Values)
- f.Set("csrf", "token")
- f.Set("multiple", "1")
- f.Add("multiple", "2")
- reader := strings.NewReader(f.Encode())
- req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader)
- req.Header.Set("Referer", "https://echo.labstack.com/")
- req.Header.Set("User-Agent", "curl/7.68.0")
- req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
- req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
- req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
- req.Header.Set(echo.HeaderXRequestID, "MY_ID")
-
- rec := httptest.NewRecorder()
-
- e.ServeHTTP(rec, req)
-
- rawlog := buf.Bytes()
- if tc.expectError != "" {
- assert.Contains(t, string(rawlog), `"level":"ERROR"`)
- assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`)
- assert.Contains(t, string(rawlog), tc.expectError)
- } else {
- assert.Contains(t, string(rawlog), `"level":"INFO"`)
- assert.Contains(t, string(rawlog), `"msg":"REQUEST"`)
- }
- assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus)
- assert.Contains(t, string(rawlog), `"method":"POST"`)
- assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`)
- assert.Contains(t, string(rawlog), `"latency":`) // this value varies
- assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`)
- assert.Contains(t, string(rawlog), `"remote_ip":"192.0.2.1"`)
- assert.Contains(t, string(rawlog), `"host":"example.com"`)
- assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`)
- assert.Contains(t, string(rawlog), `"bytes_in":"32"`)
- assert.Contains(t, string(rawlog), `"bytes_out":2`)
- })
- }
-}
-
func BenchmarkRequestLogger_withoutMapFields(b *testing.B) {
e := echo.New()
mw := RequestLoggerWithConfig(RequestLoggerConfig{
Skipper: nil,
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
return nil
},
LogLatency: true,
@@ -560,9 +491,10 @@ func BenchmarkRequestLogger_withoutMapFields(b *testing.B) {
LogReferer: true,
LogUserAgent: true,
LogStatus: true,
+ LogError: true,
LogContentLength: true,
LogResponseSize: true,
- })(func(c *echo.Context) error {
+ })(func(c echo.Context) error {
c.Request().Header.Set(echo.HeaderXRequestID, "123")
return c.String(http.StatusTeapot, "OK")
})
@@ -585,7 +517,7 @@ func BenchmarkRequestLogger_withMapFields(b *testing.B) {
e := echo.New()
mw := RequestLoggerWithConfig(RequestLoggerConfig{
- LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error {
return nil
},
LogLatency: true,
@@ -600,12 +532,13 @@ func BenchmarkRequestLogger_withMapFields(b *testing.B) {
LogReferer: true,
LogUserAgent: true,
LogStatus: true,
+ LogError: true,
LogContentLength: true,
LogResponseSize: true,
LogHeaders: []string{"accept-encoding", "User-Agent"},
LogQueryParams: []string{"lang", "checked"},
LogFormValues: []string{"csrf", "multiple"},
- })(func(c *echo.Context) error {
+ })(func(c echo.Context) error {
c.Request().Header.Set(echo.HeaderXRequestID, "123")
c.FormValue("to force parse form")
return c.String(http.StatusTeapot, "OK")
diff --git a/middleware/rewrite.go b/middleware/rewrite.go
index 02907ca49..4c19cc1cc 100644
--- a/middleware/rewrite.go
+++ b/middleware/rewrite.go
@@ -4,11 +4,9 @@
package middleware
import (
- "errors"
- "maps"
"regexp"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// RewriteConfig defines the config for Rewrite middleware.
@@ -24,48 +22,51 @@ type RewriteConfig struct {
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
// Required.
- Rules map[string]string
+ Rules map[string]string `yaml:"rules"`
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
- RegexRules map[*regexp.Regexp]string
+ RegexRules map[*regexp.Regexp]string `yaml:"-"`
+}
+
+// DefaultRewriteConfig is the default Rewrite middleware config.
+var DefaultRewriteConfig = RewriteConfig{
+ Skipper: DefaultSkipper,
}
// Rewrite returns a Rewrite middleware.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
- c := RewriteConfig{}
+ c := DefaultRewriteConfig
c.Rules = rules
return RewriteWithConfig(c)
}
-// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
-//
-// Rewrite middleware rewrites the URL path based on the provided rules.
+// RewriteWithConfig returns a Rewrite middleware with config.
+// See: `Rewrite()`.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
- return toMiddlewareOrPanic(config)
-}
+ // Defaults
+ if config.Rules == nil && config.RegexRules == nil {
+ panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
+ }
-// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
-func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultSkipper
- }
- if config.Rules == nil && config.RegexRules == nil {
- return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
+ config.Skipper = DefaultBodyDumpConfig.Skipper
}
if config.RegexRules == nil {
config.RegexRules = make(map[*regexp.Regexp]string)
}
- maps.Copy(config.RegexRules, rewriteRulesRegex(config.Rules))
+ for k, v := range rewriteRulesRegex(config.Rules) {
+ config.RegexRules[k] = v
+ }
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) (err error) {
+ return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
@@ -75,5 +76,5 @@ func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
return next(c)
}
- }, nil
+ }
}
diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go
index adcc8e9f5..d137b2d13 100644
--- a/middleware/rewrite_test.go
+++ b/middleware/rewrite_test.go
@@ -11,7 +11,7 @@ import (
"regexp"
"testing"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
@@ -26,10 +26,10 @@ func TestRewriteAfterRouting(t *testing.T) {
"/users/*/orders/*": "/user/$1/order/$2",
},
}))
- e.GET("/public/*", func(c *echo.Context) error {
+ e.GET("/public/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*"))
})
- e.GET("/*", func(c *echo.Context) error {
+ e.GET("/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*"))
})
@@ -93,74 +93,20 @@ func TestRewriteAfterRouting(t *testing.T) {
}
}
-func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
- assert.Panics(t, func() {
- RewriteWithConfig(RewriteConfig{})
- })
-}
-
-func TestMustRewriteWithConfig_skipper(t *testing.T) {
- var testCases = []struct {
- name string
- givenSkipper func(c *echo.Context) bool
- whenURL string
- expectURL string
- expectStatus int
- }{
- {
- name: "not skipped",
- whenURL: "/old",
- expectURL: "/new",
- expectStatus: http.StatusOK,
- },
- {
- name: "skipped",
- givenSkipper: func(c *echo.Context) bool {
- return true
- },
- whenURL: "/old",
- expectURL: "/old",
- expectStatus: http.StatusNotFound,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := echo.New()
-
- e.Pre(RewriteWithConfig(
- RewriteConfig{
- Skipper: tc.givenSkipper,
- Rules: map[string]string{"/old": "/new"}},
- ))
-
- e.GET("/new", func(c *echo.Context) error {
- return c.NoContent(http.StatusOK)
- })
-
- req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
- rec := httptest.NewRecorder()
-
- e.ServeHTTP(rec, req)
-
- assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
- assert.Equal(t, tc.expectStatus, rec.Code)
- })
- }
-}
-
// Issue #1086
func TestEchoRewritePreMiddleware(t *testing.T) {
e := echo.New()
+ r := e.Router()
// Rewrite old url to new one
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
- e.Pre(RewriteWithConfig(RewriteConfig{
- Rules: map[string]string{"/old": "/new"}}),
- )
+ e.Pre(Rewrite(map[string]string{
+ "/old": "/new",
+ },
+ ))
// Route
- e.Add(http.MethodGet, "/new", func(c *echo.Context) error {
+ r.Add(http.MethodGet, "/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
@@ -174,6 +120,7 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
// Issue #1143
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
e := echo.New()
+ r := e.Router()
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{
@@ -183,14 +130,14 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
},
}))
- e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error {
+ r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
return c.String(http.StatusOK, "hosts")
})
- e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error {
+ r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
return c.String(http.StatusOK, "eng")
})
- for range 100 {
+ for i := 0; i < 100; i++ {
req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
diff --git a/middleware/secure.go b/middleware/secure.go
index 022cce4a1..c904abf1a 100644
--- a/middleware/secure.go
+++ b/middleware/secure.go
@@ -6,7 +6,7 @@ package middleware
import (
"fmt"
- "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v4"
)
// SecureConfig defines the config for Secure middleware.
@@ -17,12 +17,12 @@ type SecureConfig struct {
// XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block".
- XSSProtection string
+ XSSProtection string `yaml:"xss_protection"`
// ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff".
- ContentTypeNosniff string
+ ContentTypeNosniff string `yaml:"content_type_nosniff"`
// XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a ,