Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.5',
'Framework :: Twisted'],
'install_requires': ['six'],
}

from setuptools import setup
Expand Down
6 changes: 3 additions & 3 deletions txredis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@
@brief Twisted compatible version of redis.py
"""
# for backwards compatibility
from client import *
from exceptions import *
from protocol import *
from .client import *
from .exceptions import *
from .protocol import *
32 changes: 20 additions & 12 deletions txredis/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
"""
@file client.py
"""
from __future__ import unicode_literals
import itertools

try:
from itertools import izip
except ImportError: # python 3.x
izip = zip

import six
from six.moves import range
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory

Expand Down Expand Up @@ -74,7 +82,7 @@ def post_process(values):
res = {}
if not values:
return res
for i in xrange(0, len(values) - 1, 2):
for i in range(0, len(values) - 1, 2):
res[values[i]] = values[i + 1]
return res
return self.getResponse().addCallback(post_process)
Expand Down Expand Up @@ -131,7 +139,7 @@ def msetnx(self, mapping):
unchanged.
"""

self._send('msetnx', *list(itertools.chain(*mapping.iteritems())))
self._send('msetnx', *list(itertools.chain(*six.iteritems(mapping))))
return self.getResponse()

def mset(self, mapping, preserve=False):
Expand All @@ -142,7 +150,7 @@ def mset(self, mapping, preserve=False):
command = 'MSETNX'
else:
command = 'MSET'
self._send(command, *list(itertools.chain(*mapping.iteritems())))
self._send(command, *list(itertools.chain(*six.iteritems(mapping))))
return self.getResponse()

