Pull request: 2305 limit message size

Merge in DNS/adguard-home from 2305-limit-message-size to master

Closes #2305.

Squashed commit of the following:

commit 6edd1e0521277a680f0053308efcf3d9cacc8e62
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Nov 23 14:03:36 2020 +0300

    aghio: fix final inaccuracies

commit 4dd382aaf25132b31eb269749a2cd36daf0cb792
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Nov 23 13:59:10 2020 +0300

    all: improve code quality

commit 060f923f6023d0e6f26441559b7023d5e5f96843
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Nov 23 13:10:57 2020 +0300

    aghio: add validation to constructor

commit f57a2f596f5dc578548241c315c68dce7fc93905
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:19:26 2020 +0300

    all: fix minor inaccuracies

commit 93462c71725d3d00655a4bd565b77e64451fff60
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:13:23 2020 +0300

    home: make test name follow convention

commit 4922986ad84481b054479c43b4133a1b97bee86b
Merge: 1f5472abc 046ec13fd
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:09:01 2020 +0300

    Merge branch 'master' into 2305-limit-message-size

commit 1f5472abcfa7427f389825fc59eb4253e1e2bfb7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 19:08:21 2020 +0300

    aghio: improve readability

commit 60dc706b093fa22bbf62f13b2341934364ddc4df
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 18:44:08 2020 +0300

    home: cover middleware with test

commit bedf436b947ca1fa4493af2fc94f1f40beec7c35
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 17:10:23 2020 +0300

    aghio: improved error informativeness

commit 682c5da9f21fa330fb3536bb1c112129c91b9990
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Nov 20 13:37:51 2020 +0300

    all: limit readers for ReadAll dealing with miscellanious data.

commit 78c6dd8d90a0a43fe6ee3f9ed4d5fc637b15ba74
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Nov 19 20:07:43 2020 +0300

    all: handle ReadAll calls dealing with request's bodies.

commit bfe1a6faf6468eb44515e2b0ecffa8c51f90b7e8
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Nov 19 17:25:34 2020 +0300

    home: add middlewares

commit bbd1d491b318e6ba07f8af23ad546183383783a8
Merge: 7b77c2cad 62a8fe0b7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Nov 19 16:44:04 2020 +0300

    Merge branch 'master' into 2305-limit-message-size

commit 7b77c2cad03154177392460982e1d73ee2a30177
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Nov 17 15:33:33 2020 +0300

    aghio: create package
This commit is contained in:
Ainar Garipov
2020-11-23 14:14:08 +03:00
parent 046ec13fdc
commit c129361e55
15 changed files with 413 additions and 64 deletions

View File

@@ -0,0 +1,59 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
// dealing with agherr package.
type limitedReadCloser struct {
limit int64
n int64
rc io.ReadCloser
}
// Read implements Reader interface.
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
if lrc.n == 0 {
return 0, &LimitReachedError{
Limit: lrc.limit,
}
}
if int64(len(p)) > lrc.n {
p = p[0:lrc.n]
}
n, err = lrc.rc.Read(p)
lrc.n -= int64(n)
return n, err
}
// Close implements Closer interface.
func (lrc *limitedReadCloser) Close() error {
return lrc.rc.Close()
}
// LimitReadCloser wraps ReadCloser to make it's Reader stop with
// ErrLimitReached after n bytes read.
func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n)
}
return &limitedReadCloser{
limit: n,
n: n,
rc: rc,
}, nil
}

View File

@@ -0,0 +1,108 @@
package aghio
import (
"fmt"
"io"
"io/ioutil"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLimitReadCloser(t *testing.T) {
testCases := []struct {
name string
n int64
want error
}{{
name: "positive",
n: 1,
want: nil,
}, {
name: "zero",
n: 0,
want: nil,
}, {
name: "negative",
n: -1,
want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := LimitReadCloser(nil, tc.n)
assert.Equal(t, tc.want, err)
})
}
}
func TestLimitedReadCloser_Read(t *testing.T) {
testCases := []struct {
name string
limit int64
rStr string
want int
err error
}{{
name: "perfectly_match",
limit: 3,
rStr: "abc",
want: 3,
err: nil,
}, {
name: "eof",
limit: 3,
rStr: "",
want: 0,
err: io.EOF,
}, {
name: "limit_reached",
limit: 0,
rStr: "abc",
want: 0,
err: &LimitReachedError{
Limit: 0,
},
}, {
name: "truncated",
limit: 2,
rStr: "abc",
want: 2,
err: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
readCloser := ioutil.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)
lreader, err := LimitReadCloser(readCloser, tc.limit)
assert.Nil(t, err)
n, err := lreader.Read(buf)
assert.Equal(t, n, tc.want)
assert.Equal(t, tc.err, err)
})
}
}
func TestLimitedReadCloser_LimitReachedError(t *testing.T) {
testCases := []struct {
name string
want string
err error
}{{
name: "simplest",
want: "attempted to read more than 0 bytes",
err: &LimitReachedError{
Limit: 0,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, tc.err.Error())
})
}
}

