Migrate proxy to ossrs/proxy-go repository.

This commit is contained in:
winlin
2025-06-23 10:39:57 -04:00
parent 8538575915
commit acb9b88566
24 changed files with 6 additions and 6949 deletions

4
proxy/.gitignore vendored
View File

@@ -1,4 +0,0 @@
.idea
srs-proxy
.env
.go-formarted

View File

@@ -1,23 +0,0 @@
.PHONY: all build test fmt clean run
all: build
build: fmt ./srs-proxy
./srs-proxy: *.go
go build -o srs-proxy .
test:
go test ./...
fmt: ./.go-formarted
./.go-formarted: *.go
touch .go-formarted
go fmt ./...
clean:
rm -f srs-proxy .go-formarted
run: fmt
go run .

6
proxy/README.md Normal file
View File

@@ -0,0 +1,6 @@
# Proxy
Migrated to below repositoties:
* [proxy-go](https://github.com/ossrs/proxy-go) An common proxy server for any media servers with RTMP/SRT/HLS/HTTP-FLV and WebRTC/WHIP/WHEP protocols support.

View File

@@ -1,272 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP,
// to proxy other HTTP API of SRS like the streams and clients, etc.
type srsHTTPAPIServer struct {
// The underlayer HTTP server.
server *http.Server
// The WebRTC server.
rtc *srsWebRTCServer
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer {
v := &srsHTTPAPIServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsHTTPAPIServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *srsHTTPAPIServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := envHttpAPI()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "HTTP API server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
apiResponse(ctx, w, r, map[string]string{
"signature": Signature(),
"version": Version(),
})
})
// The WebRTC WHIP API handler.
logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr)
mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
}
})
// The WebRTC WHEP API handler.
logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr)
mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) {
if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
}
})
// Run HTTP API server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "HTTP API accept err %+v", err)
} else {
logger.Df(ctx, "HTTP API server done")
}
}
}()
return nil
}
// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service
// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter
// for Prometheus metrics.
type systemAPI struct {
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI {
v := &systemAPI{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *systemAPI) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *systemAPI) Run(ctx context.Context) error {
// Parse address to listen.
addr := envSystemAPI()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "System API server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
apiResponse(ctx, w, r, map[string]string{
"signature": Signature(),
"version": Version(),
})
})
// The register service for SRS media servers.
logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr)
mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) {
if err := func() error {
var deviceID, ip, serverID, serviceID, pid string
var rtmp, stream, api, srt, rtc []string
if err := ParseBody(r.Body, &struct {
// The IP of SRS, mandatory.
IP *string `json:"ip"`
// The server id of SRS, store in file, may not change, mandatory.
ServerID *string `json:"server"`
// The service id of SRS, always change when restarted, mandatory.
ServiceID *string `json:"service"`
// The process id of SRS, always change when restarted, mandatory.
PID *string `json:"pid"`
// The RTMP listen endpoints, mandatory.
RTMP *[]string `json:"rtmp"`
// The HTTP Stream listen endpoints, optional.
HTTP *[]string `json:"http"`
// The API listen endpoints, optional.
API *[]string `json:"api"`
// The SRT listen endpoints, optional.
SRT *[]string `json:"srt"`
// The RTC listen endpoints, optional.
RTC *[]string `json:"rtc"`
// The device id of SRS, optional.
DeviceID *string `json:"device_id"`
}{
IP: &ip, DeviceID: &deviceID,
ServerID: &serverID, ServiceID: &serviceID, PID: &pid,
RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc,
}); err != nil {
return errors.Wrapf(err, "parse body")
}
if ip == "" {
return errors.Errorf("empty ip")
}
if serverID == "" {
return errors.Errorf("empty server")
}
if serviceID == "" {
return errors.Errorf("empty service")
}
if pid == "" {
return errors.Errorf("empty pid")
}
if len(rtmp) == 0 {
return errors.Errorf("empty rtmp")
}
server := NewSRSServer(func(srs *SRSServer) {
srs.IP, srs.DeviceID = ip, deviceID
srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid
srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api
srs.SRT, srs.RTC = srt, rtc
srs.UpdatedAt = time.Now()
})
if err := srsLoadBalancer.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update SRS server %+v", server)
}
logger.Df(ctx, "Register SRS media server, %+v", server)
return nil
}(); err != nil {
apiError(ctx, w, r, err)
}
type Response struct {
Code int `json:"code"`
PID string `json:"pid"`
}
apiResponse(ctx, w, r, &Response{
Code: 0, PID: fmt.Sprintf("%v", os.Getpid()),
})
})
// Run System API server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If System API server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "System API accept err %+v", err)
} else {
logger.Df(ctx, "System API server done")
}
}
}()
return nil
}

View File

@@ -1,20 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"net/http"
"srs-proxy/logger"
)
func handleGoPprof(ctx context.Context) {
if addr := envGoPprof(); addr != "" {
go func() {
logger.Df(ctx, "Start Go pprof at %v", addr)
http.ListenAndServe(addr, nil)
}()
}
}

View File

@@ -1,226 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"io"
"os"
"path"
"strings"
"srs-proxy/errors"
"srs-proxy/logger"
)
// loadEnvFile loads the environment variables from file. Note that we only use .env file.
func loadEnvFile(ctx context.Context) error {
workDir, err := os.Getwd()
if err != nil {
return errors.Wrapf(err, "getpwd")
}
envFile := path.Join(workDir, ".env")
if _, err := os.Stat(envFile); err != nil {
return nil
}
file, err := os.Open(envFile)
if err != nil {
return errors.Wrapf(err, "open %v", envFile)
}
defer file.Close()
b, err := io.ReadAll(file)
if err != nil {
return errors.Wrapf(err, "read %v", envFile)
}
lines := strings.Split(strings.Replace(string(b), "\r\n", "\n", -1), "\n")
logger.Df(ctx, "load env file %v, lines=%v", envFile, len(lines))
for _, line := range lines {
if strings.HasPrefix(strings.TrimSpace(line), "#") {
continue
}
if pos := strings.IndexByte(line, '='); pos > 0 {
key := strings.TrimSpace(line[:pos])
value := strings.TrimSpace(line[pos+1:])
if v := os.Getenv(key); v != "" {
continue
}
os.Setenv(key, value)
}
}
return nil
}
// buildDefaultEnvironmentVariables setups the default environment variables.
func buildDefaultEnvironmentVariables(ctx context.Context) {
// Whether enable the Go pprof.
setEnvDefault("GO_PPROF", "")
// Force shutdown timeout.
setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s")
// Graceful quit timeout.
setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s")
// The HTTP API server.
setEnvDefault("PROXY_HTTP_API", "11985")
// The HTTP web server.
setEnvDefault("PROXY_HTTP_SERVER", "18080")
// The RTMP media server.
setEnvDefault("PROXY_RTMP_SERVER", "11935")
// The WebRTC media server, via UDP protocol.
setEnvDefault("PROXY_WEBRTC_SERVER", "18000")
// The SRT media server, via UDP protocol.
setEnvDefault("PROXY_SRT_SERVER", "20080")
// The API server of proxy itself.
setEnvDefault("PROXY_SYSTEM_API", "12025")
// The static directory for web server.
setEnvDefault("PROXY_STATIC_FILES", "../trunk/research")
// The load balancer, use redis or memory.
setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory")
// The redis server host.
setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1")
// The redis server port.
setEnvDefault("PROXY_REDIS_PORT", "6379")
// The redis server password.
setEnvDefault("PROXY_REDIS_PASSWORD", "")
// The redis server db.
setEnvDefault("PROXY_REDIS_DB", "0")
// Whether enable the default backend server, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off")
// Default backend server IP, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1")
// Default backend server port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935")
// Default backend api port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985")
// Default backend udp rtc port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000")
// Default backend udp srt port, for debugging.
setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080")
logger.Df(ctx, "load .env as GO_PPROF=%v, "+
"PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+
"PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+
"PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+
"PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+
"PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+
"PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+
"PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+
"PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+
"PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v",
envGoPprof(),
envForceQuitTimeout(), envGraceQuitTimeout(),
envHttpAPI(), envHttpServer(), envRtmpServer(),
envWebRTCServer(), envSRTServer(),
envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(),
envDefaultBackendIP(), envDefaultBackendRTMP(),
envDefaultBackendHttp(), envDefaultBackendAPI(),
envDefaultBackendRTC(), envDefaultBackendSRT(),
envLoadBalancerType(), envRedisHost(), envRedisPort(),
envRedisPassword(), envRedisDB(),
)
}
func envStaticFiles() string {
return os.Getenv("PROXY_STATIC_FILES")
}
func envDefaultBackendSRT() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_SRT")
}
func envDefaultBackendRTC() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_RTC")
}
func envDefaultBackendAPI() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_API")
}
func envSRTServer() string {
return os.Getenv("PROXY_SRT_SERVER")
}
func envWebRTCServer() string {
return os.Getenv("PROXY_WEBRTC_SERVER")
}
func envDefaultBackendHttp() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP")
}
func envRedisDB() string {
return os.Getenv("PROXY_REDIS_DB")
}
func envRedisPassword() string {
return os.Getenv("PROXY_REDIS_PASSWORD")
}
func envRedisPort() string {
return os.Getenv("PROXY_REDIS_PORT")
}
func envRedisHost() string {
return os.Getenv("PROXY_REDIS_HOST")
}
func envLoadBalancerType() string {
return os.Getenv("PROXY_LOAD_BALANCER_TYPE")
}
func envDefaultBackendRTMP() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP")
}
func envDefaultBackendIP() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_IP")
}
func envDefaultBackendEnabled() string {
return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED")
}
func envGraceQuitTimeout() string {
return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT")
}
func envForceQuitTimeout() string {
return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT")
}
func envGoPprof() string {
return os.Getenv("GO_PPROF")
}
func envSystemAPI() string {
return os.Getenv("PROXY_SYSTEM_API")
}
func envRtmpServer() string {
return os.Getenv("PROXY_RTMP_SERVER")
}
func envHttpServer() string {
return os.Getenv("PROXY_HTTP_SERVER")
}
func envHttpAPI() string {
return os.Getenv("PROXY_HTTP_API")
}
// setEnvDefault set env key=value if not set.
func setEnvDefault(key, value string) {
if os.Getenv(key) == "" {
os.Setenv(key, value)
}
}

View File

@@ -1,270 +0,0 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

View File

@@ -1,187 +0,0 @@
// Fork from https://github.com/pkg/errors
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

View File

@@ -1,10 +0,0 @@
module srs-proxy
go 1.18
require github.com/go-redis/redis/v8 v8.11.5
require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
)

View File

@@ -1,15 +0,0 @@
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

View File

