mirror of
https://github.com/ossrs/srs.git
synced 2025-11-24 03:44:02 +08:00
Migrate proxy to ossrs/proxy-go repository.
This commit is contained in:
4
proxy/.gitignore
vendored
4
proxy/.gitignore
vendored
@@ -1,4 +0,0 @@
|
|||||||
.idea
|
|
||||||
srs-proxy
|
|
||||||
.env
|
|
||||||
.go-formarted
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
.PHONY: all build test fmt clean run
|
|
||||||
|
|
||||||
all: build
|
|
||||||
|
|
||||||
build: fmt ./srs-proxy
|
|
||||||
|
|
||||||
./srs-proxy: *.go
|
|
||||||
go build -o srs-proxy .
|
|
||||||
|
|
||||||
test:
|
|
||||||
go test ./...
|
|
||||||
|
|
||||||
fmt: ./.go-formarted
|
|
||||||
|
|
||||||
./.go-formarted: *.go
|
|
||||||
touch .go-formarted
|
|
||||||
go fmt ./...
|
|
||||||
|
|
||||||
clean:
|
|
||||||
rm -f srs-proxy .go-formarted
|
|
||||||
|
|
||||||
run: fmt
|
|
||||||
go run .
|
|
||||||
6
proxy/README.md
Normal file
6
proxy/README.md
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Proxy
|
||||||
|
|
||||||
|
Migrated to below repositoties:
|
||||||
|
|
||||||
|
* [proxy-go](https://github.com/ossrs/proxy-go) An common proxy server for any media servers with RTMP/SRT/HLS/HTTP-FLV and WebRTC/WHIP/WHEP protocols support.
|
||||||
|
|
||||||
272
proxy/api.go
272
proxy/api.go
@@ -1,272 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP,
|
|
||||||
// to proxy other HTTP API of SRS like the streams and clients, etc.
|
|
||||||
type srsHTTPAPIServer struct {
|
|
||||||
// The underlayer HTTP server.
|
|
||||||
server *http.Server
|
|
||||||
// The WebRTC server.
|
|
||||||
rtc *srsWebRTCServer
|
|
||||||
// The gracefully quit timeout, wait server to quit.
|
|
||||||
gracefulQuitTimeout time.Duration
|
|
||||||
// The wait group for all goroutines.
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer {
|
|
||||||
v := &srsHTTPAPIServer{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsHTTPAPIServer) Close() error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
v.server.Shutdown(ctx)
|
|
||||||
|
|
||||||
v.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
|
|
||||||
// Parse address to listen.
|
|
||||||
addr := envHttpAPI()
|
|
||||||
if !strings.Contains(addr, ":") {
|
|
||||||
addr = ":" + addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create server and handler.
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
v.server = &http.Server{Addr: addr, Handler: mux}
|
|
||||||
logger.Df(ctx, "HTTP API server listen at %v", addr)
|
|
||||||
|
|
||||||
// Shutdown the server gracefully when quiting.
|
|
||||||
go func() {
|
|
||||||
ctxParent := ctx
|
|
||||||
<-ctxParent.Done()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
v.server.Shutdown(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// The basic version handler, also can be used as health check API.
|
|
||||||
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
|
|
||||||
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
apiResponse(ctx, w, r, map[string]string{
|
|
||||||
"signature": Signature(),
|
|
||||||
"version": Version(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// The WebRTC WHIP API handler.
|
|
||||||
logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr)
|
|
||||||
mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil {
|
|
||||||
apiError(ctx, w, r, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// The WebRTC WHEP API handler.
|
|
||||||
logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr)
|
|
||||||
mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil {
|
|
||||||
apiError(ctx, w, r, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Run HTTP API server.
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer v.wg.Done()
|
|
||||||
|
|
||||||
err := v.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != context.Canceled {
|
|
||||||
// TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit.
|
|
||||||
logger.Wf(ctx, "HTTP API accept err %+v", err)
|
|
||||||
} else {
|
|
||||||
logger.Df(ctx, "HTTP API server done")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service
|
|
||||||
// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter
|
|
||||||
// for Prometheus metrics.
|
|
||||||
type systemAPI struct {
|
|
||||||
// The underlayer HTTP server.
|
|
||||||
server *http.Server
|
|
||||||
// The gracefully quit timeout, wait server to quit.
|
|
||||||
gracefulQuitTimeout time.Duration
|
|
||||||
// The wait group for all goroutines.
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI {
|
|
||||||
v := &systemAPI{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *systemAPI) Close() error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
v.server.Shutdown(ctx)
|
|
||||||
|
|
||||||
v.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *systemAPI) Run(ctx context.Context) error {
|
|
||||||
// Parse address to listen.
|
|
||||||
addr := envSystemAPI()
|
|
||||||
if !strings.Contains(addr, ":") {
|
|
||||||
addr = ":" + addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create server and handler.
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
v.server = &http.Server{Addr: addr, Handler: mux}
|
|
||||||
logger.Df(ctx, "System API server listen at %v", addr)
|
|
||||||
|
|
||||||
// Shutdown the server gracefully when quiting.
|
|
||||||
go func() {
|
|
||||||
ctxParent := ctx
|
|
||||||
<-ctxParent.Done()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
v.server.Shutdown(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// The basic version handler, also can be used as health check API.
|
|
||||||
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
|
|
||||||
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
apiResponse(ctx, w, r, map[string]string{
|
|
||||||
"signature": Signature(),
|
|
||||||
"version": Version(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// The register service for SRS media servers.
|
|
||||||
logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr)
|
|
||||||
mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := func() error {
|
|
||||||
var deviceID, ip, serverID, serviceID, pid string
|
|
||||||
var rtmp, stream, api, srt, rtc []string
|
|
||||||
if err := ParseBody(r.Body, &struct {
|
|
||||||
// The IP of SRS, mandatory.
|
|
||||||
IP *string `json:"ip"`
|
|
||||||
// The server id of SRS, store in file, may not change, mandatory.
|
|
||||||
ServerID *string `json:"server"`
|
|
||||||
// The service id of SRS, always change when restarted, mandatory.
|
|
||||||
ServiceID *string `json:"service"`
|
|
||||||
// The process id of SRS, always change when restarted, mandatory.
|
|
||||||
PID *string `json:"pid"`
|
|
||||||
// The RTMP listen endpoints, mandatory.
|
|
||||||
RTMP *[]string `json:"rtmp"`
|
|
||||||
// The HTTP Stream listen endpoints, optional.
|
|
||||||
HTTP *[]string `json:"http"`
|
|
||||||
// The API listen endpoints, optional.
|
|
||||||
API *[]string `json:"api"`
|
|
||||||
// The SRT listen endpoints, optional.
|
|
||||||
SRT *[]string `json:"srt"`
|
|
||||||
// The RTC listen endpoints, optional.
|
|
||||||
RTC *[]string `json:"rtc"`
|
|
||||||
// The device id of SRS, optional.
|
|
||||||
DeviceID *string `json:"device_id"`
|
|
||||||
}{
|
|
||||||
IP: &ip, DeviceID: &deviceID,
|
|
||||||
ServerID: &serverID, ServiceID: &serviceID, PID: &pid,
|
|
||||||
RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc,
|
|
||||||
}); err != nil {
|
|
||||||
return errors.Wrapf(err, "parse body")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ip == "" {
|
|
||||||
return errors.Errorf("empty ip")
|
|
||||||
}
|
|
||||||
if serverID == "" {
|
|
||||||
return errors.Errorf("empty server")
|
|
||||||
}
|
|
||||||
if serviceID == "" {
|
|
||||||
return errors.Errorf("empty service")
|
|
||||||
}
|
|
||||||
if pid == "" {
|
|
||||||
return errors.Errorf("empty pid")
|
|
||||||
}
|
|
||||||
if len(rtmp) == 0 {
|
|
||||||
return errors.Errorf("empty rtmp")
|
|
||||||
}
|
|
||||||
|
|
||||||
server := NewSRSServer(func(srs *SRSServer) {
|
|
||||||
srs.IP, srs.DeviceID = ip, deviceID
|
|
||||||
srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid
|
|
||||||
srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api
|
|
||||||
srs.SRT, srs.RTC = srt, rtc
|
|
||||||
srs.UpdatedAt = time.Now()
|
|
||||||
})
|
|
||||||
if err := srsLoadBalancer.Update(ctx, server); err != nil {
|
|
||||||
return errors.Wrapf(err, "update SRS server %+v", server)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Df(ctx, "Register SRS media server, %+v", server)
|
|
||||||
return nil
|
|
||||||
}(); err != nil {
|
|
||||||
apiError(ctx, w, r, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
PID string `json:"pid"`
|
|
||||||
}
|
|
||||||
|
|
||||||
apiResponse(ctx, w, r, &Response{
|
|
||||||
Code: 0, PID: fmt.Sprintf("%v", os.Getpid()),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// Run System API server.
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer v.wg.Done()
|
|
||||||
|
|
||||||
err := v.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != context.Canceled {
|
|
||||||
// TODO: If System API server closed unexpectedly, we should notice the main loop to quit.
|
|
||||||
logger.Wf(ctx, "System API accept err %+v", err)
|
|
||||||
} else {
|
|
||||||
logger.Df(ctx, "System API server done")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func handleGoPprof(ctx context.Context) {
|
|
||||||
if addr := envGoPprof(); addr != "" {
|
|
||||||
go func() {
|
|
||||||
logger.Df(ctx, "Start Go pprof at %v", addr)
|
|
||||||
http.ListenAndServe(addr, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
226
proxy/env.go
226
proxy/env.go
@@ -1,226 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// loadEnvFile loads the environment variables from file. Note that we only use .env file.
|
|
||||||
func loadEnvFile(ctx context.Context) error {
|
|
||||||
workDir, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "getpwd")
|
|
||||||
}
|
|
||||||
|
|
||||||
envFile := path.Join(workDir, ".env")
|
|
||||||
if _, err := os.Stat(envFile); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := os.Open(envFile)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "open %v", envFile)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
b, err := io.ReadAll(file)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read %v", envFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := strings.Split(strings.Replace(string(b), "\r\n", "\n", -1), "\n")
|
|
||||||
logger.Df(ctx, "load env file %v, lines=%v", envFile, len(lines))
|
|
||||||
|
|
||||||
for _, line := range lines {
|
|
||||||
if strings.HasPrefix(strings.TrimSpace(line), "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if pos := strings.IndexByte(line, '='); pos > 0 {
|
|
||||||
key := strings.TrimSpace(line[:pos])
|
|
||||||
value := strings.TrimSpace(line[pos+1:])
|
|
||||||
if v := os.Getenv(key); v != "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
os.Setenv(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildDefaultEnvironmentVariables setups the default environment variables.
|
|
||||||
func buildDefaultEnvironmentVariables(ctx context.Context) {
|
|
||||||
// Whether enable the Go pprof.
|
|
||||||
setEnvDefault("GO_PPROF", "")
|
|
||||||
// Force shutdown timeout.
|
|
||||||
setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s")
|
|
||||||
// Graceful quit timeout.
|
|
||||||
setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s")
|
|
||||||
|
|
||||||
// The HTTP API server.
|
|
||||||
setEnvDefault("PROXY_HTTP_API", "11985")
|
|
||||||
// The HTTP web server.
|
|
||||||
setEnvDefault("PROXY_HTTP_SERVER", "18080")
|
|
||||||
// The RTMP media server.
|
|
||||||
setEnvDefault("PROXY_RTMP_SERVER", "11935")
|
|
||||||
// The WebRTC media server, via UDP protocol.
|
|
||||||
setEnvDefault("PROXY_WEBRTC_SERVER", "18000")
|
|
||||||
// The SRT media server, via UDP protocol.
|
|
||||||
setEnvDefault("PROXY_SRT_SERVER", "20080")
|
|
||||||
// The API server of proxy itself.
|
|
||||||
setEnvDefault("PROXY_SYSTEM_API", "12025")
|
|
||||||
// The static directory for web server.
|
|
||||||
setEnvDefault("PROXY_STATIC_FILES", "../trunk/research")
|
|
||||||
|
|
||||||
// The load balancer, use redis or memory.
|
|
||||||
setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory")
|
|
||||||
// The redis server host.
|
|
||||||
setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1")
|
|
||||||
// The redis server port.
|
|
||||||
setEnvDefault("PROXY_REDIS_PORT", "6379")
|
|
||||||
// The redis server password.
|
|
||||||
setEnvDefault("PROXY_REDIS_PASSWORD", "")
|
|
||||||
// The redis server db.
|
|
||||||
setEnvDefault("PROXY_REDIS_DB", "0")
|
|
||||||
|
|
||||||
// Whether enable the default backend server, for debugging.
|
|
||||||
setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off")
|
|
||||||
// Default backend server IP, for debugging.
|
|
||||||
setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1")
|
|
||||||
// Default backend server port, for debugging.
|
|
||||||
setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935")
|
|
||||||
// Default backend api port, for debugging.
|
|
||||||
setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985")
|
|
||||||
// Default backend udp rtc port, for debugging.
|
|
||||||
setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000")
|
|
||||||
// Default backend udp srt port, for debugging.
|
|
||||||
setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080")
|
|
||||||
|
|
||||||
logger.Df(ctx, "load .env as GO_PPROF=%v, "+
|
|
||||||
"PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+
|
|
||||||
"PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+
|
|
||||||
"PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+
|
|
||||||
"PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+
|
|
||||||
"PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+
|
|
||||||
"PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+
|
|
||||||
"PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+
|
|
||||||
"PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+
|
|
||||||
"PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v",
|
|
||||||
envGoPprof(),
|
|
||||||
envForceQuitTimeout(), envGraceQuitTimeout(),
|
|
||||||
envHttpAPI(), envHttpServer(), envRtmpServer(),
|
|
||||||
envWebRTCServer(), envSRTServer(),
|
|
||||||
envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(),
|
|
||||||
envDefaultBackendIP(), envDefaultBackendRTMP(),
|
|
||||||
envDefaultBackendHttp(), envDefaultBackendAPI(),
|
|
||||||
envDefaultBackendRTC(), envDefaultBackendSRT(),
|
|
||||||
envLoadBalancerType(), envRedisHost(), envRedisPort(),
|
|
||||||
envRedisPassword(), envRedisDB(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func envStaticFiles() string {
|
|
||||||
return os.Getenv("PROXY_STATIC_FILES")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendSRT() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_SRT")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendRTC() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_RTC")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendAPI() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_API")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envSRTServer() string {
|
|
||||||
return os.Getenv("PROXY_SRT_SERVER")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envWebRTCServer() string {
|
|
||||||
return os.Getenv("PROXY_WEBRTC_SERVER")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendHttp() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envRedisDB() string {
|
|
||||||
return os.Getenv("PROXY_REDIS_DB")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envRedisPassword() string {
|
|
||||||
return os.Getenv("PROXY_REDIS_PASSWORD")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envRedisPort() string {
|
|
||||||
return os.Getenv("PROXY_REDIS_PORT")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envRedisHost() string {
|
|
||||||
return os.Getenv("PROXY_REDIS_HOST")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envLoadBalancerType() string {
|
|
||||||
return os.Getenv("PROXY_LOAD_BALANCER_TYPE")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendRTMP() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendIP() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_IP")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envDefaultBackendEnabled() string {
|
|
||||||
return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envGraceQuitTimeout() string {
|
|
||||||
return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envForceQuitTimeout() string {
|
|
||||||
return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envGoPprof() string {
|
|
||||||
return os.Getenv("GO_PPROF")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envSystemAPI() string {
|
|
||||||
return os.Getenv("PROXY_SYSTEM_API")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envRtmpServer() string {
|
|
||||||
return os.Getenv("PROXY_RTMP_SERVER")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envHttpServer() string {
|
|
||||||
return os.Getenv("PROXY_HTTP_SERVER")
|
|
||||||
}
|
|
||||||
|
|
||||||
func envHttpAPI() string {
|
|
||||||
return os.Getenv("PROXY_HTTP_API")
|
|
||||||
}
|
|
||||||
|
|
||||||
// setEnvDefault set env key=value if not set.
|
|
||||||
func setEnvDefault(key, value string) {
|
|
||||||
if os.Getenv(key) == "" {
|
|
||||||
os.Setenv(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,270 +0,0 @@
|
|||||||
// Package errors provides simple error handling primitives.
|
|
||||||
//
|
|
||||||
// The traditional error handling idiom in Go is roughly akin to
|
|
||||||
//
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// which applied recursively up the call stack results in error reports
|
|
||||||
// without context or debugging information. The errors package allows
|
|
||||||
// programmers to add context to the failure path in their code in a way
|
|
||||||
// that does not destroy the original value of the error.
|
|
||||||
//
|
|
||||||
// Adding context to an error
|
|
||||||
//
|
|
||||||
// The errors.Wrap function returns a new error that adds context to the
|
|
||||||
// original error by recording a stack trace at the point Wrap is called,
|
|
||||||
// and the supplied message. For example
|
|
||||||
//
|
|
||||||
// _, err := ioutil.ReadAll(r)
|
|
||||||
// if err != nil {
|
|
||||||
// return errors.Wrap(err, "read failed")
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// If additional control is required the errors.WithStack and errors.WithMessage
|
|
||||||
// functions destructure errors.Wrap into its component operations of annotating
|
|
||||||
// an error with a stack trace and an a message, respectively.
|
|
||||||
//
|
|
||||||
// Retrieving the cause of an error
|
|
||||||
//
|
|
||||||
// Using errors.Wrap constructs a stack of errors, adding context to the
|
|
||||||
// preceding error. Depending on the nature of the error it may be necessary
|
|
||||||
// to reverse the operation of errors.Wrap to retrieve the original error
|
|
||||||
// for inspection. Any error value which implements this interface
|
|
||||||
//
|
|
||||||
// type causer interface {
|
|
||||||
// Cause() error
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
|
|
||||||
// the topmost error which does not implement causer, which is assumed to be
|
|
||||||
// the original cause. For example:
|
|
||||||
//
|
|
||||||
// switch err := errors.Cause(err).(type) {
|
|
||||||
// case *MyError:
|
|
||||||
// // handle specifically
|
|
||||||
// default:
|
|
||||||
// // unknown error
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// causer interface is not exported by this package, but is considered a part
|
|
||||||
// of stable public API.
|
|
||||||
//
|
|
||||||
// Formatted printing of errors
|
|
||||||
//
|
|
||||||
// All error values returned from this package implement fmt.Formatter and can
|
|
||||||
// be formatted by the fmt package. The following verbs are supported
|
|
||||||
//
|
|
||||||
// %s print the error. If the error has a Cause it will be
|
|
||||||
// printed recursively
|
|
||||||
// %v see %s
|
|
||||||
// %+v extended format. Each Frame of the error's StackTrace will
|
|
||||||
// be printed in detail.
|
|
||||||
//
|
|
||||||
// Retrieving the stack trace of an error or wrapper
|
|
||||||
//
|
|
||||||
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
|
|
||||||
// invoked. This information can be retrieved with the following interface.
|
|
||||||
//
|
|
||||||
// type stackTracer interface {
|
|
||||||
// StackTrace() errors.StackTrace
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// Where errors.StackTrace is defined as
|
|
||||||
//
|
|
||||||
// type StackTrace []Frame
|
|
||||||
//
|
|
||||||
// The Frame type represents a call site in the stack trace. Frame supports
|
|
||||||
// the fmt.Formatter interface that can be used for printing information about
|
|
||||||
// the stack trace of this error. For example:
|
|
||||||
//
|
|
||||||
// if err, ok := err.(stackTracer); ok {
|
|
||||||
// for _, f := range err.StackTrace() {
|
|
||||||
// fmt.Printf("%+s:%d", f)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// stackTracer interface is not exported by this package, but is considered a part
|
|
||||||
// of stable public API.
|
|
||||||
//
|
|
||||||
// See the documentation for Frame.Format for more details.
|
|
||||||
// Fork from https://github.com/pkg/errors
|
|
||||||
package errors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// New returns an error with the supplied message.
|
|
||||||
// New also records the stack trace at the point it was called.
|
|
||||||
func New(message string) error {
|
|
||||||
return &fundamental{
|
|
||||||
msg: message,
|
|
||||||
stack: callers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Errorf formats according to a format specifier and returns the string
|
|
||||||
// as a value that satisfies error.
|
|
||||||
// Errorf also records the stack trace at the point it was called.
|
|
||||||
func Errorf(format string, args ...interface{}) error {
|
|
||||||
return &fundamental{
|
|
||||||
msg: fmt.Sprintf(format, args...),
|
|
||||||
stack: callers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fundamental is an error that has a message and a stack, but no caller.
|
|
||||||
type fundamental struct {
|
|
||||||
msg string
|
|
||||||
*stack
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fundamental) Error() string { return f.msg }
|
|
||||||
|
|
||||||
func (f *fundamental) Format(s fmt.State, verb rune) {
|
|
||||||
switch verb {
|
|
||||||
case 'v':
|
|
||||||
if s.Flag('+') {
|
|
||||||
io.WriteString(s, f.msg)
|
|
||||||
f.stack.Format(s, verb)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fallthrough
|
|
||||||
case 's':
|
|
||||||
io.WriteString(s, f.msg)
|
|
||||||
case 'q':
|
|
||||||
fmt.Fprintf(s, "%q", f.msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithStack annotates err with a stack trace at the point WithStack was called.
|
|
||||||
// If err is nil, WithStack returns nil.
|
|
||||||
func WithStack(err error) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &withStack{
|
|
||||||
err,
|
|
||||||
callers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type withStack struct {
|
|
||||||
error
|
|
||||||
*stack
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *withStack) Cause() error { return w.error }
|
|
||||||
|
|
||||||
func (w *withStack) Format(s fmt.State, verb rune) {
|
|
||||||
switch verb {
|
|
||||||
case 'v':
|
|
||||||
if s.Flag('+') {
|
|
||||||
fmt.Fprintf(s, "%+v", w.Cause())
|
|
||||||
w.stack.Format(s, verb)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fallthrough
|
|
||||||
case 's':
|
|
||||||
io.WriteString(s, w.Error())
|
|
||||||
case 'q':
|
|
||||||
fmt.Fprintf(s, "%q", w.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap returns an error annotating err with a stack trace
|
|
||||||
// at the point Wrap is called, and the supplied message.
|
|
||||||
// If err is nil, Wrap returns nil.
|
|
||||||
func Wrap(err error, message string) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err = &withMessage{
|
|
||||||
cause: err,
|
|
||||||
msg: message,
|
|
||||||
}
|
|
||||||
return &withStack{
|
|
||||||
err,
|
|
||||||
callers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrapf returns an error annotating err with a stack trace
|
|
||||||
// at the point Wrapf is call, and the format specifier.
|
|
||||||
// If err is nil, Wrapf returns nil.
|
|
||||||
func Wrapf(err error, format string, args ...interface{}) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err = &withMessage{
|
|
||||||
cause: err,
|
|
||||||
msg: fmt.Sprintf(format, args...),
|
|
||||||
}
|
|
||||||
return &withStack{
|
|
||||||
err,
|
|
||||||
callers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithMessage annotates err with a new message.
|
|
||||||
// If err is nil, WithMessage returns nil.
|
|
||||||
func WithMessage(err error, message string) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &withMessage{
|
|
||||||
cause: err,
|
|
||||||
msg: message,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type withMessage struct {
|
|
||||||
cause error
|
|
||||||
msg string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
|
|
||||||
func (w *withMessage) Cause() error { return w.cause }
|
|
||||||
|
|
||||||
func (w *withMessage) Format(s fmt.State, verb rune) {
|
|
||||||
switch verb {
|
|
||||||
case 'v':
|
|
||||||
if s.Flag('+') {
|
|
||||||
fmt.Fprintf(s, "%+v\n", w.Cause())
|
|
||||||
io.WriteString(s, w.msg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fallthrough
|
|
||||||
case 's', 'q':
|
|
||||||
io.WriteString(s, w.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cause returns the underlying cause of the error, if possible.
|
|
||||||
// An error value has a cause if it implements the following
|
|
||||||
// interface:
|
|
||||||
//
|
|
||||||
// type causer interface {
|
|
||||||
// Cause() error
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// If the error does not implement Cause, the original error will
|
|
||||||
// be returned. If the error is nil, nil will be returned without further
|
|
||||||
// investigation.
|
|
||||||
func Cause(err error) error {
|
|
||||||
type causer interface {
|
|
||||||
Cause() error
|
|
||||||
}
|
|
||||||
|
|
||||||
for err != nil {
|
|
||||||
cause, ok := err.(causer)
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
err = cause.Cause()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,187 +0,0 @@
|
|||||||
// Fork from https://github.com/pkg/errors
|
|
||||||
package errors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"path"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Frame represents a program counter inside a stack frame.
|
|
||||||
type Frame uintptr
|
|
||||||
|
|
||||||
// pc returns the program counter for this frame;
|
|
||||||
// multiple frames may have the same PC value.
|
|
||||||
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
|
|
||||||
|
|
||||||
// file returns the full path to the file that contains the
|
|
||||||
// function for this Frame's pc.
|
|
||||||
func (f Frame) file() string {
|
|
||||||
fn := runtime.FuncForPC(f.pc())
|
|
||||||
if fn == nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
file, _ := fn.FileLine(f.pc())
|
|
||||||
return file
|
|
||||||
}
|
|
||||||
|
|
||||||
// line returns the line number of source code of the
|
|
||||||
// function for this Frame's pc.
|
|
||||||
func (f Frame) line() int {
|
|
||||||
fn := runtime.FuncForPC(f.pc())
|
|
||||||
if fn == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
_, line := fn.FileLine(f.pc())
|
|
||||||
return line
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format formats the frame according to the fmt.Formatter interface.
|
|
||||||
//
|
|
||||||
// %s source file
|
|
||||||
// %d source line
|
|
||||||
// %n function name
|
|
||||||
// %v equivalent to %s:%d
|
|
||||||
//
|
|
||||||
// Format accepts flags that alter the printing of some verbs, as follows:
|
|
||||||
//
|
|
||||||
// %+s path of source file relative to the compile time GOPATH
|
|
||||||
// %+v equivalent to %+s:%d
|
|
||||||
func (f Frame) Format(s fmt.State, verb rune) {
|
|
||||||
switch verb {
|
|
||||||
case 's':
|
|
||||||
switch {
|
|
||||||
case s.Flag('+'):
|
|
||||||
pc := f.pc()
|
|
||||||
fn := runtime.FuncForPC(pc)
|
|
||||||
if fn == nil {
|
|
||||||
io.WriteString(s, "unknown")
|
|
||||||
} else {
|
|
||||||
file, _ := fn.FileLine(pc)
|
|
||||||
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
io.WriteString(s, path.Base(f.file()))
|
|
||||||
}
|
|
||||||
case 'd':
|
|
||||||
fmt.Fprintf(s, "%d", f.line())
|
|
||||||
case 'n':
|
|
||||||
name := runtime.FuncForPC(f.pc()).Name()
|
|
||||||
io.WriteString(s, funcname(name))
|
|
||||||
case 'v':
|
|
||||||
f.Format(s, 's')
|
|
||||||
io.WriteString(s, ":")
|
|
||||||
f.Format(s, 'd')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
|
|
||||||
type StackTrace []Frame
|
|
||||||
|
|
||||||
// Format formats the stack of Frames according to the fmt.Formatter interface.
|
|
||||||
//
|
|
||||||
// %s lists source files for each Frame in the stack
|
|
||||||
// %v lists the source file and line number for each Frame in the stack
|
|
||||||
//
|
|
||||||
// Format accepts flags that alter the printing of some verbs, as follows:
|
|
||||||
//
|
|
||||||
// %+v Prints filename, function, and line number for each Frame in the stack.
|
|
||||||
func (st StackTrace) Format(s fmt.State, verb rune) {
|
|
||||||
switch verb {
|
|
||||||
case 'v':
|
|
||||||
switch {
|
|
||||||
case s.Flag('+'):
|
|
||||||
for _, f := range st {
|
|
||||||
fmt.Fprintf(s, "\n%+v", f)
|
|
||||||
}
|
|
||||||
case s.Flag('#'):
|
|
||||||
fmt.Fprintf(s, "%#v", []Frame(st))
|
|
||||||
default:
|
|
||||||
fmt.Fprintf(s, "%v", []Frame(st))
|
|
||||||
}
|
|
||||||
case 's':
|
|
||||||
fmt.Fprintf(s, "%s", []Frame(st))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// stack represents a stack of program counters.
|
|
||||||
type stack []uintptr
|
|
||||||
|
|
||||||
func (s *stack) Format(st fmt.State, verb rune) {
|
|
||||||
switch verb {
|
|
||||||
case 'v':
|
|
||||||
switch {
|
|
||||||
case st.Flag('+'):
|
|
||||||
for _, pc := range *s {
|
|
||||||
f := Frame(pc)
|
|
||||||
fmt.Fprintf(st, "\n%+v", f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stack) StackTrace() StackTrace {
|
|
||||||
f := make([]Frame, len(*s))
|
|
||||||
for i := 0; i < len(f); i++ {
|
|
||||||
f[i] = Frame((*s)[i])
|
|
||||||
}
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
func callers() *stack {
|
|
||||||
const depth = 32
|
|
||||||
var pcs [depth]uintptr
|
|
||||||
n := runtime.Callers(3, pcs[:])
|
|
||||||
var st stack = pcs[0:n]
|
|
||||||
return &st
|
|
||||||
}
|
|
||||||
|
|
||||||
// funcname removes the path prefix component of a function's name reported by func.Name().
|
|
||||||
func funcname(name string) string {
|
|
||||||
i := strings.LastIndex(name, "/")
|
|
||||||
name = name[i+1:]
|
|
||||||
i = strings.Index(name, ".")
|
|
||||||
return name[i+1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func trimGOPATH(name, file string) string {
|
|
||||||
// Here we want to get the source file path relative to the compile time
|
|
||||||
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
|
|
||||||
// GOPATH at runtime, but we can infer the number of path segments in the
|
|
||||||
// GOPATH. We note that fn.Name() returns the function name qualified by
|
|
||||||
// the import path, which does not include the GOPATH. Thus we can trim
|
|
||||||
// segments from the beginning of the file path until the number of path
|
|
||||||
// separators remaining is one more than the number of path separators in
|
|
||||||
// the function name. For example, given:
|
|
||||||
//
|
|
||||||
// GOPATH /home/user
|
|
||||||
// file /home/user/src/pkg/sub/file.go
|
|
||||||
// fn.Name() pkg/sub.Type.Method
|
|
||||||
//
|
|
||||||
// We want to produce:
|
|
||||||
//
|
|
||||||
// pkg/sub/file.go
|
|
||||||
//
|
|
||||||
// From this we can easily see that fn.Name() has one less path separator
|
|
||||||
// than our desired output. We count separators from the end of the file
|
|
||||||
// path until it finds two more than in the function name and then move
|
|
||||||
// one character forward to preserve the initial path segment without a
|
|
||||||
// leading separator.
|
|
||||||
const sep = "/"
|
|
||||||
goal := strings.Count(name, sep) + 2
|
|
||||||
i := len(file)
|
|
||||||
for n := 0; n < goal; n++ {
|
|
||||||
i = strings.LastIndex(file[:i], sep)
|
|
||||||
if i == -1 {
|
|
||||||
// not enough separators found, set i so that the slice expression
|
|
||||||
// below leaves file unmodified
|
|
||||||
i = -len(sep)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// get back to 0 or trim the leading separator
|
|
||||||
file = file[i+len(sep):]
|
|
||||||
return file
|
|
||||||
}
|
|
||||||
10
proxy/go.mod
10
proxy/go.mod
@@ -1,10 +0,0 @@
|
|||||||
module srs-proxy
|
|
||||||
|
|
||||||
go 1.18
|
|
||||||
|
|
||||||
require github.com/go-redis/redis/v8 v8.11.5
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
|
||||||
)
|
|
||||||
15
proxy/go.sum
15
proxy/go.sum
@@ -1,15 +0,0 @@
|
|||||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
|
||||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
|
||||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
|
||||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
|
||||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
|
||||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
|
||||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
|
|
||||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
|
|
||||||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
|
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
|
||||||
419
proxy/http.go
419
proxy/http.go
@@ -1,419 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
stdSync "sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS,
|
|
||||||
// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy
|
|
||||||
// the request to the origin server.
|
|
||||||
type srsHTTPStreamServer struct {
|
|
||||||
// The underlayer HTTP server.
|
|
||||||
server *http.Server
|
|
||||||
// The gracefully quit timeout, wait server to quit.
|
|
||||||
gracefulQuitTimeout time.Duration
|
|
||||||
// The wait group for all goroutines.
|
|
||||||
wg stdSync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer {
|
|
||||||
v := &srsHTTPStreamServer{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsHTTPStreamServer) Close() error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
v.server.Shutdown(ctx)
|
|
||||||
|
|
||||||
v.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
|
|
||||||
// Parse address to listen.
|
|
||||||
addr := envHttpServer()
|
|
||||||
if !strings.Contains(addr, ":") {
|
|
||||||
addr = ":" + addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create server and handler.
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
v.server = &http.Server{Addr: addr, Handler: mux}
|
|
||||||
logger.Df(ctx, "HTTP Stream server listen at %v", addr)
|
|
||||||
|
|
||||||
// Shutdown the server gracefully when quiting.
|
|
||||||
go func() {
|
|
||||||
ctxParent := ctx
|
|
||||||
<-ctxParent.Done()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
v.server.Shutdown(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// The basic version handler, also can be used as health check API.
|
|
||||||
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
|
|
||||||
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
type Response struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
PID string `json:"pid"`
|
|
||||||
Data struct {
|
|
||||||
Major int `json:"major"`
|
|
||||||
Minor int `json:"minor"`
|
|
||||||
Revision int `json:"revision"`
|
|
||||||
Version string `json:"version"`
|
|
||||||
} `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())}
|
|
||||||
res.Data.Major = VersionMajor()
|
|
||||||
res.Data.Minor = VersionMinor()
|
|
||||||
res.Data.Revision = VersionRevision()
|
|
||||||
res.Data.Version = Version()
|
|
||||||
|
|
||||||
apiResponse(ctx, w, r, &res)
|
|
||||||
})
|
|
||||||
|
|
||||||
// The static web server, for the web pages.
|
|
||||||
var staticServer http.Handler
|
|
||||||
if staticFiles := envStaticFiles(); staticFiles != "" {
|
|
||||||
if _, err := os.Stat(staticFiles); err != nil {
|
|
||||||
return errors.Wrapf(err, "invalid static files %v", staticFiles)
|
|
||||||
}
|
|
||||||
|
|
||||||
staticServer = http.FileServer(http.Dir(staticFiles))
|
|
||||||
logger.Df(ctx, "Handle static files at %v", staticFiles)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The default handler, for both static web server and streaming server.
|
|
||||||
logger.Df(ctx, "Handle / by %v", addr)
|
|
||||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// For HLS streaming, we will proxy the request to the streaming server.
|
|
||||||
if strings.HasSuffix(r.URL.Path, ".m3u8") {
|
|
||||||
unifiedURL, fullURL := convertURLToStreamURL(r)
|
|
||||||
streamURL, err := buildStreamURL(unifiedURL)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) {
|
|
||||||
s.SRSProxyBackendHLSID = logger.GenerateContextID()
|
|
||||||
s.StreamURL, s.FullURL = streamURL, fullURL
|
|
||||||
}))
|
|
||||||
|
|
||||||
stream.Initialize(ctx).ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// For HTTP streaming, we will proxy the request to the streaming server.
|
|
||||||
if strings.HasSuffix(r.URL.Path, ".flv") ||
|
|
||||||
strings.HasSuffix(r.URL.Path, ".ts") {
|
|
||||||
// If SPBHID is specified, it must be a HLS stream client.
|
|
||||||
if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" {
|
|
||||||
if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest)
|
|
||||||
} else {
|
|
||||||
stream.Initialize(ctx).ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use HTTP pseudo streaming to proxy the request.
|
|
||||||
NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) {
|
|
||||||
c.ctx = ctx
|
|
||||||
}).ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serve by static server.
|
|
||||||
if staticServer != nil {
|
|
||||||
staticServer.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.NotFound(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Run HTTP server.
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer v.wg.Done()
|
|
||||||
|
|
||||||
err := v.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != context.Canceled {
|
|
||||||
// TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit.
|
|
||||||
logger.Wf(ctx, "HTTP Stream accept err %+v", err)
|
|
||||||
} else {
|
|
||||||
logger.Df(ctx, "HTTP Stream server done")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS
|
|
||||||
// connection. There is no state need to be sync between proxy servers.
|
|
||||||
//
|
|
||||||
// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request,
|
|
||||||
// then proxy to the corresponding backend server. All state is in the HTTP request, so this
|
|
||||||
// connection is stateless.
|
|
||||||
type HTTPFlvTsConnection struct {
|
|
||||||
// The context for HTTP streaming.
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection {
|
|
||||||
v := &HTTPFlvTsConnection{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer r.Body.Close()
|
|
||||||
ctx := logger.WithContext(v.ctx)
|
|
||||||
|
|
||||||
if err := v.serve(ctx, w, r); err != nil {
|
|
||||||
apiError(ctx, w, r, err)
|
|
||||||
} else {
|
|
||||||
logger.Df(ctx, "HTTP client done")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
||||||
// Always allow CORS for all requests.
|
|
||||||
if ok := apiCORS(ctx, w, r); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the stream URL in vhost/app/stream schema.
|
|
||||||
unifiedURL, fullURL := convertURLToStreamURL(r)
|
|
||||||
logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL)
|
|
||||||
|
|
||||||
streamURL, err := buildStreamURL(unifiedURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "build stream url %v", unifiedURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the RTMP stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
|
|
||||||
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
|
|
||||||
// Parse HTTP port from backend.
|
|
||||||
if len(backend.HTTP) == 0 {
|
|
||||||
return errors.Errorf("no http stream server")
|
|
||||||
}
|
|
||||||
|
|
||||||
var httpPort int
|
|
||||||
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
|
|
||||||
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
|
|
||||||
} else {
|
|
||||||
httpPort = int(iv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to backend SRS server via HTTP client.
|
|
||||||
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
|
|
||||||
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "create request to %v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "do request to %v", backendURL)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all headers from backend to client.
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
for _, vv := range v {
|
|
||||||
w.Header().Add(k, vv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Df(ctx, "HTTP start streaming")
|
|
||||||
|
|
||||||
// Proxy the stream from backend to client.
|
|
||||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
||||||
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS
|
|
||||||
// clients will share this object, and they do not use the same ctx among proxy servers.
|
|
||||||
//
|
|
||||||
// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections.
|
|
||||||
// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create
|
|
||||||
// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert
|
|
||||||
// to the stream URL and then query the backend server to serve it.
|
|
||||||
type HLSPlayStream struct {
|
|
||||||
// The context for HLS streaming.
|
|
||||||
ctx context.Context
|
|
||||||
|
|
||||||
// The spbhid, used to identify the backend server.
|
|
||||||
SRSProxyBackendHLSID string `json:"spbhid"`
|
|
||||||
// The stream URL in vhost/app/stream schema.
|
|
||||||
StreamURL string `json:"stream_url"`
|
|
||||||
// The full request URL for HLS streaming
|
|
||||||
FullURL string `json:"full_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream {
|
|
||||||
v := &HLSPlayStream{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream {
|
|
||||||
if v.ctx == nil {
|
|
||||||
v.ctx = logger.WithContext(ctx)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
if err := v.serve(v.ctx, w, r); err != nil {
|
|
||||||
apiError(v.ctx, w, r, err)
|
|
||||||
} else {
|
|
||||||
logger.Df(v.ctx, "HLS client %v for %v with %v done",
|
|
||||||
v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
||||||
ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL
|
|
||||||
|
|
||||||
// Always allow CORS for all requests.
|
|
||||||
if ok := apiCORS(ctx, w, r); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the RTMP stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
|
|
||||||
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
|
|
||||||
// Parse HTTP port from backend.
|
|
||||||
if len(backend.HTTP) == 0 {
|
|
||||||
return errors.Errorf("no rtmp server")
|
|
||||||
}
|
|
||||||
|
|
||||||
var httpPort int
|
|
||||||
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
|
|
||||||
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
|
|
||||||
} else {
|
|
||||||
httpPort = int(iv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to backend SRS server via HTTP client.
|
|
||||||
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
|
|
||||||
if r.URL.RawQuery != "" {
|
|
||||||
backendURL += "?" + r.URL.RawQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "create request to %v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Errorf("do request to %v EOF", backendURL)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all headers from backend to client.
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
for _, vv := range v {
|
|
||||||
w.Header().Add(k, vv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For TS file, directly copy it.
|
|
||||||
if !strings.HasSuffix(r.URL.Path, ".m3u8") {
|
|
||||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
||||||
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts
|
|
||||||
// URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID.
|
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read stream from %v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
m3u8 := string(b)
|
|
||||||
if strings.Contains(m3u8, ".ts?") {
|
|
||||||
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID))
|
|
||||||
} else {
|
|
||||||
m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID))
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil {
|
|
||||||
return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
)
|
|
||||||
|
|
||||||
type key string
|
|
||||||
|
|
||||||
var cidKey key = "cid.proxy.ossrs.org"
|
|
||||||
|
|
||||||
// generateContextID generates a random context id in string.
|
|
||||||
func GenerateContextID() string {
|
|
||||||
randomBytes := make([]byte, 32)
|
|
||||||
_, _ = rand.Read(randomBytes)
|
|
||||||
hash := sha256.Sum256(randomBytes)
|
|
||||||
hashString := hex.EncodeToString(hash[:])
|
|
||||||
cid := hashString[:7]
|
|
||||||
return cid
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContext creates a new context with cid, which will be used for log.
|
|
||||||
func WithContext(ctx context.Context) context.Context {
|
|
||||||
return WithContextID(ctx, GenerateContextID())
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContextID creates a new context with cid, which will be used for log.
|
|
||||||
func WithContextID(ctx context.Context, cid string) context.Context {
|
|
||||||
return context.WithValue(ctx, cidKey, cid)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContextID returns the cid in context, or empty string if not set.
|
|
||||||
func ContextID(ctx context.Context) string {
|
|
||||||
if cid, ok := ctx.Value(cidKey).(string); ok {
|
|
||||||
return cid
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io/ioutil"
|
|
||||||
stdLog "log"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
type logger interface {
|
|
||||||
Printf(ctx context.Context, format string, v ...any)
|
|
||||||
}
|
|
||||||
|
|
||||||
type loggerPlus struct {
|
|
||||||
logger *stdLog.Logger
|
|
||||||
level string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus {
|
|
||||||
v := &loggerPlus{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) {
|
|
||||||
format, args := f, a
|
|
||||||
if cid := ContextID(ctx); cid != "" {
|
|
||||||
format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...)
|
|
||||||
}
|
|
||||||
|
|
||||||
v.logger.Printf(format, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var verboseLogger logger
|
|
||||||
|
|
||||||
func Vf(ctx context.Context, format string, a ...interface{}) {
|
|
||||||
verboseLogger.Printf(ctx, format, a...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var debugLogger logger
|
|
||||||
|
|
||||||
func Df(ctx context.Context, format string, a ...interface{}) {
|
|
||||||
debugLogger.Printf(ctx, format, a...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var warnLogger logger
|
|
||||||
|
|
||||||
func Wf(ctx context.Context, format string, a ...interface{}) {
|
|
||||||
warnLogger.Printf(ctx, format, a...)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errorLogger logger
|
|
||||||
|
|
||||||
func Ef(ctx context.Context, format string, a ...interface{}) {
|
|
||||||
errorLogger.Printf(ctx, format, a...)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
logVerboseLabel = "verb"
|
|
||||||
logDebugLabel = "debug"
|
|
||||||
logWarnLabel = "warn"
|
|
||||||
logErrorLabel = "error"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
verboseLogger = newLoggerPlus(func(logger *loggerPlus) {
|
|
||||||
logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
|
|
||||||
logger.level = logVerboseLabel
|
|
||||||
})
|
|
||||||
debugLogger = newLoggerPlus(func(logger *loggerPlus) {
|
|
||||||
logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
|
|
||||||
logger.level = logDebugLabel
|
|
||||||
})
|
|
||||||
warnLogger = newLoggerPlus(func(logger *loggerPlus) {
|
|
||||||
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
|
|
||||||
logger.level = logWarnLabel
|
|
||||||
})
|
|
||||||
errorLogger = newLoggerPlus(func(logger *loggerPlus) {
|
|
||||||
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
|
|
||||||
logger.level = logErrorLabel
|
|
||||||
})
|
|
||||||
}
|
|
||||||
121
proxy/main.go
121
proxy/main.go
@@ -1,121 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
ctx := logger.WithContext(context.Background())
|
|
||||||
logger.Df(ctx, "%v/%v started", Signature(), Version())
|
|
||||||
|
|
||||||
// Install signals.
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
installSignals(ctx, cancel)
|
|
||||||
|
|
||||||
// Start the main loop, ignore the user cancel error.
|
|
||||||
err := doMain(ctx)
|
|
||||||
if err != nil && ctx.Err() != context.Canceled {
|
|
||||||
logger.Ef(ctx, "main: %+v", err)
|
|
||||||
os.Exit(-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Df(ctx, "%v done", Signature())
|
|
||||||
}
|
|
||||||
|
|
||||||
func doMain(ctx context.Context) error {
|
|
||||||
// Setup the environment variables.
|
|
||||||
if err := loadEnvFile(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "load env")
|
|
||||||
}
|
|
||||||
|
|
||||||
buildDefaultEnvironmentVariables(ctx)
|
|
||||||
|
|
||||||
// When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur
|
|
||||||
// because the main thread exits after the context is cancelled. However, sometimes the main thread
|
|
||||||
// may be blocked for some reason, so a forced exit is necessary to ensure the program terminates.
|
|
||||||
if err := installForceQuit(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "install force quit")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the Go pprof if enabled.
|
|
||||||
handleGoPprof(ctx)
|
|
||||||
|
|
||||||
// Initialize SRS load balancers.
|
|
||||||
switch lbType := envLoadBalancerType(); lbType {
|
|
||||||
case "memory":
|
|
||||||
srsLoadBalancer = NewMemoryLoadBalancer()
|
|
||||||
case "redis":
|
|
||||||
srsLoadBalancer = NewRedisLoadBalancer()
|
|
||||||
default:
|
|
||||||
return errors.Errorf("invalid load balancer %v", lbType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := srsLoadBalancer.Initialize(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "initialize srs load balancer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the gracefully quit timeout.
|
|
||||||
gracefulQuitTimeout, err := parseGracefullyQuitTimeout()
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse gracefully quit timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the RTMP server.
|
|
||||||
srsRTMPServer := NewSRSRTMPServer()
|
|
||||||
defer srsRTMPServer.Close()
|
|
||||||
if err := srsRTMPServer.Run(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "rtmp server")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the WebRTC server.
|
|
||||||
srsWebRTCServer := NewSRSWebRTCServer()
|
|
||||||
defer srsWebRTCServer.Close()
|
|
||||||
if err := srsWebRTCServer.Run(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "rtc server")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the HTTP API server.
|
|
||||||
srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) {
|
|
||||||
server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer
|
|
||||||
})
|
|
||||||
defer srsHTTPAPIServer.Close()
|
|
||||||
if err := srsHTTPAPIServer.Run(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "http api server")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the SRT server.
|
|
||||||
srsSRTServer := NewSRSSRTServer()
|
|
||||||
defer srsSRTServer.Close()
|
|
||||||
if err := srsSRTServer.Run(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "srt server")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the System API server.
|
|
||||||
systemAPI := NewSystemAPI(func(server *systemAPI) {
|
|
||||||
server.gracefulQuitTimeout = gracefulQuitTimeout
|
|
||||||
})
|
|
||||||
defer systemAPI.Close()
|
|
||||||
if err := systemAPI.Run(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "system api server")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the HTTP web server.
|
|
||||||
srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) {
|
|
||||||
server.gracefulQuitTimeout = gracefulQuitTimeout
|
|
||||||
})
|
|
||||||
defer srsHTTPStreamServer.Close()
|
|
||||||
if err := srsHTTPStreamServer.Run(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "http server")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the main loop to quit.
|
|
||||||
<-ctx.Done()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
515
proxy/rtc.go
515
proxy/rtc.go
@@ -1,515 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
stdSync "sync"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
"srs-proxy/sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out
|
|
||||||
// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the
|
|
||||||
// SDP answer.
|
|
||||||
type srsWebRTCServer struct {
|
|
||||||
// The UDP listener for WebRTC server.
|
|
||||||
listener *net.UDPConn
|
|
||||||
|
|
||||||
// Fast cache for the username to identify the connection.
|
|
||||||
// The key is username, the value is the UDP address.
|
|
||||||
usernames sync.Map[string, *RTCConnection]
|
|
||||||
// Fast cache for the udp address to identify the connection.
|
|
||||||
// The key is UDP address, the value is the username.
|
|
||||||
// TODO: Support fast earch by uint64 address.
|
|
||||||
addresses sync.Map[string, *RTCConnection]
|
|
||||||
|
|
||||||
// The wait group for server.
|
|
||||||
wg stdSync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer {
|
|
||||||
v := &srsWebRTCServer{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsWebRTCServer) Close() error {
|
|
||||||
if v.listener != nil {
|
|
||||||
_ = v.listener.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
v.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
||||||
defer r.Body.Close()
|
|
||||||
ctx = logger.WithContext(ctx)
|
|
||||||
|
|
||||||
// Always allow CORS for all requests.
|
|
||||||
if ok := apiCORS(ctx, w, r); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read remote SDP offer from body.
|
|
||||||
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read remote sdp offer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the stream URL in vhost/app/stream schema.
|
|
||||||
unifiedURL, fullURL := convertURLToStreamURL(r)
|
|
||||||
logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
|
|
||||||
|
|
||||||
streamURL, err := buildStreamURL(unifiedURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "build stream url %v", unifiedURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the RTMP stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
|
|
||||||
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
||||||
defer r.Body.Close()
|
|
||||||
ctx = logger.WithContext(ctx)
|
|
||||||
|
|
||||||
// Always allow CORS for all requests.
|
|
||||||
if ok := apiCORS(ctx, w, r); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read remote SDP offer from body.
|
|
||||||
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read remote sdp offer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the stream URL in vhost/app/stream schema.
|
|
||||||
unifiedURL, fullURL := convertURLToStreamURL(r)
|
|
||||||
logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
|
|
||||||
|
|
||||||
streamURL, err := buildStreamURL(unifiedURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "build stream url %v", unifiedURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the RTMP stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
|
|
||||||
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsWebRTCServer) proxyApiToBackend(
|
|
||||||
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer,
|
|
||||||
remoteSDPOffer string, streamURL string,
|
|
||||||
) error {
|
|
||||||
// Parse HTTP port from backend.
|
|
||||||
if len(backend.API) == 0 {
|
|
||||||
return errors.Errorf("no http api server")
|
|
||||||
}
|
|
||||||
|
|
||||||
var apiPort int
|
|
||||||
if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil {
|
|
||||||
return errors.Wrapf(err, "parse http port %v", backend.API[0])
|
|
||||||
} else {
|
|
||||||
apiPort = int(iv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to backend SRS server via HTTP client.
|
|
||||||
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
|
|
||||||
if r.URL.RawQuery != "" {
|
|
||||||
backendURL += "?" + r.URL.RawQuery
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer))
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "create request to %v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Errorf("do request to %v EOF", backendURL)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
|
||||||
return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all headers from backend to client.
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
for _, vv := range v {
|
|
||||||
w.Header().Add(k, vv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the local SDP answer from backend.
|
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read stream from %v", backendURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace the WebRTC UDP port in answer.
|
|
||||||
localSDPAnswer := string(b)
|
|
||||||
for _, endpoint := range backend.RTC {
|
|
||||||
_, _, port, err := parseListenEndpoint(endpoint)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse endpoint %v", endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
from := fmt.Sprintf(" %v typ host", port)
|
|
||||||
to := fmt.Sprintf(" %v typ host", envWebRTCServer())
|
|
||||||
localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch the ice-ufrag and ice-pwd from local SDP answer.
|
|
||||||
remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse remote sdp offer")
|
|
||||||
}
|
|
||||||
|
|
||||||
localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse local sdp answer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the new WebRTC connection to LB.
|
|
||||||
icePair := &RTCICEPair{
|
|
||||||
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
|
|
||||||
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
|
|
||||||
}
|
|
||||||
if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) {
|
|
||||||
c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag()
|
|
||||||
c.Initialize(ctx, v.listener)
|
|
||||||
|
|
||||||
// Cache the connection for fast search by username.
|
|
||||||
v.usernames.Store(c.Ufrag, c)
|
|
||||||
})); err != nil {
|
|
||||||
return errors.Wrapf(err, "load or store webrtc %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response client with local answer.
|
|
||||||
if _, err = w.Write([]byte(localSDPAnswer)); err != nil {
|
|
||||||
return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB",
|
|
||||||
len(localSDPAnswer), localICEUfrag, len(localICEPwd))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsWebRTCServer) Run(ctx context.Context) error {
|
|
||||||
// Parse address to listen.
|
|
||||||
endpoint := envWebRTCServer()
|
|
||||||
if !strings.Contains(endpoint, ":") {
|
|
||||||
endpoint = fmt.Sprintf(":%v", endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
saddr, err := net.ResolveUDPAddr("udp", endpoint)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := net.ListenUDP("udp", saddr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "listen udp %v", saddr)
|
|
||||||
}
|
|
||||||
v.listener = listener
|
|
||||||
logger.Df(ctx, "WebRTC server listen at %v", saddr)
|
|
||||||
|
|
||||||
// Consume all messages from UDP media transport.
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer v.wg.Done()
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
n, caddr, err := listener.ReadFromUDP(buf)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit.
|
|
||||||
logger.Wf(ctx, "read from udp failed, err=%+v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
|
|
||||||
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
|
|
||||||
var connection *RTCConnection
|
|
||||||
|
|
||||||
// If STUN binding request, parse the ufrag and identify the connection.
|
|
||||||
if err := func() error {
|
|
||||||
if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var pkt RTCStunPacket
|
|
||||||
if err := pkt.UnmarshalBinary(data); err != nil {
|
|
||||||
return errors.Wrapf(err, "unmarshal stun packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Search the connection in fast cache.
|
|
||||||
if s, ok := v.usernames.Load(pkt.Username); ok {
|
|
||||||
connection = s
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load connection by username.
|
|
||||||
if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
|
|
||||||
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
|
|
||||||
} else {
|
|
||||||
connection = s.Initialize(ctx, v.listener)
|
|
||||||
logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache connection for fast search.
|
|
||||||
if connection != nil {
|
|
||||||
v.usernames.Store(pkt.Username, connection)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Search the connection by addr.
|
|
||||||
if s, ok := v.addresses.Load(addr.String()); ok {
|
|
||||||
connection = s
|
|
||||||
} else if connection != nil {
|
|
||||||
// Cache the address for fast search.
|
|
||||||
v.addresses.Store(addr.String(), connection)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If connection is not found, ignore the packet.
|
|
||||||
if connection == nil {
|
|
||||||
// TODO: Should logging the dropped packet, only logging the first one for each address.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy the packet to backend.
|
|
||||||
if err := connection.HandlePacket(addr, data); err != nil {
|
|
||||||
return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC
|
|
||||||
// connection, identify by the ufrag in sdp offer/answer and ICE binding request.
|
|
||||||
//
|
|
||||||
// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is
|
|
||||||
// in the client request. The RTCConnection is stateful, and need to sync the ufrag between
|
|
||||||
// proxy servers.
|
|
||||||
//
|
|
||||||
// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch
|
|
||||||
// to another UDP address, it may connect to another WebRTC proxy, then we should discover the
|
|
||||||
// RTCConnection by the ufrag from the ICE binding request.
|
|
||||||
type RTCConnection struct {
|
|
||||||
// The stream context for WebRTC streaming.
|
|
||||||
ctx context.Context
|
|
||||||
|
|
||||||
// The stream URL in vhost/app/stream schema.
|
|
||||||
StreamURL string `json:"stream_url"`
|
|
||||||
// The ufrag for this WebRTC connection.
|
|
||||||
Ufrag string `json:"ufrag"`
|
|
||||||
|
|
||||||
// The UDP connection proxy to backend.
|
|
||||||
backendUDP *net.UDPConn
|
|
||||||
// The client UDP address. Note that it may change.
|
|
||||||
clientUDP *net.UDPAddr
|
|
||||||
// The listener UDP connection, used to send messages to client.
|
|
||||||
listenerUDP *net.UDPConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection {
|
|
||||||
v := &RTCConnection{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection {
|
|
||||||
if v.ctx == nil {
|
|
||||||
v.ctx = logger.WithContext(ctx)
|
|
||||||
}
|
|
||||||
if listener != nil {
|
|
||||||
v.listenerUDP = listener
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
|
|
||||||
ctx := v.ctx
|
|
||||||
|
|
||||||
// Update the current UDP address.
|
|
||||||
v.clientUDP = addr
|
|
||||||
|
|
||||||
// Start the UDP proxy to backend.
|
|
||||||
if err := v.connectBackend(ctx); err != nil {
|
|
||||||
return errors.Wrapf(err, "connect backend for %v", v.StreamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy client message to backend.
|
|
||||||
if v.backendUDP == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy all messages from backend to client.
|
|
||||||
go func() {
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
n, _, err := v.backendUDP.ReadFromUDP(buf)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
|
|
||||||
logger.Wf(ctx, "read from backend failed, err=%v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
|
|
||||||
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
|
|
||||||
logger.Wf(ctx, "write to client failed, err=%v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if _, err := v.backendUDP.Write(data); err != nil {
|
|
||||||
return errors.Wrapf(err, "write to backend %v", v.StreamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTCConnection) connectBackend(ctx context.Context) error {
|
|
||||||
if v.backendUDP != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the RTC stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse UDP port from backend.
|
|
||||||
if len(backend.RTC) == 0 {
|
|
||||||
return errors.Errorf("no udp server")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, udpPort, err := parseListenEndpoint(backend.RTC[0])
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to backend SRS server via UDP client.
|
|
||||||
// TODO: FIXME: Support close the connection when timeout or DTLS alert.
|
|
||||||
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
|
|
||||||
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
|
|
||||||
return errors.Wrapf(err, "dial udp to %v", backendAddr)
|
|
||||||
} else {
|
|
||||||
v.backendUDP = backendUDP
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type RTCICEPair struct {
|
|
||||||
// The remote ufrag, used for ICE username and session id.
|
|
||||||
RemoteICEUfrag string `json:"remote_ufrag"`
|
|
||||||
// The remote pwd, used for ICE password.
|
|
||||||
RemoteICEPwd string `json:"remote_pwd"`
|
|
||||||
// The local ufrag, used for ICE username and session id.
|
|
||||||
LocalICEUfrag string `json:"local_ufrag"`
|
|
||||||
// The local pwd, used for ICE password.
|
|
||||||
LocalICEPwd string `json:"local_pwd"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
|
|
||||||
func (v *RTCICEPair) Ufrag() string {
|
|
||||||
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
|
|
||||||
}
|
|
||||||
|
|
||||||
type RTCStunPacket struct {
|
|
||||||
// The stun message type.
|
|
||||||
MessageType uint16
|
|
||||||
// The stun username, or ufrag.
|
|
||||||
Username string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
|
|
||||||
if len(data) < 20 {
|
|
||||||
return errors.Errorf("stun packet too short %v", len(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
p := data
|
|
||||||
v.MessageType = binary.BigEndian.Uint16(p)
|
|
||||||
messageLen := binary.BigEndian.Uint16(p[2:])
|
|
||||||
//magicCookie := p[:8]
|
|
||||||
//transactionID := p[:20]
|
|
||||||
p = p[20:]
|
|
||||||
|
|
||||||
if len(p) != int(messageLen) {
|
|
||||||
return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
for len(p) > 0 {
|
|
||||||
typ := binary.BigEndian.Uint16(p)
|
|
||||||
length := binary.BigEndian.Uint16(p[2:])
|
|
||||||
p = p[4:]
|
|
||||||
|
|
||||||
if len(p) < int(length) {
|
|
||||||
return errors.Errorf("stun attribute length invalid %v < %v", len(p), length)
|
|
||||||
}
|
|
||||||
|
|
||||||
value := p[:length]
|
|
||||||
p = p[length:]
|
|
||||||
|
|
||||||
if length%4 != 0 {
|
|
||||||
p = p[4-length%4:]
|
|
||||||
}
|
|
||||||
|
|
||||||
switch typ {
|
|
||||||
case 0x0006:
|
|
||||||
v.Username = string(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
655
proxy/rtmp.go
655
proxy/rtmp.go
@@ -1,655 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
"srs-proxy/rtmp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS
|
|
||||||
// server. It will figure out the backend server to proxy to. Unlike the edge server, it will
|
|
||||||
// not cache the stream, but just proxy the stream to backend.
|
|
||||||
type srsRTMPServer struct {
|
|
||||||
// The TCP listener for RTMP server.
|
|
||||||
listener *net.TCPListener
|
|
||||||
// The random number generator.
|
|
||||||
rd *rand.Rand
|
|
||||||
// The wait group for all goroutines.
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer {
|
|
||||||
v := &srsRTMPServer{
|
|
||||||
rd: rand.New(rand.NewSource(time.Now().UnixNano())),
|
|
||||||
}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRTMPServer) Close() error {
|
|
||||||
if v.listener != nil {
|
|
||||||
v.listener.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
v.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRTMPServer) Run(ctx context.Context) error {
|
|
||||||
endpoint := envRtmpServer()
|
|
||||||
if !strings.Contains(endpoint, ":") {
|
|
||||||
endpoint = ":" + endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := net.ResolveTCPAddr("tcp", endpoint)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := net.ListenTCP("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "listen rtmp addr %v", addr)
|
|
||||||
}
|
|
||||||
v.listener = listener
|
|
||||||
logger.Df(ctx, "RTMP server listen at %v", addr)
|
|
||||||
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer v.wg.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
conn, err := v.listener.AcceptTCP()
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != context.Canceled {
|
|
||||||
// TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit.
|
|
||||||
logger.Wf(ctx, "RTMP server accept err %+v", err)
|
|
||||||
} else {
|
|
||||||
logger.Df(ctx, "RTMP server done")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func(ctx context.Context, conn *net.TCPConn) {
|
|
||||||
defer v.wg.Done()
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
handleErr := func(err error) {
|
|
||||||
if isPeerClosedError(err) {
|
|
||||||
logger.Df(ctx, "RTMP peer is closed")
|
|
||||||
} else {
|
|
||||||
logger.Wf(ctx, "RTMP serve err %+v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := NewRTMPConnection(func(client *RTMPConnection) {
|
|
||||||
client.rd = v.rd
|
|
||||||
})
|
|
||||||
if err := rc.serve(ctx, conn); err != nil {
|
|
||||||
handleErr(err)
|
|
||||||
} else {
|
|
||||||
logger.Df(ctx, "RTMP client done")
|
|
||||||
}
|
|
||||||
}(logger.WithContext(ctx), conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between
|
|
||||||
// proxy servers.
|
|
||||||
//
|
|
||||||
// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request,
|
|
||||||
// then proxy to the corresponding backend server. All state is in the RTMP request, so this
|
|
||||||
// connection is stateless.
|
|
||||||
type RTMPConnection struct {
|
|
||||||
// The random number generator.
|
|
||||||
rd *rand.Rand
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
|
|
||||||
v := &RTMPConnection{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
|
|
||||||
logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr())
|
|
||||||
|
|
||||||
// If any goroutine quit, cancel another one.
|
|
||||||
parentCtx := ctx
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var backend *RTMPClientToBackend
|
|
||||||
if true {
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
conn.Close()
|
|
||||||
if backend != nil {
|
|
||||||
backend.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple handshake with client.
|
|
||||||
hs := rtmp.NewHandshake(v.rd)
|
|
||||||
if _, err := hs.ReadC0S0(conn); err != nil {
|
|
||||||
return errors.Wrapf(err, "read c0")
|
|
||||||
}
|
|
||||||
if _, err := hs.ReadC1S1(conn); err != nil {
|
|
||||||
return errors.Wrapf(err, "read c1")
|
|
||||||
}
|
|
||||||
if err := hs.WriteC0S0(conn); err != nil {
|
|
||||||
return errors.Wrapf(err, "write s1")
|
|
||||||
}
|
|
||||||
if err := hs.WriteC1S1(conn); err != nil {
|
|
||||||
return errors.Wrapf(err, "write s1")
|
|
||||||
}
|
|
||||||
if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil {
|
|
||||||
return errors.Wrapf(err, "write s2")
|
|
||||||
}
|
|
||||||
if _, err := hs.ReadC2S2(conn); err != nil {
|
|
||||||
return errors.Wrapf(err, "read c2")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := rtmp.NewProtocol(conn)
|
|
||||||
logger.Df(ctx, "RTMP simple handshake done")
|
|
||||||
|
|
||||||
// Expect RTMP connect command with tcUrl.
|
|
||||||
var connectReq *rtmp.ConnectAppPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect connect req")
|
|
||||||
}
|
|
||||||
|
|
||||||
if true {
|
|
||||||
ack := rtmp.NewWindowAcknowledgementSize()
|
|
||||||
ack.AckSize = 2500000
|
|
||||||
if err := client.WritePacket(ctx, ack, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "write set ack size")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if true {
|
|
||||||
chunk := rtmp.NewSetChunkSize()
|
|
||||||
chunk.ChunkSize = 128
|
|
||||||
if err := client.WritePacket(ctx, chunk, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "write set chunk size")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
|
|
||||||
connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
|
|
||||||
connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127))
|
|
||||||
connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1))
|
|
||||||
connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
|
|
||||||
connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
|
|
||||||
connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded"))
|
|
||||||
connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0))
|
|
||||||
connectResData := rtmp.NewAmf0EcmaArray()
|
|
||||||
connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888"))
|
|
||||||
connectResData.Set("srs_version", rtmp.NewAmf0String(Version()))
|
|
||||||
connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx)))
|
|
||||||
connectRes.Args.Set("data", connectResData)
|
|
||||||
if err := client.WritePacket(ctx, connectRes, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "write connect res")
|
|
||||||
}
|
|
||||||
|
|
||||||
tcUrl := connectReq.TcUrl()
|
|
||||||
logger.Df(ctx, "RTMP connect app %v", tcUrl)
|
|
||||||
|
|
||||||
// Expect RTMP command to identify the client, a publisher or viewer.
|
|
||||||
var currentStreamID, nextStreamID int
|
|
||||||
var streamName string
|
|
||||||
var clientType RTMPClientType
|
|
||||||
for clientType == "" {
|
|
||||||
var identifyReq rtmp.Packet
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect identify req")
|
|
||||||
}
|
|
||||||
|
|
||||||
var response rtmp.Packet
|
|
||||||
switch pkt := identifyReq.(type) {
|
|
||||||
case *rtmp.CallPacket:
|
|
||||||
if pkt.CommandName == "createStream" {
|
|
||||||
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
|
|
||||||
response = identifyRes
|
|
||||||
|
|
||||||
nextStreamID = 1
|
|
||||||
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
|
|
||||||
} else if pkt.CommandName == "getStreamLength" {
|
|
||||||
// Ignore and do not reply these packets.
|
|
||||||
} else {
|
|
||||||
// For releaseStream, FCPublish, etc.
|
|
||||||
identifyRes := rtmp.NewCallPacket()
|
|
||||||
response = identifyRes
|
|
||||||
|
|
||||||
identifyRes.TransactionID = pkt.TransactionID
|
|
||||||
identifyRes.CommandName = "_result"
|
|
||||||
identifyRes.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
identifyRes.Args = rtmp.NewAmf0Undefined()
|
|
||||||
}
|
|
||||||
case *rtmp.PublishPacket:
|
|
||||||
streamName = string(pkt.StreamName)
|
|
||||||
clientType = RTMPClientTypePublisher
|
|
||||||
|
|
||||||
identifyRes := rtmp.NewCallPacket()
|
|
||||||
response = identifyRes
|
|
||||||
|
|
||||||
identifyRes.CommandName = "onFCPublish"
|
|
||||||
identifyRes.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
|
|
||||||
data := rtmp.NewAmf0Object()
|
|
||||||
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
|
|
||||||
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
|
|
||||||
identifyRes.Args = data
|
|
||||||
case *rtmp.PlayPacket:
|
|
||||||
streamName = string(pkt.StreamName)
|
|
||||||
clientType = RTMPClientTypeViewer
|
|
||||||
|
|
||||||
identifyRes := rtmp.NewCallPacket()
|
|
||||||
response = identifyRes
|
|
||||||
|
|
||||||
identifyRes.CommandName = "onStatus"
|
|
||||||
identifyRes.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
|
|
||||||
data := rtmp.NewAmf0Object()
|
|
||||||
data.Set("level", rtmp.NewAmf0String("status"))
|
|
||||||
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset"))
|
|
||||||
data.Set("description", rtmp.NewAmf0String("Playing and resetting stream."))
|
|
||||||
data.Set("details", rtmp.NewAmf0String("stream"))
|
|
||||||
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
|
|
||||||
identifyRes.Args = data
|
|
||||||
}
|
|
||||||
|
|
||||||
if response != nil {
|
|
||||||
if err := client.WritePacket(ctx, response, currentStreamID); err != nil {
|
|
||||||
return errors.Wrapf(err, "write identify res for req=%v, stream=%v",
|
|
||||||
identifyReq, currentStreamID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the stream ID for next request.
|
|
||||||
currentStreamID = nextStreamID
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v",
|
|
||||||
tcUrl, streamName, currentStreamID, clientType)
|
|
||||||
|
|
||||||
// Find a backend SRS server to proxy the RTMP stream.
|
|
||||||
backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) {
|
|
||||||
client.rd, client.typ = v.rd, clientType
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
|
|
||||||
return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the streaming.
|
|
||||||
if clientType == RTMPClientTypePublisher {
|
|
||||||
identifyRes := rtmp.NewCallPacket()
|
|
||||||
|
|
||||||
identifyRes.CommandName = "onStatus"
|
|
||||||
identifyRes.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
|
|
||||||
data := rtmp.NewAmf0Object()
|
|
||||||
data.Set("level", rtmp.NewAmf0String("status"))
|
|
||||||
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
|
|
||||||
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
|
|
||||||
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
|
|
||||||
identifyRes.Args = data
|
|
||||||
|
|
||||||
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
|
|
||||||
return errors.Wrapf(err, "start publish")
|
|
||||||
}
|
|
||||||
} else if clientType == RTMPClientTypeViewer {
|
|
||||||
identifyRes := rtmp.NewCallPacket()
|
|
||||||
|
|
||||||
identifyRes.CommandName = "onStatus"
|
|
||||||
identifyRes.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
|
|
||||||
data := rtmp.NewAmf0Object()
|
|
||||||
data.Set("level", rtmp.NewAmf0String("status"))
|
|
||||||
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start"))
|
|
||||||
data.Set("description", rtmp.NewAmf0String("Started playing stream."))
|
|
||||||
data.Set("details", rtmp.NewAmf0String("stream"))
|
|
||||||
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
|
|
||||||
identifyRes.Args = data
|
|
||||||
|
|
||||||
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
|
|
||||||
return errors.Wrapf(err, "start play")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "RTMP start streaming")
|
|
||||||
|
|
||||||
// For all proxy goroutines.
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
defer wg.Wait()
|
|
||||||
|
|
||||||
// Proxy all message from backend to client.
|
|
||||||
wg.Add(1)
|
|
||||||
var r0 error
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
r0 = func() error {
|
|
||||||
for {
|
|
||||||
m, err := backend.client.ReadMessage(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read message")
|
|
||||||
}
|
|
||||||
//logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
|
|
||||||
|
|
||||||
// TODO: Update the stream ID if not the same.
|
|
||||||
if err := client.WriteMessage(ctx, m); err != nil {
|
|
||||||
return errors.Wrapf(err, "write message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Proxy all messages from client to backend.
|
|
||||||
wg.Add(1)
|
|
||||||
var r1 error
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
r1 = func() error {
|
|
||||||
for {
|
|
||||||
m, err := client.ReadMessage(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read message")
|
|
||||||
}
|
|
||||||
//logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
|
|
||||||
|
|
||||||
// TODO: Update the stream ID if not the same.
|
|
||||||
if err := backend.client.WriteMessage(ctx, m); err != nil {
|
|
||||||
return errors.Wrapf(err, "write message")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait until all goroutine quit.
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
// Reset the error if caused by another goroutine.
|
|
||||||
if r0 != nil {
|
|
||||||
return errors.Wrapf(r0, "proxy backend->client")
|
|
||||||
}
|
|
||||||
if r1 != nil {
|
|
||||||
return errors.Wrapf(r1, "proxy client->backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
return parentCtx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
type RTMPClientType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
RTMPClientTypePublisher RTMPClientType = "publisher"
|
|
||||||
RTMPClientTypeViewer RTMPClientType = "viewer"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend.
|
|
||||||
type RTMPClientToBackend struct {
|
|
||||||
// The random number generator.
|
|
||||||
rd *rand.Rand
|
|
||||||
// The underlayer tcp client.
|
|
||||||
tcpConn *net.TCPConn
|
|
||||||
// The RTMP protocol client.
|
|
||||||
client *rtmp.Protocol
|
|
||||||
// The stream type.
|
|
||||||
typ RTMPClientType
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend {
|
|
||||||
v := &RTMPClientToBackend{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTMPClientToBackend) Close() error {
|
|
||||||
if v.tcpConn != nil {
|
|
||||||
v.tcpConn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
|
|
||||||
// Build the stream URL in vhost/app/stream schema.
|
|
||||||
streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName))
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the RTMP stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse RTMP port from backend.
|
|
||||||
if len(backend.RTMP) == 0 {
|
|
||||||
return errors.Errorf("no rtmp server %+v for %v", backend, streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
var rtmpPort int
|
|
||||||
if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil {
|
|
||||||
return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0])
|
|
||||||
} else {
|
|
||||||
rtmpPort = int(iv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to backend SRS server via TCP client.
|
|
||||||
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
|
|
||||||
c, err := net.DialTCP("tcp", nil, addr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
|
|
||||||
}
|
|
||||||
v.tcpConn = c
|
|
||||||
|
|
||||||
hs := rtmp.NewHandshake(v.rd)
|
|
||||||
client := rtmp.NewProtocol(c)
|
|
||||||
v.client = client
|
|
||||||
|
|
||||||
// Simple RTMP handshake with server.
|
|
||||||
if err := hs.WriteC0S0(c); err != nil {
|
|
||||||
return errors.Wrapf(err, "write c0")
|
|
||||||
}
|
|
||||||
if err := hs.WriteC1S1(c); err != nil {
|
|
||||||
return errors.Wrapf(err, "write c1")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = hs.ReadC0S0(c); err != nil {
|
|
||||||
return errors.Wrapf(err, "read s0")
|
|
||||||
}
|
|
||||||
if _, err := hs.ReadC1S1(c); err != nil {
|
|
||||||
return errors.Wrapf(err, "read s1")
|
|
||||||
}
|
|
||||||
if _, err = hs.ReadC2S2(c); err != nil {
|
|
||||||
return errors.Wrapf(err, "read c2")
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "backend simple handshake done, server=%v", addr)
|
|
||||||
|
|
||||||
if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
|
|
||||||
return errors.Wrapf(err, "write c2")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect RTMP app on tcUrl with server.
|
|
||||||
if true {
|
|
||||||
connectApp := rtmp.NewConnectAppPacket()
|
|
||||||
connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
|
|
||||||
if err := client.WritePacket(ctx, connectApp, 1); err != nil {
|
|
||||||
return errors.Wrapf(err, "write connect app")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if true {
|
|
||||||
var connectAppRes *rtmp.ConnectAppResPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect connect app res")
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Play or view RTMP stream with server.
|
|
||||||
if v.typ == RTMPClientTypeViewer {
|
|
||||||
return v.play(ctx, client, streamName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish RTMP stream with server.
|
|
||||||
return v.publish(ctx, client, streamName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
|
|
||||||
if true {
|
|
||||||
identifyReq := rtmp.NewCallPacket()
|
|
||||||
identifyReq.CommandName = "releaseStream"
|
|
||||||
identifyReq.TransactionID = 2
|
|
||||||
identifyReq.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
identifyReq.Args = rtmp.NewAmf0String(streamName)
|
|
||||||
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "releaseStream")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var identifyRes *rtmp.CallPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect releaseStream res")
|
|
||||||
}
|
|
||||||
if identifyRes.CommandName == "_result" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if true {
|
|
||||||
identifyReq := rtmp.NewCallPacket()
|
|
||||||
identifyReq.CommandName = "FCPublish"
|
|
||||||
identifyReq.TransactionID = 3
|
|
||||||
identifyReq.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
identifyReq.Args = rtmp.NewAmf0String(streamName)
|
|
||||||
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "FCPublish")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var identifyRes *rtmp.CallPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect FCPublish res")
|
|
||||||
}
|
|
||||||
if identifyRes.CommandName == "_result" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var currentStreamID int
|
|
||||||
if true {
|
|
||||||
createStream := rtmp.NewCreateStreamPacket()
|
|
||||||
createStream.TransactionID = 4
|
|
||||||
createStream.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
if err := client.WritePacket(ctx, createStream, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "createStream")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var identifyRes *rtmp.CreateStreamResPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect createStream res")
|
|
||||||
}
|
|
||||||
if sid := identifyRes.StreamID; sid != 0 {
|
|
||||||
currentStreamID = int(sid)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if true {
|
|
||||||
publishStream := rtmp.NewPublishPacket()
|
|
||||||
publishStream.TransactionID = 5
|
|
||||||
publishStream.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
|
|
||||||
publishStream.StreamType = *rtmp.NewAmf0String("live")
|
|
||||||
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
|
|
||||||
return errors.Wrapf(err, "publish")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var identifyRes *rtmp.CallPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect publish res")
|
|
||||||
}
|
|
||||||
// Ignore onFCPublish, expect onStatus(NetStream.Publish.Start).
|
|
||||||
if identifyRes.CommandName == "onStatus" {
|
|
||||||
if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil {
|
|
||||||
return errors.Errorf("onStatus args not object")
|
|
||||||
} else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
|
|
||||||
return errors.Errorf("onStatus code not string")
|
|
||||||
} else if *code != "NetStream.Publish.Start" {
|
|
||||||
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
|
|
||||||
var currentStreamID int
|
|
||||||
if true {
|
|
||||||
createStream := rtmp.NewCreateStreamPacket()
|
|
||||||
createStream.TransactionID = 4
|
|
||||||
createStream.CommandObject = rtmp.NewAmf0Null()
|
|
||||||
if err := client.WritePacket(ctx, createStream, 0); err != nil {
|
|
||||||
return errors.Wrapf(err, "createStream")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var identifyRes *rtmp.CreateStreamResPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect createStream res")
|
|
||||||
}
|
|
||||||
if sid := identifyRes.StreamID; sid != 0 {
|
|
||||||
currentStreamID = int(sid)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
playStream := rtmp.NewPlayPacket()
|
|
||||||
playStream.StreamName = *rtmp.NewAmf0String(streamName)
|
|
||||||
if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
|
|
||||||
return errors.Wrapf(err, "play")
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
var identifyRes *rtmp.CallPacket
|
|
||||||
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
|
|
||||||
return errors.Wrapf(err, "expect releaseStream res")
|
|
||||||
}
|
|
||||||
if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,771 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package rtmp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview
|
|
||||||
type amf0Marker uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
amf0MarkerNumber amf0Marker = iota // 0
|
|
||||||
amf0MarkerBoolean // 1
|
|
||||||
amf0MarkerString // 2
|
|
||||||
amf0MarkerObject // 3
|
|
||||||
amf0MarkerMovieClip // 4
|
|
||||||
amf0MarkerNull // 5
|
|
||||||
amf0MarkerUndefined // 6
|
|
||||||
amf0MarkerReference // 7
|
|
||||||
amf0MarkerEcmaArray // 8
|
|
||||||
amf0MarkerObjectEnd // 9
|
|
||||||
amf0MarkerStrictArray // 10
|
|
||||||
amf0MarkerDate // 11
|
|
||||||
amf0MarkerLongString // 12
|
|
||||||
amf0MarkerUnsupported // 13
|
|
||||||
amf0MarkerRecordSet // 14
|
|
||||||
amf0MarkerXmlDocument // 15
|
|
||||||
amf0MarkerTypedObject // 16
|
|
||||||
amf0MarkerAvmPlusObject // 17
|
|
||||||
|
|
||||||
amf0MarkerForbidden amf0Marker = 0xff
|
|
||||||
)
|
|
||||||
|
|
||||||
func (v amf0Marker) String() string {
|
|
||||||
switch v {
|
|
||||||
case amf0MarkerNumber:
|
|
||||||
return "Amf0Number"
|
|
||||||
case amf0MarkerBoolean:
|
|
||||||
return "amf0Boolean"
|
|
||||||
case amf0MarkerString:
|
|
||||||
return "Amf0String"
|
|
||||||
case amf0MarkerObject:
|
|
||||||
return "Amf0Object"
|
|
||||||
case amf0MarkerNull:
|
|
||||||
return "Null"
|
|
||||||
case amf0MarkerUndefined:
|
|
||||||
return "Undefined"
|
|
||||||
case amf0MarkerReference:
|
|
||||||
return "Reference"
|
|
||||||
case amf0MarkerEcmaArray:
|
|
||||||
return "EcmaArray"
|
|
||||||
case amf0MarkerObjectEnd:
|
|
||||||
return "ObjectEnd"
|
|
||||||
case amf0MarkerStrictArray:
|
|
||||||
return "StrictArray"
|
|
||||||
case amf0MarkerDate:
|
|
||||||
return "Date"
|
|
||||||
case amf0MarkerLongString:
|
|
||||||
return "LongString"
|
|
||||||
case amf0MarkerUnsupported:
|
|
||||||
return "Unsupported"
|
|
||||||
case amf0MarkerXmlDocument:
|
|
||||||
return "XmlDocument"
|
|
||||||
case amf0MarkerTypedObject:
|
|
||||||
return "TypedObject"
|
|
||||||
case amf0MarkerAvmPlusObject:
|
|
||||||
return "AvmPlusObject"
|
|
||||||
case amf0MarkerMovieClip:
|
|
||||||
return "MovieClip"
|
|
||||||
case amf0MarkerRecordSet:
|
|
||||||
return "RecordSet"
|
|
||||||
default:
|
|
||||||
return "Forbidden"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For utest to mock it.
|
|
||||||
type amf0Buffer interface {
|
|
||||||
Bytes() []byte
|
|
||||||
WriteByte(c byte) error
|
|
||||||
Write(p []byte) (n int, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var createBuffer = func() amf0Buffer {
|
|
||||||
return &bytes.Buffer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// All AMF0 things.
|
|
||||||
type amf0Any interface {
|
|
||||||
// Binary marshaler and unmarshaler.
|
|
||||||
encoding.BinaryUnmarshaler
|
|
||||||
encoding.BinaryMarshaler
|
|
||||||
// Get the size of bytes to marshal this object.
|
|
||||||
Size() int
|
|
||||||
|
|
||||||
// Get the Marker of any AMF0 stuff.
|
|
||||||
amf0Marker() amf0Marker
|
|
||||||
}
|
|
||||||
|
|
||||||
type amf0Converter struct {
|
|
||||||
from amf0Any
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAmf0Converter(from amf0Any) *amf0Converter {
|
|
||||||
return &amf0Converter{from: from}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToNumber() *amf0Number {
|
|
||||||
return amf0AnyTo[*amf0Number](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToBoolean() *amf0Boolean {
|
|
||||||
return amf0AnyTo[*amf0Boolean](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToString() *amf0String {
|
|
||||||
return amf0AnyTo[*amf0String](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToObject() *amf0Object {
|
|
||||||
return amf0AnyTo[*amf0Object](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToNull() *amf0Null {
|
|
||||||
return amf0AnyTo[*amf0Null](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToUndefined() *amf0Undefined {
|
|
||||||
return amf0AnyTo[*amf0Undefined](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray {
|
|
||||||
return amf0AnyTo[*amf0EcmaArray](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Converter) ToStrictArray() *amf0StrictArray {
|
|
||||||
return amf0AnyTo[*amf0StrictArray](v.from)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert any to specified object.
|
|
||||||
func amf0AnyTo[T amf0Any](a amf0Any) T {
|
|
||||||
var to T
|
|
||||||
if a != nil {
|
|
||||||
if v, ok := a.(T); ok {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return to
|
|
||||||
}
|
|
||||||
|
|
||||||
// Discovery the amf0 object from the bytes b.
|
|
||||||
func Amf0Discovery(p []byte) (a amf0Any, err error) {
|
|
||||||
if len(p) < 1 {
|
|
||||||
return nil, errors.Errorf("require 1 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
m := amf0Marker(p[0])
|
|
||||||
|
|
||||||
switch m {
|
|
||||||
case amf0MarkerNumber:
|
|
||||||
return NewAmf0Number(0), nil
|
|
||||||
case amf0MarkerBoolean:
|
|
||||||
return NewAmf0Boolean(false), nil
|
|
||||||
case amf0MarkerString:
|
|
||||||
return NewAmf0String(""), nil
|
|
||||||
case amf0MarkerObject:
|
|
||||||
return NewAmf0Object(), nil
|
|
||||||
case amf0MarkerNull:
|
|
||||||
return NewAmf0Null(), nil
|
|
||||||
case amf0MarkerUndefined:
|
|
||||||
return NewAmf0Undefined(), nil
|
|
||||||
case amf0MarkerReference:
|
|
||||||
case amf0MarkerEcmaArray:
|
|
||||||
return NewAmf0EcmaArray(), nil
|
|
||||||
case amf0MarkerObjectEnd:
|
|
||||||
return &amf0ObjectEOF{}, nil
|
|
||||||
case amf0MarkerStrictArray:
|
|
||||||
return NewAmf0StrictArray(), nil
|
|
||||||
case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument,
|
|
||||||
amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip,
|
|
||||||
amf0MarkerRecordSet:
|
|
||||||
return nil, errors.Errorf("Marker %v is not supported", m)
|
|
||||||
}
|
|
||||||
return nil, errors.Errorf("Marker %v is invalid", m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8
|
|
||||||
type amf0UTF8 string
|
|
||||||
|
|
||||||
func (v *amf0UTF8) Size() int {
|
|
||||||
return 2 + len(string(*v))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 2 {
|
|
||||||
return errors.Errorf("require 2 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
size := uint16(p[0])<<8 | uint16(p[1])
|
|
||||||
|
|
||||||
if p = data[2:]; len(p) < int(size) {
|
|
||||||
return errors.Errorf("require %v bytes only %v", int(size), len(p))
|
|
||||||
}
|
|
||||||
*v = amf0UTF8(string(p[:size]))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0UTF8) MarshalBinary() (data []byte, err error) {
|
|
||||||
data = make([]byte, v.Size())
|
|
||||||
|
|
||||||
size := uint16(len(string(*v)))
|
|
||||||
data[0] = byte(size >> 8)
|
|
||||||
data[1] = byte(size)
|
|
||||||
|
|
||||||
if size > 0 {
|
|
||||||
copy(data[2:], []byte(*v))
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type
|
|
||||||
type amf0Number float64
|
|
||||||
|
|
||||||
func NewAmf0Number(f float64) *amf0Number {
|
|
||||||
v := amf0Number(f)
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Number) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Number) Size() int {
|
|
||||||
return 1 + 8
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Number) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 9 {
|
|
||||||
return errors.Errorf("require 9 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != amf0MarkerNumber {
|
|
||||||
return errors.Errorf("Amf0Number amf0Marker %v is illegal", m)
|
|
||||||
}
|
|
||||||
|
|
||||||
f := binary.BigEndian.Uint64(p[1:])
|
|
||||||
*v = amf0Number(math.Float64frombits(f))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Number) MarshalBinary() (data []byte, err error) {
|
|
||||||
data = make([]byte, 9)
|
|
||||||
data[0] = byte(amf0MarkerNumber)
|
|
||||||
f := math.Float64bits(float64(*v))
|
|
||||||
binary.BigEndian.PutUint64(data[1:], f)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type
|
|
||||||
type amf0String string
|
|
||||||
|
|
||||||
func NewAmf0String(s string) *amf0String {
|
|
||||||
v := amf0String(s)
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0String) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerString
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0String) Size() int {
|
|
||||||
u := amf0UTF8(*v)
|
|
||||||
return 1 + u.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0String) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 1 {
|
|
||||||
return errors.Errorf("require 1 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != amf0MarkerString {
|
|
||||||
return errors.Errorf("Amf0String amf0Marker %v is illegal", m)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sv amf0UTF8
|
|
||||||
if err = sv.UnmarshalBinary(p[1:]); err != nil {
|
|
||||||
return errors.WithMessage(err, "utf8")
|
|
||||||
}
|
|
||||||
*v = amf0String(string(sv))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0String) MarshalBinary() (data []byte, err error) {
|
|
||||||
u := amf0UTF8(*v)
|
|
||||||
|
|
||||||
var pb []byte
|
|
||||||
if pb, err = u.MarshalBinary(); err != nil {
|
|
||||||
return nil, errors.WithMessage(err, "utf8")
|
|
||||||
}
|
|
||||||
|
|
||||||
data = append([]byte{byte(amf0MarkerString)}, pb...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type
|
|
||||||
type amf0ObjectEOF struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectEOF) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerObjectEnd
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectEOF) Size() int {
|
|
||||||
return 3
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
p := data
|
|
||||||
|
|
||||||
if len(p) < 3 {
|
|
||||||
return errors.Errorf("require 3 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
|
|
||||||
if p[0] != 0 || p[1] != 0 || p[2] != 9 {
|
|
||||||
return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3])
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) {
|
|
||||||
return []byte{0, 0, 9}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use array for object and ecma array, to keep the original order.
|
|
||||||
type amf0Property struct {
|
|
||||||
key amf0UTF8
|
|
||||||
value amf0Any
|
|
||||||
}
|
|
||||||
|
|
||||||
// The object-like AMF0 structure, like object and ecma array and strict array.
|
|
||||||
type amf0ObjectBase struct {
|
|
||||||
properties []*amf0Property
|
|
||||||
lock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectBase) Size() int {
|
|
||||||
v.lock.Lock()
|
|
||||||
defer v.lock.Unlock()
|
|
||||||
|
|
||||||
var size int
|
|
||||||
|
|
||||||
for _, p := range v.properties {
|
|
||||||
key, value := p.key, p.value
|
|
||||||
size += key.Size() + value.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
return size
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectBase) Get(key string) amf0Any {
|
|
||||||
v.lock.Lock()
|
|
||||||
defer v.lock.Unlock()
|
|
||||||
|
|
||||||
for _, p := range v.properties {
|
|
||||||
if string(p.key) == key {
|
|
||||||
return p.value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase {
|
|
||||||
v.lock.Lock()
|
|
||||||
defer v.lock.Unlock()
|
|
||||||
|
|
||||||
prop := &amf0Property{key: amf0UTF8(key), value: value}
|
|
||||||
|
|
||||||
var ok bool
|
|
||||||
for i, p := range v.properties {
|
|
||||||
if string(p.key) == key {
|
|
||||||
v.properties[i] = prop
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
v.properties = append(v.properties, prop)
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) {
|
|
||||||
// if no eof, elems specified by maxElems.
|
|
||||||
if !eof && maxElems < 0 {
|
|
||||||
return errors.Errorf("maxElems=%v without eof", maxElems)
|
|
||||||
}
|
|
||||||
// if eof, maxElems must be -1.
|
|
||||||
if eof && maxElems != -1 {
|
|
||||||
return errors.Errorf("maxElems=%v with eof", maxElems)
|
|
||||||
}
|
|
||||||
|
|
||||||
readOne := func() (amf0UTF8, amf0Any, error) {
|
|
||||||
var u amf0UTF8
|
|
||||||
if err = u.UnmarshalBinary(p); err != nil {
|
|
||||||
return "", nil, errors.WithMessage(err, "prop name")
|
|
||||||
}
|
|
||||||
|
|
||||||
p = p[u.Size():]
|
|
||||||
var a amf0Any
|
|
||||||
if a, err = Amf0Discovery(p); err != nil {
|
|
||||||
return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u)))
|
|
||||||
}
|
|
||||||
return u, a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
pushOne := func(u amf0UTF8, a amf0Any) error {
|
|
||||||
// For object property, consume the whole bytes.
|
|
||||||
if err = a.UnmarshalBinary(p); err != nil {
|
|
||||||
return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u)))
|
|
||||||
}
|
|
||||||
|
|
||||||
v.Set(string(u), a)
|
|
||||||
p = p[a.Size():]
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for eof {
|
|
||||||
u, a, err := readOne()
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithMessage(err, "read")
|
|
||||||
}
|
|
||||||
|
|
||||||
// For object EOF, we should only consume total 3bytes.
|
|
||||||
if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd {
|
|
||||||
// 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte.
|
|
||||||
p = p[1:]
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := pushOne(u, a); err != nil {
|
|
||||||
return errors.WithMessage(err, "push")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for len(v.properties) < maxElems {
|
|
||||||
u, a, err := readOne()
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithMessage(err, "read")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := pushOne(u, a); err != nil {
|
|
||||||
return errors.WithMessage(err, "push")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) {
|
|
||||||
v.lock.Lock()
|
|
||||||
defer v.lock.Unlock()
|
|
||||||
|
|
||||||
var pb []byte
|
|
||||||
for _, p := range v.properties {
|
|
||||||
key, value := p.key, p.value
|
|
||||||
|
|
||||||
if pb, err = key.MarshalBinary(); err != nil {
|
|
||||||
return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key)))
|
|
||||||
}
|
|
||||||
if _, err = b.Write(pb); err != nil {
|
|
||||||
return errors.Wrapf(err, "write %v", string(key))
|
|
||||||
}
|
|
||||||
|
|
||||||
if pb, err = value.MarshalBinary(); err != nil {
|
|
||||||
return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key)))
|
|
||||||
}
|
|
||||||
if _, err = b.Write(pb); err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal value for %v", string(key))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type
|
|
||||||
type amf0Object struct {
|
|
||||||
amf0ObjectBase
|
|
||||||
eof amf0ObjectEOF
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAmf0Object() *amf0Object {
|
|
||||||
v := &amf0Object{}
|
|
||||||
v.properties = []*amf0Property{}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Object) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerObject
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Object) Size() int {
|
|
||||||
return int(1) + v.eof.Size() + v.amf0ObjectBase.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Object) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 1 {
|
|
||||||
return errors.Errorf("require 1 byte only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != amf0MarkerObject {
|
|
||||||
return errors.Errorf("Amf0Object amf0Marker %v is illegal", m)
|
|
||||||
}
|
|
||||||
p = p[1:]
|
|
||||||
|
|
||||||
if err = v.unmarshal(p, true, -1); err != nil {
|
|
||||||
return errors.WithMessage(err, "unmarshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Object) MarshalBinary() (data []byte, err error) {
|
|
||||||
b := createBuffer()
|
|
||||||
|
|
||||||
if err = b.WriteByte(byte(amf0MarkerObject)); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.marshal(b); err != nil {
|
|
||||||
return nil, errors.WithMessage(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
var pb []byte
|
|
||||||
if pb, err = v.eof.MarshalBinary(); err != nil {
|
|
||||||
return nil, errors.WithMessage(err, "marshal")
|
|
||||||
}
|
|
||||||
if _, err = b.Write(pb); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type
|
|
||||||
type amf0EcmaArray struct {
|
|
||||||
amf0ObjectBase
|
|
||||||
count uint32
|
|
||||||
eof amf0ObjectEOF
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAmf0EcmaArray() *amf0EcmaArray {
|
|
||||||
v := &amf0EcmaArray{}
|
|
||||||
v.properties = []*amf0Property{}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0EcmaArray) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerEcmaArray
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0EcmaArray) Size() int {
|
|
||||||
return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 5 {
|
|
||||||
return errors.Errorf("require 5 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray {
|
|
||||||
return errors.Errorf("EcmaArray amf0Marker %v is illegal", m)
|
|
||||||
}
|
|
||||||
v.count = binary.BigEndian.Uint32(p[1:])
|
|
||||||
p = p[5:]
|
|
||||||
|
|
||||||
if err = v.unmarshal(p, true, -1); err != nil {
|
|
||||||
return errors.WithMessage(err, "unmarshal")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
|
|
||||||
b := createBuffer()
|
|
||||||
|
|
||||||
if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.marshal(b); err != nil {
|
|
||||||
return nil, errors.WithMessage(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
var pb []byte
|
|
||||||
if pb, err = v.eof.MarshalBinary(); err != nil {
|
|
||||||
return nil, errors.WithMessage(err, "marshal")
|
|
||||||
}
|
|
||||||
if _, err = b.Write(pb); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type
|
|
||||||
type amf0StrictArray struct {
|
|
||||||
amf0ObjectBase
|
|
||||||
count uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAmf0StrictArray() *amf0StrictArray {
|
|
||||||
v := &amf0StrictArray{}
|
|
||||||
v.properties = []*amf0Property{}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0StrictArray) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerStrictArray
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0StrictArray) Size() int {
|
|
||||||
return int(1) + 4 + v.amf0ObjectBase.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 5 {
|
|
||||||
return errors.Errorf("require 5 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != amf0MarkerStrictArray {
|
|
||||||
return errors.Errorf("StrictArray amf0Marker %v is illegal", m)
|
|
||||||
}
|
|
||||||
v.count = binary.BigEndian.Uint32(p[1:])
|
|
||||||
p = p[5:]
|
|
||||||
|
|
||||||
if int(v.count) <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.unmarshal(p, false, int(v.count)); err != nil {
|
|
||||||
return errors.WithMessage(err, "unmarshal")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) {
|
|
||||||
b := createBuffer()
|
|
||||||
|
|
||||||
if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = v.marshal(b); err != nil {
|
|
||||||
return nil, errors.WithMessage(err, "marshal")
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined.
|
|
||||||
type amf0SingleMarkerObject struct {
|
|
||||||
target amf0Marker
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject {
|
|
||||||
return amf0SingleMarkerObject{target: m}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker {
|
|
||||||
return v.target
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0SingleMarkerObject) Size() int {
|
|
||||||
return int(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 1 {
|
|
||||||
return errors.Errorf("require 1 byte only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != v.target {
|
|
||||||
return errors.Errorf("%v amf0Marker %v is illegal", v.target, m)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) {
|
|
||||||
return []byte{byte(v.target)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type
|
|
||||||
type amf0Null struct {
|
|
||||||
amf0SingleMarkerObject
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAmf0Null() *amf0Null {
|
|
||||||
v := amf0Null{}
|
|
||||||
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull)
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type
|
|
||||||
type amf0Undefined struct {
|
|
||||||
amf0SingleMarkerObject
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAmf0Undefined() amf0Any {
|
|
||||||
v := amf0Undefined{}
|
|
||||||
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined)
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type
|
|
||||||
type amf0Boolean bool
|
|
||||||
|
|
||||||
func NewAmf0Boolean(b bool) amf0Any {
|
|
||||||
v := amf0Boolean(b)
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Boolean) amf0Marker() amf0Marker {
|
|
||||||
return amf0MarkerBoolean
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Boolean) Size() int {
|
|
||||||
return int(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) {
|
|
||||||
var p []byte
|
|
||||||
if p = data; len(p) < 2 {
|
|
||||||
return errors.Errorf("require 2 bytes only %v", len(p))
|
|
||||||
}
|
|
||||||
if m := amf0Marker(p[0]); m != amf0MarkerBoolean {
|
|
||||||
return errors.Errorf("BOOL amf0Marker %v is illegal", m)
|
|
||||||
}
|
|
||||||
if p[1] == 0 {
|
|
||||||
*v = false
|
|
||||||
} else {
|
|
||||||
*v = true
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *amf0Boolean) MarshalBinary() (data []byte, err error) {
|
|
||||||
var b byte
|
|
||||||
if *v {
|
|
||||||
b = 1
|
|
||||||
}
|
|
||||||
return []byte{byte(amf0MarkerBoolean), b}, nil
|
|
||||||
}
|
|
||||||
1792
proxy/rtmp/rtmp.go
1792
proxy/rtmp/rtmp.go
File diff suppressed because it is too large
Load Diff
@@ -1,44 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func installSignals(ctx context.Context, cancel context.CancelFunc) {
|
|
||||||
sc := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for s := range sc {
|
|
||||||
logger.Df(ctx, "Got signal %v", s)
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func installForceQuit(ctx context.Context) error {
|
|
||||||
var forceTimeout time.Duration
|
|
||||||
if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil {
|
|
||||||
return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout())
|
|
||||||
} else {
|
|
||||||
forceTimeout = t
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
time.Sleep(forceTimeout)
|
|
||||||
logger.Wf(ctx, "Force to exit by timeout")
|
|
||||||
os.Exit(1)
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
553
proxy/srs.go
553
proxy/srs.go
@@ -1,553 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
"srs-proxy/sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// If server heartbeat in this duration, it's alive.
|
|
||||||
const srsServerAliveDuration = 300 * time.Second
|
|
||||||
|
|
||||||
// If HLS streaming update in this duration, it's alive.
|
|
||||||
const srsHLSAliveDuration = 120 * time.Second
|
|
||||||
|
|
||||||
// If WebRTC streaming update in this duration, it's alive.
|
|
||||||
const srsRTCAliveDuration = 120 * time.Second
|
|
||||||
|
|
||||||
type SRSServer struct {
|
|
||||||
// The server IP.
|
|
||||||
IP string `json:"ip,omitempty"`
|
|
||||||
// The server device ID, configured by user.
|
|
||||||
DeviceID string `json:"device_id,omitempty"`
|
|
||||||
// The server id of SRS, store in file, may not change, mandatory.
|
|
||||||
ServerID string `json:"server_id,omitempty"`
|
|
||||||
// The service id of SRS, always change when restarted, mandatory.
|
|
||||||
ServiceID string `json:"service_id,omitempty"`
|
|
||||||
// The process id of SRS, always change when restarted, mandatory.
|
|
||||||
PID string `json:"pid,omitempty"`
|
|
||||||
// The RTMP listen endpoints.
|
|
||||||
RTMP []string `json:"rtmp,omitempty"`
|
|
||||||
// The HTTP Stream listen endpoints.
|
|
||||||
HTTP []string `json:"http,omitempty"`
|
|
||||||
// The HTTP API listen endpoints.
|
|
||||||
API []string `json:"api,omitempty"`
|
|
||||||
// The SRT server listen endpoints.
|
|
||||||
SRT []string `json:"srt,omitempty"`
|
|
||||||
// The RTC server listen endpoints.
|
|
||||||
RTC []string `json:"rtc,omitempty"`
|
|
||||||
// Last update time.
|
|
||||||
UpdatedAt time.Time `json:"update_at,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRSServer) ID() string {
|
|
||||||
return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRSServer) String() string {
|
|
||||||
return fmt.Sprintf("%v", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRSServer) Format(f fmt.State, c rune) {
|
|
||||||
switch c {
|
|
||||||
case 'v', 's':
|
|
||||||
if f.Flag('+') {
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID))
|
|
||||||
if v.DeviceID != "" {
|
|
||||||
sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID))
|
|
||||||
}
|
|
||||||
if len(v.RTMP) > 0 {
|
|
||||||
sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ",")))
|
|
||||||
}
|
|
||||||
if len(v.HTTP) > 0 {
|
|
||||||
sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ",")))
|
|
||||||
}
|
|
||||||
if len(v.API) > 0 {
|
|
||||||
sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ",")))
|
|
||||||
}
|
|
||||||
if len(v.SRT) > 0 {
|
|
||||||
sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ",")))
|
|
||||||
}
|
|
||||||
if len(v.RTC) > 0 {
|
|
||||||
sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ",")))
|
|
||||||
}
|
|
||||||
sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999")))
|
|
||||||
fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String())
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID())
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
fmt.Fprintf(f, "%v, fmt=%%%c", v, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
|
|
||||||
v := &SRSServer{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only.
|
|
||||||
func NewDefaultSRSForDebugging() (*SRSServer, error) {
|
|
||||||
if envDefaultBackendEnabled() != "on" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if envDefaultBackendIP() == "" {
|
|
||||||
return nil, fmt.Errorf("empty default backend ip")
|
|
||||||
}
|
|
||||||
if envDefaultBackendRTMP() == "" {
|
|
||||||
return nil, fmt.Errorf("empty default backend rtmp")
|
|
||||||
}
|
|
||||||
|
|
||||||
server := NewSRSServer(func(srs *SRSServer) {
|
|
||||||
srs.IP = envDefaultBackendIP()
|
|
||||||
srs.RTMP = []string{envDefaultBackendRTMP()}
|
|
||||||
srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID())
|
|
||||||
srs.ServiceID = logger.GenerateContextID()
|
|
||||||
srs.PID = fmt.Sprintf("%v", os.Getpid())
|
|
||||||
srs.UpdatedAt = time.Now()
|
|
||||||
})
|
|
||||||
|
|
||||||
if envDefaultBackendHttp() != "" {
|
|
||||||
server.HTTP = []string{envDefaultBackendHttp()}
|
|
||||||
}
|
|
||||||
if envDefaultBackendAPI() != "" {
|
|
||||||
server.API = []string{envDefaultBackendAPI()}
|
|
||||||
}
|
|
||||||
if envDefaultBackendRTC() != "" {
|
|
||||||
server.RTC = []string{envDefaultBackendRTC()}
|
|
||||||
}
|
|
||||||
if envDefaultBackendSRT() != "" {
|
|
||||||
server.SRT = []string{envDefaultBackendSRT()}
|
|
||||||
}
|
|
||||||
return server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SRSLoadBalancer is the interface to load balance the SRS servers.
|
|
||||||
type SRSLoadBalancer interface {
|
|
||||||
// Initialize the load balancer.
|
|
||||||
Initialize(ctx context.Context) error
|
|
||||||
// Update the backer server.
|
|
||||||
Update(ctx context.Context, server *SRSServer) error
|
|
||||||
// Pick a backend server for the specified stream URL.
|
|
||||||
Pick(ctx context.Context, streamURL string) (*SRSServer, error)
|
|
||||||
// Load or store the HLS streaming for the specified stream URL.
|
|
||||||
LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error)
|
|
||||||
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
|
|
||||||
LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error)
|
|
||||||
// Store the WebRTC streaming for the specified stream URL.
|
|
||||||
StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
|
|
||||||
// Load the WebRTC streaming by ufrag, the ICE username.
|
|
||||||
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// srsLoadBalancer is the global SRS load balancer.
|
|
||||||
var srsLoadBalancer SRSLoadBalancer
|
|
||||||
|
|
||||||
// srsMemoryLoadBalancer stores state in memory.
|
|
||||||
type srsMemoryLoadBalancer struct {
|
|
||||||
// All available SRS servers, key is server ID.
|
|
||||||
servers sync.Map[string, *SRSServer]
|
|
||||||
// The picked server to servce client by specified stream URL, key is stream url.
|
|
||||||
picked sync.Map[string, *SRSServer]
|
|
||||||
// The HLS streaming, key is stream URL.
|
|
||||||
hlsStreamURL sync.Map[string, *HLSPlayStream]
|
|
||||||
// The HLS streaming, key is SPBHID.
|
|
||||||
hlsSPBHID sync.Map[string, *HLSPlayStream]
|
|
||||||
// The WebRTC streaming, key is stream URL.
|
|
||||||
rtcStreamURL sync.Map[string, *RTCConnection]
|
|
||||||
// The WebRTC streaming, key is ufrag.
|
|
||||||
rtcUfrag sync.Map[string, *RTCConnection]
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMemoryLoadBalancer() SRSLoadBalancer {
|
|
||||||
return &srsMemoryLoadBalancer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error {
|
|
||||||
if server, err := NewDefaultSRSForDebugging(); err != nil {
|
|
||||||
return errors.Wrapf(err, "initialize default SRS")
|
|
||||||
} else if server != nil {
|
|
||||||
if err := v.Update(ctx, server); err != nil {
|
|
||||||
return errors.Wrapf(err, "update default SRS %+v", server)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keep alive.
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(30 * time.Second):
|
|
||||||
if err := v.Update(ctx, server); err != nil {
|
|
||||||
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
|
|
||||||
v.servers.Store(server.ID(), server)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
|
|
||||||
// Always proxy to the same server for the same stream URL.
|
|
||||||
if server, ok := v.picked.Load(streamURL); ok {
|
|
||||||
return server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gather all servers that were alive within the last few seconds.
|
|
||||||
var servers []*SRSServer
|
|
||||||
v.servers.Range(func(key string, server *SRSServer) bool {
|
|
||||||
if time.Since(server.UpdatedAt) < srsServerAliveDuration {
|
|
||||||
servers = append(servers, server)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
// If no servers available, use all possible servers.
|
|
||||||
if len(servers) == 0 {
|
|
||||||
v.servers.Range(func(key string, server *SRSServer) bool {
|
|
||||||
servers = append(servers, server)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// No server found, failed.
|
|
||||||
if len(servers) == 0 {
|
|
||||||
return nil, fmt.Errorf("no server available for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a server randomly from servers.
|
|
||||||
server := servers[rand.Intn(len(servers))]
|
|
||||||
v.picked.Store(streamURL, server)
|
|
||||||
return server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
|
|
||||||
// Load the HLS streaming for the SPBHID, for TS files.
|
|
||||||
if actual, ok := v.hlsSPBHID.Load(spbhid); !ok {
|
|
||||||
return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
|
|
||||||
} else {
|
|
||||||
return actual, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
|
|
||||||
// Update the HLS streaming for the stream URL, for M3u8.
|
|
||||||
actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value)
|
|
||||||
if actual == nil {
|
|
||||||
return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the HLS streaming for the SPBHID, for TS files.
|
|
||||||
v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual)
|
|
||||||
|
|
||||||
return actual, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
|
|
||||||
// Update the WebRTC streaming for the stream URL.
|
|
||||||
v.rtcStreamURL.Store(streamURL, value)
|
|
||||||
|
|
||||||
// Update the WebRTC streaming for the ufrag.
|
|
||||||
v.rtcUfrag.Store(value.Ufrag, value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
|
|
||||||
if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
|
|
||||||
return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
|
|
||||||
} else {
|
|
||||||
return actual, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type srsRedisLoadBalancer struct {
|
|
||||||
// The redis client sdk.
|
|
||||||
rdb *redis.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRedisLoadBalancer() SRSLoadBalancer {
|
|
||||||
return &srsRedisLoadBalancer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error {
|
|
||||||
redisDatabase, err := strconv.Atoi(envRedisDB())
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB())
|
|
||||||
}
|
|
||||||
|
|
||||||
rdb := redis.NewClient(&redis.Options{
|
|
||||||
Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()),
|
|
||||||
Password: envRedisPassword(),
|
|
||||||
DB: redisDatabase,
|
|
||||||
})
|
|
||||||
v.rdb = rdb
|
|
||||||
|
|
||||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
|
||||||
return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String())
|
|
||||||
|
|
||||||
if server, err := NewDefaultSRSForDebugging(); err != nil {
|
|
||||||
return errors.Wrapf(err, "initialize default SRS")
|
|
||||||
} else if server != nil {
|
|
||||||
if err := v.Update(ctx, server); err != nil {
|
|
||||||
return errors.Wrapf(err, "update default SRS %+v", server)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keep alive.
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(30 * time.Second):
|
|
||||||
if err := v.Update(ctx, server); err != nil {
|
|
||||||
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
|
|
||||||
b, err := json.Marshal(server)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal server %+v", server)
|
|
||||||
}
|
|
||||||
|
|
||||||
key := v.redisKeyServer(server.ID())
|
|
||||||
if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil {
|
|
||||||
return errors.Wrapf(err, "set key=%v server %+v", key, server)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query all servers from redis, in json string.
|
|
||||||
var serverKeys []string
|
|
||||||
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
|
|
||||||
if err := json.Unmarshal(b, &serverKeys); err != nil {
|
|
||||||
return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check each server expiration, if not exists in redis, remove from servers.
|
|
||||||
for i := len(serverKeys) - 1; i >= 0; i-- {
|
|
||||||
if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil {
|
|
||||||
serverKeys = append(serverKeys[:i], serverKeys[i+1:]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add server to servers if not exists.
|
|
||||||
var found bool
|
|
||||||
for _, serverKey := range serverKeys {
|
|
||||||
if serverKey == key {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
serverKeys = append(serverKeys, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update all servers to redis.
|
|
||||||
b, err = json.Marshal(serverKeys)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal servers %+v", serverKeys)
|
|
||||||
}
|
|
||||||
if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil {
|
|
||||||
return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
|
|
||||||
key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
|
|
||||||
|
|
||||||
// Always proxy to the same server for the same stream URL.
|
|
||||||
if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
|
|
||||||
// If server not exists, ignore and pick another server for the stream URL.
|
|
||||||
if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
|
|
||||||
var server SRSServer
|
|
||||||
if err := json.Unmarshal(b, &server); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: If server fail, we should migrate the streams to another server.
|
|
||||||
return &server, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query all servers from redis, in json string.
|
|
||||||
var serverKeys []string
|
|
||||||
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
|
|
||||||
if err := json.Unmarshal(b, &serverKeys); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No server found, failed.
|
|
||||||
if len(serverKeys) == 0 {
|
|
||||||
return nil, fmt.Errorf("no server available for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// All server should be alive, if not, should have been removed by redis. So we only
|
|
||||||
// random pick one that is always available.
|
|
||||||
var serverKey string
|
|
||||||
var server SRSServer
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
tryServerKey := serverKeys[rand.Intn(len(serverKeys))]
|
|
||||||
b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
|
|
||||||
if err == nil && len(b) > 0 {
|
|
||||||
if err := json.Unmarshal(b, &server); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKey = tryServerKey
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if serverKey == "" {
|
|
||||||
return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the picked server for the stream URL.
|
|
||||||
if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
|
|
||||||
key := v.redisKeySPBHID(spbhid)
|
|
||||||
|
|
||||||
b, err := v.rdb.Get(ctx, key).Bytes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "get key=%v HLS", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
var actual HLSPlayStream
|
|
||||||
if err := json.Unmarshal(b, &actual); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &actual, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
|
|
||||||
b, err := json.Marshal(value)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "marshal HLS %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
key := v.redisKeyHLS(streamURL)
|
|
||||||
if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID)
|
|
||||||
if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the HLS streaming from redis.
|
|
||||||
b2, err := v.rdb.Get(ctx, key).Bytes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "get key=%v HLS", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
var actual HLSPlayStream
|
|
||||||
if err := json.Unmarshal(b2, &actual); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &actual, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
|
|
||||||
b, err := json.Marshal(value)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal WebRTC %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
key := v.redisKeyRTC(streamURL)
|
|
||||||
if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil {
|
|
||||||
return errors.Wrapf(err, "set key=%v WebRTC %v", key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
key2 := v.redisKeyUfrag(value.Ufrag)
|
|
||||||
if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil {
|
|
||||||
return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
|
|
||||||
key := v.redisKeyUfrag(ufrag)
|
|
||||||
|
|
||||||
b, err := v.rdb.Get(ctx, key).Bytes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
var actual RTCConnection
|
|
||||||
if err := json.Unmarshal(b, &actual); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &actual, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
|
|
||||||
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string {
|
|
||||||
return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
|
|
||||||
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string {
|
|
||||||
return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
|
|
||||||
return fmt.Sprintf("srs-proxy-server:%v", serverID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsRedisLoadBalancer) redisKeyServers() string {
|
|
||||||
return fmt.Sprintf("srs-proxy-all-servers")
|
|
||||||
}
|
|
||||||
574
proxy/srt.go
574
proxy/srt.go
@@ -1,574 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
stdSync "sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
"srs-proxy/sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to
|
|
||||||
// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the
|
|
||||||
// backend server.
|
|
||||||
type srsSRTServer struct {
|
|
||||||
// The UDP listener for SRT server.
|
|
||||||
listener *net.UDPConn
|
|
||||||
|
|
||||||
// The SRT connections, identify by the socket ID.
|
|
||||||
sockets sync.Map[uint32, *SRTConnection]
|
|
||||||
// The system start time.
|
|
||||||
start time.Time
|
|
||||||
|
|
||||||
// The wait group for server.
|
|
||||||
wg stdSync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer {
|
|
||||||
v := &srsSRTServer{
|
|
||||||
start: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsSRTServer) Close() error {
|
|
||||||
if v.listener != nil {
|
|
||||||
v.listener.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
v.wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsSRTServer) Run(ctx context.Context) error {
|
|
||||||
// Parse address to listen.
|
|
||||||
endpoint := envSRTServer()
|
|
||||||
if !strings.Contains(endpoint, ":") {
|
|
||||||
endpoint = ":" + endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
saddr, err := net.ResolveUDPAddr("udp", endpoint)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := net.ListenUDP("udp", saddr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "listen udp %v", saddr)
|
|
||||||
}
|
|
||||||
v.listener = listener
|
|
||||||
logger.Df(ctx, "SRT server listen at %v", saddr)
|
|
||||||
|
|
||||||
// Consume all messages from UDP media transport.
|
|
||||||
v.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer v.wg.Done()
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
n, caddr, err := v.listener.ReadFromUDP(buf)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: If SRT server closed unexpectedly, we should notice the main loop to quit.
|
|
||||||
logger.Wf(ctx, "read from udp failed, err=%+v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
|
|
||||||
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
|
|
||||||
socketID := srtParseSocketID(data)
|
|
||||||
|
|
||||||
var pkt *SRTHandshakePacket
|
|
||||||
if srtIsHandshake(data) {
|
|
||||||
pkt = &SRTHandshakePacket{}
|
|
||||||
if err := pkt.UnmarshalBinary(data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if socketID == 0 {
|
|
||||||
socketID = pkt.SRTSocketID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) {
|
|
||||||
c.ctx = logger.WithContext(ctx)
|
|
||||||
c.listenerUDP, c.socketID = v.listener, socketID
|
|
||||||
c.start = v.start
|
|
||||||
}))
|
|
||||||
|
|
||||||
ctx = conn.ctx
|
|
||||||
if !ok {
|
|
||||||
logger.Df(ctx, "Create new SRT connection skt=%v", socketID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil {
|
|
||||||
return errors.Wrapf(err, "handle packet")
|
|
||||||
} else if newSocketID != 0 && newSocketID != socketID {
|
|
||||||
// The connection may use a new socket ID.
|
|
||||||
// TODO: FIXME: Should cleanup the dead SRT connection.
|
|
||||||
v.sockets.Store(newSocketID, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT
|
|
||||||
// connection, identify by the socket ID.
|
|
||||||
//
|
|
||||||
// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in
|
|
||||||
// the client request. The SRTConnection is stateless, and no need to sync between proxy servers.
|
|
||||||
//
|
|
||||||
// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the
|
|
||||||
// client should never switch to another network or port. If this occurs, the client may be served
|
|
||||||
// by a different proxy server and fail because the other proxy server cannot identify the client.
|
|
||||||
type SRTConnection struct {
|
|
||||||
// The stream context for SRT connection.
|
|
||||||
ctx context.Context
|
|
||||||
|
|
||||||
// The current socket ID.
|
|
||||||
socketID uint32
|
|
||||||
|
|
||||||
// The UDP connection proxy to backend.
|
|
||||||
backendUDP *net.UDPConn
|
|
||||||
// The listener UDP connection, used to send messages to client.
|
|
||||||
listenerUDP *net.UDPConn
|
|
||||||
|
|
||||||
// Listener start time.
|
|
||||||
start time.Time
|
|
||||||
|
|
||||||
// Handshake packets with client.
|
|
||||||
handshake0 *SRTHandshakePacket
|
|
||||||
handshake1 *SRTHandshakePacket
|
|
||||||
handshake2 *SRTHandshakePacket
|
|
||||||
handshake3 *SRTHandshakePacket
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection {
|
|
||||||
v := &SRTConnection{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(v)
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) {
|
|
||||||
ctx := v.ctx
|
|
||||||
|
|
||||||
// If not handshake, try to proxy to backend directly.
|
|
||||||
if pkt == nil {
|
|
||||||
// Proxy client message to backend.
|
|
||||||
if v.backendUDP != nil {
|
|
||||||
if _, err := v.backendUDP.Write(data); err != nil {
|
|
||||||
return v.socketID, errors.Wrapf(err, "write to backend")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return v.socketID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle handshake messages.
|
|
||||||
if err := v.handleHandshake(ctx, pkt, addr, data); err != nil {
|
|
||||||
return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt)
|
|
||||||
}
|
|
||||||
|
|
||||||
return v.socketID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error {
|
|
||||||
// Handle handshake 0 and 1 messages.
|
|
||||||
if pkt.SynCookie == 0 {
|
|
||||||
// Save handshake 0 packet.
|
|
||||||
v.handshake0 = pkt
|
|
||||||
logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0)
|
|
||||||
|
|
||||||
// Response handshake 1.
|
|
||||||
v.handshake1 = &SRTHandshakePacket{
|
|
||||||
ControlFlag: pkt.ControlFlag,
|
|
||||||
ControlType: 0,
|
|
||||||
SubType: 0,
|
|
||||||
AdditionalInfo: 0,
|
|
||||||
Timestamp: uint32(time.Since(v.start).Microseconds()),
|
|
||||||
SocketID: pkt.SRTSocketID,
|
|
||||||
Version: 5,
|
|
||||||
EncryptionField: 0,
|
|
||||||
ExtensionField: 0x4A17,
|
|
||||||
InitSequence: pkt.InitSequence,
|
|
||||||
MTU: pkt.MTU,
|
|
||||||
FlowWindow: pkt.FlowWindow,
|
|
||||||
HandshakeType: 1,
|
|
||||||
SRTSocketID: pkt.SRTSocketID,
|
|
||||||
SynCookie: 0x418d5e4e,
|
|
||||||
PeerIP: net.ParseIP("127.0.0.1"),
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1)
|
|
||||||
|
|
||||||
if b, err := v.handshake1.MarshalBinary(); err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal handshake 1")
|
|
||||||
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
|
|
||||||
return errors.Wrapf(err, "write handshake 1")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle handshake 2 and 3 messages.
|
|
||||||
// Parse stream id from packet.
|
|
||||||
streamID, err := pkt.StreamID()
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse stream id")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save handshake packet.
|
|
||||||
v.handshake2 = pkt
|
|
||||||
logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID)
|
|
||||||
|
|
||||||
// Start the UDP proxy to backend.
|
|
||||||
if err := v.connectBackend(ctx, streamID); err != nil {
|
|
||||||
return errors.Wrapf(err, "connect backend for %v", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy client message to backend.
|
|
||||||
if v.backendUDP == nil {
|
|
||||||
return errors.Errorf("no backend for %v", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy handshake 0 to backend server.
|
|
||||||
if b, err := v.handshake0.MarshalBinary(); err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal handshake 0")
|
|
||||||
} else if _, err = v.backendUDP.Write(b); err != nil {
|
|
||||||
return errors.Wrapf(err, "write handshake 0")
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0)
|
|
||||||
|
|
||||||
// Read handshake 1 from backend server.
|
|
||||||
b := make([]byte, 4096)
|
|
||||||
handshake1p := &SRTHandshakePacket{}
|
|
||||||
if nn, err := v.backendUDP.Read(b); err != nil {
|
|
||||||
return errors.Wrapf(err, "read handshake 1")
|
|
||||||
} else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil {
|
|
||||||
return errors.Wrapf(err, "unmarshal handshake 1")
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p)
|
|
||||||
|
|
||||||
// Proxy handshake 2 to backend server.
|
|
||||||
handshake2p := *v.handshake2
|
|
||||||
handshake2p.SynCookie = handshake1p.SynCookie
|
|
||||||
if b, err := handshake2p.MarshalBinary(); err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal handshake 2")
|
|
||||||
} else if _, err = v.backendUDP.Write(b); err != nil {
|
|
||||||
return errors.Wrapf(err, "write handshake 2")
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p)
|
|
||||||
|
|
||||||
// Read handshake 3 from backend server.
|
|
||||||
handshake3p := &SRTHandshakePacket{}
|
|
||||||
if nn, err := v.backendUDP.Read(b); err != nil {
|
|
||||||
return errors.Wrapf(err, "read handshake 3")
|
|
||||||
} else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil {
|
|
||||||
return errors.Wrapf(err, "unmarshal handshake 3")
|
|
||||||
}
|
|
||||||
logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p)
|
|
||||||
|
|
||||||
// Response handshake 3 to client.
|
|
||||||
v.handshake3 = &*handshake3p
|
|
||||||
v.handshake3.SynCookie = v.handshake1.SynCookie
|
|
||||||
v.socketID = handshake3p.SRTSocketID
|
|
||||||
logger.Df(ctx, "Handshake 3: %v", v.handshake3)
|
|
||||||
|
|
||||||
if b, err := v.handshake3.MarshalBinary(); err != nil {
|
|
||||||
return errors.Wrapf(err, "marshal handshake 3")
|
|
||||||
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
|
|
||||||
return errors.Wrapf(err, "write handshake 3")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start a goroutine to proxy message from backend to client.
|
|
||||||
// TODO: FIXME: Support close the connection when timeout or client disconnected.
|
|
||||||
go func() {
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
nn, err := v.backendUDP.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
|
|
||||||
logger.Wf(ctx, "read from backend failed, err=%v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil {
|
|
||||||
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
|
|
||||||
logger.Wf(ctx, "write to client failed, err=%v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error {
|
|
||||||
if v.backendUDP != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse stream id to host and resource.
|
|
||||||
host, resource, err := parseSRTStreamID(streamID)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse stream id %v", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if host == "" {
|
|
||||||
host = "localhost"
|
|
||||||
}
|
|
||||||
|
|
||||||
streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource))
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "build stream url %v", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick a backend SRS server to proxy the SRT stream.
|
|
||||||
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "pick backend for %v", streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse UDP port from backend.
|
|
||||||
if len(backend.SRT) == 0 {
|
|
||||||
return errors.Errorf("no udp server %v for %v", backend, streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, udpPort, err := parseListenEndpoint(backend.SRT[0])
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to backend SRS server via UDP client.
|
|
||||||
// TODO: FIXME: Support close the connection when timeout or client disconnected.
|
|
||||||
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
|
|
||||||
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
|
|
||||||
return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL)
|
|
||||||
} else {
|
|
||||||
v.backendUDP = backendUDP
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2
|
|
||||||
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1
|
|
||||||
type SRTHandshakePacket struct {
|
|
||||||
// F: 1 bit. Packet Type Flag. The control packet has this flag set to
|
|
||||||
// "1". The data packet has this flag set to "0".
|
|
||||||
ControlFlag uint8
|
|
||||||
// Control Type: 15 bits. Control Packet Type. The use of these bits
|
|
||||||
// is determined by the control packet type definition.
|
|
||||||
// Handshake control packets (Control Type = 0x0000) are used to
|
|
||||||
// exchange peer configurations, to agree on connection parameters, and
|
|
||||||
// to establish a connection.
|
|
||||||
ControlType uint16
|
|
||||||
// Subtype: 16 bits. This field specifies an additional subtype for
|
|
||||||
// specific packets.
|
|
||||||
SubType uint16
|
|
||||||
// Type-specific Information: 32 bits. The use of this field depends on
|
|
||||||
// the particular control packet type. Handshake packets do not use
|
|
||||||
// this field.
|
|
||||||
AdditionalInfo uint32
|
|
||||||
// Timestamp: 32 bits.
|
|
||||||
Timestamp uint32
|
|
||||||
// Destination Socket ID: 32 bits.
|
|
||||||
SocketID uint32
|
|
||||||
|
|
||||||
// Version: 32 bits. A base protocol version number. Currently used
|
|
||||||
// values are 4 and 5. Values greater than 5 are reserved for future
|
|
||||||
// use.
|
|
||||||
Version uint32
|
|
||||||
// Encryption Field: 16 bits. Block cipher family and key size. The
|
|
||||||
// values of this field are described in Table 2. The default value
|
|
||||||
// is AES-128.
|
|
||||||
// 0 | No Encryption Advertised
|
|
||||||
// 2 | AES-128
|
|
||||||
// 3 | AES-192
|
|
||||||
// 4 | AES-256
|
|
||||||
EncryptionField uint16
|
|
||||||
// Extension Field: 16 bits. This field is message specific extension
|
|
||||||
// related to Handshake Type field. The value MUST be set to 0
|
|
||||||
// except for the following cases. (1) If the handshake control
|
|
||||||
// packet is the INDUCTION message, this field is sent back by the
|
|
||||||
// Listener. (2) In the case of a CONCLUSION message, this field
|
|
||||||
// value should contain a combination of Extension Type values.
|
|
||||||
// 0x00000001 | HSREQ
|
|
||||||
// 0x00000002 | KMREQ
|
|
||||||
// 0x00000004 | CONFIG
|
|
||||||
// 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1
|
|
||||||
ExtensionField uint16
|
|
||||||
// Initial Packet Sequence Number: 32 bits. The sequence number of the
|
|
||||||
// very first data packet to be sent.
|
|
||||||
InitSequence uint32
|
|
||||||
// Maximum Transmission Unit Size: 32 bits. This value is typically set
|
|
||||||
// to 1500, which is the default Maximum Transmission Unit (MTU) size
|
|
||||||
// for Ethernet, but can be less.
|
|
||||||
MTU uint32
|
|
||||||
// Maximum Flow Window Size: 32 bits. The value of this field is the
|
|
||||||
// maximum number of data packets allowed to be "in flight" (i.e. the
|
|
||||||
// number of sent packets for which an ACK control packet has not yet
|
|
||||||
// been received).
|
|
||||||
FlowWindow uint32
|
|
||||||
// Handshake Type: 32 bits. This field indicates the handshake packet
|
|
||||||
// type.
|
|
||||||
// 0xFFFFFFFD | DONE
|
|
||||||
// 0xFFFFFFFE | AGREEMENT
|
|
||||||
// 0xFFFFFFFF | CONCLUSION
|
|
||||||
// 0x00000000 | WAVEHAND
|
|
||||||
// 0x00000001 | INDUCTION
|
|
||||||
HandshakeType uint32
|
|
||||||
// SRT Socket ID: 32 bits. This field holds the ID of the source SRT
|
|
||||||
// socket from which a handshake packet is issued.
|
|
||||||
SRTSocketID uint32
|
|
||||||
// SYN Cookie: 32 bits. Randomized value for processing a handshake.
|
|
||||||
// The value of this field is specified by the handshake message
|
|
||||||
// type.
|
|
||||||
SynCookie uint32
|
|
||||||
// Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's
|
|
||||||
// sender. The value consists of four 32-bit fields.
|
|
||||||
PeerIP net.IP
|
|
||||||
// Extensions.
|
|
||||||
// Extension Type: 16 bits. The value of this field is used to process
|
|
||||||
// an integrated handshake. Each extension can have a pair of
|
|
||||||
// request and response types.
|
|
||||||
// Extension Length: 16 bits. The length of the Extension Contents
|
|
||||||
// field in four-byte blocks.
|
|
||||||
// Extension Contents: variable length. The payload of the extension.
|
|
||||||
ExtraData []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) IsData() bool {
|
|
||||||
return v.ControlFlag == 0x00
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) IsControl() bool {
|
|
||||||
return v.ControlFlag == 0x80
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) IsHandshake() bool {
|
|
||||||
return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) StreamID() (string, error) {
|
|
||||||
p := v.ExtraData
|
|
||||||
for {
|
|
||||||
if len(p) < 2 {
|
|
||||||
return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData))
|
|
||||||
}
|
|
||||||
|
|
||||||
extType := binary.BigEndian.Uint16(p)
|
|
||||||
extSize := binary.BigEndian.Uint16(p[2:])
|
|
||||||
p = p[4:]
|
|
||||||
|
|
||||||
if len(p) < int(extSize*4) {
|
|
||||||
return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore other packets except stream id.
|
|
||||||
if extType != 0x05 {
|
|
||||||
p = p[extSize*4:]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// We must copy it, because we will decode the stream id.
|
|
||||||
data := append([]byte{}, p[:extSize*4]...)
|
|
||||||
|
|
||||||
// Reverse the stream id encoded in little-endian to big-endian.
|
|
||||||
for i := 0; i < len(data); i += 4 {
|
|
||||||
value := binary.LittleEndian.Uint32(data[i:])
|
|
||||||
binary.BigEndian.PutUint32(data[i:], value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trim the trailing zero bytes.
|
|
||||||
data = bytes.TrimRight(data, "\x00")
|
|
||||||
return string(data), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) String() string {
|
|
||||||
return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB",
|
|
||||||
v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error {
|
|
||||||
if len(b) < 4 {
|
|
||||||
return errors.Errorf("Invalid packet length %v", len(b))
|
|
||||||
}
|
|
||||||
v.ControlFlag = b[0] & 0x80
|
|
||||||
v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff
|
|
||||||
v.SubType = binary.BigEndian.Uint16(b[2:4])
|
|
||||||
|
|
||||||
if len(b) < 64 {
|
|
||||||
return errors.Errorf("Invalid packet length %v", len(b))
|
|
||||||
}
|
|
||||||
v.AdditionalInfo = binary.BigEndian.Uint32(b[4:])
|
|
||||||
v.Timestamp = binary.BigEndian.Uint32(b[8:])
|
|
||||||
v.SocketID = binary.BigEndian.Uint32(b[12:])
|
|
||||||
v.Version = binary.BigEndian.Uint32(b[16:])
|
|
||||||
v.EncryptionField = binary.BigEndian.Uint16(b[20:])
|
|
||||||
v.ExtensionField = binary.BigEndian.Uint16(b[22:])
|
|
||||||
v.InitSequence = binary.BigEndian.Uint32(b[24:])
|
|
||||||
v.MTU = binary.BigEndian.Uint32(b[28:])
|
|
||||||
v.FlowWindow = binary.BigEndian.Uint32(b[32:])
|
|
||||||
v.HandshakeType = binary.BigEndian.Uint32(b[36:])
|
|
||||||
v.SRTSocketID = binary.BigEndian.Uint32(b[40:])
|
|
||||||
v.SynCookie = binary.BigEndian.Uint32(b[44:])
|
|
||||||
|
|
||||||
// Only support IPv4.
|
|
||||||
v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48])
|
|
||||||
|
|
||||||
v.ExtraData = b[64:]
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) {
|
|
||||||
b := make([]byte, 64+len(v.ExtraData))
|
|
||||||
binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType)
|
|
||||||
binary.BigEndian.PutUint16(b[2:], v.SubType)
|
|
||||||
binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo)
|
|
||||||
binary.BigEndian.PutUint32(b[8:], v.Timestamp)
|
|
||||||
binary.BigEndian.PutUint32(b[12:], v.SocketID)
|
|
||||||
binary.BigEndian.PutUint32(b[16:], v.Version)
|
|
||||||
binary.BigEndian.PutUint16(b[20:], v.EncryptionField)
|
|
||||||
binary.BigEndian.PutUint16(b[22:], v.ExtensionField)
|
|
||||||
binary.BigEndian.PutUint32(b[24:], v.InitSequence)
|
|
||||||
binary.BigEndian.PutUint32(b[28:], v.MTU)
|
|
||||||
binary.BigEndian.PutUint32(b[32:], v.FlowWindow)
|
|
||||||
binary.BigEndian.PutUint32(b[36:], v.HandshakeType)
|
|
||||||
binary.BigEndian.PutUint32(b[40:], v.SRTSocketID)
|
|
||||||
binary.BigEndian.PutUint32(b[44:], v.SynCookie)
|
|
||||||
|
|
||||||
// Only support IPv4.
|
|
||||||
ip := v.PeerIP.To4()
|
|
||||||
b[48] = ip[3]
|
|
||||||
b[49] = ip[2]
|
|
||||||
b[50] = ip[1]
|
|
||||||
b[51] = ip[0]
|
|
||||||
|
|
||||||
if len(v.ExtraData) > 0 {
|
|
||||||
copy(b[64:], v.ExtraData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package sync
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
type Map[K comparable, V any] struct {
|
|
||||||
m sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) Delete(key K) {
|
|
||||||
m.m.Delete(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
|
|
||||||
v, ok := m.m.Load(key)
|
|
||||||
if !ok {
|
|
||||||
return value, ok
|
|
||||||
}
|
|
||||||
return v.(V), ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
|
|
||||||
v, loaded := m.m.LoadAndDelete(key)
|
|
||||||
if !loaded {
|
|
||||||
return value, loaded
|
|
||||||
}
|
|
||||||
return v.(V), loaded
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
|
||||||
a, loaded := m.m.LoadOrStore(key, value)
|
|
||||||
return a.(V), loaded
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
|
|
||||||
m.m.Range(func(key, value any) bool {
|
|
||||||
return f(key.(K), value.(V))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) Store(key K, value V) {
|
|
||||||
m.m.Store(key, value)
|
|
||||||
}
|
|
||||||
276
proxy/utils.go
276
proxy/utils.go
@@ -1,276 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
stdErr "errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"srs-proxy/errors"
|
|
||||||
"srs-proxy/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) {
|
|
||||||
w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version()))
|
|
||||||
|
|
||||||
b, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
logger.Wf(ctx, "HTTP API error %+v", err)
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
fmt.Fprintln(w, fmt.Sprintf("%v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
|
|
||||||
// Always support CORS. Note that browser may send origin header for m3u8, but no origin header
|
|
||||||
// for ts. So we always response CORS header.
|
|
||||||
if true {
|
|
||||||
// SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin,
|
|
||||||
// headers, expose headers and methods.
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
||||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
|
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
|
||||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
|
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "*")
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGracefullyQuitTimeout() (time.Duration, error) {
|
|
||||||
if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil {
|
|
||||||
return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout())
|
|
||||||
} else {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseBody read the body from r, and unmarshal JSON to v.
|
|
||||||
func ParseBody(r io.ReadCloser, v interface{}) error {
|
|
||||||
b, err := ioutil.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "read body")
|
|
||||||
}
|
|
||||||
defer r.Close()
|
|
||||||
|
|
||||||
if len(b) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(b, v); err != nil {
|
|
||||||
return errors.Wrapf(err, "json unmarshal %v", string(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildStreamURL build as vhost/app/stream for stream URL r.
|
|
||||||
func buildStreamURL(r string) (string, error) {
|
|
||||||
u, err := url.Parse(r)
|
|
||||||
if err != nil {
|
|
||||||
return "", errors.Wrapf(err, "parse url %v", r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If not domain or ip in hostname, it's __defaultVhost__.
|
|
||||||
defaultVhost := !strings.Contains(u.Hostname(), ".")
|
|
||||||
|
|
||||||
// If hostname is actually an IP address, it's __defaultVhost__.
|
|
||||||
if ip := net.ParseIP(u.Hostname()); ip.To4() != nil {
|
|
||||||
defaultVhost = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if defaultVhost {
|
|
||||||
return fmt.Sprintf("__defaultVhost__%v", u.Path), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore port, only use hostname as vhost.
|
|
||||||
return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPeerClosedError indicates whether peer object closed the connection.
|
|
||||||
func isPeerClosedError(err error) bool {
|
|
||||||
causeErr := errors.Cause(err)
|
|
||||||
|
|
||||||
if stdErr.Is(causeErr, io.EOF) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if stdErr.Is(causeErr, syscall.EPIPE) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if netErr, ok := causeErr.(*net.OpError); ok {
|
|
||||||
if sysErr, ok := netErr.Err.(*os.SyscallError); ok {
|
|
||||||
if stdErr.Is(sysErr.Err, syscall.ECONNRESET) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL
|
|
||||||
// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL
|
|
||||||
// with extension.
|
|
||||||
func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) {
|
|
||||||
scheme := "http"
|
|
||||||
if r.TLS != nil {
|
|
||||||
scheme = "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := "__defaultVhost__"
|
|
||||||
if strings.Contains(r.Host, ":") {
|
|
||||||
if v, _, err := net.SplitHostPort(r.Host); err == nil {
|
|
||||||
hostname = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var appStream, streamExt string
|
|
||||||
|
|
||||||
// Parse app/stream from query string.
|
|
||||||
q := r.URL.Query()
|
|
||||||
if app := q.Get("app"); app != "" {
|
|
||||||
appStream = "/" + app
|
|
||||||
}
|
|
||||||
if stream := q.Get("stream"); stream != "" {
|
|
||||||
appStream = fmt.Sprintf("%v/%v", appStream, stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse app/stream from path.
|
|
||||||
if appStream == "" {
|
|
||||||
streamExt = path.Ext(r.URL.Path)
|
|
||||||
appStream = strings.TrimSuffix(r.URL.Path, streamExt)
|
|
||||||
}
|
|
||||||
|
|
||||||
unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream)
|
|
||||||
fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// rtcIsSTUN returns true if data of UDP payload is a STUN packet.
|
|
||||||
func rtcIsSTUN(data []byte) bool {
|
|
||||||
return len(data) > 0 && (data[0] == 0 || data[0] == 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet.
|
|
||||||
func rtcIsRTPOrRTCP(data []byte) bool {
|
|
||||||
return len(data) >= 12 && (data[0]&0xC0) == 0x80
|
|
||||||
}
|
|
||||||
|
|
||||||
// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet.
|
|
||||||
func srtIsHandshake(data []byte) bool {
|
|
||||||
return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000
|
|
||||||
}
|
|
||||||
|
|
||||||
// srtParseSocketID parse the socket id from the SRT packet.
|
|
||||||
func srtParseSocketID(data []byte) uint32 {
|
|
||||||
if len(data) >= 16 {
|
|
||||||
return binary.BigEndian.Uint32(data[12:])
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP.
|
|
||||||
func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) {
|
|
||||||
if true {
|
|
||||||
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
|
|
||||||
ufragMatch := ufragRe.FindStringSubmatch(sdp)
|
|
||||||
if len(ufragMatch) <= 1 {
|
|
||||||
return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp)
|
|
||||||
}
|
|
||||||
ufrag = ufragMatch[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
if true {
|
|
||||||
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
|
|
||||||
pwdMatch := pwdRe.FindStringSubmatch(sdp)
|
|
||||||
if len(pwdMatch) <= 1 {
|
|
||||||
return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp)
|
|
||||||
}
|
|
||||||
pwd = pwdMatch[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
return ufrag, pwd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required).
|
|
||||||
// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url
|
|
||||||
func parseSRTStreamID(sid string) (host, resource string, err error) {
|
|
||||||
if true {
|
|
||||||
hostRe := regexp.MustCompile(`h=([^,]+)`)
|
|
||||||
hostMatch := hostRe.FindStringSubmatch(sid)
|
|
||||||
if len(hostMatch) > 1 {
|
|
||||||
host = hostMatch[1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if true {
|
|
||||||
resourceRe := regexp.MustCompile(`r=([^,]+)`)
|
|
||||||
resourceMatch := resourceRe.FindStringSubmatch(sid)
|
|
||||||
if len(resourceMatch) <= 1 {
|
|
||||||
return "", "", errors.Errorf("no resource in sid %v", sid)
|
|
||||||
}
|
|
||||||
resource = resourceMatch[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, resource, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseListenEndpoint parse the listen endpoint as:
|
|
||||||
// port The tcp listen port, like 1935.
|
|
||||||
// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
|
|
||||||
func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) {
|
|
||||||
// If no colon in ep, it's port in string.
|
|
||||||
if !strings.Contains(ep, ":") {
|
|
||||||
if p, err := strconv.Atoi(ep); err != nil {
|
|
||||||
return "", nil, 0, errors.Wrapf(err, "parse port %v", ep)
|
|
||||||
} else {
|
|
||||||
return "tcp", nil, uint16(p), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must be protocol://ip:port schema.
|
|
||||||
parts := strings.Split(ep, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return "", nil, 0, errors.Errorf("invalid endpoint %v", ep)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p, err := strconv.Atoi(parts[2]); err != nil {
|
|
||||||
return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2])
|
|
||||||
} else {
|
|
||||||
return parts[0], net.ParseIP(parts[1]), uint16(p), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
// Copyright (c) 2025 Winlin
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
package main
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
func VersionMajor() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// VersionMinor specifies the typical version of SRS we adapt to.
|
|
||||||
func VersionMinor() int {
|
|
||||||
return 5
|
|
||||||
}
|
|
||||||
|
|
||||||
func VersionRevision() int {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func Version() string {
|
|
||||||
return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision())
|
|
||||||
}
|
|
||||||
|
|
||||||
func Signature() string {
|
|
||||||
return "SRSProxy"
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user