mirror of https://github.com/synctv-org/synctv.git
Fix: proxy
This commit is contained in:
parent
3c2cfa5a94
commit
9c15b581d2
|
@ -12,15 +12,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type HttpReadSeeker struct {
|
type HttpReadSeeker struct {
|
||||||
offset int64
|
offset int64
|
||||||
url string
|
url string
|
||||||
contentLength int64
|
contentLength int64
|
||||||
method string
|
method string
|
||||||
body io.Reader
|
body io.Reader
|
||||||
client *http.Client
|
client *http.Client
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
allowedContentTypes []string
|
allowedContentTypes []string
|
||||||
|
allowedStatusCodes []int
|
||||||
|
notAllowedStatusCodes []int
|
||||||
}
|
}
|
||||||
|
|
||||||
type HttpReadSeekerConf func(h *HttpReadSeeker)
|
type HttpReadSeekerConf func(h *HttpReadSeeker)
|
||||||
|
@ -90,6 +92,18 @@ func AllowedContentTypes(types ...string) HttpReadSeekerConf {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AllowedStatusCodes(codes ...int) HttpReadSeekerConf {
|
||||||
|
return func(h *HttpReadSeeker) {
|
||||||
|
h.allowedStatusCodes = codes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf {
|
||||||
|
return func(h *HttpReadSeeker) {
|
||||||
|
h.notAllowedStatusCodes = codes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker {
|
func NewHttpReadSeeker(url string, conf ...HttpReadSeekerConf) *HttpReadSeeker {
|
||||||
rs := &HttpReadSeeker{
|
rs := &HttpReadSeeker{
|
||||||
offset: 0,
|
offset: 0,
|
||||||
|
@ -121,6 +135,9 @@ func (h *HttpReadSeeker) fix() *HttpReadSeeker {
|
||||||
if h.client == nil {
|
if h.client == nil {
|
||||||
h.client = http.DefaultClient
|
h.client = http.DefaultClient
|
||||||
}
|
}
|
||||||
|
if len(h.notAllowedStatusCodes) == 0 {
|
||||||
|
h.notAllowedStatusCodes = []int{http.StatusNotFound}
|
||||||
|
}
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,6 +155,10 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if err := h.checkStatusCode(resp.StatusCode); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
|
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -148,11 +169,24 @@ func (h *HttpReadSeeker) Read(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HttpReadSeeker) checkContentType(ct string) error {
|
func (h *HttpReadSeeker) checkContentType(ct string) error {
|
||||||
if ct == "" {
|
if len(h.allowedContentTypes) != 0 {
|
||||||
return errors.New("content type is empty")
|
if ct == "" || slices.Index(h.allowedContentTypes, ct) == -1 {
|
||||||
|
return fmt.Errorf("content type `%s` not allowed", ct)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if len(h.allowedContentTypes) != 0 && slices.Index(h.allowedContentTypes, ct) == -1 {
|
return nil
|
||||||
return fmt.Errorf("content type `%s` not allowed", ct)
|
}
|
||||||
|
|
||||||
|
func (h *HttpReadSeeker) checkStatusCode(code int) error {
|
||||||
|
if len(h.allowedStatusCodes) != 0 {
|
||||||
|
if slices.Index(h.allowedStatusCodes, code) == -1 {
|
||||||
|
return fmt.Errorf("status code `%d` not allowed", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(h.notAllowedStatusCodes) != 0 {
|
||||||
|
if slices.Index(h.notAllowedStatusCodes, code) != -1 {
|
||||||
|
return fmt.Errorf("status code `%d` not allowed", code)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -177,6 +211,10 @@ func (h *HttpReadSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if err := h.checkStatusCode(resp.StatusCode); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
|
if err := h.checkContentType(resp.Header.Get("Content-Type")); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"image"
|
"image"
|
||||||
"image/color"
|
"image/color"
|
||||||
"image/png"
|
"image/png"
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
|
@ -21,7 +22,6 @@ import (
|
||||||
"github.com/synctv-org/synctv/internal/op"
|
"github.com/synctv-org/synctv/internal/op"
|
||||||
"github.com/synctv-org/synctv/internal/rtmp"
|
"github.com/synctv-org/synctv/internal/rtmp"
|
||||||
pb "github.com/synctv-org/synctv/proto/message"
|
pb "github.com/synctv-org/synctv/proto/message"
|
||||||
"github.com/synctv-org/synctv/proxy"
|
|
||||||
"github.com/synctv-org/synctv/server/model"
|
"github.com/synctv-org/synctv/server/model"
|
||||||
"github.com/synctv-org/synctv/utils"
|
"github.com/synctv-org/synctv/utils"
|
||||||
refreshcache "github.com/synctv-org/synctv/utils/refreshCache"
|
refreshcache "github.com/synctv-org/synctv/utils/refreshCache"
|
||||||
|
@ -497,16 +497,31 @@ func ProxyMovie(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func proxyURL(ctx *gin.Context, url string, headers map[string]string) error {
|
func proxyURL(ctx *gin.Context, u string, headers map[string]string) error {
|
||||||
if l, err := utils.ParseURLIsLocalIP(url); err != nil || l {
|
if l, err := utils.ParseURLIsLocalIP(u); err != nil || l {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hrs := proxy.NewBufferedHttpReadSeeker(512*1024, url,
|
r, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||||
proxy.WithContext(ctx),
|
if err != nil {
|
||||||
proxy.WithHeaders(headers),
|
return err
|
||||||
)
|
}
|
||||||
http.ServeContent(ctx.Writer, ctx.Request, url, time.Now(), hrs)
|
for k, v := range headers {
|
||||||
return nil
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
r.Header.Set("Range", ctx.GetHeader("Range"))
|
||||||
|
resp, err := http.DefaultClient.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
ctx.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||||
|
ctx.Header("Content-Length", resp.Header.Get("Content-Length"))
|
||||||
|
ctx.Header("Accept-Ranges", resp.Header.Get("Accept-Ranges"))
|
||||||
|
ctx.Header("Cache-Control", resp.Header.Get("Cache-Control"))
|
||||||
|
ctx.Header("Content-Range", resp.Header.Get("Content-Range"))
|
||||||
|
ctx.Status(resp.StatusCode)
|
||||||
|
_, err = io.Copy(ctx.Writer, resp.Body)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
type FormatErrNotSupportFileType string
|
type FormatErrNotSupportFileType string
|
||||||
|
|
Loading…
Reference in New Issue