Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ lint:
internal/lint/golangci-lint run ./... --fix

check: lint
go test -cover ./...
go test -race -cover ./...
go mod tidy

.PHONY: example
2 changes: 1 addition & 1 deletion graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, opts ...WebSocket
header: http.Header{},
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
subscriptions: subscriptionMap{map_: make(map[string]*subscription)},
}

for _, opt := range opts {
Expand Down
73 changes: 57 additions & 16 deletions graphql/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,60 @@ package graphql

import (
"fmt"
"reflect"
"sync"
)

// map of subscription ID to subscription
type subscriptionMap struct {
map_ map[string]subscription
map_ map[string]*subscription
sync.RWMutex
}

type subscription struct {
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
id string
hasBeenUnsubscribed bool
// interfaceChan is passed in by the user when creating a subscription but
// closed by webSocketClient when the subscription is unsubscribed, i.e.
// ownership of interfaceChan is passed from the user to the client.
//
// The subscription is unsubscribed either explicitly by the user or when
// a message of webSocketTypeComplete is received. On unsubscribe,
// the _hasBeenUnsubscribed flag is set to true. listenWebSocket then
// closes interfaceChan on the next receive loop.
//
// The listenWebSocket client method handles both sending on the channel
// and closing of the channel, so is no possibility of races between send
// and close.
interfaceChan interface{}

forwardDataFunc ForwardDataFunction
id string

// Hold when accessing _hasBeenUnsubscribed
hasBeenUnsubscribedMu sync.Mutex
_hasBeenUnsubscribed bool
}

func (s *subscription) unsubscribe() {
s.hasBeenUnsubscribedMu.Lock()
defer s.hasBeenUnsubscribedMu.Unlock()

s._hasBeenUnsubscribed = true
}

func (s *subscription) hasBeenUnsubscribed() bool {
s.hasBeenUnsubscribedMu.Lock()
defer s.hasBeenUnsubscribedMu.Unlock()

return s._hasBeenUnsubscribed
}

func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
s.Lock()
defer s.Unlock()
s.map_[subscriptionID] = subscription{
id: subscriptionID,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
hasBeenUnsubscribed: false,
s.map_[subscriptionID] = &subscription{
id: subscriptionID,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
_hasBeenUnsubscribed: false,
}
}

Expand All @@ -37,16 +66,28 @@ func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
if !success {
return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID)
}
hasBeenUnsubscribed := unsub.hasBeenUnsubscribed
unsub.hasBeenUnsubscribed = true
unsub.unsubscribe()
s.map_[subscriptionID] = unsub

if !hasBeenUnsubscribed {
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
}
return nil
}

func (s *subscriptionMap) forEachSubscription(fn func(sub *subscription)) {
s.Lock()
defer s.Unlock()

for id := range s.map_ {
fn(s.map_[id])
}
}

func (s *subscriptionMap) GetSubscription(subscriptionID string) (*subscription, bool) {
s.Lock()
defer s.Unlock()
sub, ok := s.map_[subscriptionID]
return sub, ok
}

