diff --git a/CHANGELOG.md b/CHANGELOG.md index 401e0e5b3bf978c1ca95675408b8a54632f47bfd..91a67acf60ea745109ea91fb3c3de95b88517fcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ ## Changes since v5.1.1 +- [#503](https://github.com/oauth2-proxy/oauth2-proxy/pull/503) Implements --real-client-ip-header option to select the header from which to obtain a proxied client's IP (@Izzette) - [#529](https://github.com/oauth2-proxy/oauth2-proxy/pull/529) Add local test environments for testing changes and new features (@JoelSpeed) - [#537](https://github.com/oauth2-proxy/oauth2-proxy/pull/537) Drop Fallback to Email if User not set (@JoelSpeed) - [#535](https://github.com/oauth2-proxy/oauth2-proxy/pull/535) Drop support for pre v3.1 cookies (@JoelSpeed) diff --git a/contrib/oauth2-proxy_autocomplete.sh b/contrib/oauth2-proxy_autocomplete.sh index bea798b9574f6891eff3bdddd78e85b35811c07b..d4eaa1b1e76c7575db02675f525bf5a8eedc1ecb 100644 --- a/contrib/oauth2-proxy_autocomplete.sh +++ b/contrib/oauth2-proxy_autocomplete.sh @@ -20,6 +20,10 @@ _oauth2_proxy() { COMPREPLY=( $(compgen -W "google azure facebook github keycloak gitlab linkedin login.gov digitalocean" -- ${cur}) ) return 0 ;; + --real-client-ip-header) + COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) ) + return 0 + ;; -@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url)) return 0 ;; diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md index 90d26e6ca6faf7526cf1c11fb27986da4bf9204d..69b3cfa018a54bc69cede1a3b2033d5aa95e5113 100644 --- a/docs/configuration/configuration.md +++ b/docs/configuration/configuration.md @@ -90,6 +90,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example | `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`<oauth2>/sign_in`) | `"/oauth2"` | | `--proxy-websockets` | bool | enables WebSocket proxying | true | | `--pubjwk-url` | string | JWK pubkey access endpoint: required by login.gov | | +| `--real-client-ip-header` | string | Header used to determine the real IP of the client, requires `--reverse-proxy` to be set (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP) | X-Real-IP | | `--redeem-url` | string | Token redemption endpoint | | | `--redirect-url` | string | the OAuth Redirect URL. ie: `"https://internalapp.yourcompany.com/oauth2/callback"` | | | `--redis-cluster-connection-urls` | string \| list | List of Redis cluster connection URLs (eg redis://HOST[:PORT]). Used in conjunction with `--redis-use-cluster` | | diff --git a/main.go b/main.go index 049b6ae8d6bdc2ef4addd91cc7ee0c9fae949299..4507d8ee1b683265c057133a7fc43f9ecf874022 100644 --- a/main.go +++ b/main.go @@ -26,6 +26,7 @@ func main() { flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients") flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients") flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted") + flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)") flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests") flagSet.String("tls-cert-file", "", "path to certificate file") flagSet.String("tls-key-file", "", "path to private key file") diff --git a/oauthproxy.go b/oauthproxy.go index 42934232b9edf56131116d800b704d59f0484058..6a4d699f8ce745963f3900e8f78f449e48c0824f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -112,6 +112,7 @@ type OAuthProxy struct { jwtBearerVerifiers []*oidc.IDTokenVerifier compiledRegex []*regexp.Regexp templates *template.Template + realClientIPParser realClientIPParser Banner string Footer string } @@ -308,6 +309,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { skipJwtBearerTokens: opts.SkipJwtBearerTokens, jwtBearerVerifiers: opts.jwtBearerVerifiers, compiledRegex: opts.compiledRegex, + realClientIPParser: opts.realClientIPParser, SetXAuthRequest: opts.SetXAuthRequest, PassBasicAuth: opts.PassBasicAuth, SetBasicAuth: opts.SetBasicAuth, @@ -636,14 +638,6 @@ func (p *OAuthProxy) IsWhitelistedPath(path string) bool { return false } -func getRemoteAddr(req *http.Request) (s string) { - s = req.RemoteAddr - if req.Header.Get("X-Real-IP") != "" { - s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) - } - return -} - // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ "Expires": time.Unix(0, 0).Format(time.RFC1123), @@ -766,7 +760,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { // OAuthCallback is the OAuth2 authentication flow callback that finishes the // OAuth2 authentication flow func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { - remoteAddr := getRemoteAddr(req) + remoteAddr := getClientString(p.realClientIPParser, req, true) // finish the oauth cycle err := req.ParseForm() @@ -894,7 +888,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R } } - remoteAddr := getRemoteAddr(req) + remoteAddr := getClientString(p.realClientIPParser, req, true) if session == nil { session, err = p.LoadCookiedSession(req) if err != nil { diff --git a/options.go b/options.go index 80b1ff576fcf6667d2ab329a0fd14f808226fcc0..7220da849a61ac4efe06134af41511cdc366e8b6 100644 --- a/options.go +++ b/options.go @@ -31,19 +31,20 @@ import ( // Options holds Configuration Options that can be set by Command Line Flag, // or Config File type Options struct { - ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix" env:"OAUTH2_PROXY_PROXY_PREFIX"` - PingPath string `flag:"ping-path" cfg:"ping_path" env:"OAUTH2_PROXY_PING_PATH"` - ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets" env:"OAUTH2_PROXY_PROXY_WEBSOCKETS"` - HTTPAddress string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"` - HTTPSAddress string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"` - ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy" env:"OAUTH2_PROXY_REVERSE_PROXY"` - ForceHTTPS bool `flag:"force-https" cfg:"force_https" env:"OAUTH2_PROXY_FORCE_HTTPS"` - RedirectURL string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"` - ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` - ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` - ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file" env:"OAUTH2_PROXY_CLIENT_SECRET_FILE"` - TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file" env:"OAUTH2_PROXY_TLS_CERT_FILE"` - TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file" env:"OAUTH2_PROXY_TLS_KEY_FILE"` + ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix" env:"OAUTH2_PROXY_PROXY_PREFIX"` + PingPath string `flag:"ping-path" cfg:"ping_path" env:"OAUTH2_PROXY_PING_PATH"` + ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets" env:"OAUTH2_PROXY_PROXY_WEBSOCKETS"` + HTTPAddress string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"` + HTTPSAddress string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"` + ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy" env:"OAUTH2_PROXY_REVERSE_PROXY"` + RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header" env:"OAUTH2_PROXY_REAL_CLIENT_IP_HEADER"` + ForceHTTPS bool `flag:"force-https" cfg:"force_https" env:"OAUTH2_PROXY_FORCE_HTTPS"` + RedirectURL string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"` + ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` + ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` + ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file" env:"OAUTH2_PROXY_CLIENT_SECRET_FILE"` + TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file" env:"OAUTH2_PROXY_TLS_CERT_FILE"` + TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file" env:"OAUTH2_PROXY_TLS_KEY_FILE"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file" env:"OAUTH2_PROXY_AUTHENTICATED_EMAILS_FILE"` KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group" env:"OAUTH2_PROXY_KEYCLOAK_GROUP"` @@ -139,6 +140,7 @@ type Options struct { signatureData *SignatureData oidcVerifier *oidc.IDTokenVerifier jwtBearerVerifiers []*oidc.IDTokenVerifier + realClientIPParser realClientIPParser } // SignatureData holds hmacauth signature hash and key @@ -456,6 +458,13 @@ func (o *Options) Validate() error { msgs = validateCookieName(o, msgs) msgs = setupLogger(o, msgs) + if o.ReverseProxy { + o.realClientIPParser, err = getRealClientIPParser(o.RealClientIPHeader) + if err != nil { + msgs = append(msgs, fmt.Sprintf("real_client_ip_header (%s) not accepted parameter value: %v", o.RealClientIPHeader, err)) + } + } + if len(msgs) != 0 { return fmt.Errorf("invalid configuration:\n %s", strings.Join(msgs, "\n ")) @@ -695,7 +704,9 @@ func setupLogger(o *Options, msgs []string) []string { logger.SetStandardTemplate(o.StandardLoggingFormat) logger.SetAuthTemplate(o.AuthLoggingFormat) logger.SetReqTemplate(o.RequestLoggingFormat) - logger.SetReverseProxy(o.ReverseProxy) + logger.SetGetClientFunc(func(r *http.Request) string { + return getClientString(o.realClientIPParser, r, false) + }) excludePaths := make([]string, 0) excludePaths = append(excludePaths, strings.Split(o.ExcludeLoggingPaths, ",")...) diff --git a/options_test.go b/options_test.go index b14c9e465a2d8ab55faba00225ceff3eea7484b2..ae978b6af561d1a30af4f2f8540aca9393bb80b2 100644 --- a/options_test.go +++ b/options_test.go @@ -326,3 +326,46 @@ func TestGCPHealthcheck(t *testing.T) { o.GCPHealthChecks = true assert.Equal(t, nil, o.Validate()) } + +func TestRealClientIPHeader(t *testing.T) { + var o *Options + var err error + var expected string + + // Ensure nil if ReverseProxy not set. + o = testOptions() + o.RealClientIPHeader = "X-Real-IP" + assert.Equal(t, nil, o.Validate()) + assert.Nil(t, o.realClientIPParser) + + // Ensure simple use case works. + o = testOptions() + o.ReverseProxy = true + o.RealClientIPHeader = "X-Forwarded-For" + assert.Equal(t, nil, o.Validate()) + assert.NotNil(t, o.realClientIPParser) + + // Ensure unknown header format process an error. + o = testOptions() + o.ReverseProxy = true + o.RealClientIPHeader = "Forwarded" + err = o.Validate() + assert.NotEqual(t, nil, err) + expected = errorMsg([]string{ + "real_client_ip_header (Forwarded) not accepted parameter value: the http header key (Forwarded) is either invalid or unsupported", + }) + assert.Equal(t, expected, err.Error()) + assert.Nil(t, o.realClientIPParser) + + // Ensure invalid header format produces an error. + o = testOptions() + o.ReverseProxy = true + o.RealClientIPHeader = "!934invalidheader-23:" + err = o.Validate() + assert.NotEqual(t, nil, err) + expected = errorMsg([]string{ + "real_client_ip_header (!934invalidheader-23:) not accepted parameter value: the http header key (!934invalidheader-23:) is either invalid or unsupported", + }) + assert.Equal(t, expected, err.Error()) + assert.Nil(t, o.realClientIPParser) +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 0664f8434d3d3b2b12011c0904ea74741978b66f..9bfc2b3a6a7351f43cf43edd1f828903f7073907 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -3,7 +3,6 @@ package logger import ( "fmt" "io" - "net" "net/http" "net/url" "os" @@ -76,6 +75,9 @@ type reqLogMessageData struct { Username string } +// Returns the apparent "real client IP" as a string. +type GetClientFunc = func(r *http.Request) string + // A Logger represents an active logging object that generates lines of // output to an io.Writer passed through a formatter. Each logging // operation makes a single call to the Writer's Write method. A Logger @@ -88,7 +90,7 @@ type Logger struct { stdEnabled bool authEnabled bool reqEnabled bool - reverseProxy bool + getClientFunc GetClientFunc excludePaths map[string]struct{} stdLogTemplate *template.Template authTemplate *template.Template @@ -103,7 +105,7 @@ func New(flag int) *Logger { stdEnabled: true, authEnabled: true, reqEnabled: true, - reverseProxy: false, + getClientFunc: func(r *http.Request) string { return r.RemoteAddr }, excludePaths: nil, stdLogTemplate: template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat)), authTemplate: template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat)), @@ -153,7 +155,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu username = "-" } - client := GetClient(req, l.reverseProxy) + client := l.getClientFunc(req) l.mu.Lock() defer l.mu.Unlock() @@ -201,7 +203,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. } } - client := GetClient(req, l.reverseProxy) + client := l.getClientFunc(req) l.mu.Lock() defer l.mu.Unlock() @@ -252,22 +254,6 @@ func (l *Logger) GetFileLineString(calldepth int) string { return fmt.Sprintf("%s:%d", file, line) } -// GetClient parses an HTTP request for the client/remote IP address. -func GetClient(req *http.Request, reverseProxy bool) string { - client := req.RemoteAddr - if reverseProxy { - if ip := req.Header.Get("X-Real-IP"); ip != "" { - client = ip - } - } - - if c, _, err := net.SplitHostPort(client); err == nil { - client = c - } - - return client -} - // FormatTimestamp returns a formatted timestamp. func (l *Logger) FormatTimestamp(ts time.Time) string { if l.flag&LUTC != 0 { @@ -312,11 +298,11 @@ func (l *Logger) SetReqEnabled(e bool) { l.reqEnabled = e } -// SetReverseProxy controls whether logging will trust headers that can be set by a reverse proxy. -func (l *Logger) SetReverseProxy(e bool) { +// SetGetClientFunc sets the function which determines the apparent "real client IP". +func (l *Logger) SetGetClientFunc(f GetClientFunc) { l.mu.Lock() defer l.mu.Unlock() - l.reverseProxy = e + l.getClientFunc = f } // SetExcludePaths sets the paths to exclude from logging. @@ -392,10 +378,10 @@ func SetReqEnabled(e bool) { std.SetReqEnabled(e) } -// SetReverseProxy controls whether logging will trust headers that can be set -// by a reverse proxy for the standard logger. -func SetReverseProxy(e bool) { - std.SetReverseProxy(e) +// SetGetClientFunc sets the function which determines the apparent IP address +// set by a reverse proxy for the standard logger. +func SetGetClientFunc(f GetClientFunc) { + std.SetGetClientFunc(f) } // SetExcludePaths sets the path to exclude from logging, eg: health checks diff --git a/realclientip.go b/realclientip.go new file mode 100644 index 0000000000000000000000000000000000000000..b45ef7c29365e53adcccc8b185877b9c888b3267 --- /dev/null +++ b/realclientip.go @@ -0,0 +1,102 @@ +package main + +import ( + "fmt" + "net" + "net/http" + "strings" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +type realClientIPParser interface { + GetRealClientIP(http.Header) (net.IP, error) +} + +func getRealClientIPParser(headerKey string) (realClientIPParser, error) { + headerKey = http.CanonicalHeaderKey(headerKey) + + switch headerKey { + case http.CanonicalHeaderKey("X-Forwarded-For"), http.CanonicalHeaderKey("X-Real-IP"), http.CanonicalHeaderKey("X-ProxyUser-IP"): + return &xForwardedForClientIPParser{header: headerKey}, nil + } + + // TODO: implement the more standardized but more complex `Forwarded` header. + return nil, fmt.Errorf("the http header key (%s) is either invalid or unsupported", headerKey) +} + +type xForwardedForClientIPParser struct { + header string +} + +// GetRealClientIP obtain the IP address of the end-user (not proxy). +// Parses headers sharing the format as specified by: +// * https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For. +// Returns the `<client>` portion specified in the above document. +// Additionally, is capable of parsing IPs with the port included, for v4 in the format "<ip>:<port>" and for v6 in the +// format "[<ip>]:<port>". With-port and without-port formats are seamlessly supported concurrently. +func (p xForwardedForClientIPParser) GetRealClientIP(h http.Header) (net.IP, error) { + var ipStr string + if realIP := h.Get(p.header); realIP != "" { + ipStr = realIP + } else { + return nil, nil + } + + // Each successive proxy may append itself, comma separated, to the end of the X-Forwarded-for header. + // Select only the first IP listed, as it is the client IP recorded by the first proxy. + if commaIndex := strings.IndexRune(ipStr, ','); commaIndex != -1 { + ipStr = ipStr[:commaIndex] + } + ipStr = strings.TrimSpace(ipStr) + + if ipHost, _, err := net.SplitHostPort(ipStr); err == nil { + ipStr = ipHost + } + + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("unable to parse ip (%s) from %s header", ipStr, http.CanonicalHeaderKey(p.header)) + } + + return ip, nil +} + +// getRemoteIP obtains the IP of the low-level connected network host +func getRemoteIP(req *http.Request) (net.IP, error) { + if ipStr, _, err := net.SplitHostPort(req.RemoteAddr); err != nil { + return nil, fmt.Errorf("unable to get ip and port from http.RemoteAddr (%s)", req.RemoteAddr) + } else if ip := net.ParseIP(ipStr); ip != nil { + return ip, nil + } else { + return nil, fmt.Errorf("unable to parse ip (%s)", ipStr) + } +} + +// getClientString obtains the human readable string of the remote IP and optionally the real client IP if available +func getClientString(p realClientIPParser, req *http.Request, full bool) (s string) { + var realClientIPStr string + if p != nil { + if realClientIP, err := p.GetRealClientIP(req.Header); err != nil { + logger.Printf("Unable to get real client IP: %v", err) + } else if realClientIP != nil { + realClientIPStr = realClientIP.String() + } + } + + var remoteIPStr string + if remoteIP, err := getRemoteIP(req); err == nil { + remoteIPStr = remoteIP.String() + } else { + // Should not happen, if it does, likely a bug. + logger.Printf("Unable to get remote IP(?!?!): %v", err) + } + + if !full && realClientIPStr != "" { + return realClientIPStr + } + if full && realClientIPStr != "" { + return fmt.Sprintf("%s (%s)", remoteIPStr, realClientIPStr) + } + return remoteIPStr +} diff --git a/realclientip_test.go b/realclientip_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0271e2e331c344119df37b5ee47a2f6d4ef20196 --- /dev/null +++ b/realclientip_test.go @@ -0,0 +1,176 @@ +package main + +import ( + "net" + "net/http" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetRealClientIPParser(t *testing.T) { + forwardedForType := reflect.TypeOf((*xForwardedForClientIPParser)(nil)) + + tests := []struct { + header string + errString string + parserType reflect.Type + }{ + {"X-Forwarded-For", "", forwardedForType}, + {"X-REAL-IP", "", forwardedForType}, + {"x-proxyuser-ip", "", forwardedForType}, + {"", "the http header key () is either invalid or unsupported", nil}, + {"Forwarded", "the http header key (Forwarded) is either invalid or unsupported", nil}, + {"2#* @##$$:kd", "the http header key (2#* @##$$:kd) is either invalid or unsupported", nil}, + } + + for _, test := range tests { + p, err := getRealClientIPParser(test.header) + + if test.errString == "" { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + assert.Equal(t, test.errString, err.Error()) + } + + if test.parserType == nil { + assert.Nil(t, p) + } else { + assert.NotNil(t, p) + assert.Equal(t, test.parserType, reflect.TypeOf(p)) + } + + if xp, ok := p.(*xForwardedForClientIPParser); ok { + assert.Equal(t, http.CanonicalHeaderKey(test.header), xp.header) + } + } +} + +func TestXForwardedForClientIPParser(t *testing.T) { + p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")} + + tests := []struct { + headerValue string + errString string + expectedIP net.IP + }{ + {"", "", nil}, + {"1.2.3.4", "", net.ParseIP("1.2.3.4")}, + {"10::23", "", net.ParseIP("10::23")}, + {"::1", "", net.ParseIP("::1")}, + {"[::1]:1234", "", net.ParseIP("::1")}, + {"10.0.10.11:1234", "", net.ParseIP("10.0.10.11")}, + {"192.168.10.50, 10.0.0.1, 1.2.3.4", "", net.ParseIP("192.168.10.50")}, + {"nil", "unable to parse ip (nil) from X-Forwarded-For header", nil}, + {"10000.10000.10000.10000", "unable to parse ip (10000.10000.10000.10000) from X-Forwarded-For header", nil}, + } + + for _, test := range tests { + h := http.Header{} + h.Add("X-Forwarded-For", test.headerValue) + + ip, err := p.GetRealClientIP(h) + + if test.errString == "" { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + assert.Equal(t, test.errString, err.Error()) + } + + if test.expectedIP == nil { + assert.Nil(t, ip) + } else { + assert.NotNil(t, ip) + assert.Equal(t, test.expectedIP, ip) + } + } +} + +func TestXForwardedForClientIPParserIgnoresOthers(t *testing.T) { + p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")} + + h := http.Header{} + expectedIPString := "192.168.10.50" + h.Add("X-Real-IP", "10.0.0.1") + h.Add("X-ProxyUser-IP", "10.0.0.1") + h.Add("X-Forwarded-For", expectedIPString) + ip, err := p.GetRealClientIP(h) + assert.Nil(t, err) + assert.NotNil(t, ip) + assert.Equal(t, ip, net.ParseIP(expectedIPString)) +} + +func TestGetRemoteIP(t *testing.T) { + tests := []struct { + remoteAddr string + errString string + expectedIP net.IP + }{ + {"", "unable to get ip and port from http.RemoteAddr ()", nil}, + {"nil", "unable to get ip and port from http.RemoteAddr (nil)", nil}, + {"235.28.129.186", "unable to get ip and port from http.RemoteAddr (235.28.129.186)", nil}, + {"90::45", "unable to get ip and port from http.RemoteAddr (90::45)", nil}, + {"192.168.73.165:14976, 10.4.201.15:18453", "unable to get ip and port from http.RemoteAddr (192.168.73.165:14976, 10.4.201.15:18453)", nil}, + {"10000.10000.10000.10000:8080", "unable to parse ip (10000.10000.10000.10000)", nil}, + {"[::1]:48290", "", net.ParseIP("::1")}, + {"10.254.244.165:62750", "", net.ParseIP("10.254.244.165")}, + } + + for _, test := range tests { + req := &http.Request{RemoteAddr: test.remoteAddr} + + ip, err := getRemoteIP(req) + + if test.errString == "" { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + assert.Equal(t, test.errString, err.Error()) + } + + if test.expectedIP == nil { + assert.Nil(t, ip) + } else { + assert.NotNil(t, ip) + assert.Equal(t, test.expectedIP, ip) + } + } +} + +func TestGetClientString(t *testing.T) { + p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")} + + tests := []struct { + parser realClientIPParser + remoteAddr string + headerValue string + expectedClient string + expectedClientFull string + }{ + // Should fail quietly, only printing warnings to the log + {nil, "", "", "", ""}, + {p, "127.0.0.1:11950", "", "127.0.0.1", "127.0.0.1"}, + {p, "[::1]:28660", "99.103.56.12", "99.103.56.12", "::1 (99.103.56.12)"}, + {nil, "10.254.244.165:62750", "", "10.254.244.165", "10.254.244.165"}, + // Parser is nil, the contents of X-Forwarded-For should be ignored in all cases. + {nil, "[2001:470:26:307:a5a1:1177:2ae3:e9c3]:48290", "127.0.0.1", "2001:470:26:307:a5a1:1177:2ae3:e9c3", "2001:470:26:307:a5a1:1177:2ae3:e9c3"}, + } + + for _, test := range tests { + h := http.Header{} + h.Add("X-Forwarded-For", test.headerValue) + req := &http.Request{ + Header: h, + RemoteAddr: test.remoteAddr, + } + + client := getClientString(test.parser, req, false) + assert.Equal(t, test.expectedClient, client) + + clientFull := getClientString(test.parser, req, true) + assert.Equal(t, test.expectedClientFull, clientFull) + } +}