@@ -1,419 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
stdSync "sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS,
// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy
// the request to the origin server.
type srsHTTPStreamServer struct {
// The underlayer HTTP server.
server *http.Server
// The gracefully quit timeout, wait server to quit.
gracefulQuitTimeout time.Duration
// The wait group for all goroutines.
wg stdSync.WaitGroup
}
func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer {
v := &srsHTTPStreamServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsHTTPStreamServer) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
v.wg.Wait()
return nil
}
func (v *srsHTTPStreamServer) Run(ctx context.Context) error {
// Parse address to listen.
addr := envHttpServer()
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
// Create server and handler.
mux := http.NewServeMux()
v.server = &http.Server{Addr: addr, Handler: mux}
logger.Df(ctx, "HTTP Stream server listen at %v", addr)
// Shutdown the server gracefully when quiting.
go func() {
ctxParent := ctx
<-ctxParent.Done()
ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout)
defer cancel()
v.server.Shutdown(ctx)
}()
// The basic version handler, also can be used as health check API.
logger.Df(ctx, "Handle /api/v1/versions by %v", addr)
mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) {
type Response struct {
Code int `json:"code"`
PID string `json:"pid"`
Data struct {
Major int `json:"major"`
Minor int `json:"minor"`
Revision int `json:"revision"`
Version string `json:"version"`
} `json:"data"`
}
res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())}
res.Data.Major = VersionMajor()
res.Data.Minor = VersionMinor()
res.Data.Revision = VersionRevision()
res.Data.Version = Version()
apiResponse(ctx, w, r, &res)
})
// The static web server, for the web pages.
var staticServer http.Handler
if staticFiles := envStaticFiles(); staticFiles != "" {
if _, err := os.Stat(staticFiles); err != nil {
return errors.Wrapf(err, "invalid static files %v", staticFiles)
}
staticServer = http.FileServer(http.Dir(staticFiles))
logger.Df(ctx, "Handle static files at %v", staticFiles)
}
// The default handler, for both static web server and streaming server.
logger.Df(ctx, "Handle / by %v", addr)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// For HLS streaming, we will proxy the request to the streaming server.
if strings.HasSuffix(r.URL.Path, ".m3u8") {
unifiedURL, fullURL := convertURLToStreamURL(r)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest)
return
}
stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) {
s.SRSProxyBackendHLSID = logger.GenerateContextID()
s.StreamURL, s.FullURL = streamURL, fullURL
}))
stream.Initialize(ctx).ServeHTTP(w, r)
return
}
// For HTTP streaming, we will proxy the request to the streaming server.
if strings.HasSuffix(r.URL.Path, ".flv") ||
strings.HasSuffix(r.URL.Path, ".ts") {
// If SPBHID is specified, it must be a HLS stream client.
if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" {
if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil {
http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest)
} else {
stream.Initialize(ctx).ServeHTTP(w, r)
}
return
}
// Use HTTP pseudo streaming to proxy the request.
NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) {
c.ctx = ctx
}).ServeHTTP(w, r)
return
}
// Serve by static server.
if staticServer != nil {
staticServer.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
})
// Run HTTP server.
v.wg.Add(1)
go func() {
defer v.wg.Done()
err := v.server.ListenAndServe()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "HTTP Stream accept err %+v", err)
} else {
logger.Df(ctx, "HTTP Stream server done")
}
}
}()
return nil
}
// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS
// connection. There is no state need to be sync between proxy servers.
//
// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request,
// then proxy to the corresponding backend server. All state is in the HTTP request, so this
// connection is stateless.
type HTTPFlvTsConnection struct {
// The context for HTTP streaming.
ctx context.Context
}
func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection {
v := &HTTPFlvTsConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
ctx := logger.WithContext(v.ctx)
if err := v.serve(ctx, w, r); err != nil {
apiError(ctx, w, r, err)
} else {
logger.Df(ctx, "HTTP client done")
}
}
func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no http stream server")
}
var httpPort int
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
} else {
httpPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Wrapf(err, "do request to %v", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
logger.Df(ctx, "HTTP start streaming")
// Proxy the stream from backend to client.
if _, err := io.Copy(w, resp.Body); err != nil {
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
}
return nil
}
// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS
// clients will share this object, and they do not use the same ctx among proxy servers.
//
// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections.
// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create
// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert
// to the stream URL and then query the backend server to serve it.
type HLSPlayStream struct {
// The context for HLS streaming.
ctx context.Context
// The spbhid, used to identify the backend server.
SRSProxyBackendHLSID string `json:"spbhid"`
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The full request URL for HLS streaming
FullURL string `json:"full_url"`
}
func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream {
v := &HLSPlayStream{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
return v
}
func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err := v.serve(v.ctx, w, r); err != nil {
apiError(v.ctx, w, r, err)
} else {
logger.Df(v.ctx, "HLS client %v for %v with %v done",
v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path)
}
}
func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.serveByBackend(ctx, w, r, backend); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error {
// Parse HTTP port from backend.
if len(backend.HTTP) == 0 {
return errors.Errorf("no rtmp server")
}
var httpPort int
if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.HTTP[0])
} else {
httpPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil)
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Errorf("do request to %v EOF", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
// For TS file, directly copy it.
if !strings.HasSuffix(r.URL.Path, ".m3u8") {
if _, err := io.Copy(w, resp.Body); err != nil {
return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL)
}
return nil
}
// Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts
// URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID.
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "read stream from %v", backendURL)
}
m3u8 := string(b)
if strings.Contains(m3u8, ".ts?") {
m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID))
} else {
m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID))
}
if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil {
return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL)
}
return nil
}

View File

@@ -1,43 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
)
type key string
var cidKey key = "cid.proxy.ossrs.org"
// generateContextID generates a random context id in string.
func GenerateContextID() string {
randomBytes := make([]byte, 32)
_, _ = rand.Read(randomBytes)
hash := sha256.Sum256(randomBytes)
hashString := hex.EncodeToString(hash[:])
cid := hashString[:7]
return cid
}
// WithContext creates a new context with cid, which will be used for log.
func WithContext(ctx context.Context) context.Context {
return WithContextID(ctx, GenerateContextID())
}
// WithContextID creates a new context with cid, which will be used for log.
func WithContextID(ctx context.Context, cid string) context.Context {
return context.WithValue(ctx, cidKey, cid)
}
// ContextID returns the cid in context, or empty string if not set.
func ContextID(ctx context.Context) string {
if cid, ok := ctx.Value(cidKey).(string); ok {
return cid
}
return ""
}

View File

@@ -1,87 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package logger
import (
"context"
"io/ioutil"
stdLog "log"
"os"
)
type logger interface {
Printf(ctx context.Context, format string, v ...any)
}
type loggerPlus struct {
logger *stdLog.Logger
level string
}
func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus {
v := &loggerPlus{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) {
format, args := f, a
if cid := ContextID(ctx); cid != "" {
format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...)
}
v.logger.Printf(format, args...)
}
var verboseLogger logger
func Vf(ctx context.Context, format string, a ...interface{}) {
verboseLogger.Printf(ctx, format, a...)
}
var debugLogger logger
func Df(ctx context.Context, format string, a ...interface{}) {
debugLogger.Printf(ctx, format, a...)
}
var warnLogger logger
func Wf(ctx context.Context, format string, a ...interface{}) {
warnLogger.Printf(ctx, format, a...)
}
var errorLogger logger
func Ef(ctx context.Context, format string, a ...interface{}) {
errorLogger.Printf(ctx, format, a...)
}
const (
logVerboseLabel = "verb"
logDebugLabel = "debug"
logWarnLabel = "warn"
logErrorLabel = "error"
)
func init() {
verboseLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logVerboseLabel
})
debugLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logDebugLabel
})
warnLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logWarnLabel
})
errorLogger = newLoggerPlus(func(logger *loggerPlus) {
logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds)
logger.level = logErrorLabel
})
}

View File

@@ -1,121 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"srs-proxy/errors"
"srs-proxy/logger"
)
func main() {
ctx := logger.WithContext(context.Background())
logger.Df(ctx, "%v/%v started", Signature(), Version())
// Install signals.
ctx, cancel := context.WithCancel(ctx)
installSignals(ctx, cancel)
// Start the main loop, ignore the user cancel error.
err := doMain(ctx)
if err != nil && ctx.Err() != context.Canceled {
logger.Ef(ctx, "main: %+v", err)
os.Exit(-1)
}
logger.Df(ctx, "%v done", Signature())
}
func doMain(ctx context.Context) error {
// Setup the environment variables.
if err := loadEnvFile(ctx); err != nil {
return errors.Wrapf(err, "load env")
}
buildDefaultEnvironmentVariables(ctx)
// When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur
// because the main thread exits after the context is cancelled. However, sometimes the main thread
// may be blocked for some reason, so a forced exit is necessary to ensure the program terminates.
if err := installForceQuit(ctx); err != nil {
return errors.Wrapf(err, "install force quit")
}
// Start the Go pprof if enabled.
handleGoPprof(ctx)
// Initialize SRS load balancers.
switch lbType := envLoadBalancerType(); lbType {
case "memory":
srsLoadBalancer = NewMemoryLoadBalancer()
case "redis":
srsLoadBalancer = NewRedisLoadBalancer()
default:
return errors.Errorf("invalid load balancer %v", lbType)
}
if err := srsLoadBalancer.Initialize(ctx); err != nil {
return errors.Wrapf(err, "initialize srs load balancer")
}
// Parse the gracefully quit timeout.
gracefulQuitTimeout, err := parseGracefullyQuitTimeout()
if err != nil {
return errors.Wrapf(err, "parse gracefully quit timeout")
}
// Start the RTMP server.
srsRTMPServer := NewSRSRTMPServer()
defer srsRTMPServer.Close()
if err := srsRTMPServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtmp server")
}
// Start the WebRTC server.
srsWebRTCServer := NewSRSWebRTCServer()
defer srsWebRTCServer.Close()
if err := srsWebRTCServer.Run(ctx); err != nil {
return errors.Wrapf(err, "rtc server")
}
// Start the HTTP API server.
srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) {
server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer
})
defer srsHTTPAPIServer.Close()
if err := srsHTTPAPIServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http api server")
}
// Start the SRT server.
srsSRTServer := NewSRSSRTServer()
defer srsSRTServer.Close()
if err := srsSRTServer.Run(ctx); err != nil {
return errors.Wrapf(err, "srt server")
}
// Start the System API server.
systemAPI := NewSystemAPI(func(server *systemAPI) {
server.gracefulQuitTimeout = gracefulQuitTimeout
})
defer systemAPI.Close()
if err := systemAPI.Run(ctx); err != nil {
return errors.Wrapf(err, "system api server")
}
// Start the HTTP web server.
srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) {
server.gracefulQuitTimeout = gracefulQuitTimeout
})
defer srsHTTPStreamServer.Close()
if err := srsHTTPStreamServer.Run(ctx); err != nil {
return errors.Wrapf(err, "http server")
}
// Wait for the main loop to quit.
<-ctx.Done()
return nil
}

View File

