diff --git a/proxy/read_seeker.go b/proxy/read_seeker.go index 3bcf422..6e00bb6 100644 --- a/proxy/read_seeker.go +++ b/proxy/read_seeker.go @@ -7,18 +7,20 @@ import ( "fmt" "io" "net/http" + "slices" "strconv" ) type HttpReadSeeker struct { - offset int64 - url string - contentLength int64 - method string - body io.Reader - client *http.Client - headers map[string]string - ctx context.Context + offset int64 + url string + contentLength int64 + method string + body io.Reader + client *http.Client + headers map[string]string + ctx context.Context + allowedContentTypes []string } type HttpReadSeekerConf func(h *HttpReadSeeker) @@ -82,6 +84,12 @@ func WithStartOffset(offset int64) HttpReadSeekerConf { } } +func AllowedContentTypes(types ...string) HttpReadSeekerConf { + return func(h *HttpReadSeeker) { + h.allowedContentTypes = types + } +} + func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker { rs := &HttpReadSeeker{ offset: 0, @@ -130,11 +138,25 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) { return 0, err } defer resp.Body.Close() + if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil { + return 0, err + } + n, err = io.ReadFull(resp.Body, p) h.offset += int64(n) return n, err } +func (h *HttpReadSeeker) checkContentType(ct string) error { + if ct == "" { + return errors.New("content type is empty") + } + if len(h.allowedContentTypes) != 0 && slices.Index(h.allowedContentTypes, ct) == -1 { + return fmt.Errorf("content type `%s` not allowed", ct) + } + return nil +} + func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) { switch whence { case io.SeekStart: @@ -155,6 +177,9 @@ func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) { return 0, err } defer resp.Body.Close() + if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil { + return 0, err + } h.contentLength, err = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) if err != nil {