From 111d17efdecd4f716bc49b3ba8755ff2096d8c2c Mon Sep 17 00:00:00 2001
From: Isabelle COWAN-BERGMAN <Izzette@users.noreply.github.com>
Date: Tue, 12 May 2020 19:41:25 +0200
Subject: [PATCH] Implements --real-client-ip-header option. (#503)

* Implements -real-client-ip-header option.

* The -real-client-ip-header determines what HTTP header is used for
  determining the "real client IP" of the remote client.
* The -real-client-ip-header option supports the following headers:
  X-Forwarded-For X-ProxyUser-IP and X-Real-IP (default).
* Introduces new realClientIPParser interface to allow for multiple
  polymorphic classes to decide how to determine the real client IP.
* TODO: implement the more standard, but more complex `Forwarded` HTTP
  header.

* Corrected order of expected/actual in test cases

* Improved error message in getRemoteIP

* Add tests for getRemoteIP and getClientString

* Add comment explaining splitting of header

* Update documentation on -real-client-ip-header w/o -reverse-proxy

* Add PR number in changelog.

* Fix typo repeated word: "it"

Co-Authored-By: Joel Speed <Joel.speed@hotmail.co.uk>

* Update extended configuration language

* Simplify the language around dependance on -reverse-proxy

Co-Authored-By: Joel Speed <Joel.speed@hotmail.co.uk>

* Added completions

* Reorder real client IP header options

* Update CHANGELOG.md

* Apply suggestions from code review

Co-authored-by: Isabelle COWAN-BERGMAN <Izzette@users.noreply.github.com>

Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
---
 CHANGELOG.md                         |   1 +
 contrib/oauth2-proxy_autocomplete.sh |   4 +
 docs/configuration/configuration.md  |   1 +
 main.go                              |   1 +
 oauthproxy.go                        |  14 +--
 options.go                           |  39 +++---
 options_test.go                      |  43 +++++++
 pkg/logger/logger.go                 |  42 +++----
 realclientip.go                      | 102 ++++++++++++++++
 realclientip_test.go                 | 176 +++++++++++++++++++++++++++
 10 files changed, 371 insertions(+), 52 deletions(-)
 create mode 100644 realclientip.go
 create mode 100644 realclientip_test.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 401e0e5..91a67ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,7 @@
 
 ## Changes since v5.1.1
 
+- [#503](https://github.com/oauth2-proxy/oauth2-proxy/pull/503) Implements --real-client-ip-header option to select the header from which to obtain a proxied client's IP (@Izzette)
 - [#529](https://github.com/oauth2-proxy/oauth2-proxy/pull/529) Add local test environments for testing changes and new features (@JoelSpeed)
 - [#537](https://github.com/oauth2-proxy/oauth2-proxy/pull/537) Drop Fallback to Email if User not set (@JoelSpeed)
 - [#535](https://github.com/oauth2-proxy/oauth2-proxy/pull/535) Drop support for pre v3.1 cookies (@JoelSpeed)
diff --git a/contrib/oauth2-proxy_autocomplete.sh b/contrib/oauth2-proxy_autocomplete.sh
index bea798b..d4eaa1b 100644
--- a/contrib/oauth2-proxy_autocomplete.sh
+++ b/contrib/oauth2-proxy_autocomplete.sh
@@ -20,6 +20,10 @@ _oauth2_proxy() {
 			COMPREPLY=( $(compgen -W "google azure facebook github keycloak gitlab linkedin login.gov digitalocean" -- ${cur}) )
 			return 0
 			;;
+		--real-client-ip-header)
+			COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) )
+			return 0
+			;;
 		-@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url))
 			return 0
 			;;
diff --git a/docs/configuration/configuration.md b/docs/configuration/configuration.md
index 90d26e6..69b3cfa 100644
--- a/docs/configuration/configuration.md
+++ b/docs/configuration/configuration.md
@@ -90,6 +90,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example
 | `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`<oauth2>/sign_in`) | `"/oauth2"` |
 | `--proxy-websockets` | bool | enables WebSocket proxying | true |
 | `--pubjwk-url` | string | JWK pubkey access endpoint: required by login.gov | |
+| `--real-client-ip-header` | string | Header used to determine the real IP of the client, requires `--reverse-proxy` to be set (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP) | X-Real-IP |
 | `--redeem-url` | string | Token redemption endpoint | |
 | `--redirect-url` | string | the OAuth Redirect URL. ie: `"https://internalapp.yourcompany.com/oauth2/callback"` | |
 | `--redis-cluster-connection-urls` | string \| list | List of Redis cluster connection URLs (eg redis://HOST[:PORT]). Used in conjunction with `--redis-use-cluster` | |
diff --git a/main.go b/main.go
index 049b6ae..4507d8e 100644
--- a/main.go
+++ b/main.go
@@ -26,6 +26,7 @@ func main() {
 	flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients")
 	flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients")
 	flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted")
+	flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)")
 	flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests")
 	flagSet.String("tls-cert-file", "", "path to certificate file")
 	flagSet.String("tls-key-file", "", "path to private key file")
diff --git a/oauthproxy.go b/oauthproxy.go
index 4293423..6a4d699 100644
--- a/oauthproxy.go
+++ b/oauthproxy.go
@@ -112,6 +112,7 @@ type OAuthProxy struct {
 	jwtBearerVerifiers   []*oidc.IDTokenVerifier
 	compiledRegex        []*regexp.Regexp
 	templates            *template.Template
+	realClientIPParser   realClientIPParser
 	Banner               string
 	Footer               string
 }
@@ -308,6 +309,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
 		skipJwtBearerTokens:  opts.SkipJwtBearerTokens,
 		jwtBearerVerifiers:   opts.jwtBearerVerifiers,
 		compiledRegex:        opts.compiledRegex,
+		realClientIPParser:   opts.realClientIPParser,
 		SetXAuthRequest:      opts.SetXAuthRequest,
 		PassBasicAuth:        opts.PassBasicAuth,
 		SetBasicAuth:         opts.SetBasicAuth,
@@ -636,14 +638,6 @@ func (p *OAuthProxy) IsWhitelistedPath(path string) bool {
 	return false
 }
 
-func getRemoteAddr(req *http.Request) (s string) {
-	s = req.RemoteAddr
-	if req.Header.Get("X-Real-IP") != "" {
-		s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
-	}
-	return
-}
-
 // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 var noCacheHeaders = map[string]string{
 	"Expires":         time.Unix(0, 0).Format(time.RFC1123),
@@ -766,7 +760,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
 // OAuthCallback is the OAuth2 authentication flow callback that finishes the
 // OAuth2 authentication flow
 func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
-	remoteAddr := getRemoteAddr(req)
+	remoteAddr := getClientString(p.realClientIPParser, req, true)
 
 	// finish the oauth cycle
 	err := req.ParseForm()
@@ -894,7 +888,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
 		}
 	}
 
-	remoteAddr := getRemoteAddr(req)
+	remoteAddr := getClientString(p.realClientIPParser, req, true)
 	if session == nil {
 		session, err = p.LoadCookiedSession(req)
 		if err != nil {
diff --git a/options.go b/options.go
index 80b1ff5..7220da8 100644
--- a/options.go
+++ b/options.go
@@ -31,19 +31,20 @@ import (
 // Options holds Configuration Options that can be set by Command Line Flag,
 // or Config File
 type Options struct {
-	ProxyPrefix      string `flag:"proxy-prefix" cfg:"proxy_prefix" env:"OAUTH2_PROXY_PROXY_PREFIX"`
-	PingPath         string `flag:"ping-path" cfg:"ping_path" env:"OAUTH2_PROXY_PING_PATH"`
-	ProxyWebSockets  bool   `flag:"proxy-websockets" cfg:"proxy_websockets" env:"OAUTH2_PROXY_PROXY_WEBSOCKETS"`
-	HTTPAddress      string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"`
-	HTTPSAddress     string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"`
-	ReverseProxy     bool   `flag:"reverse-proxy" cfg:"reverse_proxy" env:"OAUTH2_PROXY_REVERSE_PROXY"`
-	ForceHTTPS       bool   `flag:"force-https" cfg:"force_https" env:"OAUTH2_PROXY_FORCE_HTTPS"`
-	RedirectURL      string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"`
-	ClientID         string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
-	ClientSecret     string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
-	ClientSecretFile string `flag:"client-secret-file" cfg:"client_secret_file" env:"OAUTH2_PROXY_CLIENT_SECRET_FILE"`
-	TLSCertFile      string `flag:"tls-cert-file" cfg:"tls_cert_file" env:"OAUTH2_PROXY_TLS_CERT_FILE"`
-	TLSKeyFile       string `flag:"tls-key-file" cfg:"tls_key_file" env:"OAUTH2_PROXY_TLS_KEY_FILE"`
+	ProxyPrefix        string `flag:"proxy-prefix" cfg:"proxy_prefix" env:"OAUTH2_PROXY_PROXY_PREFIX"`
+	PingPath           string `flag:"ping-path" cfg:"ping_path" env:"OAUTH2_PROXY_PING_PATH"`
+	ProxyWebSockets    bool   `flag:"proxy-websockets" cfg:"proxy_websockets" env:"OAUTH2_PROXY_PROXY_WEBSOCKETS"`
+	HTTPAddress        string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"`
+	HTTPSAddress       string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"`
+	ReverseProxy       bool   `flag:"reverse-proxy" cfg:"reverse_proxy" env:"OAUTH2_PROXY_REVERSE_PROXY"`
+	RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header" env:"OAUTH2_PROXY_REAL_CLIENT_IP_HEADER"`
+	ForceHTTPS         bool   `flag:"force-https" cfg:"force_https" env:"OAUTH2_PROXY_FORCE_HTTPS"`
+	RedirectURL        string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"`
+	ClientID           string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
+	ClientSecret       string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
+	ClientSecretFile   string `flag:"client-secret-file" cfg:"client_secret_file" env:"OAUTH2_PROXY_CLIENT_SECRET_FILE"`
+	TLSCertFile        string `flag:"tls-cert-file" cfg:"tls_cert_file" env:"OAUTH2_PROXY_TLS_CERT_FILE"`
+	TLSKeyFile         string `flag:"tls-key-file" cfg:"tls_key_file" env:"OAUTH2_PROXY_TLS_KEY_FILE"`
 
 	AuthenticatedEmailsFile  string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file" env:"OAUTH2_PROXY_AUTHENTICATED_EMAILS_FILE"`
 	KeycloakGroup            string   `flag:"keycloak-group" cfg:"keycloak_group" env:"OAUTH2_PROXY_KEYCLOAK_GROUP"`
@@ -139,6 +140,7 @@ type Options struct {
 	signatureData      *SignatureData
 	oidcVerifier       *oidc.IDTokenVerifier
 	jwtBearerVerifiers []*oidc.IDTokenVerifier
+	realClientIPParser realClientIPParser
 }
 
 // SignatureData holds hmacauth signature hash and key
@@ -456,6 +458,13 @@ func (o *Options) Validate() error {
 	msgs = validateCookieName(o, msgs)
 	msgs = setupLogger(o, msgs)
 
+	if o.ReverseProxy {
+		o.realClientIPParser, err = getRealClientIPParser(o.RealClientIPHeader)
+		if err != nil {
+			msgs = append(msgs, fmt.Sprintf("real_client_ip_header (%s) not accepted parameter value: %v", o.RealClientIPHeader, err))
+		}
+	}
+
 	if len(msgs) != 0 {
 		return fmt.Errorf("invalid configuration:\n  %s",
 			strings.Join(msgs, "\n  "))
@@ -695,7 +704,9 @@ func setupLogger(o *Options, msgs []string) []string {
 	logger.SetStandardTemplate(o.StandardLoggingFormat)
 	logger.SetAuthTemplate(o.AuthLoggingFormat)
 	logger.SetReqTemplate(o.RequestLoggingFormat)
-	logger.SetReverseProxy(o.ReverseProxy)
+	logger.SetGetClientFunc(func(r *http.Request) string {
+		return getClientString(o.realClientIPParser, r, false)
+	})
 
 	excludePaths := make([]string, 0)
 	excludePaths = append(excludePaths, strings.Split(o.ExcludeLoggingPaths, ",")...)
diff --git a/options_test.go b/options_test.go
index b14c9e4..ae978b6 100644
--- a/options_test.go
+++ b/options_test.go
@@ -326,3 +326,46 @@ func TestGCPHealthcheck(t *testing.T) {
 	o.GCPHealthChecks = true
 	assert.Equal(t, nil, o.Validate())
 }
+
+func TestRealClientIPHeader(t *testing.T) {
+	var o *Options
+	var err error
+	var expected string
+
+	// Ensure nil if ReverseProxy not set.
+	o = testOptions()
+	o.RealClientIPHeader = "X-Real-IP"
+	assert.Equal(t, nil, o.Validate())
+	assert.Nil(t, o.realClientIPParser)
+
+	// Ensure simple use case works.
+	o = testOptions()
+	o.ReverseProxy = true
+	o.RealClientIPHeader = "X-Forwarded-For"
+	assert.Equal(t, nil, o.Validate())
+	assert.NotNil(t, o.realClientIPParser)
+
+	// Ensure unknown header format process an error.
+	o = testOptions()
+	o.ReverseProxy = true
+	o.RealClientIPHeader = "Forwarded"
+	err = o.Validate()
+	assert.NotEqual(t, nil, err)
+	expected = errorMsg([]string{
+		"real_client_ip_header (Forwarded) not accepted parameter value: the http header key (Forwarded) is either invalid or unsupported",
+	})
+	assert.Equal(t, expected, err.Error())
+	assert.Nil(t, o.realClientIPParser)
+
+	// Ensure invalid header format produces an error.
+	o = testOptions()
+	o.ReverseProxy = true
+	o.RealClientIPHeader = "!934invalidheader-23:"
+	err = o.Validate()
+	assert.NotEqual(t, nil, err)
+	expected = errorMsg([]string{
+		"real_client_ip_header (!934invalidheader-23:) not accepted parameter value: the http header key (!934invalidheader-23:) is either invalid or unsupported",
+	})
+	assert.Equal(t, expected, err.Error())
+	assert.Nil(t, o.realClientIPParser)
+}
diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go
index 0664f84..9bfc2b3 100644
--- a/pkg/logger/logger.go
+++ b/pkg/logger/logger.go
@@ -3,7 +3,6 @@ package logger
 import (
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -76,6 +75,9 @@ type reqLogMessageData struct {
 	Username string
 }
 
+// Returns the apparent "real client IP" as a string.
+type GetClientFunc = func(r *http.Request) string
+
 // A Logger represents an active logging object that generates lines of
 // output to an io.Writer passed through a formatter. Each logging
 // operation makes a single call to the Writer's Write method. A Logger
@@ -88,7 +90,7 @@ type Logger struct {
 	stdEnabled     bool
 	authEnabled    bool
 	reqEnabled     bool
-	reverseProxy   bool
+	getClientFunc  GetClientFunc
 	excludePaths   map[string]struct{}
 	stdLogTemplate *template.Template
 	authTemplate   *template.Template
@@ -103,7 +105,7 @@ func New(flag int) *Logger {
 		stdEnabled:     true,
 		authEnabled:    true,
 		reqEnabled:     true,
-		reverseProxy:   false,
+		getClientFunc:  func(r *http.Request) string { return r.RemoteAddr },
 		excludePaths:   nil,
 		stdLogTemplate: template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat)),
 		authTemplate:   template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat)),
@@ -153,7 +155,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
 		username = "-"
 	}
 
-	client := GetClient(req, l.reverseProxy)
+	client := l.getClientFunc(req)
 
 	l.mu.Lock()
 	defer l.mu.Unlock()
@@ -201,7 +203,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
 		}
 	}
 