@@ -1,515 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
stdSync "sync"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out
// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the
// SDP answer.
type srsWebRTCServer struct {
// The UDP listener for WebRTC server.
listener *net.UDPConn
// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
usernames sync.Map[string, *RTCConnection]
// Fast cache for the udp address to identify the connection.
// The key is UDP address, the value is the username.
// TODO: Support fast earch by uint64 address.
addresses sync.Map[string, *RTCConnection]
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer {
v := &srsWebRTCServer{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsWebRTCServer) Close() error {
if v.listener != nil {
_ = v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
defer r.Body.Close()
ctx = logger.WithContext(ctx)
// Always allow CORS for all requests.
if ok := apiCORS(ctx, w, r); ok {
return nil
}
// Read remote SDP offer from body.
remoteSDPOffer, err := ioutil.ReadAll(r.Body)
if err != nil {
return errors.Wrapf(err, "read remote sdp offer")
}
// Build the stream URL in vhost/app/stream schema.
unifiedURL, fullURL := convertURLToStreamURL(r)
logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL)
streamURL, err := buildStreamURL(unifiedURL)
if err != nil {
return errors.Wrapf(err, "build stream url %v", unifiedURL)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil {
return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend)
}
return nil
}
func (v *srsWebRTCServer) proxyApiToBackend(
ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer,
remoteSDPOffer string, streamURL string,
) error {
// Parse HTTP port from backend.
if len(backend.API) == 0 {
return errors.Errorf("no http api server")
}
var apiPort int
if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse http port %v", backend.API[0])
} else {
apiPort = int(iv)
}
// Connect to backend SRS server via HTTP client.
backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path)
if r.URL.RawQuery != "" {
backendURL += "?" + r.URL.RawQuery
}
req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer))
if err != nil {
return errors.Wrapf(err, "create request to %v", backendURL)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return errors.Errorf("do request to %v EOF", backendURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status)
}
// Copy all headers from backend to client.
w.WriteHeader(resp.StatusCode)
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
// Parse the local SDP answer from backend.
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "read stream from %v", backendURL)
}
// Replace the WebRTC UDP port in answer.
localSDPAnswer := string(b)
for _, endpoint := range backend.RTC {
_, _, port, err := parseListenEndpoint(endpoint)
if err != nil {
return errors.Wrapf(err, "parse endpoint %v", endpoint)
}
from := fmt.Sprintf(" %v typ host", port)
to := fmt.Sprintf(" %v typ host", envWebRTCServer())
localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1)
}
// Fetch the ice-ufrag and ice-pwd from local SDP answer.
remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer)
if err != nil {
return errors.Wrapf(err, "parse remote sdp offer")
}
localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer)
if err != nil {
return errors.Wrapf(err, "parse local sdp answer")
}
// Save the new WebRTC connection to LB.
icePair := &RTCICEPair{
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) {
c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag()
c.Initialize(ctx, v.listener)
// Cache the connection for fast search by username.
v.usernames.Store(c.Ufrag, c)
})); err != nil {
return errors.Wrapf(err, "load or store webrtc %v", streamURL)
}
// Response client with local answer.
if _, err = w.Write([]byte(localSDPAnswer)); err != nil {
return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer)
}
logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB",
len(localSDPAnswer), localICEUfrag, len(localICEPwd))
return nil
}
func (v *srsWebRTCServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := envWebRTCServer()
if !strings.Contains(endpoint, ":") {
endpoint = fmt.Sprintf(":%v", endpoint)
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
}
v.listener = listener
logger.Df(ctx, "WebRTC server listen at %v", saddr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := listener.ReadFromUDP(buf)
if err != nil {
// TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%+v", err)
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
}
}
}()
return nil
}
func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var connection *RTCConnection
// If STUN binding request, parse the ufrag and identify the connection.
if err := func() error {
if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) {
return nil
}
var pkt RTCStunPacket
if err := pkt.UnmarshalBinary(data); err != nil {
return errors.Wrapf(err, "unmarshal stun packet")
}
// Search the connection in fast cache.
if s, ok := v.usernames.Load(pkt.Username); ok {
connection = s
return nil
}
// Load connection by username.
if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
connection = s.Initialize(ctx, v.listener)
logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL)
}
// Cache connection for fast search.
if connection != nil {
v.usernames.Store(pkt.Username, connection)
}
return nil
}(); err != nil {
return err
}
// Search the connection by addr.
if s, ok := v.addresses.Load(addr.String()); ok {
connection = s
} else if connection != nil {
// Cache the address for fast search.
v.addresses.Store(addr.String(), connection)
}
// If connection is not found, ignore the packet.
if connection == nil {
// TODO: Should logging the dropped packet, only logging the first one for each address.
return nil
}
// Proxy the packet to backend.
if err := connection.HandlePacket(addr, data); err != nil {
return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL)
}
return nil
}
// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC
// connection, identify by the ufrag in sdp offer/answer and ICE binding request.
//
// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is
// in the client request. The RTCConnection is stateful, and need to sync the ufrag between
// proxy servers.
//
// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch
// to another UDP address, it may connect to another WebRTC proxy, then we should discover the
// RTCConnection by the ufrag from the ICE binding request.
type RTCConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The ufrag for this WebRTC connection.
Ufrag string `json:"ufrag"`
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The client UDP address. Note that it may change.
clientUDP *net.UDPAddr
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
}
func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection {
v := &RTCConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection {
if v.ctx == nil {
v.ctx = logger.WithContext(ctx)
}
if listener != nil {
v.listenerUDP = listener
}
return v
}
func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error {
ctx := v.ctx
// Update the current UDP address.
v.clientUDP = addr
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx); err != nil {
return errors.Wrapf(err, "connect backend for %v", v.StreamURL)
}
// Proxy client message to backend.
if v.backendUDP == nil {
return nil
}
// Proxy all messages from backend to client.
go func() {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, _, err := v.backendUDP.ReadFromUDP(buf)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
break
}
if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
break
}
}
}()
if _, err := v.backendUDP.Write(data); err != nil {
return errors.Wrapf(err, "write to backend %v", v.StreamURL)
}
return nil
}
func (v *RTCConnection) connectBackend(ctx context.Context) error {
if v.backendUDP != nil {
return nil
}
// Pick a backend SRS server to proxy the RTC stream.
backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL)
if err != nil {
return errors.Wrapf(err, "pick backend")
}
// Parse UDP port from backend.
if len(backend.RTC) == 0 {
return errors.Errorf("no udp server")
}
_, _, udpPort, err := parseListenEndpoint(backend.RTC[0])
if err != nil {
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL)
}
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or DTLS alert.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v", backendAddr)
} else {
v.backendUDP = backendUDP
}
return nil
}
type RTCICEPair struct {
// The remote ufrag, used for ICE username and session id.
RemoteICEUfrag string `json:"remote_ufrag"`
// The remote pwd, used for ICE password.
RemoteICEPwd string `json:"remote_pwd"`
// The local ufrag, used for ICE username and session id.
LocalICEUfrag string `json:"local_ufrag"`
// The local pwd, used for ICE password.
LocalICEPwd string `json:"local_pwd"`
}
// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
func (v *RTCICEPair) Ufrag() string {
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
}
type RTCStunPacket struct {
// The stun message type.
MessageType uint16
// The stun username, or ufrag.
Username string
}
func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
if len(data) < 20 {
return errors.Errorf("stun packet too short %v", len(data))
}
p := data
v.MessageType = binary.BigEndian.Uint16(p)
messageLen := binary.BigEndian.Uint16(p[2:])
//magicCookie := p[:8]
//transactionID := p[:20]
p = p[20:]
if len(p) != int(messageLen) {
return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen)
}
for len(p) > 0 {
typ := binary.BigEndian.Uint16(p)
length := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(length) {
return errors.Errorf("stun attribute length invalid %v < %v", len(p), length)
}
value := p[:length]
p = p[length:]
if length%4 != 0 {
p = p[4-length%4:]
}
switch typ {
case 0x0006:
v.Username = string(value)
}
}
return nil
}

View File

