Fix: proxy

This commit is contained in:
zijiren233 2023-11-05 21:19:59 +08:00
parent 3c2cfa5a94
commit 9c15b581d2
2 changed files with 75 additions and 22 deletions

View File

@ -12,15 +12,17 @@ import (
)
type HttpReadSeeker struct {
offset int64
url string
contentLength int64
method string
body io.Reader
client *http.Client
headers map[string]string
ctx context.Context
allowedContentTypes []string
offset int64
url string
contentLength int64
method string
body io.Reader
client *http.Client
headers map[string]string
ctx context.Context
allowedContentTypes []string
allowedStatusCodes []int
notAllowedStatusCodes []int
}
type HttpReadSeekerConf func(h *HttpReadSeeker)
@ -90,6 +92,18 @@ func AllowedContentTypes(types ...string) HttpReadSeekerConf {
}
}
func AllowedStatusCodes(codes ...int) HttpReadSeekerConf {
return func(h *HttpReadSeeker) {
h.allowedStatusCodes = codes
}
}
func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf {
return func(h *HttpReadSeeker) {
h.notAllowedStatusCodes = codes
}
}
func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker {
rs := &HttpReadSeeker{
offset: 0,
@ -121,6 +135,9 @@ func (h *HttpReadSeeker) fix() *HttpReadSeeker {
if h.client == nil {
h.client = http.DefaultClient
}
if len(h.notAllowedStatusCodes) == 0 {
h.notAllowedStatusCodes = []int{http.StatusNotFound}
}
return h
}
@ -138,6 +155,10 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
return 0, err
}
defer resp.Body.Close()
if err := h.checkStatusCode(resp.StatusCode); err != nil {
return 0, err
}
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
return 0, err
}
@ -148,11 +169,24 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
}
func (h *HttpReadSeeker) checkContentType(ct string) error {
if ct == "" {
return errors.New("content type is empty")
if len(h.allowedContentTypes) != 0 {
if ct == "" || slices.Index(h.allowedContentTypes, ct) == -1 {
return fmt.Errorf("content type `%s` not allowed", ct)
}
}
if len(h.allowedContentTypes) != 0 && slices.Index(h.allowedContentTypes, ct) == -1 {
return fmt.Errorf("content type `%s` not allowed", ct)
return nil
}
func (h *HttpReadSeeker) checkStatusCode(code int) error {
if len(h.allowedStatusCodes) != 0 {
if slices.Index(h.allowedStatusCodes, code) == -1 {
return fmt.Errorf("status code `%d` not allowed", code)
}
}
if len(h.notAllowedStatusCodes) != 0 {
if slices.Index(h.notAllowedStatusCodes, code) != -1 {
return fmt.Errorf("status code `%d` not allowed", code)
}
}
return nil
}
@ -177,6 +211,10 @@ func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
return 0, err
}
defer resp.Body.Close()
if err := h.checkStatusCode(resp.StatusCode); err != nil {
return 0, err
}
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
return 0, err
}

View File

@ -6,6 +6,7 @@ import (
"image"
"image/color"
"image/png"
"io"
"math/rand"
"net/http"
"path"
@ -21,7 +22,6 @@ import (
"github.com/synctv-org/synctv/internal/op"
"github.com/synctv-org/synctv/internal/rtmp"
pb "github.com/synctv-org/synctv/proto/message"
"github.com/synctv-org/synctv/proxy"
"github.com/synctv-org/synctv/server/model"
"github.com/synctv-org/synctv/utils"
refreshcache "github.com/synctv-org/synctv/utils/refreshCache"
@ -497,16 +497,31 @@ func ProxyMovie(ctx *gin.Context) {
}
}
func proxyURL(ctx *gin.Context, url string, headers map[string]string) error {
if l, err := utils.ParseURLIsLocalIP(url); err != nil || l {
func proxyURL(ctx *gin.Context, u string, headers map[string]string) error {
if l, err := utils.ParseURLIsLocalIP(u); err != nil || l {
return err
}
hrs := proxy.NewBufferedHttpReadSeeker(512*1024, url,
proxy.WithContext(ctx),
proxy.WithHeaders(headers),
)
http.ServeContent(ctx.Writer, ctx.Request, url, time.Now(), hrs)
return nil
r, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return err
}
for k, v := range headers {
r.Header.Set(k, v)
}
r.Header.Set("Range", ctx.GetHeader("Range"))
resp, err := http.DefaultClient.Do(r)
if err != nil {
return err
}
defer resp.Body.Close()
ctx.Header("Content-Type", resp.Header.Get("Content-Type"))
ctx.Header("Content-Length", resp.Header.Get("Content-Length"))
ctx.Header("Accept-Ranges", resp.Header.Get("Accept-Ranges"))
ctx.Header("Cache-Control", resp.Header.Get("Cache-Control"))
ctx.Header("Content-Range", resp.Header.Get("Content-Range"))
ctx.Status(resp.StatusCode)
_, err = io.Copy(ctx.Writer, resp.Body)
return err
}
type FormatErrNotSupportFileType string