Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func main() {

handler := simple.New(logger)

var interceptor session.Interceptor
var beforeHandler, afterHandler session.Interceptor

// Loading .env file to environment
err := godotenv.Load()
Expand All @@ -62,7 +62,7 @@ func main() {
}

// mGate server for MQTT without TLS
mqttProxy := mqtt.New(mqttConfig, handler, interceptor, logger)
mqttProxy := mqtt.New(mqttConfig, handler, beforeHandler, afterHandler, logger)
g.Go(func() error {
return mqttProxy.Listen(ctx)
})
Expand All @@ -74,7 +74,7 @@ func main() {
}

// mGate server for MQTT with TLS
mqttTLSProxy := mqtt.New(mqttTLSConfig, handler, interceptor, logger)
mqttTLSProxy := mqtt.New(mqttTLSConfig, handler, beforeHandler, afterHandler, logger)
g.Go(func() error {
return mqttTLSProxy.Listen(ctx)
})
Expand All @@ -86,7 +86,7 @@ func main() {
}

// mGate server for MQTT with mTLS
mqttMTlsProxy := mqtt.New(mqttMTLSConfig, handler, interceptor, logger)
mqttMTlsProxy := mqtt.New(mqttMTLSConfig, handler, beforeHandler, afterHandler, logger)
g.Go(func() error {
return mqttMTlsProxy.Listen(ctx)
})
Expand All @@ -98,7 +98,7 @@ func main() {
}

// mGate server for MQTT over Websocket without TLS
wsProxy := websocket.New(wsConfig, handler, interceptor, logger)
wsProxy := websocket.New(wsConfig, handler, beforeHandler, afterHandler, logger)
g.Go(func() error {
return wsProxy.Listen(ctx)
})
Expand All @@ -110,7 +110,7 @@ func main() {
}

// mGate server for MQTT over Websocket with TLS
wsTLSProxy := websocket.New(wsTLSConfig, handler, interceptor, logger)
wsTLSProxy := websocket.New(wsTLSConfig, handler, beforeHandler, afterHandler, logger)
g.Go(func() error {
return wsTLSProxy.Listen(ctx)
})
Expand All @@ -122,7 +122,7 @@ func main() {
}

// mGate server for MQTT over Websocket with mTLS
wsMTLSProxy := websocket.New(wsMTLSConfig, handler, interceptor, logger)
wsMTLSProxy := websocket.New(wsMTLSConfig, handler, beforeHandler, afterHandler, logger)
g.Go(func() error {
return wsMTLSProxy.Listen(ctx)
})
Expand Down
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module github.com/absmach/mgate

go 1.22.7
toolchain go1.24.1
go 1.23.0

require (
github.com/caarlos0/env/v11 v11.3.1
Expand Down
24 changes: 13 additions & 11 deletions pkg/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ import (

// Proxy is main MQTT proxy struct.
type Proxy struct {
config mgate.Config
handler session.Handler
interceptor session.Interceptor
logger *slog.Logger
dialer net.Dialer
config mgate.Config
handler session.Handler
beforeHandler session.Interceptor
afterHandler session.Interceptor
logger *slog.Logger
dialer net.Dialer
}

// New returns a new MQTT Proxy instance.
func New(config mgate.Config, handler session.Handler, interceptor session.Interceptor, logger *slog.Logger) *Proxy {
func New(config mgate.Config, handler session.Handler, beforeHandler, afterHandler session.Interceptor, logger *slog.Logger) *Proxy {
return &Proxy{
config: config,
handler: handler,
logger: logger,
interceptor: interceptor,
config: config,
handler: handler,
logger: logger,
beforeHandler: beforeHandler,
afterHandler: afterHandler,
}
}

Expand Down Expand Up @@ -68,7 +70,7 @@ func (p Proxy) handle(ctx context.Context, inbound net.Conn) {
return
}

if err = session.Stream(ctx, inbound, outbound, p.handler, p.interceptor, clientCert); err != io.EOF {
if err = session.Stream(ctx, inbound, outbound, p.handler, p.beforeHandler, p.afterHandler, clientCert); err != io.EOF {
p.logger.Warn(err.Error())
}
}
Expand Down
22 changes: 12 additions & 10 deletions pkg/mqtt/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ import (

// Proxy represents WS Proxy.
type Proxy struct {
config mgate.Config
handler session.Handler
interceptor session.Interceptor
logger *slog.Logger
config mgate.Config
handler session.Handler
beforeHandler session.Interceptor
afterHandler session.Interceptor
logger *slog.Logger
}

// New - creates new WS proxy.
func New(config mgate.Config, handler session.Handler, interceptor session.Interceptor, logger *slog.Logger) *Proxy {
func New(config mgate.Config, handler session.Handler, beforeHandler, afterHandler session.Interceptor, logger *slog.Logger) *Proxy {
return &Proxy{
config: config,
handler: handler,
interceptor: interceptor,
logger: logger,
config: config,
handler: handler,
beforeHandler: beforeHandler,
afterHandler: afterHandler,
logger: logger,
}
}

Expand Down Expand Up @@ -92,7 +94,7 @@ func (p Proxy) pass(in *websocket.Conn) {
return
}

err = session.Stream(ctx, inboundConn, outboundConn, p.handler, p.interceptor, clientCert)
err = session.Stream(ctx, inboundConn, outboundConn, p.handler, p.beforeHandler, p.afterHandler, clientCert)
errc <- err
p.logger.Warn("Broken connection for client", slog.Any("error", err))
}
Expand Down
20 changes: 14 additions & 6 deletions pkg/session/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ var (
)

// Stream starts proxy between client and broker.
func Stream(ctx context.Context, in, out net.Conn, h Handler, ic Interceptor, cert x509.Certificate) error {
func Stream(ctx context.Context, in, out net.Conn, h Handler, preIc, postIc Interceptor, cert x509.Certificate) error {
s := Session{
Cert: cert,
}
ctx = NewContext(ctx, &s)
errs := make(chan error, 2)

go stream(ctx, Up, in, out, h, ic, errs)
go stream(ctx, Down, out, in, h, ic, errs)
go stream(ctx, Up, in, out, h, preIc, postIc, errs)
go stream(ctx, Down, out, in, h, preIc, postIc, errs)

// Handle whichever error happens first.
// The other routine won't be blocked when writing
Expand All @@ -49,7 +49,7 @@ func Stream(ctx context.Context, in, out net.Conn, h Handler, ic Interceptor, ce
return errors.Join(err, disconnectErr)
}

func stream(ctx context.Context, dir Direction, r, w net.Conn, h Handler, ic Interceptor, errs chan error) {
func stream(ctx context.Context, dir Direction, r, w net.Conn, h Handler, preIc, postIc Interceptor, errs chan error) {
for {
// Read from one connection.
pkt, err := packets.ReadPacket(r)
Expand All @@ -58,6 +58,14 @@ func stream(ctx context.Context, dir Direction, r, w net.Conn, h Handler, ic Int
return
}

if preIc != nil {
pkt, err = preIc.Intercept(ctx, pkt, dir)
if err != nil {
errs <- wrap(ctx, err, dir)
return
}
}

switch dir {
case Up:
if err = authorize(ctx, pkt, h); err != nil {
Expand All @@ -81,8 +89,8 @@ func stream(ctx context.Context, dir Direction, r, w net.Conn, h Handler, ic Int
}
}

if ic != nil {
pkt, err = ic.Intercept(ctx, pkt, dir)
if postIc != nil {
pkt, err = postIc.Intercept(ctx, pkt, dir)
if err != nil {
errs <- wrap(ctx, err, dir)
return
Expand Down