feat: handle range not satisfiable

pull/254/head
zijiren233 9 months ago
parent 1d20475b82
commit 31fcf5eb4e

@ -26,10 +26,9 @@ type Cache interface {
// CacheMetadata stores metadata about a cached response // CacheMetadata stores metadata about a cached response
type CacheMetadata struct { type CacheMetadata struct {
Headers http.Header `json:"headers,omitempty"` Headers http.Header `json:"h,omitempty"`
ContentType string `json:"content_type,omitempty"` ContentType string `json:"ct,omitempty"`
ContentTotalLength int64 `json:"content_total_length,omitempty"` ContentTotalLength int64 `json:"ctl,omitempty"`
NotSupportRange bool `json:"not_support_range,omitempty"`
} }
func (m *CacheMetadata) MarshalBinary() ([]byte, error) { func (m *CacheMetadata) MarshalBinary() ([]byte, error) {

@ -18,24 +18,21 @@ var (
) )
type HttpReadSeekCloser struct { type HttpReadSeekCloser struct {
ctx context.Context ctx context.Context
headHeaders http.Header headHeaders http.Header
currentResp *http.Response currentResp *http.Response
headers http.Header headers http.Header
client *http.Client client *http.Client
contentType string contentType string
method string method string
headMethod string headMethod string
url string url string
allowedContentTypes []string allowedContentTypes []string
notAllowedStatusCodes []int offset int64
allowedStatusCodes []int contentTotalLength int64
offset int64 perLength int64
contentTotalLength int64 currentRespMaxOffset int64
perLength int64 notSupportRange bool
currentRespMaxOffset int64
notSupportRange bool
// if the server does not support range requests, the seek method will be unusable
notSupportSeekWhenNotSupportRange bool notSupportSeekWhenNotSupportRange bool
} }
@ -105,22 +102,6 @@ func AllowedContentTypes(types ...string) HttpReadSeekerConf {
} }
} }
func AllowedStatusCodes(codes ...int) HttpReadSeekerConf {
return func(h *HttpReadSeekCloser) {
if len(codes) > 0 {
h.allowedStatusCodes = slices.Clone(codes)
}
}
}
func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf {
return func(h *HttpReadSeekCloser) {
if len(codes) > 0 {
h.notAllowedStatusCodes = slices.Clone(codes)
}
}
}
// sets the per length of the request // sets the per length of the request
func WithPerLength(length int64) HttpReadSeekerConf { func WithPerLength(length int64) HttpReadSeekerConf {
return func(h *HttpReadSeekCloser) { return func(h *HttpReadSeekCloser) {
@ -177,9 +158,6 @@ func (h *HttpReadSeekCloser) fix() *HttpReadSeekCloser {
if h.client == nil { if h.client == nil {
h.client = http.DefaultClient h.client = http.DefaultClient
} }
if len(h.notAllowedStatusCodes) == 0 {
h.notAllowedStatusCodes = []int{http.StatusNotFound}
}
if h.perLength <= 0 { if h.perLength <= 0 {
h.perLength = 1024 * 1024 h.perLength = 1024 * 1024
} }
@ -241,80 +219,59 @@ func (h *HttpReadSeekCloser) FetchNextChunk() error {
return fmt.Errorf("failed to execute request: %w", err) return fmt.Errorf("failed to execute request: %w", err)
} }
h.contentType = resp.Header.Get("Content-Type") if resp.StatusCode != http.StatusPartialContent &&
resp.StatusCode != http.StatusOK &&
if resp.StatusCode == http.StatusOK { resp.StatusCode != http.StatusRequestedRangeNotSatisfiable {
if ar := resp.Header.Get("Accept-Ranges"); ar == "" || ar == "none" { resp.Body.Close()
h.notSupportRange = true return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
} }
// if the maximum offset of the current response is less than the content length minus one, it means that the server does not support range requests if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
if h.currentRespMaxOffset < resp.ContentLength-1 { contentTotalLength, err := ParseContentRangeTotalLength(resp.Header.Get("Content-Range"))
h.notSupportRange = true if err == nil && contentTotalLength > 0 {
h.contentTotalLength = contentTotalLength
} }
resp.Body.Close()
return fmt.Errorf("requested range not satisfiable, content total length: %d, offset: %d", h.contentTotalLength, h.offset)
}
if h.notSupportRange { if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
h.contentTotalLength = resp.ContentLength resp.Body.Close()
h.currentRespMaxOffset = h.contentTotalLength - 1 return fmt.Errorf("response validation failed: %w", err)
}
// If offset > 0, read and discard bytes until reaching the desired offset
if h.offset > 0 {
if _, err := io.CopyN(io.Discard, resp.Body, h.offset); err != nil {
resp.Body.Close()
if err == io.EOF {
return io.EOF
}
return fmt.Errorf("failed to discard bytes: %w", err)
}
}
h.currentResp = resp
return nil
}
// if the content length is not known, it may be because the requested length is too long, and a new request is needed
if h.contentTotalLength < 0 {
h.contentTotalLength = resp.ContentLength
resp.Body.Close()
return h.FetchNextChunk()
}
if h.contentTotalLength != resp.ContentLength { h.contentType = resp.Header.Get("Content-Type")
resp.Body.Close()
return fmt.Errorf("content length mismatch: %d != %d", h.contentTotalLength, resp.ContentLength)
}
h.notSupportRange = true if resp.StatusCode == http.StatusOK {
if h.offset > 0 { if h.offset > 0 {
if h.notSupportSeekWhenNotSupportRange {
return fmt.Errorf("not support seek when not support range")
}
if _, err := io.CopyN(io.Discard, resp.Body, h.offset); err != nil { if _, err := io.CopyN(io.Discard, resp.Body, h.offset); err != nil {
resp.Body.Close() resp.Body.Close()
if err == io.EOF {
return io.EOF
}
return fmt.Errorf("failed to discard bytes: %w", err) return fmt.Errorf("failed to discard bytes: %w", err)
} }
} }
h.notSupportRange = true
h.contentTotalLength = resp.ContentLength
h.currentRespMaxOffset = h.contentTotalLength - 1 h.currentRespMaxOffset = h.contentTotalLength - 1
h.currentResp = resp h.currentResp = resp
return nil return nil
} }
if resp.StatusCode != http.StatusPartialContent {
resp.Body.Close()
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
if err := h.checkResponse(resp); err != nil {
resp.Body.Close()
return fmt.Errorf("response validation failed: %w", err)
}
contentTotalLength, err := ParseContentRangeTotalLength(resp.Header.Get("Content-Range")) contentTotalLength, err := ParseContentRangeTotalLength(resp.Header.Get("Content-Range"))
if err == nil && contentTotalLength > 0 { if err == nil && contentTotalLength > 0 {
h.contentTotalLength = contentTotalLength h.contentTotalLength = contentTotalLength
} }
_, end, err := ParseContentRangeStartAndEnd(resp.Header.Get("Content-Range")) start, end, err := ParseContentRangeStartAndEnd(resp.Header.Get("Content-Range"))
if err == nil && end != -1 { if err == nil {
h.currentRespMaxOffset = end if end != -1 {
h.currentRespMaxOffset = end
}
if h.offset != start {
return fmt.Errorf("offset mismatch, expected: %d, got: %d", start, h.offset)
}
} }
h.currentResp = resp h.currentResp = resp
@ -358,13 +315,6 @@ func (h *HttpReadSeekCloser) createRequestWithoutRange() (*http.Request, error)
return req, nil return req, nil
} }
func (h *HttpReadSeekCloser) checkResponse(resp *http.Response) error {
if err := h.checkStatusCode(resp.StatusCode); err != nil {
return err
}
return h.checkContentType(resp.Header.Get("Content-Type"))
}
func (h *HttpReadSeekCloser) closeCurrentResp() { func (h *HttpReadSeekCloser) closeCurrentResp() {
if h.currentResp != nil { if h.currentResp != nil {
h.currentResp.Body.Close() h.currentResp.Body.Close()
@ -381,21 +331,6 @@ func (h *HttpReadSeekCloser) checkContentType(ct string) error {
return nil return nil
} }
func (h *HttpReadSeekCloser) checkStatusCode(code int) error {
if len(h.allowedStatusCodes) != 0 {
if slices.Index(h.allowedStatusCodes, code) == -1 {
return fmt.Errorf("status code %d is not in the list of allowed status codes: %v", code, h.allowedStatusCodes)
}
return nil
}
if len(h.notAllowedStatusCodes) != 0 {
if slices.Index(h.notAllowedStatusCodes, code) != -1 {
return fmt.Errorf("status code %d is in the list of not allowed status codes: %v", code, h.notAllowedStatusCodes)
}
}
return nil
}
func (h *HttpReadSeekCloser) Seek(offset int64, whence int) (int64, error) { func (h *HttpReadSeekCloser) Seek(offset int64, whence int) (int64, error) {
newOffset, err := h.calculateNewOffset(offset, whence) newOffset, err := h.calculateNewOffset(offset, whence)
if err != nil { if err != nil {
@ -451,10 +386,10 @@ func (h *HttpReadSeekCloser) fetchContentLength() error {
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP status code in HEAD request: %d (expected 200 OK)", resp.StatusCode) return fmt.Errorf("unexpected status code in HEAD request: %d", resp.StatusCode)
} }
if err := h.checkResponse(resp); err != nil { if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
return fmt.Errorf("HEAD response validation failed: %w", err) return fmt.Errorf("HEAD response validation failed: %w", err)
} }
@ -491,10 +426,6 @@ func (h *HttpReadSeekCloser) ContentType() (string, error) {
return "", fmt.Errorf("content type is not available - no successful response received yet") return "", fmt.Errorf("content type is not available - no successful response received yet")
} }
func (h *HttpReadSeekCloser) AcceptRanges() bool {
return !h.notSupportRange
}
func (h *HttpReadSeekCloser) ContentTotalLength() (int64, error) { func (h *HttpReadSeekCloser) ContentTotalLength() (int64, error) {
if h.contentTotalLength > 0 { if h.contentTotalLength > 0 {
return h.contentTotalLength, nil return h.contentTotalLength, nil
@ -502,10 +433,6 @@ func (h *HttpReadSeekCloser) ContentTotalLength() (int64, error) {
return 0, fmt.Errorf("content total length is not available - no successful response received yet") return 0, fmt.Errorf("content total length is not available - no successful response received yet")
} }
func (h *HttpReadSeekCloser) SetContentTotalLength(length int64) {
h.contentTotalLength = length
}
func ParseContentRangeStartAndEnd(contentRange string) (int64, int64, error) { func ParseContentRangeStartAndEnd(contentRange string) (int64, int64, error) {
if contentRange == "" { if contentRange == "" {
return 0, 0, fmt.Errorf("Content-Range header is empty") return 0, 0, fmt.Errorf("Content-Range header is empty")

@ -18,15 +18,10 @@ var mu = ksync.DefaultKmutex()
// Proxy defines the interface for proxy implementations // Proxy defines the interface for proxy implementations
type Proxy interface { type Proxy interface {
io.ReadSeeker io.ReadSeeker
AcceptRanges() bool
ContentTotalLength() (int64, error) ContentTotalLength() (int64, error)
ContentType() (string, error) ContentType() (string, error)
} }
type SetContentTotalLength interface {
SetContentTotalLength(int64)
}
// Headers defines the interface for accessing response headers // Headers defines the interface for accessing response headers
type Headers interface { type Headers interface {
Headers() http.Header Headers() http.Header
@ -114,37 +109,21 @@ func (c *SliceCacheProxy) Proxy(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to parse Range header: %w", err) return fmt.Errorf("failed to parse Range header: %w", err)
} }
isRangeRequest := r.Header.Get("Range") != ""
if isRangeRequest {
// avoid the request exceeding the total length of the file due to the large slice size
if st, ok := c.r.(SetContentTotalLength); ok {
cacheItem, ok, err := c.cache.GetAnyWithPrefix(cachePrefix(c.key, c.sliceSize))
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get cache item: %v", err), http.StatusInternalServerError)
return fmt.Errorf("failed to get cache item: %w", err)
}
if ok {
st.SetContentTotalLength(cacheItem.Metadata.ContentTotalLength)
}
}
}
alignedOffset := alignedOffset(byteRange.Start, c.sliceSize) alignedOffset := alignedOffset(byteRange.Start, c.sliceSize)
cacheItem, err := c.getCacheItem(alignedOffset) cacheItem, cached, err := c.getCacheItem(alignedOffset)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("Failed to get cache item: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to get cache item: %v", err), http.StatusInternalServerError)
return fmt.Errorf("failed to get cache item: %w", err) return fmt.Errorf("failed to get cache item: %w", err)
} }
c.setResponseHeaders(w, byteRange, cacheItem, isRangeRequest) c.setResponseHeaders(w, byteRange, cacheItem, cached, r.Header.Get("Range") != "")
if err := c.writeResponse(w, byteRange, alignedOffset, cacheItem); err != nil { if err := c.writeResponse(w, byteRange, alignedOffset, cacheItem); err != nil {
return fmt.Errorf("failed to write response: %w", err) return fmt.Errorf("failed to write response: %w", err)
} }
return nil return nil
} }
func (c *SliceCacheProxy) setResponseHeaders(w http.ResponseWriter, byteRange *ByteRange, cacheItem *CacheItem, isRangeRequest bool) { func (c *SliceCacheProxy) setResponseHeaders(w http.ResponseWriter, byteRange *ByteRange, cacheItem *CacheItem, cached bool, isRangeRequest bool) {
// Copy headers excluding special ones // Copy headers excluding special ones
for k, v := range cacheItem.Metadata.Headers { for k, v := range cacheItem.Metadata.Headers {
switch k { switch k {
@ -155,11 +134,12 @@ func (c *SliceCacheProxy) setResponseHeaders(w http.ResponseWriter, byteRange *B
} }
} }
if !cacheItem.Metadata.NotSupportRange { if cached {
w.Header().Set("Accept-Ranges", "bytes") w.Header().Set("Cache-Status", "HIT")
} else { } else {
w.Header().Set("Accept-Ranges", "none") w.Header().Set("Cache-Status", "MISS")
} }
w.Header().Set("Accept-Ranges", "bytes")
w.Header().Set("Content-Length", fmtContentLength(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength)) w.Header().Set("Content-Length", fmtContentLength(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength))
w.Header().Set("Content-Type", cacheItem.Metadata.ContentType) w.Header().Set("Content-Type", cacheItem.Metadata.ContentType)
if isRangeRequest { if isRangeRequest {
@ -198,7 +178,7 @@ func (c *SliceCacheProxy) writeResponse(w http.ResponseWriter, byteRange *ByteRa
// Write subsequent slices // Write subsequent slices
currentOffset := alignedOffset + c.sliceSize currentOffset := alignedOffset + c.sliceSize
for remainingLength > 0 { for remainingLength > 0 {
cacheItem, err := c.getCacheItem(currentOffset) cacheItem, _, err := c.getCacheItem(currentOffset)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cache item at offset %d: %w", currentOffset, err) return fmt.Errorf("failed to get cache item at offset %d: %w", currentOffset, err)
} }
@ -219,9 +199,9 @@ func (c *SliceCacheProxy) writeResponse(w http.ResponseWriter, byteRange *ByteRa
return nil return nil
} }
func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, error) { func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, bool, error) {
if alignedOffset < 0 { if alignedOffset < 0 {
return nil, fmt.Errorf("cache item offset cannot be negative, got: %d", alignedOffset) return nil, false, fmt.Errorf("cache item offset cannot be negative, got: %d", alignedOffset)
} }
cacheKey := cacheKey(c.key, alignedOffset, c.sliceSize) cacheKey := cacheKey(c.key, alignedOffset, c.sliceSize)
@ -231,24 +211,24 @@ func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, error)
// Try to get from cache first // Try to get from cache first
slice, ok, err := c.cache.Get(cacheKey) slice, ok, err := c.cache.Get(cacheKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get item from cache: %w", err) return nil, false, fmt.Errorf("failed to get item from cache: %w", err)
} }
if ok { if ok {
return slice, nil return slice, true, nil
} }
// Fetch from source if not in cache // Fetch from source if not in cache
slice, err = c.fetchFromSource(alignedOffset) slice, err = c.fetchFromSource(alignedOffset)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch item from source: %w", err) return nil, false, fmt.Errorf("failed to fetch item from source: %w", err)
} }
// Store in cache // Store in cache
if err = c.cache.Set(cacheKey, slice); err != nil { if err = c.cache.Set(cacheKey, slice); err != nil {
return nil, fmt.Errorf("failed to store item in cache: %w", err) return nil, false, fmt.Errorf("failed to store item in cache: %w", err)
} }
return slice, nil return slice, false, nil
} }
func (c *SliceCacheProxy) contentTotalLength() (int64, error) { func (c *SliceCacheProxy) contentTotalLength() (int64, error) {
@ -311,7 +291,6 @@ func (c *SliceCacheProxy) fetchFromSource(offset int64) (*CacheItem, error) {
Headers: headers, Headers: headers,
ContentTotalLength: total, ContentTotalLength: total,
ContentType: contentType, ContentType: contentType,
NotSupportRange: !c.r.AcceptRanges(),
}, },
Data: buf[:n], Data: buf[:n],
}, nil }, nil

Loading…
Cancel
Save