@@ -1,655 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/rtmp"
)
// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS
// server. It will figure out the backend server to proxy to. Unlike the edge server, it will
// not cache the stream, but just proxy the stream to backend.
type srsRTMPServer struct {
// The TCP listener for RTMP server.
listener *net.TCPListener
// The random number generator.
rd *rand.Rand
// The wait group for all goroutines.
wg sync.WaitGroup
}
func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer {
v := &srsRTMPServer{
rd: rand.New(rand.NewSource(time.Now().UnixNano())),
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsRTMPServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsRTMPServer) Run(ctx context.Context) error {
endpoint := envRtmpServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
}
addr, err := net.ResolveTCPAddr("tcp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return errors.Wrapf(err, "listen rtmp addr %v", addr)
}
v.listener = listener
logger.Df(ctx, "RTMP server listen at %v", addr)
v.wg.Add(1)
go func() {
defer v.wg.Done()
for {
conn, err := v.listener.AcceptTCP()
if err != nil {
if ctx.Err() != context.Canceled {
// TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "RTMP server accept err %+v", err)
} else {
logger.Df(ctx, "RTMP server done")
}
return
}
v.wg.Add(1)
go func(ctx context.Context, conn *net.TCPConn) {
defer v.wg.Done()
defer conn.Close()
handleErr := func(err error) {
if isPeerClosedError(err) {
logger.Df(ctx, "RTMP peer is closed")
} else {
logger.Wf(ctx, "RTMP serve err %+v", err)
}
}
rc := NewRTMPConnection(func(client *RTMPConnection) {
client.rd = v.rd
})
if err := rc.serve(ctx, conn); err != nil {
handleErr(err)
} else {
logger.Df(ctx, "RTMP client done")
}
}(logger.WithContext(ctx), conn)
}
}()
return nil
}
// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between
// proxy servers.
//
// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request,
// then proxy to the corresponding backend server. All state is in the RTMP request, so this
// connection is stateless.
type RTMPConnection struct {
// The random number generator.
rd *rand.Rand
}
func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
v := &RTMPConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr())
// If any goroutine quit, cancel another one.
parentCtx := ctx
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var backend *RTMPClientToBackend
if true {
go func() {
<-ctx.Done()
conn.Close()
if backend != nil {
backend.Close()
}
}()
}
// Simple handshake with client.
hs := rtmp.NewHandshake(v.rd)
if _, err := hs.ReadC0S0(conn); err != nil {
return errors.Wrapf(err, "read c0")
}
if _, err := hs.ReadC1S1(conn); err != nil {
return errors.Wrapf(err, "read c1")
}
if err := hs.WriteC0S0(conn); err != nil {
return errors.Wrapf(err, "write s1")
}
if err := hs.WriteC1S1(conn); err != nil {
return errors.Wrapf(err, "write s1")
}
if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write s2")
}
if _, err := hs.ReadC2S2(conn); err != nil {
return errors.Wrapf(err, "read c2")
}
client := rtmp.NewProtocol(conn)
logger.Df(ctx, "RTMP simple handshake done")
// Expect RTMP connect command with tcUrl.
var connectReq *rtmp.ConnectAppPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil {
return errors.Wrapf(err, "expect connect req")
}
if true {
ack := rtmp.NewWindowAcknowledgementSize()
ack.AckSize = 2500000
if err := client.WritePacket(ctx, ack, 0); err != nil {
return errors.Wrapf(err, "write set ack size")
}
}
if true {
chunk := rtmp.NewSetChunkSize()
chunk.ChunkSize = 128
if err := client.WritePacket(ctx, chunk, 0); err != nil {
return errors.Wrapf(err, "write set chunk size")
}
}
connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127))
connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1))
connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded"))
connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0))
connectResData := rtmp.NewAmf0EcmaArray()
connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888"))
connectResData.Set("srs_version", rtmp.NewAmf0String(Version()))
connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx)))
connectRes.Args.Set("data", connectResData)
if err := client.WritePacket(ctx, connectRes, 0); err != nil {
return errors.Wrapf(err, "write connect res")
}
tcUrl := connectReq.TcUrl()
logger.Df(ctx, "RTMP connect app %v", tcUrl)
// Expect RTMP command to identify the client, a publisher or viewer.
var currentStreamID, nextStreamID int
var streamName string
var clientType RTMPClientType
for clientType == "" {
var identifyReq rtmp.Packet
if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil {
return errors.Wrapf(err, "expect identify req")
}
var response rtmp.Packet
switch pkt := identifyReq.(type) {
case *rtmp.CallPacket:
if pkt.CommandName == "createStream" {
identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
response = identifyRes
nextStreamID = 1
identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
} else if pkt.CommandName == "getStreamLength" {
// Ignore and do not reply these packets.
} else {
// For releaseStream, FCPublish, etc.
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.TransactionID = pkt.TransactionID
identifyRes.CommandName = "_result"
identifyRes.CommandObject = rtmp.NewAmf0Null()
identifyRes.Args = rtmp.NewAmf0Undefined()
}
case *rtmp.PublishPacket:
streamName = string(pkt.StreamName)
clientType = RTMPClientTypePublisher
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.CommandName = "onFCPublish"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
identifyRes.Args = data
case *rtmp.PlayPacket:
streamName = string(pkt.StreamName)
clientType = RTMPClientTypeViewer
identifyRes := rtmp.NewCallPacket()
response = identifyRes
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset"))
data.Set("description", rtmp.NewAmf0String("Playing and resetting stream."))
data.Set("details", rtmp.NewAmf0String("stream"))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
}
if response != nil {
if err := client.WritePacket(ctx, response, currentStreamID); err != nil {
return errors.Wrapf(err, "write identify res for req=%v, stream=%v",
identifyReq, currentStreamID)
}
}
// Update the stream ID for next request.
currentStreamID = nextStreamID
}
logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v",
tcUrl, streamName, currentStreamID, clientType)
// Find a backend SRS server to proxy the RTMP stream.
backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) {
client.rd, client.typ = v.rd, clientType
})
defer backend.Close()
if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)
}
// Start the streaming.
if clientType == RTMPClientTypePublisher {
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start publish")
}
} else if clientType == RTMPClientTypeViewer {
identifyRes := rtmp.NewCallPacket()
identifyRes.CommandName = "onStatus"
identifyRes.CommandObject = rtmp.NewAmf0Null()
data := rtmp.NewAmf0Object()
data.Set("level", rtmp.NewAmf0String("status"))
data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start"))
data.Set("description", rtmp.NewAmf0String("Started playing stream."))
data.Set("details", rtmp.NewAmf0String("stream"))
data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
identifyRes.Args = data
if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
return errors.Wrapf(err, "start play")
}
}
logger.Df(ctx, "RTMP start streaming")
// For all proxy goroutines.
var wg sync.WaitGroup
defer wg.Wait()
// Proxy all message from backend to client.
wg.Add(1)
var r0 error
go func() {
defer wg.Done()
defer cancel()
r0 = func() error {
for {
m, err := backend.client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
// TODO: Update the stream ID if not the same.
if err := client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}
}()
}()
// Proxy all messages from client to backend.
wg.Add(1)
var r1 error
go func() {
defer wg.Done()
defer cancel()
r1 = func() error {
for {
m, err := client.ReadMessage(ctx)
if err != nil {
return errors.Wrapf(err, "read message")
}
//logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
// TODO: Update the stream ID if not the same.
if err := backend.client.WriteMessage(ctx, m); err != nil {
return errors.Wrapf(err, "write message")
}
}
}()
}()
// Wait until all goroutine quit.
wg.Wait()
// Reset the error if caused by another goroutine.
if r0 != nil {
return errors.Wrapf(r0, "proxy backend->client")
}
if r1 != nil {
return errors.Wrapf(r1, "proxy client->backend")
}
return parentCtx.Err()
}
type RTMPClientType string
const (
RTMPClientTypePublisher RTMPClientType = "publisher"
RTMPClientTypeViewer RTMPClientType = "viewer"
)
// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend.
type RTMPClientToBackend struct {
// The random number generator.
rd *rand.Rand
// The underlayer tcp client.
tcpConn *net.TCPConn
// The RTMP protocol client.
client *rtmp.Protocol
// The stream type.
typ RTMPClientType
}
func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend {
v := &RTMPClientToBackend{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *RTMPClientToBackend) Close() error {
if v.tcpConn != nil {
v.tcpConn.Close()
}
return nil
}
func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
// Build the stream URL in vhost/app/stream schema.
streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName))
if err != nil {
return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName)
}
// Pick a backend SRS server to proxy the RTMP stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
// Parse RTMP port from backend.
if len(backend.RTMP) == 0 {
return errors.Errorf("no rtmp server %+v for %v", backend, streamURL)
}
var rtmpPort int
if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0])
} else {
rtmpPort = int(iv)
}
// Connect to backend SRS server via TCP client.
addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
c, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
}
v.tcpConn = c
hs := rtmp.NewHandshake(v.rd)
client := rtmp.NewProtocol(c)
v.client = client
// Simple RTMP handshake with server.
if err := hs.WriteC0S0(c); err != nil {
return errors.Wrapf(err, "write c0")
}
if err := hs.WriteC1S1(c); err != nil {
return errors.Wrapf(err, "write c1")
}
if _, err = hs.ReadC0S0(c); err != nil {
return errors.Wrapf(err, "read s0")
}
if _, err := hs.ReadC1S1(c); err != nil {
return errors.Wrapf(err, "read s1")
}
if _, err = hs.ReadC2S2(c); err != nil {
return errors.Wrapf(err, "read c2")
}
logger.Df(ctx, "backend simple handshake done, server=%v", addr)
if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
return errors.Wrapf(err, "write c2")
}
// Connect RTMP app on tcUrl with server.
if true {
connectApp := rtmp.NewConnectAppPacket()
connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
if err := client.WritePacket(ctx, connectApp, 1); err != nil {
return errors.Wrapf(err, "write connect app")
}
}
if true {
var connectAppRes *rtmp.ConnectAppResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil {
return errors.Wrapf(err, "expect connect app res")
}
logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID())
}
// Play or view RTMP stream with server.
if v.typ == RTMPClientTypeViewer {
return v.play(ctx, client, streamName)
}
// Publish RTMP stream with server.
return v.publish(ctx, client, streamName)
}
func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "releaseStream"
identifyReq.TransactionID = 2
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "releaseStream")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "_result" {
break
}
}
if true {
identifyReq := rtmp.NewCallPacket()
identifyReq.CommandName = "FCPublish"
identifyReq.TransactionID = 3
identifyReq.CommandObject = rtmp.NewAmf0Null()
identifyReq.Args = rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
return errors.Wrapf(err, "FCPublish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect FCPublish res")
}
if identifyRes.CommandName == "_result" {
break
}
}
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}
if true {
publishStream := rtmp.NewPublishPacket()
publishStream.TransactionID = 5
publishStream.CommandObject = rtmp.NewAmf0Null()
publishStream.StreamName = *rtmp.NewAmf0String(streamName)
publishStream.StreamType = *rtmp.NewAmf0String("live")
if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
return errors.Wrapf(err, "publish")
}
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect publish res")
}
// Ignore onFCPublish, expect onStatus(NetStream.Publish.Start).
if identifyRes.CommandName == "onStatus" {
if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil {
return errors.Errorf("onStatus args not object")
} else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
return errors.Errorf("onStatus code not string")
} else if *code != "NetStream.Publish.Start" {
return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
}
break
}
}
logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID)
return nil
}
func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
var currentStreamID int
if true {
createStream := rtmp.NewCreateStreamPacket()
createStream.TransactionID = 4
createStream.CommandObject = rtmp.NewAmf0Null()
if err := client.WritePacket(ctx, createStream, 0); err != nil {
return errors.Wrapf(err, "createStream")
}
}
for {
var identifyRes *rtmp.CreateStreamResPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect createStream res")
}
if sid := identifyRes.StreamID; sid != 0 {
currentStreamID = int(sid)
break
}
}
playStream := rtmp.NewPlayPacket()
playStream.StreamName = *rtmp.NewAmf0String(streamName)
if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
return errors.Wrapf(err, "play")
}
for {
var identifyRes *rtmp.CallPacket
if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
return errors.Wrapf(err, "expect releaseStream res")
}
if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" {
break
}
}
return nil
}

View File

