Feat: proxy allow context type

pull/31/head
zijiren233 2 years ago
parent a70c115b23
commit 10d7bc5e46

@ -7,18 +7,20 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"slices"
"strconv" "strconv"
) )
type HttpReadSeeker struct { type HttpReadSeeker struct {
offset int64 offset int64
url string url string
contentLength int64 contentLength int64
method string method string
body io.Reader body io.Reader
client *http.Client client *http.Client
headers map[string]string headers map[string]string
ctx context.Context ctx context.Context
allowedContentTypes []string
} }
type HttpReadSeekerConf func(h *HttpReadSeeker) type HttpReadSeekerConf func(h *HttpReadSeeker)
@ -82,6 +84,12 @@ func WithStartOffset(offset int64) HttpReadSeekerConf {
} }
} }
func AllowedContentTypes(types ...string) HttpReadSeekerConf {
return func(h *HttpReadSeeker) {
h.allowedContentTypes = types
}
}
func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker { func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker {
rs := &HttpReadSeeker{ rs := &HttpReadSeeker{
offset: 0, offset: 0,
@ -130,11 +138,25 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
return 0, err return 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
return 0, err
}
n, err = io.ReadFull(resp.Body, p) n, err = io.ReadFull(resp.Body, p)
h.offset += int64(n) h.offset += int64(n)
return n, err return n, err
} }
func (h *HttpReadSeeker) checkContentType(ct string) error {
if ct == "" {
return errors.New("content type is empty")
}
if len(h.allowedContentTypes) != 0 && slices.Index(h.allowedContentTypes, ct) == -1 {
return fmt.Errorf("content type `%s` not allowed", ct)
}
return nil
}
func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) { func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
switch whence { switch whence {
case io.SeekStart: case io.SeekStart:
@ -155,6 +177,9 @@ func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
return 0, err return 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
return 0, err
}
h.contentLength, err = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) h.contentLength, err = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil { if err != nil {

Loading…
Cancel
Save