Skip to content

Commit 5e8cb3d

Browse files
committed
fix: improve relay robustness — errgroup, ping handling, PAKE for local relay
- Replace sync.WaitGroup + panic with errgroup.Group (crash prevention) - Add receiveSkippingPing helper for relay keepalive pings - PAKE key exchange for encrypted IP discovery in local relay - Fix race condition: send error to errchan on connection failure - Gracefully refuse local relay on unexpected data - Add golang.org/x/sync v0.10.0 dependency
1 parent d225c85 commit 5e8cb3d

3 files changed

Lines changed: 133 additions & 24 deletions

File tree

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ require (
1919
github.com/stretchr/testify v1.11.1
2020
golang.org/x/crypto v0.51.0
2121
golang.org/x/net v0.54.0
22+
golang.org/x/sync v0.10.0
2223
golang.org/x/sys v0.44.0
2324
golang.org/x/term v0.43.0
2425
golang.org/x/time v0.15.0

go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
9292
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
9393
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
9494
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
95+
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
9596
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
9697
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
9798
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

src/croc/croc.go

Lines changed: 131 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"sync"
2323
"time"
2424

25+
"golang.org/x/sync/errgroup"
26+
2527
"github.com/denisbrodbeck/machineid"
2628
ignore "github.com/sabhiram/go-gitignore"
2729
log "github.com/schollz/logger"
@@ -626,6 +628,19 @@ func (c *Client) broadcastOnLocalNetwork(useipv6 bool) {
626628
}
627629
}
628630