@@ -1,771 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package rtmp
import (
"bytes"
"encoding"
"encoding/binary"
"fmt"
"math"
"sync"
"srs-proxy/errors"
)
// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview
type amf0Marker uint8
const (
amf0MarkerNumber amf0Marker = iota // 0
amf0MarkerBoolean // 1
amf0MarkerString // 2
amf0MarkerObject // 3
amf0MarkerMovieClip // 4
amf0MarkerNull // 5
amf0MarkerUndefined // 6
amf0MarkerReference // 7
amf0MarkerEcmaArray // 8
amf0MarkerObjectEnd // 9
amf0MarkerStrictArray // 10
amf0MarkerDate // 11
amf0MarkerLongString // 12
amf0MarkerUnsupported // 13
amf0MarkerRecordSet // 14
amf0MarkerXmlDocument // 15
amf0MarkerTypedObject // 16
amf0MarkerAvmPlusObject // 17
amf0MarkerForbidden amf0Marker = 0xff
)
func (v amf0Marker) String() string {
switch v {
case amf0MarkerNumber:
return "Amf0Number"
case amf0MarkerBoolean:
return "amf0Boolean"
case amf0MarkerString:
return "Amf0String"
case amf0MarkerObject:
return "Amf0Object"
case amf0MarkerNull:
return "Null"
case amf0MarkerUndefined:
return "Undefined"
case amf0MarkerReference:
return "Reference"
case amf0MarkerEcmaArray:
return "EcmaArray"
case amf0MarkerObjectEnd:
return "ObjectEnd"
case amf0MarkerStrictArray:
return "StrictArray"
case amf0MarkerDate:
return "Date"
case amf0MarkerLongString:
return "LongString"
case amf0MarkerUnsupported:
return "Unsupported"
case amf0MarkerXmlDocument:
return "XmlDocument"
case amf0MarkerTypedObject:
return "TypedObject"
case amf0MarkerAvmPlusObject:
return "AvmPlusObject"
case amf0MarkerMovieClip:
return "MovieClip"
case amf0MarkerRecordSet:
return "RecordSet"
default:
return "Forbidden"
}
}
// For utest to mock it.
type amf0Buffer interface {
Bytes() []byte
WriteByte(c byte) error
Write(p []byte) (n int, err error)
}
var createBuffer = func() amf0Buffer {
return &bytes.Buffer{}
}
// All AMF0 things.
type amf0Any interface {
// Binary marshaler and unmarshaler.
encoding.BinaryUnmarshaler
encoding.BinaryMarshaler
// Get the size of bytes to marshal this object.
Size() int
// Get the Marker of any AMF0 stuff.
amf0Marker() amf0Marker
}
type amf0Converter struct {
from amf0Any
}
func NewAmf0Converter(from amf0Any) *amf0Converter {
return &amf0Converter{from: from}
}
func (v *amf0Converter) ToNumber() *amf0Number {
return amf0AnyTo[*amf0Number](v.from)
}
func (v *amf0Converter) ToBoolean() *amf0Boolean {
return amf0AnyTo[*amf0Boolean](v.from)
}
func (v *amf0Converter) ToString() *amf0String {
return amf0AnyTo[*amf0String](v.from)
}
func (v *amf0Converter) ToObject() *amf0Object {
return amf0AnyTo[*amf0Object](v.from)
}
func (v *amf0Converter) ToNull() *amf0Null {
return amf0AnyTo[*amf0Null](v.from)
}
func (v *amf0Converter) ToUndefined() *amf0Undefined {
return amf0AnyTo[*amf0Undefined](v.from)
}
func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray {
return amf0AnyTo[*amf0EcmaArray](v.from)
}
func (v *amf0Converter) ToStrictArray() *amf0StrictArray {
return amf0AnyTo[*amf0StrictArray](v.from)
}
// Convert any to specified object.
func amf0AnyTo[T amf0Any](a amf0Any) T {
var to T
if a != nil {
if v, ok := a.(T); ok {
return v
}
}
return to
}
// Discovery the amf0 object from the bytes b.
func Amf0Discovery(p []byte) (a amf0Any, err error) {
if len(p) < 1 {
return nil, errors.Errorf("require 1 bytes only %v", len(p))
}
m := amf0Marker(p[0])
switch m {
case amf0MarkerNumber:
return NewAmf0Number(0), nil
case amf0MarkerBoolean:
return NewAmf0Boolean(false), nil
case amf0MarkerString:
return NewAmf0String(""), nil
case amf0MarkerObject:
return NewAmf0Object(), nil
case amf0MarkerNull:
return NewAmf0Null(), nil
case amf0MarkerUndefined:
return NewAmf0Undefined(), nil
case amf0MarkerReference:
case amf0MarkerEcmaArray:
return NewAmf0EcmaArray(), nil
case amf0MarkerObjectEnd:
return &amf0ObjectEOF{}, nil
case amf0MarkerStrictArray:
return NewAmf0StrictArray(), nil
case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument,
amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip,
amf0MarkerRecordSet:
return nil, errors.Errorf("Marker %v is not supported", m)
}
return nil, errors.Errorf("Marker %v is invalid", m)
}
// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8
type amf0UTF8 string
func (v *amf0UTF8) Size() int {
return 2 + len(string(*v))
}
func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 2 {
return errors.Errorf("require 2 bytes only %v", len(p))
}
size := uint16(p[0])<<8 | uint16(p[1])
if p = data[2:]; len(p) < int(size) {
return errors.Errorf("require %v bytes only %v", int(size), len(p))
}
*v = amf0UTF8(string(p[:size]))
return
}
func (v *amf0UTF8) MarshalBinary() (data []byte, err error) {
data = make([]byte, v.Size())
size := uint16(len(string(*v)))
data[0] = byte(size >> 8)
data[1] = byte(size)
if size > 0 {
copy(data[2:], []byte(*v))
}
return
}
// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type
type amf0Number float64
func NewAmf0Number(f float64) *amf0Number {
v := amf0Number(f)
return &v
}
func (v *amf0Number) amf0Marker() amf0Marker {
return amf0MarkerNumber
}
func (v *amf0Number) Size() int {
return 1 + 8
}
func (v *amf0Number) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 9 {
return errors.Errorf("require 9 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerNumber {
return errors.Errorf("Amf0Number amf0Marker %v is illegal", m)
}
f := binary.BigEndian.Uint64(p[1:])
*v = amf0Number(math.Float64frombits(f))
return
}
func (v *amf0Number) MarshalBinary() (data []byte, err error) {
data = make([]byte, 9)
data[0] = byte(amf0MarkerNumber)
f := math.Float64bits(float64(*v))
binary.BigEndian.PutUint64(data[1:], f)
return
}
// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type
type amf0String string
func NewAmf0String(s string) *amf0String {
v := amf0String(s)
return &v
}
func (v *amf0String) amf0Marker() amf0Marker {
return amf0MarkerString
}
func (v *amf0String) Size() int {
u := amf0UTF8(*v)
return 1 + u.Size()
}
func (v *amf0String) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerString {
return errors.Errorf("Amf0String amf0Marker %v is illegal", m)
}
var sv amf0UTF8
if err = sv.UnmarshalBinary(p[1:]); err != nil {
return errors.WithMessage(err, "utf8")
}
*v = amf0String(string(sv))
return
}
func (v *amf0String) MarshalBinary() (data []byte, err error) {
u := amf0UTF8(*v)
var pb []byte
if pb, err = u.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "utf8")
}
data = append([]byte{byte(amf0MarkerString)}, pb...)
return
}
// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type
type amf0ObjectEOF struct {
}
func (v *amf0ObjectEOF) amf0Marker() amf0Marker {
return amf0MarkerObjectEnd
}
func (v *amf0ObjectEOF) Size() int {
return 3
}
func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) {
p := data
if len(p) < 3 {
return errors.Errorf("require 3 bytes only %v", len(p))
}
if p[0] != 0 || p[1] != 0 || p[2] != 9 {
return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3])
}
return
}
func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) {
return []byte{0, 0, 9}, nil
}
// Use array for object and ecma array, to keep the original order.
type amf0Property struct {
key amf0UTF8
value amf0Any
}
// The object-like AMF0 structure, like object and ecma array and strict array.
type amf0ObjectBase struct {
properties []*amf0Property
lock sync.Mutex
}
func (v *amf0ObjectBase) Size() int {
v.lock.Lock()
defer v.lock.Unlock()
var size int
for _, p := range v.properties {
key, value := p.key, p.value
size += key.Size() + value.Size()
}
return size
}
func (v *amf0ObjectBase) Get(key string) amf0Any {
v.lock.Lock()
defer v.lock.Unlock()
for _, p := range v.properties {
if string(p.key) == key {
return p.value
}
}
return nil
}
func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase {
v.lock.Lock()
defer v.lock.Unlock()
prop := &amf0Property{key: amf0UTF8(key), value: value}
var ok bool
for i, p := range v.properties {
if string(p.key) == key {
v.properties[i] = prop
ok = true
}
}
if !ok {
v.properties = append(v.properties, prop)
}
return v
}
func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) {
// if no eof, elems specified by maxElems.
if !eof && maxElems < 0 {
return errors.Errorf("maxElems=%v without eof", maxElems)
}
// if eof, maxElems must be -1.
if eof && maxElems != -1 {
return errors.Errorf("maxElems=%v with eof", maxElems)
}
readOne := func() (amf0UTF8, amf0Any, error) {
var u amf0UTF8
if err = u.UnmarshalBinary(p); err != nil {
return "", nil, errors.WithMessage(err, "prop name")
}
p = p[u.Size():]
var a amf0Any
if a, err = Amf0Discovery(p); err != nil {
return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u)))
}
return u, a, nil
}
pushOne := func(u amf0UTF8, a amf0Any) error {
// For object property, consume the whole bytes.
if err = a.UnmarshalBinary(p); err != nil {
return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u)))
}
v.Set(string(u), a)
p = p[a.Size():]
return nil
}
for eof {
u, a, err := readOne()
if err != nil {
return errors.WithMessage(err, "read")
}
// For object EOF, we should only consume total 3bytes.
if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd {
// 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte.
p = p[1:]
return nil
}
if err := pushOne(u, a); err != nil {
return errors.WithMessage(err, "push")
}
}
for len(v.properties) < maxElems {
u, a, err := readOne()
if err != nil {
return errors.WithMessage(err, "read")
}
if err := pushOne(u, a); err != nil {
return errors.WithMessage(err, "push")
}
}
return
}
func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) {
v.lock.Lock()
defer v.lock.Unlock()
var pb []byte
for _, p := range v.properties {
key, value := p.key, p.value
if pb, err = key.MarshalBinary(); err != nil {
return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key)))
}
if _, err = b.Write(pb); err != nil {
return errors.Wrapf(err, "write %v", string(key))
}
if pb, err = value.MarshalBinary(); err != nil {
return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key)))
}
if _, err = b.Write(pb); err != nil {
return errors.Wrapf(err, "marshal value for %v", string(key))
}
}
return
}
// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type
type amf0Object struct {
amf0ObjectBase
eof amf0ObjectEOF
}
func NewAmf0Object() *amf0Object {
v := &amf0Object{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0Object) amf0Marker() amf0Marker {
return amf0MarkerObject
}
func (v *amf0Object) Size() int {
return int(1) + v.eof.Size() + v.amf0ObjectBase.Size()
}
func (v *amf0Object) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 byte only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerObject {
return errors.Errorf("Amf0Object amf0Marker %v is illegal", m)
}
p = p[1:]
if err = v.unmarshal(p, true, -1); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0Object) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerObject)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type
type amf0EcmaArray struct {
amf0ObjectBase
count uint32
eof amf0ObjectEOF
}
func NewAmf0EcmaArray() *amf0EcmaArray {
v := &amf0EcmaArray{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0EcmaArray) amf0Marker() amf0Marker {
return amf0MarkerEcmaArray
}
func (v *amf0EcmaArray) Size() int {
return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size()
}
func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 5 {
return errors.Errorf("require 5 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray {
return errors.Errorf("EcmaArray amf0Marker %v is illegal", m)
}
v.count = binary.BigEndian.Uint32(p[1:])
p = p[5:]
if err = v.unmarshal(p, true, -1); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
var pb []byte
if pb, err = v.eof.MarshalBinary(); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
if _, err = b.Write(pb); err != nil {
return nil, errors.Wrap(err, "marshal")
}
return b.Bytes(), nil
}
// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type
type amf0StrictArray struct {
amf0ObjectBase
count uint32
}
func NewAmf0StrictArray() *amf0StrictArray {
v := &amf0StrictArray{}
v.properties = []*amf0Property{}
return v
}
func (v *amf0StrictArray) amf0Marker() amf0Marker {
return amf0MarkerStrictArray
}
func (v *amf0StrictArray) Size() int {
return int(1) + 4 + v.amf0ObjectBase.Size()
}
func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 5 {
return errors.Errorf("require 5 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerStrictArray {
return errors.Errorf("StrictArray amf0Marker %v is illegal", m)
}
v.count = binary.BigEndian.Uint32(p[1:])
p = p[5:]
if int(v.count) <= 0 {
return
}
if err = v.unmarshal(p, false, int(v.count)); err != nil {
return errors.WithMessage(err, "unmarshal")
}
return
}
func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) {
b := createBuffer()
if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = binary.Write(b, binary.BigEndian, v.count); err != nil {
return nil, errors.Wrap(err, "marshal")
}
if err = v.marshal(b); err != nil {
return nil, errors.WithMessage(err, "marshal")
}
return b.Bytes(), nil
}
// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined.
type amf0SingleMarkerObject struct {
target amf0Marker
}
func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject {
return amf0SingleMarkerObject{target: m}
}
func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker {
return v.target
}
func (v *amf0SingleMarkerObject) Size() int {
return int(1)
}
func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 1 {
return errors.Errorf("require 1 byte only %v", len(p))
}
if m := amf0Marker(p[0]); m != v.target {
return errors.Errorf("%v amf0Marker %v is illegal", v.target, m)
}
return
}
func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) {
return []byte{byte(v.target)}, nil
}
// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type
type amf0Null struct {
amf0SingleMarkerObject
}
func NewAmf0Null() *amf0Null {
v := amf0Null{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull)
return &v
}
// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type
type amf0Undefined struct {
amf0SingleMarkerObject
}
func NewAmf0Undefined() amf0Any {
v := amf0Undefined{}
v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined)
return &v
}
// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type
type amf0Boolean bool
func NewAmf0Boolean(b bool) amf0Any {
v := amf0Boolean(b)
return &v
}
func (v *amf0Boolean) amf0Marker() amf0Marker {
return amf0MarkerBoolean
}
func (v *amf0Boolean) Size() int {
return int(2)
}
func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) {
var p []byte
if p = data; len(p) < 2 {
return errors.Errorf("require 2 bytes only %v", len(p))
}
if m := amf0Marker(p[0]); m != amf0MarkerBoolean {
return errors.Errorf("BOOL amf0Marker %v is illegal", m)
}
if p[1] == 0 {
*v = false
} else {
*v = true
}
return
}
func (v *amf0Boolean) MarshalBinary() (data []byte, err error) {
var b byte
if *v {
b = 1
}
return []byte{byte(amf0MarkerBoolean), b}, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,44 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
func installSignals(ctx context.Context, cancel context.CancelFunc) {
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
go func() {
for s := range sc {
logger.Df(ctx, "Got signal %v", s)
cancel()
}
}()
}
func installForceQuit(ctx context.Context) error {
var forceTimeout time.Duration
if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil {
return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout())
} else {
forceTimeout = t
}
go func() {
<-ctx.Done()
time.Sleep(forceTimeout)
logger.Wf(ctx, "Force to exit by timeout")
os.Exit(1)
}()
return nil
}

