mirror of https://github.com/synctv-org/synctv.git
Fix: proxy
This commit is contained in:
parent
3c2cfa5a94
commit
9c15b581d2
|
@ -21,6 +21,8 @@ type HttpReadSeeker struct {
|
|||
headers map[string]string
|
||||
ctx context.Context
|
||||
allowedContentTypes []string
|
||||
allowedStatusCodes []int
|
||||
notAllowedStatusCodes []int
|
||||
}
|
||||
|
||||
type HttpReadSeekerConf func(h *HttpReadSeeker)
|
||||
|
@ -90,6 +92,18 @@ func AllowedContentTypes(types ...string) HttpReadSeekerConf {
|
|||
}
|
||||
}
|
||||
|
||||
func AllowedStatusCodes(codes ...int) HttpReadSeekerConf {
|
||||
return func(h *HttpReadSeeker) {
|
||||
h.allowedStatusCodes = codes
|
||||
}
|
||||
}
|
||||
|
||||
func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf {
|
||||
return func(h *HttpReadSeeker) {
|
||||
h.notAllowedStatusCodes = codes
|
||||
}
|
||||
}
|
||||
|
||||
func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker {
|
||||
rs := &HttpReadSeeker{
|
||||
offset: 0,
|
||||
|
@ -121,6 +135,9 @@ func (h *HttpReadSeeker) fix() *HttpReadSeeker {
|
|||
if h.client == nil {
|
||||
h.client = http.DefaultClient
|
||||
}
|
||||
if len(h.notAllowedStatusCodes) == 0 {
|
||||
h.notAllowedStatusCodes = []int{http.StatusNotFound}
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
|
@ -138,6 +155,10 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
|
|||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if err := h.checkStatusCode(resp.StatusCode); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -148,12 +169,25 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (h *HttpReadSeeker) checkContentType(ct string) error {
|
||||
if ct == "" {
|
||||
return errors.New("content type is empty")
|
||||
}
|
||||
if len(h.allowedContentTypes) != 0 && slices.Index(h.allowedContentTypes, ct) == -1 {
|
||||
if len(h.allowedContentTypes) != 0 {
|
||||
if ct == "" || slices.Index(h.allowedContentTypes, ct) == -1 {
|
||||
return fmt.Errorf("content type `%s` not allowed", ct)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HttpReadSeeker) checkStatusCode(code int) error {
|
||||
if len(h.allowedStatusCodes) != 0 {
|
||||
if slices.Index(h.allowedStatusCodes, code) == -1 {
|
||||
return fmt.Errorf("status code `%d` not allowed", code)
|
||||
}
|
||||
}
|
||||
if len(h.notAllowedStatusCodes) != 0 {
|
||||
if slices.Index(h.notAllowedStatusCodes, code) != -1 {
|
||||
return fmt.Errorf("status code `%d` not allowed", code)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -177,6 +211,10 @@ func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
|||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if err := h.checkStatusCode(resp.StatusCode); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"path"
|
||||
|
@ -21,7 +22,6 @@ import (
|
|||
"github.com/synctv-org/synctv/internal/op"
|
||||
"github.com/synctv-org/synctv/internal/rtmp"
|
||||
pb "github.com/synctv-org/synctv/proto/message"
|
||||
"github.com/synctv-org/synctv/proxy"
|
||||
"github.com/synctv-org/synctv/server/model"
|
||||
"github.com/synctv-org/synctv/utils"
|
||||
refreshcache "github.com/synctv-org/synctv/utils/refreshCache"
|
||||
|
@ -497,16 +497,31 @@ func ProxyMovie(ctx *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
func proxyURL(ctx *gin.Context, url string, headers map[string]string) error {
|
||||
if l, err := utils.ParseURLIsLocalIP(url); err != nil || l {
|
||||
func proxyURL(ctx *gin.Context, u string, headers map[string]string) error {
|
||||
if l, err := utils.ParseURLIsLocalIP(u); err != nil || l {
|
||||
return err
|
||||
}
|
||||
hrs := proxy.NewBufferedHttpReadSeeker(512*1024, url,
|
||||
proxy.WithContext(ctx),
|
||||
proxy.WithHeaders(headers),
|
||||
)
|
||||
http.ServeContent(ctx.Writer, ctx.Request, url, time.Now(), hrs)
|
||||
return nil
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
r.Header.Set("Range", ctx.GetHeader("Range"))
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ctx.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
ctx.Header("Content-Length", resp.Header.Get("Content-Length"))
|
||||
ctx.Header("Accept-Ranges", resp.Header.Get("Accept-Ranges"))
|
||||
ctx.Header("Cache-Control", resp.Header.Get("Cache-Control"))
|
||||
ctx.Header("Content-Range", resp.Header.Get("Content-Range"))
|
||||
ctx.Status(resp.StatusCode)
|
||||
_, err = io.Copy(ctx.Writer, resp.Body)
|
||||
return err
|
||||
}
|
||||
|
||||
type FormatErrNotSupportFileType string
|
||||
|
|
Loading…
Reference in New Issue