View File

@@ -299,6 +299,7 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
// . Check if a static IP is configured for the network interface
// Respond with results
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := ioutil.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err)

View File

@@ -10,6 +10,7 @@ import (
"time"
"unsafe"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/log"
)
@@ -18,8 +19,10 @@ var GLMode bool
var glFilePrefix = "/tmp/gl_token_"
const glTokenTimeoutSeconds = 3600
const glCookieName = "Admin-Token"
const (
glTokenTimeoutSeconds = 3600
glCookieName = "Admin-Token"
)
func glProcessRedirect(w http.ResponseWriter, r *http.Request) bool {
if !GLMode {
@@ -71,14 +74,28 @@ func archIsLittleEndian() bool {
return (b == 0x04)
}
// MaxFileSize is a maximum file length in bytes.
const MaxFileSize = 1024 * 1024
func glGetTokenDate(file string) uint32 {
f, err := os.Open(file)
if err != nil {
log.Error("os.Open: %s", err)
return 0
}
defer f.Close()
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize)
if err != nil {
log.Error("LimitReadCloser: %s", err)
return 0
}
defer fileReadCloser.Close()
var dateToken uint32
bs, err := ioutil.ReadAll(f)
// This use of ReadAll is now safe, because we limited reader.
bs, err := ioutil.ReadAll(fileReadCloser)
if err != nil {
log.Error("ioutil.ReadAll: %s", err)
return 0

View File

@@ -3,7 +3,6 @@ package home
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
)
@@ -150,16 +149,11 @@ func clientHostToJSON(ip string, ch ClientHost) clientJSON {
// Add a new client
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
return
}
cj := clientJSON{}
err = json.Unmarshal(body, &cj)
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
return
}
@@ -183,16 +177,17 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
// Remove client
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
return
}
cj := clientJSON{}
err = json.Unmarshal(body, &cj)
if err != nil || len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
if len(cj.Name) == 0 {
httpError(w, http.StatusBadRequest, "client's name must be non-empty")
return
}
@@ -211,18 +206,14 @@ type updateJSON struct {
// Update client's properties
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
dj := updateJSON{}
err := json.NewDecoder(r.Body).Decode(&dj)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to read request body: %s", err)
httpError(w, http.StatusBadRequest, "failed to process request body: %s", err)
return
}
var dj updateJSON
err = json.Unmarshal(body, &dj)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
if len(dj.Name) == 0 {
httpError(w, http.StatusBadRequest, "Invalid request")
return

View File

@@ -214,6 +214,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)

View File

@@ -66,6 +66,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
}
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := ioutil.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err)

View File