View File

@@ -1,553 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"os"
"strconv"
"strings"
"time"
// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// If server heartbeat in this duration, it's alive.
const srsServerAliveDuration = 300 * time.Second
// If HLS streaming update in this duration, it's alive.
const srsHLSAliveDuration = 120 * time.Second
// If WebRTC streaming update in this duration, it's alive.
const srsRTCAliveDuration = 120 * time.Second
type SRSServer struct {
// The server IP.
IP string `json:"ip,omitempty"`
// The server device ID, configured by user.
DeviceID string `json:"device_id,omitempty"`
// The server id of SRS, store in file, may not change, mandatory.
ServerID string `json:"server_id,omitempty"`
// The service id of SRS, always change when restarted, mandatory.
ServiceID string `json:"service_id,omitempty"`
// The process id of SRS, always change when restarted, mandatory.
PID string `json:"pid,omitempty"`
// The RTMP listen endpoints.
RTMP []string `json:"rtmp,omitempty"`
// The HTTP Stream listen endpoints.
HTTP []string `json:"http,omitempty"`
// The HTTP API listen endpoints.
API []string `json:"api,omitempty"`
// The SRT server listen endpoints.
SRT []string `json:"srt,omitempty"`
// The RTC server listen endpoints.
RTC []string `json:"rtc,omitempty"`
// Last update time.
UpdatedAt time.Time `json:"update_at,omitempty"`
}
func (v *SRSServer) ID() string {
return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID)
}
func (v *SRSServer) String() string {
return fmt.Sprintf("%v", v)
}
func (v *SRSServer) Format(f fmt.State, c rune) {
switch c {
case 'v', 's':
if f.Flag('+') {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID))
if v.DeviceID != "" {
sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID))
}
if len(v.RTMP) > 0 {
sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ",")))
}
if len(v.HTTP) > 0 {
sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ",")))
}
if len(v.API) > 0 {
sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ",")))
}
if len(v.SRT) > 0 {
sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ",")))
}
if len(v.RTC) > 0 {
sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ",")))
}
sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999")))
fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String())
} else {
fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID())
}
default:
fmt.Fprintf(f, "%v, fmt=%%%c", v, c)
}
}
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
v := &SRSServer{}
for _, opt := range opts {
opt(v)
}
return v
}
// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only.
func NewDefaultSRSForDebugging() (*SRSServer, error) {
if envDefaultBackendEnabled() != "on" {
return nil, nil
}
if envDefaultBackendIP() == "" {
return nil, fmt.Errorf("empty default backend ip")
}
if envDefaultBackendRTMP() == "" {
return nil, fmt.Errorf("empty default backend rtmp")
}
server := NewSRSServer(func(srs *SRSServer) {
srs.IP = envDefaultBackendIP()
srs.RTMP = []string{envDefaultBackendRTMP()}
srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID())
srs.ServiceID = logger.GenerateContextID()
srs.PID = fmt.Sprintf("%v", os.Getpid())
srs.UpdatedAt = time.Now()
})
if envDefaultBackendHttp() != "" {
server.HTTP = []string{envDefaultBackendHttp()}
}
if envDefaultBackendAPI() != "" {
server.API = []string{envDefaultBackendAPI()}
}
if envDefaultBackendRTC() != "" {
server.RTC = []string{envDefaultBackendRTC()}
}
if envDefaultBackendSRT() != "" {
server.SRT = []string{envDefaultBackendSRT()}
}
return server, nil
}
// SRSLoadBalancer is the interface to load balance the SRS servers.
type SRSLoadBalancer interface {
// Initialize the load balancer.
Initialize(ctx context.Context) error
// Update the backer server.
Update(ctx context.Context, server *SRSServer) error
// Pick a backend server for the specified stream URL.
Pick(ctx context.Context, streamURL string) (*SRSServer, error)
// Load or store the HLS streaming for the specified stream URL.
LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error)
// Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error)
// Store the WebRTC streaming for the specified stream URL.
StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
// Load the WebRTC streaming by ufrag, the ICE username.
LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
}
// srsLoadBalancer is the global SRS load balancer.
var srsLoadBalancer SRSLoadBalancer
// srsMemoryLoadBalancer stores state in memory.
type srsMemoryLoadBalancer struct {
// All available SRS servers, key is server ID.
servers sync.Map[string, *SRSServer]
// The picked server to servce client by specified stream URL, key is stream url.
picked sync.Map[string, *SRSServer]
// The HLS streaming, key is stream URL.
hlsStreamURL sync.Map[string, *HLSPlayStream]
// The HLS streaming, key is SPBHID.
hlsSPBHID sync.Map[string, *HLSPlayStream]
// The WebRTC streaming, key is stream URL.
rtcStreamURL sync.Map[string, *RTCConnection]
// The WebRTC streaming, key is ufrag.
rtcUfrag sync.Map[string, *RTCConnection]
}
func NewMemoryLoadBalancer() SRSLoadBalancer {
return &srsMemoryLoadBalancer{}
}
func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error {
if server, err := NewDefaultSRSForDebugging(); err != nil {
return errors.Wrapf(err, "initialize default SRS")
} else if server != nil {
if err := v.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update default SRS %+v", server)
}
// Keep alive.
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
if err := v.Update(ctx, server); err != nil {
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
}
}
}
}()
logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server)
}
return nil
}
func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
v.servers.Store(server.ID(), server)
return nil
}
func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
// Always proxy to the same server for the same stream URL.
if server, ok := v.picked.Load(streamURL); ok {
return server, nil
}
// Gather all servers that were alive within the last few seconds.
var servers []*SRSServer
v.servers.Range(func(key string, server *SRSServer) bool {
if time.Since(server.UpdatedAt) < srsServerAliveDuration {
servers = append(servers, server)
}
return true
})
// If no servers available, use all possible servers.
if len(servers) == 0 {
v.servers.Range(func(key string, server *SRSServer) bool {
servers = append(servers, server)
return true
})
}
// No server found, failed.
if len(servers) == 0 {
return nil, fmt.Errorf("no server available for %v", streamURL)
}
// Pick a server randomly from servers.
server := servers[rand.Intn(len(servers))]
v.picked.Store(streamURL, server)
return server, nil
}
func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
// Load the HLS streaming for the SPBHID, for TS files.
if actual, ok := v.hlsSPBHID.Load(spbhid); !ok {
return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
} else {
return actual, nil
}
}
func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
// Update the HLS streaming for the stream URL, for M3u8.
actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value)
if actual == nil {
return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL)
}
// Update the HLS streaming for the SPBHID, for TS files.
v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual)
return actual, nil
}
func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
// Update the WebRTC streaming for the stream URL.
v.rtcStreamURL.Store(streamURL, value)
// Update the WebRTC streaming for the ufrag.
v.rtcUfrag.Store(value.Ufrag, value)
return nil
}
func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
} else {
return actual, nil
}
}
type srsRedisLoadBalancer struct {
// The redis client sdk.
rdb *redis.Client
}
func NewRedisLoadBalancer() SRSLoadBalancer {
return &srsRedisLoadBalancer{}
}
func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error {
redisDatabase, err := strconv.Atoi(envRedisDB())
if err != nil {
return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB())
}
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()),
Password: envRedisPassword(),
DB: redisDatabase,
})
v.rdb = rdb
if err := rdb.Ping(ctx).Err(); err != nil {
return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
}
logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String())
if server, err := NewDefaultSRSForDebugging(); err != nil {
return errors.Wrapf(err, "initialize default SRS")
} else if server != nil {
if err := v.Update(ctx, server); err != nil {
return errors.Wrapf(err, "update default SRS %+v", server)
}
// Keep alive.
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(30 * time.Second):
if err := v.Update(ctx, server); err != nil {
logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
}
}
}
}()
logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
}
return nil
}
func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
b, err := json.Marshal(server)
if err != nil {
return errors.Wrapf(err, "marshal server %+v", server)
}
key := v.redisKeyServer(server.ID())
if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v server %+v", key, server)
}
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// Check each server expiration, if not exists in redis, remove from servers.
for i := len(serverKeys) - 1; i >= 0; i-- {
if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil {
serverKeys = append(serverKeys[:i], serverKeys[i+1:]...)
}
}
// Add server to servers if not exists.
var found bool
for _, serverKey := range serverKeys {
if serverKey == key {
found = true
break
}
}
if !found {
serverKeys = append(serverKeys, key)
}
// Update all servers to redis.
b, err = json.Marshal(serverKeys)
if err != nil {
return errors.Wrapf(err, "marshal servers %+v", serverKeys)
}
if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil {
return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys)
}
return nil
}
func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
// Always proxy to the same server for the same stream URL.
if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
// If server not exists, ignore and pick another server for the stream URL.
if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
var server SRSServer
if err := json.Unmarshal(b, &server); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b))
}
// TODO: If server fail, we should migrate the streams to another server.
return &server, nil
}
}
// Query all servers from redis, in json string.
var serverKeys []string
if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
if err := json.Unmarshal(b, &serverKeys); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
}
}
// No server found, failed.
if len(serverKeys) == 0 {
return nil, fmt.Errorf("no server available for %v", streamURL)
}
// All server should be alive, if not, should have been removed by redis. So we only
// random pick one that is always available.
var serverKey string
var server SRSServer
for i := 0; i < 3; i++ {
tryServerKey := serverKeys[rand.Intn(len(serverKeys))]
b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
if err == nil && len(b) > 0 {
if err := json.Unmarshal(b, &server); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
}
serverKey = tryServerKey
break
}
}
if serverKey == "" {
return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
}
// Update the picked server for the stream URL.
if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
}
return &server, nil
}
func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
key := v.redisKeySPBHID(spbhid)
b, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v HLS", key)
}
var actual HLSPlayStream
if err := json.Unmarshal(b, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
b, err := json.Marshal(value)
if err != nil {
return nil, errors.Wrapf(err, "marshal HLS %v", value)
}
key := v.redisKeyHLS(streamURL)
if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value)
}
key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID)
if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil {
return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value)
}
// Query the HLS streaming from redis.
b2, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v HLS", key)
}
var actual HLSPlayStream
if err := json.Unmarshal(b2, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
b, err := json.Marshal(value)
if err != nil {
return errors.Wrapf(err, "marshal WebRTC %v", value)
}
key := v.redisKeyRTC(streamURL)
if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v WebRTC %v", key, value)
}
key2 := v.redisKeyUfrag(value.Ufrag)
if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil {
return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value)
}
return nil
}
func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
key := v.redisKeyUfrag(ufrag)
b, err := v.rdb.Get(ctx, key).Bytes()
if err != nil {
return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
}
var actual RTCConnection
if err := json.Unmarshal(b, &actual); err != nil {
return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b))
}
return &actual, nil
}
func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
}
func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string {
return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
}
func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
}
func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string {
return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
}
func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
return fmt.Sprintf("srs-proxy-server:%v", serverID)
}
func (v *srsRedisLoadBalancer) redisKeyServers() string {
return fmt.Sprintf("srs-proxy-all-servers")
}