-	client := GetClient(req, l.reverseProxy)
+	client := l.getClientFunc(req)
 
 	l.mu.Lock()
 	defer l.mu.Unlock()
@@ -252,22 +254,6 @@ func (l *Logger) GetFileLineString(calldepth int) string {
 	return fmt.Sprintf("%s:%d", file, line)
 }
 
-// GetClient parses an HTTP request for the client/remote IP address.
-func GetClient(req *http.Request, reverseProxy bool) string {
-	client := req.RemoteAddr
-	if reverseProxy {
-		if ip := req.Header.Get("X-Real-IP"); ip != "" {
-			client = ip
-		}
-	}
-
-	if c, _, err := net.SplitHostPort(client); err == nil {
-		client = c
-	}
-
-	return client
-}
-
 // FormatTimestamp returns a formatted timestamp.
 func (l *Logger) FormatTimestamp(ts time.Time) string {
 	if l.flag&LUTC != 0 {
@@ -312,11 +298,11 @@ func (l *Logger) SetReqEnabled(e bool) {
 	l.reqEnabled = e
 }
 
-// SetReverseProxy controls whether logging will trust headers that can be set by a reverse proxy.
-func (l *Logger) SetReverseProxy(e bool) {
+// SetGetClientFunc sets the function which determines the apparent "real client IP".
+func (l *Logger) SetGetClientFunc(f GetClientFunc) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
-	l.reverseProxy = e
+	l.getClientFunc = f
 }
 
 // SetExcludePaths sets the paths to exclude from logging.
@@ -392,10 +378,10 @@ func SetReqEnabled(e bool) {
 	std.SetReqEnabled(e)
 }
 
-// SetReverseProxy controls whether logging will trust headers that can be set
-// by a reverse proxy for the standard logger.
-func SetReverseProxy(e bool) {
-	std.SetReverseProxy(e)
+// SetGetClientFunc sets the function which determines the apparent IP address
+// set by a reverse proxy for the standard logger.
+func SetGetClientFunc(f GetClientFunc) {
+	std.SetGetClientFunc(f)
 }
 
 // SetExcludePaths sets the path to exclude from logging, eg: health checks
diff --git a/realclientip.go b/realclientip.go
new file mode 100644
index 0000000..b45ef7c
--- /dev/null
+++ b/realclientip.go
@@ -0,0 +1,102 @@
+package main
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"strings"
+
+	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
+)
+
+type realClientIPParser interface {
+	GetRealClientIP(http.Header) (net.IP, error)
+}
+
+func getRealClientIPParser(headerKey string) (realClientIPParser, error) {
+	headerKey = http.CanonicalHeaderKey(headerKey)
+
+	switch headerKey {
+	case http.CanonicalHeaderKey("X-Forwarded-For"), http.CanonicalHeaderKey("X-Real-IP"), http.CanonicalHeaderKey("X-ProxyUser-IP"):
+		return &xForwardedForClientIPParser{header: headerKey}, nil
+	}
+
+	// TODO: implement the more standardized but more complex `Forwarded` header.
+	return nil, fmt.Errorf("the http header key (%s) is either invalid or unsupported", headerKey)
+}
+
+type xForwardedForClientIPParser struct {
+	header string
+}
+
+// GetRealClientIP obtain the IP address of the end-user (not proxy).
+// Parses headers sharing the format as specified by:
+// * https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For.
+// Returns the `<client>` portion specified in the above document.
+// Additionally, is capable of parsing IPs with the port included, for v4 in the format "<ip>:<port>" and for v6 in the
+// format "[<ip>]:<port>".  With-port and without-port formats are seamlessly supported concurrently.
+func (p xForwardedForClientIPParser) GetRealClientIP(h http.Header) (net.IP, error) {
+	var ipStr string
+	if realIP := h.Get(p.header); realIP != "" {
+		ipStr = realIP
+	} else {
+		return nil, nil
+	}
+
+	// Each successive proxy may append itself, comma separated, to the end of the X-Forwarded-for header.
+	// Select only the first IP listed, as it is the client IP recorded by the first proxy.
+	if commaIndex := strings.IndexRune(ipStr, ','); commaIndex != -1 {
+		ipStr = ipStr[:commaIndex]
+	}
+	ipStr = strings.TrimSpace(ipStr)
+
+	if ipHost, _, err := net.SplitHostPort(ipStr); err == nil {
+		ipStr = ipHost
+	}
+
+	ip := net.ParseIP(ipStr)
+	if ip == nil {
+		return nil, fmt.Errorf("unable to parse ip (%s) from %s header", ipStr, http.CanonicalHeaderKey(p.header))
+	}
+
+	return ip, nil
+}
+
+// getRemoteIP obtains the IP of the low-level connected network host
+func getRemoteIP(req *http.Request) (net.IP, error) {
+	if ipStr, _, err := net.SplitHostPort(req.RemoteAddr); err != nil {
+		return nil, fmt.Errorf("unable to get ip and port from http.RemoteAddr (%s)", req.RemoteAddr)
+	} else if ip := net.ParseIP(ipStr); ip != nil {
+		return ip, nil
+	} else {
+		return nil, fmt.Errorf("unable to parse ip (%s)", ipStr)
+	}
+}
+
+// getClientString obtains the human readable string of the remote IP and optionally the real client IP if available
+func getClientString(p realClientIPParser, req *http.Request, full bool) (s string) {
+	var realClientIPStr string
+	if p != nil {
+		if realClientIP, err := p.GetRealClientIP(req.Header); err != nil {
+			logger.Printf("Unable to get real client IP: %v", err)
+		} else if realClientIP != nil {
+			realClientIPStr = realClientIP.String()
+		}
+	}
+
+	var remoteIPStr string
+	if remoteIP, err := getRemoteIP(req); err == nil {
+		remoteIPStr = remoteIP.String()
+	} else {
+		// Should not happen, if it does, likely a bug.
+		logger.Printf("Unable to get remote IP(?!?!): %v", err)
+	}
+
+	if !full && realClientIPStr != "" {
+		return realClientIPStr
+	}
+	if full && realClientIPStr != "" {
+		return fmt.Sprintf("%s (%s)", remoteIPStr, realClientIPStr)
+	}
+	return remoteIPStr
+}
diff --git a/realclientip_test.go b/realclientip_test.go
new file mode 100644
index 0000000..0271e2e
--- /dev/null
+++ b/realclientip_test.go
@@ -0,0 +1,176 @@
+package main
+
+import (
+	"net"
+	"net/http"
+	"reflect"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetRealClientIPParser(t *testing.T) {
+	forwardedForType := reflect.TypeOf((*xForwardedForClientIPParser)(nil))
+
+	tests := []struct {
+		header     string
+		errString  string
+		parserType reflect.Type
+	}{
+		{"X-Forwarded-For", "", forwardedForType},
+		{"X-REAL-IP", "", forwardedForType},
+		{"x-proxyuser-ip", "", forwardedForType},
+		{"", "the http header key () is either invalid or unsupported", nil},
+		{"Forwarded", "the http header key (Forwarded) is either invalid or unsupported", nil},
+		{"2#* @##$$:kd", "the http header key (2#* @##$$:kd) is either invalid or unsupported", nil},
+	}
+
+	for _, test := range tests {
+		p, err := getRealClientIPParser(test.header)
+
+		if test.errString == "" {
+			assert.Nil(t, err)
+		} else {
+			assert.NotNil(t, err)
+			assert.Equal(t, test.errString, err.Error())
+		}
+
+		if test.parserType == nil {
+			assert.Nil(t, p)
+		} else {
+			assert.NotNil(t, p)
+			assert.Equal(t, test.parserType, reflect.TypeOf(p))
+		}
+
+		if xp, ok := p.(*xForwardedForClientIPParser); ok {
+			assert.Equal(t, http.CanonicalHeaderKey(test.header), xp.header)
+		}
+	}
+}
+
+func TestXForwardedForClientIPParser(t *testing.T) {
+	p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")}
+
+	tests := []struct {
+		headerValue string
+		errString   string
+		expectedIP  net.IP
+	}{
+		{"", "", nil},
+		{"1.2.3.4", "", net.ParseIP("1.2.3.4")},
+		{"10::23", "", net.ParseIP("10::23")},
+		{"::1", "", net.ParseIP("::1")},
+		{"[::1]:1234", "", net.ParseIP("::1")},
+		{"10.0.10.11:1234", "", net.ParseIP("10.0.10.11")},
+		{"192.168.10.50, 10.0.0.1, 1.2.3.4", "", net.ParseIP("192.168.10.50")},
+		{"nil", "unable to parse ip (nil) from X-Forwarded-For header", nil},
+		{"10000.10000.10000.10000", "unable to parse ip (10000.10000.10000.10000) from X-Forwarded-For header", nil},
+	}
+
+	for _, test := range tests {
+		h := http.Header{}
+		h.Add("X-Forwarded-For", test.headerValue)
+
+		ip, err := p.GetRealClientIP(h)
+
+		if test.errString == "" {
+			assert.Nil(t, err)
+		} else {
+			assert.NotNil(t, err)
+			assert.Equal(t, test.errString, err.Error())
+		}
+
+		if test.expectedIP == nil {
+			assert.Nil(t, ip)
+		} else {
+			assert.NotNil(t, ip)
+			assert.Equal(t, test.expectedIP, ip)
+		}
+	}
+}
+
+func TestXForwardedForClientIPParserIgnoresOthers(t *testing.T) {
+	p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")}
+
+	h := http.Header{}
+	expectedIPString := "192.168.10.50"
+	h.Add("X-Real-IP", "10.0.0.1")
+	h.Add("X-ProxyUser-IP", "10.0.0.1")
+	h.Add("X-Forwarded-For", expectedIPString)
+	ip, err := p.GetRealClientIP(h)
+	assert.Nil(t, err)
+	assert.NotNil(t, ip)
+	assert.Equal(t, ip, net.ParseIP(expectedIPString))
+}
+
+func TestGetRemoteIP(t *testing.T) {
+	tests := []struct {
+		remoteAddr string
+		errString  string
+		expectedIP net.IP
+	}{
+		{"", "unable to get ip and port from http.RemoteAddr ()", nil},
+		{"nil", "unable to get ip and port from http.RemoteAddr (nil)", nil},
+		{"235.28.129.186", "unable to get ip and port from http.RemoteAddr (235.28.129.186)", nil},
+		{"90::45", "unable to get ip and port from http.RemoteAddr (90::45)", nil},
+		{"192.168.73.165:14976, 10.4.201.15:18453", "unable to get ip and port from http.RemoteAddr (192.168.73.165:14976, 10.4.201.15:18453)", nil},
+		{"10000.10000.10000.10000:8080", "unable to parse ip (10000.10000.10000.10000)", nil},
+		{"[::1]:48290", "", net.ParseIP("::1")},
+		{"10.254.244.165:62750", "", net.ParseIP("10.254.244.165")},
+	}
+
+	for _, test := range tests {
+		req := &http.Request{RemoteAddr: test.remoteAddr}
+
+		ip, err := getRemoteIP(req)
+
+		if test.errString == "" {
+			assert.Nil(t, err)
+		} else {
+			assert.NotNil(t, err)
+			assert.Equal(t, test.errString, err.Error())
+		}
+
+		if test.expectedIP == nil {
+			assert.Nil(t, ip)
+		} else {
+			assert.NotNil(t, ip)
+			assert.Equal(t, test.expectedIP, ip)
+		}
+	}
+}
+
+func TestGetClientString(t *testing.T) {
+	p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")}
+
+	tests := []struct {
+		parser             realClientIPParser
+		remoteAddr         string
+		headerValue        string
+		expectedClient     string
+		expectedClientFull string
+	}{
+		// Should fail quietly, only printing warnings to the log
+		{nil, "", "", "", ""},
+		{p, "127.0.0.1:11950", "", "127.0.0.1", "127.0.0.1"},
+		{p, "[::1]:28660", "99.103.56.12", "99.103.56.12", "::1 (99.103.56.12)"},
+		{nil, "10.254.244.165:62750", "", "10.254.244.165", "10.254.244.165"},
+		// Parser is nil, the contents of X-Forwarded-For should be ignored in all cases.
+		{nil, "[2001:470:26:307:a5a1:1177:2ae3:e9c3]:48290", "127.0.0.1", "2001:470:26:307:a5a1:1177:2ae3:e9c3", "2001:470:26:307:a5a1:1177:2ae3:e9c3"},
+	}
+
+	for _, test := range tests {
+		h := http.Header{}
+		h.Add("X-Forwarded-For", test.headerValue)
+		req := &http.Request{
+			Header:     h,
+			RemoteAddr: test.remoteAddr,
+		}
+
+		client := getClientString(test.parser, req, false)
+		assert.Equal(t, test.expectedClient, client)
+
+		clientFull := getClientString(test.parser, req, true)
+		assert.Equal(t, test.expectedClientFull, clientFull)
+	}
+}
-- 
GitLab