631+
// transferOverLocalRelay connects the sender to its own local relay and waits
632+
// for the receiver to join. The receiver may initiate a PAKE key exchange
633+
// (pake1/pake2) followed by an encrypted ipRequest to discover the sender's
634+
// local IPs. Although the relay is local, the PAKE exchange is still valuable:
635+
//
636+
// - The receiver may not have discovered the sender via multicast (e.g. due
637+
// to firewalls, different subnets, or IPv4/IPv6 mismatches), so it falls
638+
// back to requesting IPs through the relay pipe.
639+
// - The shared secret used in PAKE ensures that only the legitimate receiver
640+
// can decrypt the IP list, preventing leakage to unauthorized peers that
641+
// might guess the room name on the local relay.
642+
// - Keeping the same protocol flow as the public relay goroutine (see Send)
643+
// ensures consistent behaviour regardless of which relay path is used.
629644
func (c *Client) transferOverLocalRelay(errchan chan<- error) {
630645
time.Sleep(500 * time.Millisecond)
631646
log.Debug("establishing connection")
@@ -636,21 +651,92 @@ func (c *Client) transferOverLocalRelay(errchan chan<- error) {
636651
err = fmt.Errorf("could not connect to 127.0.0.1:%s: %w", c.Options.RelayPorts[0], err)
637652
log.Debug(err)
638653
// not really an error because it will try to connect over the actual relay
654+
// but if not send to errchan second read from it stuck
655+
errchan <- err
639656
return
640657
}
641658
log.Debugf("local connection established: %+v", conn)
659+
var kB []byte
660+
B, _ := pake.InitCurve([]byte(c.Options.SharedSecret[5:]), 1, c.Options.Curve)
642661
for {
643662
if err := c.ctxErr(); err != nil {
644663
errchan <- err
645664
return
646665
}
647-
data, _ := conn.Receive()
648-
if bytes.Equal(data, handshakeRequest) {
666+
var dataMessage SimpleMessage
667+
data, connErr := conn.Receive()
668+
if connErr != nil {
669+
log.Tracef("[%+v] had error: %s", conn, connErr.Error())
670+
}
671+
json.Unmarshal(data, &dataMessage)
672+
// if kB not null, then use it to decrypt
673+
if kB != nil {
674+
var decryptErr error
675+
var dataDecrypt []byte
676+
dataDecrypt, decryptErr = crypt.Decrypt(data, kB)
677+
if decryptErr != nil {
678+
log.Tracef("error decrypting: %v: '%s'", decryptErr, data)
679+
if strings.Contains(decryptErr.Error(), "message authentication failed") {
680+
errchan <- decryptErr
681+
return
682+
}
683+
} else {
684+
data = dataDecrypt
685+
log.Tracef("decrypted: %s", data)
686+
}
687+
}
688+
if bytes.Equal(data, ipRequest) {
689+
log.Tracef("got ipRequest")
690+
// recipient wants to try to connect to local ips
691+
var ips []string
692+
if !c.Options.DisableLocal {
693+
ips, err = utils.GetLocalIPs()
694+
if err != nil {
695+
log.Tracef("error getting local ips: %v", err)
696+
}
697+
ips = append([]string{c.Options.RelayPorts[0]}, ips...)
698+
}
699+
log.Tracef("sending ips: %+v", ips)
700+
bips, errIps := json.Marshal(ips)
701+
if errIps != nil {
702+
log.Tracef("error marshalling ips: %v", errIps)
703+
}
704+
bips, errIps = crypt.Encrypt(bips, kB)
705+
if errIps != nil {
706+
log.Tracef("error encrypting ips: %v", errIps)
707+
}
708+
if err = conn.Send(bips); err != nil {
709+
log.Errorf("error sending: %v", err)
710+
}
711+
} else if dataMessage.Kind == "pake1" {
712+
log.Trace("got pake1")
713+
var pakeError error
714+
pakeError = B.Update(dataMessage.Bytes)
715+
if pakeError == nil {
716+
kB, pakeError = B.SessionKey()
717+
if pakeError == nil {
718+
log.Tracef("dataMessage kB: %x", kB)
719+
dataMessage.Bytes = B.Bytes()
720+
dataMessage.Kind = "pake2"
721+
data, _ = json.Marshal(dataMessage)
722+
if pakeError = conn.Send(data); pakeError != nil {
723+
log.Errorf("dataMessage error sending: %v", pakeError)
724+
}
725+
}
726+
}
727+
} else if bytes.Equal(data, handshakeRequest) {
728+
log.Trace("got handshake")
649729
break
650730
} else if bytes.Equal(data, []byte{1}) {
651731
log.Trace("got ping")
652732
} else {
653-
log.Debugf("instead of handshake got: %s", data)
733+
log.Tracef("[%+v] got weird bytes: %+v", conn, data)
734+
// throttle the reading
735+
if connErr == nil {
736+
connErr = fmt.Errorf("gracefully refusing using the local relay")
737+
}
738+
errchan <- connErr
739+
return
654740
}
655741
}
656742
c.conn[0] = conn
@@ -1044,7 +1130,7 @@ func (c *Client) Receive() (err error) {
10441130
log.Errorf("dataMessage send error: %v", err)
10451131
return
10461132
}
1047-
data, err = c.conn[0].Receive()
1133+
data, err = c.receiveSkippingPing(0)
10481134
if err != nil {
10491135
return
10501136
}
@@ -1073,7 +1159,7 @@ func (c *Client) Receive() (err error) {
10731159
if err = c.conn[0].Send(data); err != nil {
10741160
log.Errorf("ips send error: %v", err)
10751161
}
1076-
data, err = c.conn[0].Receive()
1162+
data, err = c.receiveSkippingPing(0)
10771163
if err != nil {
10781164
return
10791165
}
@@ -1159,6 +1245,24 @@ func (c *Client) Receive() (err error) {
11591245
return
11601246
}
11611247

1248+
// receiveSkippingPing receives data from the specified connection,
1249+
// silently discarding relay keepalive ping messages ([]byte{1}).
1250+
// The relay sends these pings periodically while waiting for the second peer
1251+
// to connect, so all Receive calls must skip them to avoid protocol errors.
1252+
func (c *Client) receiveSkippingPing(connIndex int) (data []byte, err error) {
1253+
for {
1254+
data, err = c.conn[connIndex].Receive()
1255+
if err != nil {
1256+
return
1257+
}
1258+
if bytes.Equal(data, []byte{1}) {
1259+
log.Trace("got ping")
1260+
continue
1261+
}
1262+
return
1263+
}
1264+
}
1265+
11621266
func (c *Client) transfer() (err error) {
11631267
// connect to the server
11641268

@@ -1187,7 +1291,7 @@ func (c *Client) transfer() (err error) {
11871291
}
11881292
var data []byte
11891293
var done bool
1190-
data, err = c.conn[0].Receive()
1294+
data, err = c.receiveSkippingPing(0)
11911295
if err != nil {
11921296
log.Debugf("got error receiving: %v", err)
11931297
if !c.Step1ChannelSecured {
@@ -1499,39 +1603,45 @@ func (c *Client) processMessagePake(m message.Message) (err error) {
14991603
log.Debugf("generated key = %+x with salt %x", c.Key, salt)
15001604

15011605
// connects to the other ports of the server for transfer
1502-
var wg sync.WaitGroup
1503-
wg.Add(len(c.Options.RelayPorts))
1606+
var g errgroup.Group
15041607
for i := 0; i < len(c.Options.RelayPorts); i++ {
15051608
log.Debugf("port: [%s]", c.Options.RelayPorts[i])
1506-
go func(j int) {
1507-
defer wg.Done()
1609+
j := i
1610+
g.Go(func() error {
15081611
var host string
15091612
if c.Options.RelayAddress == "127.0.0.1" {
15101613
host = c.Options.RelayAddress
15111614
} else {
1512-
host, _, err = net.SplitHostPort(c.Options.RelayAddress)
1513-
if err != nil {
1514-
log.Errorf("bad relay address %s", c.Options.RelayAddress)
1515-
return
1615+
var splitErr error
1616+
host, _, splitErr = net.SplitHostPort(c.Options.RelayAddress)
1617+
if splitErr != nil {
1618+
return fmt.Errorf("bad relay address %s: %w", c.Options.RelayAddress, splitErr)
15161619
}
15171620
}
15181621
server := net.JoinHostPort(host, c.Options.RelayPorts[j])
15191622
log.Debugf("connecting to %s", server)
1520-
c.conn[j+1], _, _, err = tcp.ConnectToTCPServer(
1623+
var connErr error
1624+
c.conn[j+1], _, _, connErr = tcp.ConnectToTCPServer(
15211625
server,
15221626
c.Options.RelayPassword,
15231627
fmt.Sprintf("%s-%d", c.Options.RoomName, j),
15241628
)
1525-
if err != nil {
1526-
panic(err)
1629+
if connErr != nil {
1630+
return fmt.Errorf("connect to port %s: %w", c.Options.RelayPorts[j], connErr)
15271631
}
15281632
log.Debugf("connected to %s", server)
15291633
if !c.Options.IsSender {
15301634
go c.receiveData(j)
15311635
}
1532-
}(i)
1636+
return nil
1637+
})
1638+
}
1639+
if err = g.Wait(); err != nil {
1640+
if c.stop.gui {
1641+
c.stop.Cancel()
1642+
}
1643+
return err
15331644
}
1534-
wg.Wait()
15351645
if !c.Options.IsSender {
15361646
log.Debug("sending external IP")
15371647
err = message.Send(c.conn[0], c.Key, message.Message{
@@ -1662,6 +1772,7 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) {
16621772

16631773
func (c *Client) updateIfSenderChannelSecured() (err error) {
16641774
if c.Options.IsSender && c.Step1ChannelSecured && !c.Step2FileInfoTransferred {
1775+
16651776
var b []byte
16661777
machID, _ := machineid.ID()
16671778
b, err = json.Marshal(SenderInfo{
@@ -2094,14 +2205,10 @@ func (c *Client) receiveData(i int) {
20942205
}()
20952206
log.Tracef("%d receiving data", i)
20962207
for {
2097-
data, err := c.conn[i+1].Receive()
2208+
data, err := c.receiveSkippingPing(i + 1)
20982209
if err != nil {
20992210
break
21002211
}
2101-
if bytes.Equal(data, []byte{1}) {
2102-
log.Trace("got ping")
2103-
continue
2104-
}
21052212

21062213
data, err = crypt.Decrypt(data, c.Key)
21072214
if err != nil {

0 commit comments

Comments
 (0)