func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
s.RLock()
defer s.RUnlock()
Expand Down
22 changes: 11 additions & 11 deletions graphql/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ func Test_subscriptionMap_Unsubscribe(t *testing.T) {
{
name: "unsubscribe existing subscription",
sm: subscriptionMap{
map_: map[string]subscription{
map_: map[string]*subscription{
"sub1": {
id: "sub1",
interfaceChan: make(chan struct{}),
forwardDataFunc: nil,
hasBeenUnsubscribed: false,
id: "sub1",
interfaceChan: make(chan struct{}),
forwardDataFunc: nil,
_hasBeenUnsubscribed: false,
},
},
},
Expand All @@ -32,20 +32,20 @@ func Test_subscriptionMap_Unsubscribe(t *testing.T) {
{
name: "unsubscribe non-existent subscription",
sm: subscriptionMap{
map_: map[string]subscription{},
map_: map[string]*subscription{},
},
args: args{subscriptionID: "doesnotexist"},
wantErr: true,
},
{
name: "unsubscribe already unsubscribed subscription",
sm: subscriptionMap{
map_: map[string]subscription{
map_: map[string]*subscription{
"sub2": {
id: "sub2",
interfaceChan: nil,
forwardDataFunc: nil,
hasBeenUnsubscribed: true,
id: "sub2",
interfaceChan: nil,
forwardDataFunc: nil,
_hasBeenUnsubscribed: true,
},
},
},
Expand Down
78 changes: 44 additions & 34 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,19 @@ const (
)

type webSocketClient struct {
Dialer Dialer
header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
Dialer Dialer
conn WSConn
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Struct fields were reordered due to the following lint error:

graphql/websocket.go:46:22: fieldalignment: struct with 104 pointer bytes could be 80 (govet)

header http.Header
connParams map[string]interface{}
// Closed when exiting the receive loop in listenWebSocket
errChan chan error
endpoint string
subscriptions subscriptionMap
isClosing bool
sync.Mutex

// Hold when accessing `exitListenWebSocket`
exitListenWebSocketMu sync.Mutex
// Set to indicate the listenWebSocket should exit
exitListenWebSocket bool
}

type webSocketInitMessage struct {
Expand Down Expand Up @@ -104,27 +108,34 @@ func (w *webSocketClient) waitForConnAck() error {
return nil
}

func (w *webSocketClient) handleErr(err error) {
w.Lock()
defer w.Unlock()
if !w.isClosing {
w.errChan <- err
}
}

func (w *webSocketClient) listenWebSocket() {
for {
if w.isClosing {
// The listenWebSocket goroutine "owns" interfaceChan. Both sending
// (in forwardWebSocketData below) and closure (here) happen in this
// goroutine, so there is no possibility of races between send and close.
//
// interfaceChan's are closed at the top of listenWebSocket to
// guarantee the channels are closed even if listenWebSocket will exit.
w.subscriptions.forEachSubscription(func(sub *subscription) {
if sub.hasBeenUnsubscribed() && sub.interfaceChan != nil {
reflect.ValueOf(sub.interfaceChan).Close()
sub.interfaceChan = nil
}
})
w.exitListenWebSocketMu.Lock()
if w.exitListenWebSocket {
close(w.errChan)
return
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to unlock here?

}
w.exitListenWebSocketMu.Unlock()
_, message, err := w.conn.ReadMessage()
if err != nil {
w.handleErr(err)
w.errChan <- err
return
}
err = w.forwardWebSocketData(message)
if err != nil {
w.handleErr(err)
w.errChan <- err
return
}
}
Expand All @@ -139,22 +150,20 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
if wsMsg.ID == "" { // e.g. keep-alive messages
return nil
}
w.subscriptions.Lock()
defer w.subscriptions.Unlock()
sub, success := w.subscriptions.map_[wsMsg.ID]
if !success {
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
if wsMsg.Type == webSocketTypeComplete {
return w.subscriptions.Unsubscribe(wsMsg.ID)
}
if sub.hasBeenUnsubscribed {
return nil

sub, ok := w.subscriptions.GetSubscription(wsMsg.ID)
if !ok {
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
}
if wsMsg.Type == webSocketTypeComplete {
sub.hasBeenUnsubscribed = true
w.subscriptions.map_[wsMsg.ID] = sub
reflect.ValueOf(sub.interfaceChan).Close()
// Note: there's no data race between hasBeenUnsubscribed and the closed
// state of interfaceChan because interfaceChan is only closed by the
// caller of this function.
if sub.hasBeenUnsubscribed() {
return nil
}

return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload)
}

Expand Down Expand Up @@ -206,10 +215,11 @@ func (w *webSocketClient) Close() error {
if err != nil {
return fmt.Errorf("failed to send closure message: %w", err)
}
w.Lock()
defer w.Unlock()
w.isClosing = true
close(w.errChan)

w.exitListenWebSocketMu.Lock()
w.exitListenWebSocket = true
w.exitListenWebSocketMu.Unlock()

return w.conn.Close()
}

Expand Down
6 changes: 3 additions & 3 deletions graphql/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ func forgeTestWebSocketClient(hasBeenUnsubscribed bool) *webSocketClient {
return &webSocketClient{
subscriptions: subscriptionMap{
RWMutex: sync.RWMutex{},
map_: map[string]subscription{
map_: map[string]*subscription{
testSubscriptionID: {
hasBeenUnsubscribed: hasBeenUnsubscribed,
interfaceChan: make(chan any),
_hasBeenUnsubscribed: hasBeenUnsubscribed,
interfaceChan: make(chan any),
forwardDataFunc: func(interfaceChan any, jsonRawMsg json.RawMessage) error {
return nil
},
Expand Down
Loading