fix(httpgetter): prevent DNS rebinding in link metadata fetch

pull/5947/head
boojack 3 weeks ago
parent 4a1e401bd9
commit 078488ca81

@ -1,6 +1,7 @@
package httpgetter
import (
"context"
"fmt"
"io"
"net"
@ -17,17 +18,112 @@ var ErrInternalIP = errors.New("internal IP addresses are not allowed")
// maxHTMLMetaBytes is a byte limit of 512 KiB (512 * 1024).
// NOTE(review): presumably the cap on how much of a response body is read
// while extracting HTML metadata; its use is not visible here, so confirm.
const maxHTMLMetaBytes = 512 * 1024
var httpClient = &http.Client{
Timeout: 5 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if err := validateURL(req.URL.String()); err != nil {
return errors.Wrap(err, "redirect to internal IP")
// Package-level dependencies of the link metadata fetcher. lookupIPAddr and
// dialContext are variables rather than direct calls so that tests in this
// package can replace them with stubs and restore them afterwards.
var (
// lookupIPAddr resolves a hostname to its IP addresses; resolveAllowedIPs
// calls it for every host that is not an IP literal.
lookupIPAddr = net.DefaultResolver.LookupIPAddr
// dialContext opens the actual network connection; secureDialContext calls
// it only with addresses that have already passed the internal-IP check.
dialContext = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext
// httpClient is the shared client built once by newHTTPClient.
httpClient = newHTTPClient()
)
func newHTTPClient() *http.Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Proxy = nil
transport.DialContext = secureDialContext
return &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if err := validateURL(req.URL.String()); err != nil {
return errors.Wrap(err, "redirect to internal IP")
}
if len(via) >= 10 {
return errors.New("too many redirects")
}
return nil
},
}
}
// secureDialContext is a DialContext implementation that resolves the target
// host itself and connects only to the addresses that passed the internal-IP
// check, so the address that is validated is the address that is dialed.
//
// address must be in "host:port" form. The allowed IPs are tried in order and
// the first successful connection is returned. If every attempt fails, the
// error from the last attempt is returned. Errors from resolveAllowedIPs are
// returned unchanged.
func secureDialContext(ctx context.Context, network, address string) (net.Conn, error) {
	host, port, splitErr := net.SplitHostPort(address)
	if splitErr != nil {
		return nil, errors.Wrap(splitErr, "invalid address")
	}

	allowed, resolveErr := resolveAllowedIPs(ctx, host)
	if resolveErr != nil {
		return nil, resolveErr
	}

	var lastErr error
	for i := range allowed {
		// Dial the vetted IP literal, never the hostname, so no second
		// lookup can hand back a different address.
		target := net.JoinHostPort(allowed[i].String(), port)
		conn, err := dialContext(ctx, network, target)
		if err != nil {
			lastErr = err
			continue
		}
		return conn, nil
	}
	return nil, lastErr
}
func resolveAllowedIPs(ctx context.Context, host string) ([]net.IP, error) {
if ip := net.ParseIP(host); ip != nil {
if isInternalIP(ip) {
return nil, errors.Wrap(ErrInternalIP, ip.String())
}
if len(via) >= 10 {
return errors.New("too many redirects")
return []net.IP{ip}, nil
}
addrs, err := lookupIPAddr(ctx, host)
if err != nil {
return nil, errors.Errorf("failed to resolve hostname: %v", err)
}
ips := make([]net.IP, 0, len(addrs))
for _, addr := range addrs {
ip := addr.IP
if ip == nil {
continue
}
return nil
},
if isInternalIP(ip) {
return nil, errors.Wrapf(ErrInternalIP, "host=%s, ip=%s", host, ip.String())
}
ips = append(ips, ip)
}
if len(ips) == 0 {
return nil, errors.New("hostname resolved to no addresses")
}
return ips, nil
}
// isInternalIP reports whether ip is a loopback, private, link-local unicast,
// or unspecified address, i.e. one the fetcher must not connect to.
func isInternalIP(ip net.IP) bool {
	switch {
	case ip.IsLoopback(),
		ip.IsPrivate(),
		ip.IsLinkLocalUnicast(),
		ip.IsUnspecified():
		return true
	}
	return false
}
// validateURL performs the checks on urlStr that need no network access.
//
// It rejects URLs that fail to parse, use a scheme other than http or https,
// have an empty hostname, or whose hostname is an IP literal classified as
// internal by isInternalIP (the error wraps ErrInternalIP). Hostnames that are
// not IP literals are accepted here without being resolved.
func validateURL(urlStr string) error {
	parsed, err := url.Parse(urlStr)
	if err != nil {
		return errors.New("invalid URL format")
	}

	switch parsed.Scheme {
	case "http", "https":
		// allowed
	default:
		return errors.New("only http/https protocols are allowed")
	}

	hostname := parsed.Hostname()
	if hostname == "" {
		return errors.New("empty hostname")
	}

	literal := net.ParseIP(hostname)
	if literal == nil {
		// Not an IP literal: nothing more to check without resolving.
		return nil
	}
	if isInternalIP(literal) {
		return errors.Wrap(ErrInternalIP, literal.String())
	}
	return nil
}
type HTMLMeta struct {
@ -118,44 +214,6 @@ func extractMetaProperty(token html.Token, prop string) (content string, ok bool
return content, ok
}
// validateURL checks that urlStr is an http/https URL with a non-empty
// hostname that does not point at an internal address.
//
// An IP-literal hostname is rejected when it is loopback, private, or
// link-local unicast. Any other hostname is resolved with net.LookupIP and
// rejected if any returned address falls in one of those classes. Rejections
// wrap ErrInternalIP; a failed lookup is reported as
// "failed to resolve hostname".
//
// Only an error is returned, not the resolved addresses, so a caller that
// connects afterwards has to resolve the hostname again.
// NOTE(review): per the commit title, that gap between check and connect is
// the DNS-rebinding window this change closes; confirm against the caller.
func validateURL(urlStr string) error {
u, err := url.Parse(urlStr)
if err != nil {
return errors.New("invalid URL format")
}
if u.Scheme != "http" && u.Scheme != "https" {
return errors.New("only http/https protocols are allowed")
}
host := u.Hostname()
if host == "" {
return errors.New("empty hostname")
}
// IP literal: classify it directly, no lookup needed.
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
return errors.Wrap(ErrInternalIP, ip.String())
}
return nil
}
// Hostname: resolve it and reject if any returned IP is internal.
ips, err := net.LookupIP(host)
if err != nil {
return errors.Errorf("failed to resolve hostname: %v", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
return errors.Wrapf(ErrInternalIP, "host=%s, ip=%s", host, ip.String())
}
}
return nil
}
func enrichSiteMeta(url *url.URL, meta *HTMLMeta) {
if url.Hostname() == "www.youtube.com" {
if url.Path == "/watch" {

@ -1,8 +1,10 @@
package httpgetter
import (
"context"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
@ -68,3 +70,52 @@ func TestGetHTMLMetaForInternal(t *testing.T) {
func TestHTTPClientHasTimeout(t *testing.T) {
require.NotZero(t, httpClient.Timeout)
}
// TestSecureDialContextRejectsResolvedInternalIP verifies that a hostname
// resolving to a loopback address is refused with ErrInternalIP and that no
// dial is attempted.
func TestSecureDialContextRejectsResolvedInternalIP(t *testing.T) {
	// Swap in stubs for the package-level resolver and dialer, restoring the
	// real ones when the test finishes.
	savedLookup, savedDial := lookupIPAddr, dialContext
	t.Cleanup(func() {
		lookupIPAddr, dialContext = savedLookup, savedDial
	})

	lookupIPAddr = func(context.Context, string) ([]net.IPAddr, error) {
		loopback := net.IPAddr{IP: net.ParseIP("127.0.0.1")}
		return []net.IPAddr{loopback}, nil
	}
	dialContext = func(context.Context, string, string) (net.Conn, error) {
		t.Fatal("internal IP should be rejected before dialing")
		return nil, nil
	}

	_, dialErr := secureDialContext(context.Background(), "tcp", "rebind.example:80")
	require.ErrorIs(t, dialErr, ErrInternalIP)
}
// TestSecureDialContextDialsResolvedIP verifies that, for a hostname resolving
// to a public address, the dialer is given the resolved IP and original port
// rather than the hostname.
func TestSecureDialContextDialsResolvedIP(t *testing.T) {
	// Swap in stubs for the package-level resolver and dialer, restoring the
	// real ones when the test finishes.
	savedLookup, savedDial := lookupIPAddr, dialContext
	t.Cleanup(func() {
		lookupIPAddr, dialContext = savedLookup, savedDial
	})

	lookupIPAddr = func(context.Context, string) ([]net.IPAddr, error) {
		public := net.IPAddr{IP: net.ParseIP("93.184.216.34")}
		return []net.IPAddr{public}, nil
	}

	// Record the address handed to the dialer and return one end of an
	// in-memory pipe in place of a real connection.
	var gotAddress string
	dialContext = func(_ context.Context, _ string, address string) (net.Conn, error) {
		gotAddress = address
		client, server := net.Pipe()
		t.Cleanup(func() {
			client.Close()
			server.Close()
		})
		return client, nil
	}

	conn, dialErr := secureDialContext(context.Background(), "tcp", "rebind.example:80")
	require.NoError(t, dialErr)
	require.NotNil(t, conn)
	require.Equal(t, "93.184.216.34:80", gotAddress)
}

Loading…
Cancel
Save