@@ -0,0 +1,59 @@
package home
import (
"net/http"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/log"
)
// middlerware is a wrapper function signature.
type middleware func(http.Handler) http.Handler
// withMiddlewares consequently wraps h with all the middlewares.
func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Handler) {
wrapped = h
for _, mw := range middlewares {
wrapped = mw(wrapped)
}
return wrapped
}
// RequestBodySizeLimit is maximum request body length in bytes.
const RequestBodySizeLimit = 64 * 1024
// limitRequestBody wraps underlying handler h, making it's request's body Read
// method limited.
func limitRequestBody(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var err error
r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
if err != nil {
log.Error("limitRequestBody: %s", err)
return
}
h.ServeHTTP(w, r)
})
}
// TODO(a.garipov): We currently have to use this, because everything registers
// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP
// API initialization process and stop using the gosh darn http.DefaultServeMux
// for anything at all. Gosh darn global variables.
func filterPProf(h http.Handler) (filtered http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/debug/pprof") {
http.NotFound(w, r)
return
}
h.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,64 @@
package home
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/stretchr/testify/assert"
)
func TestLimitRequestBody(t *testing.T) {
errReqLimitReached := &aghio.LimitReachedError{
Limit: RequestBodySizeLimit,
}
testCases := []struct {
name string
body string
want []byte
wantErr error
}{{
name: "not_so_big",
body: "somestr",
want: []byte("somestr"),
wantErr: nil,
}, {
name: "so_big",
body: string(make([]byte, RequestBodySizeLimit+1)),
want: make([]byte, RequestBodySizeLimit),
wantErr: errReqLimitReached,
}, {
name: "empty",
body: "",
want: []byte(nil),
wantErr: nil,
}}
makeHandler := func(err *error) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte
b, *err = ioutil.ReadAll(r.Body)
w.Write(b)
})
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var err error
handler := makeHandler(&err)
lim := limitRequestBody(handler)
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
res := httptest.NewRecorder()
lim.ServeHTTP(res, req)
assert.Equal(t, tc.want, res.Body.Bytes())
assert.Equal(t, tc.wantErr, err)
})
}
}

View File

@@ -7,7 +7,6 @@ import (
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/util"
@@ -142,7 +141,7 @@ func (web *Web) Start() {
web.httpServer = &http.Server{
ErrorLog: web.errLogger,
Addr: address,
Handler: filterPPROF(http.DefaultServeMux),
Handler: withMiddlewares(http.DefaultServeMux, filterPProf, limitRequestBody),
}
err := web.httpServer.ListenAndServe()
if err != http.ErrServerClosed {
@@ -153,22 +152,6 @@ func (web *Web) Start() {
}
}
// TODO(a.garipov): We currently have to use this, because everything registers
// its HTTP handlers in http.DefaultServeMux. In the future, refactor our HTTP
// API initialization process and stop using the gosh darn http.DefaultServeMux
// for anything at all. Gosh darn global variables.
func filterPPROF(h http.Handler) (filtered http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/debug/pprof") {
http.NotFound(w, r)
return
}
h.ServeHTTP(w, r)
})
}
// Close - stop HTTP server, possibly waiting for all active connections to be closed
func (web *Web) Close() {
log.Info("Stopping HTTP server...")

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/cache"
@@ -115,6 +116,9 @@ func whoisParse(data string) map[string]string {
return m
}
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response
func (w *Whois) query(target, serverAddr string) (string, error) {
addr, _, _ := net.SplitHostPort(serverAddr)
@@ -127,13 +131,20 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
}
defer conn.Close()
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize)
if err != nil {
return "", err
}
defer connReadCloser.Close()
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n"))
if err != nil {
return "", err
}
data, err := ioutil.ReadAll(conn)
// This use of ReadAll is now safe, because we limited the conn Reader.
data, err := ioutil.ReadAll(connReadCloser)
if err != nil {
return "", err
}

View File

@@ -6,6 +6,8 @@ import (
"io/ioutil"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
)
const versionCheckPeriod = 8 * 60 * 60
@@ -19,6 +21,9 @@ type VersionInfo struct {
CanAutoUpdate bool // If true - we can auto-update
}
// MaxResponseSize is responses on server's requests maximum length in bytes.
const MaxResponseSize = 64 * 1024
// GetVersionResponse - downloads version.json (if needed) and deserializes it
func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
if !forceRecheck &&
@@ -27,14 +32,19 @@ func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
}
resp, err := u.Client.Get(u.VersionURL)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)
}
defer resp.Body.Close()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
}
defer resp.Body.Close()
// This use of ReadAll is safe, because we just limited the appropriate
// ReadCloser.
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)

View File

@@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
)
@@ -217,17 +218,27 @@ func (u *Updater) clean() {
_ = os.RemoveAll(u.updateDir)
}
// MaxPackageFileSize is a maximum package file length in bytes. The largest
// package whose size is limited by this constant currently has the size of
// approximately 9 MiB.
const MaxPackageFileSize = 32 * 1024 * 1024
// Download package file and save it to disk
func (u *Updater) downloadPackageFile(url string, filename string) error {
resp, err := u.Client.Get(url)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
}
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
defer resp.Body.Close()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
}
defer resp.Body.Close()
log.Debug("updater: reading HTTP body")
// This use of ReadAll is now safe, because we limited body's Reader.
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ioutil.ReadAll() failed: %w", err)