View File

@@ -1,574 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net"
"strings"
stdSync "sync"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)
// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to
// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the
// backend server.
type srsSRTServer struct {
// The UDP listener for SRT server.
listener *net.UDPConn
// The SRT connections, identify by the socket ID.
sockets sync.Map[uint32, *SRTConnection]
// The system start time.
start time.Time
// The wait group for server.
wg stdSync.WaitGroup
}
func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer {
v := &srsSRTServer{
start: time.Now(),
}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *srsSRTServer) Close() error {
if v.listener != nil {
v.listener.Close()
}
v.wg.Wait()
return nil
}
func (v *srsSRTServer) Run(ctx context.Context) error {
// Parse address to listen.
endpoint := envSRTServer()
if !strings.Contains(endpoint, ":") {
endpoint = ":" + endpoint
}
saddr, err := net.ResolveUDPAddr("udp", endpoint)
if err != nil {
return errors.Wrapf(err, "resolve udp addr %v", endpoint)
}
listener, err := net.ListenUDP("udp", saddr)
if err != nil {
return errors.Wrapf(err, "listen udp %v", saddr)
}
v.listener = listener
logger.Df(ctx, "SRT server listen at %v", saddr)
// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, caddr, err := v.listener.ReadFromUDP(buf)
if err != nil {
// TODO: If SRT server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%+v", err)
continue
}
if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err)
}
}
}()
return nil
}
func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
socketID := srtParseSocketID(data)
var pkt *SRTHandshakePacket
if srtIsHandshake(data) {
pkt = &SRTHandshakePacket{}
if err := pkt.UnmarshalBinary(data); err != nil {
return err
}
if socketID == 0 {
socketID = pkt.SRTSocketID
}
}
conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) {
c.ctx = logger.WithContext(ctx)
c.listenerUDP, c.socketID = v.listener, socketID
c.start = v.start
}))
ctx = conn.ctx
if !ok {
logger.Df(ctx, "Create new SRT connection skt=%v", socketID)
}
if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil {
return errors.Wrapf(err, "handle packet")
} else if newSocketID != 0 && newSocketID != socketID {
// The connection may use a new socket ID.
// TODO: FIXME: Should cleanup the dead SRT connection.
v.sockets.Store(newSocketID, conn)
}
return nil
}
// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT
// connection, identify by the socket ID.
//
// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in
// the client request. The SRTConnection is stateless, and no need to sync between proxy servers.
//
// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the
// client should never switch to another network or port. If this occurs, the client may be served
// by a different proxy server and fail because the other proxy server cannot identify the client.
type SRTConnection struct {
// The stream context for SRT connection.
ctx context.Context
// The current socket ID.
socketID uint32
// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
// Listener start time.
start time.Time
// Handshake packets with client.
handshake0 *SRTHandshakePacket
handshake1 *SRTHandshakePacket
handshake2 *SRTHandshakePacket
handshake3 *SRTHandshakePacket
}
func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection {
v := &SRTConnection{}
for _, opt := range opts {
opt(v)
}
return v
}
func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) {
ctx := v.ctx
// If not handshake, try to proxy to backend directly.
if pkt == nil {
// Proxy client message to backend.
if v.backendUDP != nil {
if _, err := v.backendUDP.Write(data); err != nil {
return v.socketID, errors.Wrapf(err, "write to backend")
}
}
return v.socketID, nil
}
// Handle handshake messages.
if err := v.handleHandshake(ctx, pkt, addr, data); err != nil {
return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt)
}
return v.socketID, nil
}
func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error {
// Handle handshake 0 and 1 messages.
if pkt.SynCookie == 0 {
// Save handshake 0 packet.
v.handshake0 = pkt
logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0)
// Response handshake 1.
v.handshake1 = &SRTHandshakePacket{
ControlFlag: pkt.ControlFlag,
ControlType: 0,
SubType: 0,
AdditionalInfo: 0,
Timestamp: uint32(time.Since(v.start).Microseconds()),
SocketID: pkt.SRTSocketID,
Version: 5,
EncryptionField: 0,
ExtensionField: 0x4A17,
InitSequence: pkt.InitSequence,
MTU: pkt.MTU,
FlowWindow: pkt.FlowWindow,
HandshakeType: 1,
SRTSocketID: pkt.SRTSocketID,
SynCookie: 0x418d5e4e,
PeerIP: net.ParseIP("127.0.0.1"),
}
logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1)
if b, err := v.handshake1.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 1")
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
return errors.Wrapf(err, "write handshake 1")
}
return nil
}
// Handle handshake 2 and 3 messages.
// Parse stream id from packet.
streamID, err := pkt.StreamID()
if err != nil {
return errors.Wrapf(err, "parse stream id")
}
// Save handshake packet.
v.handshake2 = pkt
logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID)
// Start the UDP proxy to backend.
if err := v.connectBackend(ctx, streamID); err != nil {
return errors.Wrapf(err, "connect backend for %v", streamID)
}
// Proxy client message to backend.
if v.backendUDP == nil {
return errors.Errorf("no backend for %v", streamID)
}
// Proxy handshake 0 to backend server.
if b, err := v.handshake0.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 0")
} else if _, err = v.backendUDP.Write(b); err != nil {
return errors.Wrapf(err, "write handshake 0")
}
logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0)
// Read handshake 1 from backend server.
b := make([]byte, 4096)
handshake1p := &SRTHandshakePacket{}
if nn, err := v.backendUDP.Read(b); err != nil {
return errors.Wrapf(err, "read handshake 1")
} else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil {
return errors.Wrapf(err, "unmarshal handshake 1")
}
logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p)
// Proxy handshake 2 to backend server.
handshake2p := *v.handshake2
handshake2p.SynCookie = handshake1p.SynCookie
if b, err := handshake2p.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 2")
} else if _, err = v.backendUDP.Write(b); err != nil {
return errors.Wrapf(err, "write handshake 2")
}
logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p)
// Read handshake 3 from backend server.
handshake3p := &SRTHandshakePacket{}
if nn, err := v.backendUDP.Read(b); err != nil {
return errors.Wrapf(err, "read handshake 3")
} else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil {
return errors.Wrapf(err, "unmarshal handshake 3")
}
logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p)
// Response handshake 3 to client.
v.handshake3 = &*handshake3p
v.handshake3.SynCookie = v.handshake1.SynCookie
v.socketID = handshake3p.SRTSocketID
logger.Df(ctx, "Handshake 3: %v", v.handshake3)
if b, err := v.handshake3.MarshalBinary(); err != nil {
return errors.Wrapf(err, "marshal handshake 3")
} else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil {
return errors.Wrapf(err, "write handshake 3")
}
// Start a goroutine to proxy message from backend to client.
// TODO: FIXME: Support close the connection when timeout or client disconnected.
go func() {
for ctx.Err() == nil {
nn, err := v.backendUDP.Read(b)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
return
}
if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
return
}
}
}()
return nil
}
func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error {
if v.backendUDP != nil {
return nil
}
// Parse stream id to host and resource.
host, resource, err := parseSRTStreamID(streamID)
if err != nil {
return errors.Wrapf(err, "parse stream id %v", streamID)
}
if host == "" {
host = "localhost"
}
streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource))
if err != nil {
return errors.Wrapf(err, "build stream url %v", streamID)
}
// Pick a backend SRS server to proxy the SRT stream.
backend, err := srsLoadBalancer.Pick(ctx, streamURL)
if err != nil {
return errors.Wrapf(err, "pick backend for %v", streamURL)
}
// Parse UDP port from backend.
if len(backend.SRT) == 0 {
return errors.Errorf("no udp server %v for %v", backend, streamURL)
}
_, _, udpPort, err := parseListenEndpoint(backend.SRT[0])
if err != nil {
return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL)
}
// Connect to backend SRS server via UDP client.
// TODO: FIXME: Support close the connection when timeout or client disconnected.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL)
} else {
v.backendUDP = backendUDP
}
return nil
}
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2
// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1
type SRTHandshakePacket struct {
// F: 1 bit. Packet Type Flag. The control packet has this flag set to
// "1". The data packet has this flag set to "0".
ControlFlag uint8
// Control Type: 15 bits. Control Packet Type. The use of these bits
// is determined by the control packet type definition.
// Handshake control packets (Control Type = 0x0000) are used to
// exchange peer configurations, to agree on connection parameters, and
// to establish a connection.
ControlType uint16
// Subtype: 16 bits. This field specifies an additional subtype for
// specific packets.
SubType uint16
// Type-specific Information: 32 bits. The use of this field depends on
// the particular control packet type. Handshake packets do not use
// this field.
AdditionalInfo uint32
// Timestamp: 32 bits.
Timestamp uint32
// Destination Socket ID: 32 bits.
SocketID uint32
// Version: 32 bits. A base protocol version number. Currently used
// values are 4 and 5. Values greater than 5 are reserved for future
// use.
Version uint32
// Encryption Field: 16 bits. Block cipher family and key size. The
// values of this field are described in Table 2. The default value
// is AES-128.
// 0 | No Encryption Advertised
// 2 | AES-128
// 3 | AES-192
// 4 | AES-256
EncryptionField uint16
// Extension Field: 16 bits. This field is message specific extension
// related to Handshake Type field. The value MUST be set to 0
// except for the following cases. (1) If the handshake control
// packet is the INDUCTION message, this field is sent back by the
// Listener. (2) In the case of a CONCLUSION message, this field
// value should contain a combination of Extension Type values.
// 0x00000001 | HSREQ
// 0x00000002 | KMREQ
// 0x00000004 | CONFIG
// 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1
ExtensionField uint16
// Initial Packet Sequence Number: 32 bits. The sequence number of the
// very first data packet to be sent.
InitSequence uint32
// Maximum Transmission Unit Size: 32 bits. This value is typically set
// to 1500, which is the default Maximum Transmission Unit (MTU) size
// for Ethernet, but can be less.
MTU uint32
// Maximum Flow Window Size: 32 bits. The value of this field is the
// maximum number of data packets allowed to be "in flight" (i.e. the
// number of sent packets for which an ACK control packet has not yet
// been received).
FlowWindow uint32
// Handshake Type: 32 bits. This field indicates the handshake packet
// type.
// 0xFFFFFFFD | DONE
// 0xFFFFFFFE | AGREEMENT
// 0xFFFFFFFF | CONCLUSION
// 0x00000000 | WAVEHAND
// 0x00000001 | INDUCTION
HandshakeType uint32
// SRT Socket ID: 32 bits. This field holds the ID of the source SRT
// socket from which a handshake packet is issued.
SRTSocketID uint32
// SYN Cookie: 32 bits. Randomized value for processing a handshake.
// The value of this field is specified by the handshake message
// type.
SynCookie uint32
// Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's
// sender. The value consists of four 32-bit fields.
PeerIP net.IP
// Extensions.
// Extension Type: 16 bits. The value of this field is used to process
// an integrated handshake. Each extension can have a pair of
// request and response types.
// Extension Length: 16 bits. The length of the Extension Contents
// field in four-byte blocks.
// Extension Contents: variable length. The payload of the extension.
ExtraData []byte
}
func (v *SRTHandshakePacket) IsData() bool {
return v.ControlFlag == 0x00
}
func (v *SRTHandshakePacket) IsControl() bool {
return v.ControlFlag == 0x80
}
func (v *SRTHandshakePacket) IsHandshake() bool {
return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00
}
func (v *SRTHandshakePacket) StreamID() (string, error) {
p := v.ExtraData
for {
if len(p) < 2 {
return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData))
}
extType := binary.BigEndian.Uint16(p)
extSize := binary.BigEndian.Uint16(p[2:])
p = p[4:]
if len(p) < int(extSize*4) {
return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData))
}
// Ignore other packets except stream id.
if extType != 0x05 {
p = p[extSize*4:]
continue
}
// We must copy it, because we will decode the stream id.
data := append([]byte{}, p[:extSize*4]...)
// Reverse the stream id encoded in little-endian to big-endian.
for i := 0; i < len(data); i += 4 {
value := binary.LittleEndian.Uint32(data[i:])
binary.BigEndian.PutUint32(data[i:], value)
}
// Trim the trailing zero bytes.
data = bytes.TrimRight(data, "\x00")
return string(data), nil
}
}
func (v *SRTHandshakePacket) String() string {
return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB",
v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData))
}
func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error {
if len(b) < 4 {
return errors.Errorf("Invalid packet length %v", len(b))
}
v.ControlFlag = b[0] & 0x80
v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff
v.SubType = binary.BigEndian.Uint16(b[2:4])
if len(b) < 64 {
return errors.Errorf("Invalid packet length %v", len(b))
}
v.AdditionalInfo = binary.BigEndian.Uint32(b[4:])
v.Timestamp = binary.BigEndian.Uint32(b[8:])
v.SocketID = binary.BigEndian.Uint32(b[12:])
v.Version = binary.BigEndian.Uint32(b[16:])
v.EncryptionField = binary.BigEndian.Uint16(b[20:])
v.ExtensionField = binary.BigEndian.Uint16(b[22:])
v.InitSequence = binary.BigEndian.Uint32(b[24:])
v.MTU = binary.BigEndian.Uint32(b[28:])
v.FlowWindow = binary.BigEndian.Uint32(b[32:])
v.HandshakeType = binary.BigEndian.Uint32(b[36:])
v.SRTSocketID = binary.BigEndian.Uint32(b[40:])
v.SynCookie = binary.BigEndian.Uint32(b[44:])
// Only support IPv4.
v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48])
v.ExtraData = b[64:]
return nil
}
func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) {
b := make([]byte, 64+len(v.ExtraData))
binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType)
binary.BigEndian.PutUint16(b[2:], v.SubType)
binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo)
binary.BigEndian.PutUint32(b[8:], v.Timestamp)
binary.BigEndian.PutUint32(b[12:], v.SocketID)
binary.BigEndian.PutUint32(b[16:], v.Version)
binary.BigEndian.PutUint16(b[20:], v.EncryptionField)
binary.BigEndian.PutUint16(b[22:], v.ExtensionField)
binary.BigEndian.PutUint32(b[24:], v.InitSequence)
binary.BigEndian.PutUint32(b[28:], v.MTU)
binary.BigEndian.PutUint32(b[32:], v.FlowWindow)
binary.BigEndian.PutUint32(b[36:], v.HandshakeType)
binary.BigEndian.PutUint32(b[40:], v.SRTSocketID)
binary.BigEndian.PutUint32(b[44:], v.SynCookie)
// Only support IPv4.
ip := v.PeerIP.To4()
b[48] = ip[3]
b[49] = ip[2]
b[50] = ip[1]
b[51] = ip[0]
if len(v.ExtraData) > 0 {
copy(b[64:], v.ExtraData)
}
return b, nil
}