def append(self, key, value):
Expand Down Expand Up @@ -963,7 +971,7 @@ def sort(self, key, by=None, get=None, start=None, num=None, desc=False,
stmt.extend(['LIMIT', start, num])
if get is None:
pass
elif isinstance(get, basestring):
elif isinstance(get, six.string_types):
stmt.extend(['GET', get])
elif isinstance(get, list) or isinstance(get, tuple):
for g in get:
Expand Down Expand Up @@ -1011,7 +1019,7 @@ def hmset(self, key, in_dict):
at key. This command overwrites any existing fields in the hash. If key
does not exist, a new key holding a hash is created.
"""
fields = list(itertools.chain(*in_dict.iteritems()))
fields = list(itertools.chain(*six.iteritems(in_dict)))
self._send('HMSET', key, *fields)
return self.getResponse()

Expand Down Expand Up @@ -1040,17 +1048,17 @@ def hget(self, key, field):
"""
Returns the value associated with field in the hash stored at key.
"""
if isinstance(field, basestring):
if isinstance(field, six.string_types):
self._send('HGET', key, field)
else:
self._send('HMGET', *([key] + field))

def post_process(values):
if not values:
return values
if isinstance(field, basestring):
if isinstance(field, six.string_types):
return {field: values}
return dict(itertools.izip(field, values))
return dict(izip(field, values))

return self.getResponse().addCallback(post_process)
hmget = hget
Expand All @@ -1059,7 +1067,7 @@ def hget_value(self, key, field):
"""
Get the value of a hash field
"""
assert isinstance(field, basestring)
assert isinstance(field, six.string_types)
self._send('HGET', key, field)
return self.getResponse()

Expand Down Expand Up @@ -1171,7 +1179,7 @@ def zadd(self, key, *item_tuples, **kwargs):
as (value, score) for backwards compatibility reasons.
"""
if not kwargs and len(item_tuples) == 2 and \
isinstance(item_tuples[0], basestring):
isinstance(item_tuples[0], six.string_types):
self._send('ZADD', key, item_tuples[1], item_tuples[0])
elif not kwargs:
self._send('ZADD', key, *item_tuples)
Expand Down Expand Up @@ -1222,9 +1230,9 @@ def _zopstore(self, op, dstkey, keys, aggregate=None):
args = [op, dstkey, len(keys)]
# add in key names, and optionally weights
if isinstance(keys, dict):
args.extend(list(keys.iterkeys()))
args.extend(list(six.iterkeys(keys)))
args.append('WEIGHTS')
args.extend(list(keys.itervalues()))
args.extend(list(six.itervalues(keys)))
else:
args.extend(keys)
if aggregate:
Expand Down
2 changes: 2 additions & 0 deletions txredis/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
@file exceptions.py
"""
from __future__ import unicode_literals


class RedisError(Exception):
pass
Expand Down
61 changes: 37 additions & 24 deletions txredis/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,36 @@ def main():
Command doc strings taken from the CommandReference wiki page.

"""
from __future__ import unicode_literals

from collections import deque

from six import PY3
from twisted.internet import defer, protocol
from twisted.protocols import policies

from txredis import exceptions

if PY3:
unicode = str


class RedisBase(protocol.Protocol, policies.TimeoutMixin, object):
"""The main Redis client."""

ERROR = "-"
SINGLE_LINE = "+"
INTEGER = ":"
BULK = "$"
MULTI_BULK = "*"
ERROR = b"-"
SINGLE_LINE = b"+"
INTEGER = b":"
BULK = b"$"
MULTI_BULK = b"*"

def __init__(self, db=None, password=None, charset='utf8',
errors='strict'):
self.charset = charset
self.db = db if db is not None else 0
self.password = password
self.errors = errors
self._buffer = ''
self._buffer = b''
self._bulk_length = None
self._disconnected = False
# Format of _multi_bulk_stack elements is:
Expand Down Expand Up @@ -90,16 +96,16 @@ def dataReceived(self, data):
continue

# wait until we have a line
if '\r\n' not in self._buffer:
if b'\r\n' not in self._buffer:
return

# grab a line
line, self._buffer = self._buffer.split('\r\n', 1)
line, self._buffer = self._buffer.split(b'\r\n', 1)
if len(line) == 0:
continue

# first byte indicates reply type
reply_type = line[0]
reply_type = line[0:1]
reply_data = line[1:]

# Error message (-)
Expand Down Expand Up @@ -141,6 +147,8 @@ def dataReceived(self, data):
self._multi_bulk_stack.append([multi_bulk_length, []])
if multi_bulk_length == 0:
self.multiBulkDataReceived()
else:
raise exceptions.InvalidData("Unexpected reply_type: %r", reply_type)

def failRequests(self, reason):
while self._request_queue:
Expand Down Expand Up @@ -186,14 +194,14 @@ def timeoutConnection(self):

def errorReceived(self, data):
"""Error response received."""
if data[:4] == 'ERR ':
reply = exceptions.ResponseError(data[4:])
elif data[:9] == 'NOSCRIPT ':
reply = exceptions.NoScript(data[9:])
elif data[:8] == 'NOTBUSY ':
reply = exceptions.NotBusy(data[8:])
if data[:4] == b'ERR ':
reply = exceptions.ResponseError(data[4:].decode(self.charset))
elif data[:9] == b'NOSCRIPT ':
reply = exceptions.NoScript(data[9:].decode(self.charset))
elif data[:8] == b'NOTBUSY ':
reply = exceptions.NotBusy(data[8:].decode(self.charset))
else:
reply = exceptions.ResponseError(data)
reply = exceptions.ResponseError(data.decode(self.charset))

if self._request_queue:
# properly errback this reply
Expand All @@ -204,11 +212,11 @@ def errorReceived(self, data):

def singleLineReceived(self, data):
"""Single line response received."""
if data == 'none':
if data == b'none':
# should this happen here in the client?
reply = None
else:
reply = data
reply = data.decode(self.charset)

self.responseReceived(reply)

Expand All @@ -235,6 +243,8 @@ def integerReceived(self, data):
def bulkDataReceived(self, data):
"""Bulk data response received."""
self._bulk_length = None
if isinstance(data, bytes):
data = data.decode(self.charset)
self.responseReceived(data)

def multiBulkDataReceived(self):
Expand Down Expand Up @@ -278,16 +288,19 @@ def getResponse(self):

def _encode(self, s):
"""Encode a value for sending to the server."""
if isinstance(s, str):
return s
if not isinstance(s, (unicode, str, bytes)):
s = str(s)
# we made "unicode" an alias for "str" on Python 3 at the head of the file
if isinstance(s, unicode):
try:
return s.encode(self.charset, self.errors)
except UnicodeEncodeError, e:
except UnicodeEncodeError as e:
raise exceptions.InvalidData(
"Error encoding unicode value '%s': %s" % (
s.encode(self.charset, 'replace'), e))
return str(s)
if isinstance(s, (str, bytes)):
return s
raise exceptions.InvalidData("Unexpected data: %r" % s)

def _send(self, *args):
"""Encode and send a request
Expand All @@ -298,8 +311,8 @@ def _send(self, *args):
cmds = []
for i in args:
v = self._encode(i)
cmds.append('$%s\r\n%s\r\n' % (len(v), v))
cmd = '*%s\r\n' % len(args) + ''.join(cmds)
cmds.append(b'$%d\r\n%s\r\n' % (len(v), v))
cmd = b'*%d\r\n' % len(args) + b''.join(cmds)
self.transport.write(cmd)

def send(self, command, *args):
Expand Down
2 changes: 2 additions & 0 deletions txredis/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

This module provides the basic needs to run txRedis unit tests.
"""
from __future__ import unicode_literals

from twisted.internet import protocol
from twisted.internet import reactor
from twisted.trial import unittest
Expand Down
Loading