Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, opts ...WebSocket
header: http.Header{},
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
subscriptions: subscriptionMap{map_: make(map[string]*subscription)},
}

for _, opt := range opts {
Expand Down
73 changes: 57 additions & 16 deletions graphql/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,60 @@ package graphql

import (
"fmt"
"reflect"
"sync"
)

// map of subscription ID to subscription
type subscriptionMap struct {
map_ map[string]subscription
map_ map[string]*subscription
sync.RWMutex
}

type subscription struct {
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
id string
hasBeenUnsubscribed bool
// interfaceChan is passed in by the user when creating a subscription but
// closed by webSocketClient when the subscription is unsubscribed, i.e.
// ownership of interfaceChan is passed from the user to the client.
//
// The subscription is unsubscribed either explicitly by the user or when
// a message of webSocketTypeComplete is received. On unsubscribe,
// the _hasBeenUnsubscribed flag is set to true. listenWebSocket then
// closes interfaceChan on the next receive loop.
//
// The listenWebSocket client method handles both sending on the channel
// and closing of the channel, so is no possibility of races between send
// and close.
interfaceChan interface{}

forwardDataFunc ForwardDataFunction
id string

// Hold when accessing _hasBeenUnsubscribed
hasBeenUnsubscribedMu sync.Mutex
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the copylocks (below) is a real problem -- you have to make the mutex a pointer and then hae a constructor for this struct.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the code to use *subscription in the subscription map, which should work as well.

_hasBeenUnsubscribed bool
}

func (s *subscription) unsubscribe() {
s.hasBeenUnsubscribedMu.Lock()
defer s.hasBeenUnsubscribedMu.Unlock()

s._hasBeenUnsubscribed = true
}

func (s *subscription) hasBeenUnsubscribed() bool {
s.hasBeenUnsubscribedMu.Lock()
defer s.hasBeenUnsubscribedMu.Unlock()

return s._hasBeenUnsubscribed
}

func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
s.Lock()
defer s.Unlock()
s.map_[subscriptionID] = subscription{
id: subscriptionID,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
hasBeenUnsubscribed: false,
s.map_[subscriptionID] = &subscription{
id: subscriptionID,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
_hasBeenUnsubscribed: false,
}
}

Expand All @@ -37,16 +66,28 @@ func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
if !success {
return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID)
}
hasBeenUnsubscribed := unsub.hasBeenUnsubscribed
unsub.hasBeenUnsubscribed = true
unsub.unsubscribe()
s.map_[subscriptionID] = unsub

if !hasBeenUnsubscribed {
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
}
return nil
}

func (s *subscriptionMap) forEachSubscription(fn func(sub *subscription)) {
s.Lock()
defer s.Unlock()

for id := range s.map_ {
fn(s.map_[id])
}
}

func (s *subscriptionMap) GetSubscription(subscriptionID string) (*subscription, bool) {
s.Lock()
defer s.Unlock()
sub, ok := s.map_[subscriptionID]
return sub, ok
}

func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
s.RLock()
defer s.RUnlock()
Expand Down
22 changes: 11 additions & 11 deletions graphql/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ func Test_subscriptionMap_Unsubscribe(t *testing.T) {
{
name: "unsubscribe existing subscription",
sm: subscriptionMap{
map_: map[string]subscription{
map_: map[string]*subscription{
"sub1": {
id: "sub1",
interfaceChan: make(chan struct{}),
forwardDataFunc: nil,
hasBeenUnsubscribed: false,
id: "sub1",
interfaceChan: make(chan struct{}),
forwardDataFunc: nil,
_hasBeenUnsubscribed: false,
},
},
},
Expand All @@ -32,20 +32,20 @@ func Test_subscriptionMap_Unsubscribe(t *testing.T) {
{
name: "unsubscribe non-existent subscription",
sm: subscriptionMap{
map_: map[string]subscription{},
map_: map[string]*subscription{},
},
args: args{subscriptionID: "doesnotexist"},
wantErr: true,
},
{
name: "unsubscribe already unsubscribed subscription",
sm: subscriptionMap{
map_: map[string]subscription{
map_: map[string]*subscription{
"sub2": {
id: "sub2",
interfaceChan: nil,
forwardDataFunc: nil,
hasBeenUnsubscribed: true,
id: "sub2",
interfaceChan: nil,
forwardDataFunc: nil,
_hasBeenUnsubscribed: true,
},
},
},
Expand Down
34 changes: 22 additions & 12 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ func (w *webSocketClient) handleErr(err error) {

func (w *webSocketClient) listenWebSocket() {
for {
// The listenWebSocket goroutine "owns" interfaceChan. Both sending
// (in forwardWebSocketData below) and closure (here) happen in this
// goroutine, so there is no possibility of races between send and close.
//
// interfaceChan's are closed at the top of listenWebSocket to
// guarantee the channels are closed even if listenWebSocket will exit.
w.subscriptions.forEachSubscription(func(sub *subscription) {
if sub.hasBeenUnsubscribed() && sub.interfaceChan != nil {
reflect.ValueOf(sub.interfaceChan).Close()
sub.interfaceChan = nil
}
})
if w.isClosing {
return
}
Expand All @@ -139,22 +151,20 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
if wsMsg.ID == "" { // e.g. keep-alive messages
return nil
}
w.subscriptions.Lock()
defer w.subscriptions.Unlock()
sub, success := w.subscriptions.map_[wsMsg.ID]
if !success {
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
if wsMsg.Type == webSocketTypeComplete {
return w.subscriptions.Unsubscribe(wsMsg.ID)
}
if sub.hasBeenUnsubscribed {
return nil

sub, ok := w.subscriptions.GetSubscription(wsMsg.ID)
if !ok {
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
}
if wsMsg.Type == webSocketTypeComplete {
sub.hasBeenUnsubscribed = true
w.subscriptions.map_[wsMsg.ID] = sub
reflect.ValueOf(sub.interfaceChan).Close()
// Note: there's no data race between hasBeenUnsubscribed and the closed
// state of interfaceChan because interfaceChan is only closed by the
// caller of this function.
if sub.hasBeenUnsubscribed() {
return nil
}

return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload)
}

Expand Down
6 changes: 3 additions & 3 deletions graphql/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ func forgeTestWebSocketClient(hasBeenUnsubscribed bool) *webSocketClient {
return &webSocketClient{
subscriptions: subscriptionMap{
RWMutex: sync.RWMutex{},
map_: map[string]subscription{
map_: map[string]*subscription{
testSubscriptionID: {
hasBeenUnsubscribed: hasBeenUnsubscribed,
interfaceChan: make(chan any),
_hasBeenUnsubscribed: hasBeenUnsubscribed,
interfaceChan: make(chan any),
forwardDataFunc: func(interfaceChan any, jsonRawMsg json.RawMessage) error {
return nil
},
Expand Down
Loading