View File

@@ -1,45 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package sync
import "sync"
type Map[K comparable, V any] struct {
m sync.Map
}
func (m *Map[K, V]) Delete(key K) {
m.m.Delete(key)
}
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
v, ok := m.m.Load(key)
if !ok {
return value, ok
}
return v.(V), ok
}
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
v, loaded := m.m.LoadAndDelete(key)
if !loaded {
return value, loaded
}
return v.(V), loaded
}
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
a, loaded := m.m.LoadOrStore(key, value)
return a.(V), loaded
}
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.m.Range(func(key, value any) bool {
return f(key.(K), value.(V))
})
}
func (m *Map[K, V]) Store(key K, value V) {
m.m.Store(key, value)
}

View File

@@ -1,276 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import (
"context"
"encoding/binary"
"encoding/json"
stdErr "errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path"
"reflect"
"regexp"
"strconv"
"strings"
"syscall"
"time"
"srs-proxy/errors"
"srs-proxy/logger"
)
func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) {
w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version()))
b, err := json.Marshal(data)
if err != nil {
apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
logger.Wf(ctx, "HTTP API error %+v", err)
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, fmt.Sprintf("%v", err))
}
func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
// Always support CORS. Note that browser may send origin header for m3u8, but no origin header
// for ts. So we always response CORS header.
if true {
// SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin,
// headers, expose headers and methods.
w.Header().Set("Access-Control-Allow-Origin", "*")
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
w.Header().Set("Access-Control-Allow-Headers", "*")
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
w.Header().Set("Access-Control-Allow-Methods", "*")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return true
}
return false
}
func parseGracefullyQuitTimeout() (time.Duration, error) {
if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil {
return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout())
} else {
return t, nil
}
}
// ParseBody read the body from r, and unmarshal JSON to v.
func ParseBody(r io.ReadCloser, v interface{}) error {
b, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrapf(err, "read body")
}
defer r.Close()
if len(b) == 0 {
return nil
}
if err := json.Unmarshal(b, v); err != nil {
return errors.Wrapf(err, "json unmarshal %v", string(b))
}
return nil
}
// buildStreamURL build as vhost/app/stream for stream URL r.
func buildStreamURL(r string) (string, error) {
u, err := url.Parse(r)
if err != nil {
return "", errors.Wrapf(err, "parse url %v", r)
}
// If not domain or ip in hostname, it's __defaultVhost__.
defaultVhost := !strings.Contains(u.Hostname(), ".")
// If hostname is actually an IP address, it's __defaultVhost__.
if ip := net.ParseIP(u.Hostname()); ip.To4() != nil {
defaultVhost = true
}
if defaultVhost {
return fmt.Sprintf("__defaultVhost__%v", u.Path), nil
}
// Ignore port, only use hostname as vhost.
return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil
}
// isPeerClosedError indicates whether peer object closed the connection.
func isPeerClosedError(err error) bool {
causeErr := errors.Cause(err)
if stdErr.Is(causeErr, io.EOF) {
return true
}
if stdErr.Is(causeErr, syscall.EPIPE) {
return true
}
if netErr, ok := causeErr.(*net.OpError); ok {
if sysErr, ok := netErr.Err.(*os.SyscallError); ok {
if stdErr.Is(sysErr.Err, syscall.ECONNRESET) {
return true
}
}
}
return false
}
// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL
// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL
// with extension.
func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
hostname := "__defaultVhost__"
if strings.Contains(r.Host, ":") {
if v, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = v
}
}
var appStream, streamExt string
// Parse app/stream from query string.
q := r.URL.Query()
if app := q.Get("app"); app != "" {
appStream = "/" + app
}
if stream := q.Get("stream"); stream != "" {
appStream = fmt.Sprintf("%v/%v", appStream, stream)
}
// Parse app/stream from path.
if appStream == "" {
streamExt = path.Ext(r.URL.Path)
appStream = strings.TrimSuffix(r.URL.Path, streamExt)
}
unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream)
fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt)
return
}
// rtcIsSTUN returns true if data of UDP payload is a STUN packet.
func rtcIsSTUN(data []byte) bool {
return len(data) > 0 && (data[0] == 0 || data[0] == 1)
}
// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet.
func rtcIsRTPOrRTCP(data []byte) bool {
return len(data) >= 12 && (data[0]&0xC0) == 0x80
}
// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet.
func srtIsHandshake(data []byte) bool {
return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000
}
// srtParseSocketID parse the socket id from the SRT packet.
func srtParseSocketID(data []byte) uint32 {
if len(data) >= 16 {
return binary.BigEndian.Uint32(data[12:])
}
return 0
}
// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP.
func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) {
if true {
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
ufragMatch := ufragRe.FindStringSubmatch(sdp)
if len(ufragMatch) <= 1 {
return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp)
}
ufrag = ufragMatch[1]
}
if true {
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
pwdMatch := pwdRe.FindStringSubmatch(sdp)
if len(pwdMatch) <= 1 {
return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp)
}
pwd = pwdMatch[1]
}
return ufrag, pwd, nil
}
// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required).
// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url
func parseSRTStreamID(sid string) (host, resource string, err error) {
if true {
hostRe := regexp.MustCompile(`h=([^,]+)`)
hostMatch := hostRe.FindStringSubmatch(sid)
if len(hostMatch) > 1 {
host = hostMatch[1]
}
}
if true {
resourceRe := regexp.MustCompile(`r=([^,]+)`)
resourceMatch := resourceRe.FindStringSubmatch(sid)
if len(resourceMatch) <= 1 {
return "", "", errors.Errorf("no resource in sid %v", sid)
}
resource = resourceMatch[1]
}
return host, resource, nil
}
// parseListenEndpoint parse the listen endpoint as:
// port The tcp listen port, like 1935.
// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) {
// If no colon in ep, it's port in string.
if !strings.Contains(ep, ":") {
if p, err := strconv.Atoi(ep); err != nil {
return "", nil, 0, errors.Wrapf(err, "parse port %v", ep)
} else {
return "tcp", nil, uint16(p), nil
}
}
// Must be protocol://ip:port schema.
parts := strings.Split(ep, ":")
if len(parts) != 3 {
return "", nil, 0, errors.Errorf("invalid endpoint %v", ep)
}
if p, err := strconv.Atoi(parts[2]); err != nil {
return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2])
} else {
return parts[0], net.ParseIP(parts[1]), uint16(p), nil
}
}

View File

@@ -1,27 +0,0 @@
// Copyright (c) 2025 Winlin
//
// SPDX-License-Identifier: MIT
package main
import "fmt"
func VersionMajor() int {
return 1
}
// VersionMinor specifies the typical version of SRS we adapt to.
func VersionMinor() int {
return 5
}
func VersionRevision() int {
return 0
}
func Version() string {
return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision())
}
func Signature() string {
return "SRSProxy"
}