From 9c15b581d29fa25a40fb457052403dced3aba6af Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Sun, 5 Nov 2023 21:19:59 +0800 Subject: [PATCH] Fix: proxy --- proxy/read_seeker.go | 64 ++++++++++++++++++++++++++++++++-------- server/handlers/movie.go | 33 +++++++++++++++------ 2 files changed, 75 insertions(+), 22 deletions(-) diff --git a/proxy/read_seeker.go b/proxy/read_seeker.go index 6e00bb6..5105c64 100644 --- a/proxy/read_seeker.go +++ b/proxy/read_seeker.go @@ -12,15 +12,17 @@ import ( ) type HttpReadSeeker struct { - offset int64 - url string - contentLength int64 - method string - body io.Reader - client *http.Client - headers map[string]string - ctx context.Context - allowedContentTypes []string + offset int64 + url string + contentLength int64 + method string + body io.Reader + client *http.Client + headers map[string]string + ctx context.Context + allowedContentTypes []string + allowedStatusCodes []int + notAllowedStatusCodes []int } type HttpReadSeekerConf func(h *HttpReadSeeker) @@ -90,6 +92,18 @@ func AllowedContentTypes(types ...string) HttpReadSeekerConf { } } +func AllowedStatusCodes(codes ...int) HttpReadSeekerConf { + return func(h *HttpReadSeeker) { + h.allowedStatusCodes = codes + } +} + +func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf { + return func(h *HttpReadSeeker) { + h.notAllowedStatusCodes = codes + } +} + func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker { rs := &HttpReadSeeker{ offset: 0, @@ -121,6 +135,9 @@ func (h *HttpReadSeeker) fix() *HttpReadSeeker { if h.client == nil { h.client = http.DefaultClient } + if len(h.notAllowedStatusCodes) == 0 { + h.notAllowedStatusCodes = []int{http.StatusNotFound} + } return h } @@ -138,6 +155,10 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) { return 0, err } defer resp.Body.Close() + + if err := h.checkStatusCode(resp.StatusCode); err != nil { + return 0, err + } if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil { return 0, err } @@ -148,11 +169,24 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) { } func (h *HttpReadSeeker) checkContentType(ct string) error { - if ct == "" { - return errors.New("content type is empty") + if len(h.allowedContentTypes) != 0 { + if ct == "" || slices.Index(h.allowedContentTypes, ct) == -1 { + return fmt.Errorf("content type `%s` not allowed", ct) + } } - if len(h.allowedContentTypes) != 0 && slices.Index(h.allowedContentTypes, ct) == -1 { - return fmt.Errorf("content type `%s` not allowed", ct) + return nil +} + +func (h *HttpReadSeeker) checkStatusCode(code int) error { + if len(h.allowedStatusCodes) != 0 { + if slices.Index(h.allowedStatusCodes, code) == -1 { + return fmt.Errorf("status code `%d` not allowed", code) + } + } + if len(h.notAllowedStatusCodes) != 0 { + if slices.Index(h.notAllowedStatusCodes, code) != -1 { + return fmt.Errorf("status code `%d` not allowed", code) + } } return nil } @@ -177,6 +211,10 @@ func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) { return 0, err } defer resp.Body.Close() + + if err := h.checkStatusCode(resp.StatusCode); err != nil { + return 0, err + } if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil { return 0, err } diff --git a/server/handlers/movie.go b/server/handlers/movie.go index a64cc81..cc368b4 100644 --- a/server/handlers/movie.go +++ b/server/handlers/movie.go @@ -6,6 +6,7 @@ import ( "image" "image/color" "image/png" + "io" "math/rand" "net/http" "path" @@ -21,7 +22,6 @@ import ( "github.com/synctv-org/synctv/internal/op" "github.com/synctv-org/synctv/internal/rtmp" pb "github.com/synctv-org/synctv/proto/message" - "github.com/synctv-org/synctv/proxy" "github.com/synctv-org/synctv/server/model" "github.com/synctv-org/synctv/utils" refreshcache "github.com/synctv-org/synctv/utils/refreshCache" @@ -497,16 +497,31 @@ func ProxyMovie(ctx *gin.Context) { } } -func proxyURL(ctx *gin.Context, url string, headers map[string]string) error { - if l, err := utils.ParseURLIsLocalIP(url); err != nil || l { +func proxyURL(ctx *gin.Context, u string, headers map[string]string) error { + if l, err := utils.ParseURLIsLocalIP(u); err != nil || l { return err } - hrs := proxy.NewBufferedHttpReadSeeker(512*1024, url, - proxy.WithContext(ctx), - proxy.WithHeaders(headers), - ) - http.ServeContent(ctx.Writer, ctx.Request, url, time.Now(), hrs) - return nil + r, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return err + } + for k, v := range headers { + r.Header.Set(k, v) + } + r.Header.Set("Range", ctx.GetHeader("Range")) + resp, err := http.DefaultClient.Do(r) + if err != nil { + return err + } + defer resp.Body.Close() + ctx.Header("Content-Type", resp.Header.Get("Content-Type")) + ctx.Header("Content-Length", resp.Header.Get("Content-Length")) + ctx.Header("Accept-Ranges", resp.Header.Get("Accept-Ranges")) + ctx.Header("Cache-Control", resp.Header.Get("Cache-Control")) + ctx.Header("Content-Range", resp.Header.Get("Content-Range")) + ctx.Status(resp.StatusCode) + _, err = io.Copy(ctx.Writer, resp.Body) + return err } type FormatErrNotSupportFileType string