diff --git a/bmemcached/client/distributed.py b/bmemcached/client/distributed.py index 80cb242..1f7e52b 100644 --- a/bmemcached/client/distributed.py +++ b/bmemcached/client/distributed.py @@ -39,7 +39,7 @@ def delete_multi(self, keys): servers[server_key].append(key) return all([server.delete_multi(keys_) for server, keys_ in servers.items()]) - def set(self, key, value, time=0, compress_level=-1): + def set(self, key, value, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server. @@ -53,11 +53,15 @@ def set(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.set(key, value, time, compress_level) + return server.set(key, value, time, compress_level, get_cas=get_cas) def set_multi(self, mappings, time=0, compress_level=-1): """ @@ -86,7 +90,37 @@ def set_multi(self, mappings, time=0, compress_level=-1): return list(returns) - def add(self, key, value, time=0, compress_level=-1): + def set_multi_cas(self, mappings, time=0, compress_level=-1): + """ + Set multiple keys with their values on server, returning the new CAS + value for each successfully stored key. + + :param mappings: A dict with keys/values. Keys may be (key, cas) + tuples as in set_multi. + :type mappings: dict + :param time: Time in seconds that your key will expire. + :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int + :return: A dict keyed by the string key of every input mapping. The + value is the new CAS int on success or None on failure. + :rtype: dict + """ + if not mappings: + return {} + result = {} + server_mappings = defaultdict(dict) + for key, value in mappings.items(): + str_key = key[0] if isinstance(key, tuple) else key + server_key = self._get_server(str_key) + server_mappings[server_key][key] = value + for server, m in server_mappings.items(): + result.update(server.set_multi_cas(m, time, compress_level)) + return result + + def add(self, key, value, time=0, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -100,13 +134,17 @@ def add(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is added False if key already exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.add(key, value, time, compress_level) + return server.add(key, value, time, compress_level, get_cas=get_cas) - def replace(self, key, value, time=0, compress_level=-1): + def replace(self, key, value, time=0, compress_level=-1, get_cas=False): """ Replace a key/value to server ony if it does exist. @@ -120,11 +158,15 @@ def replace(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is replace False if key does not exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is replace False if key does not exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.replace(key, value, time, compress_level) + return server.replace(key, value, time, compress_level, get_cas=get_cas) def get(self, key, default=None, get_cas=False): """ @@ -182,7 +224,7 @@ def gets(self, key): server = self._get_server(key) return server.get(key) - def cas(self, key, value, cas, time=0, compress_level=-1): + def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server if its CAS value matches cas. @@ -198,11 +240,15 @@ def cas(self, key, value, cas, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, new_cas) where new_cas is + the item's new CAS after the operation, or None on failure. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, new_cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.cas(key, value, cas, time, compress_level) + return server.cas(key, value, cas, time, compress_level, get_cas=get_cas) def incr(self, key, value, default=0, time=1000000): """ diff --git a/bmemcached/client/mixin.py b/bmemcached/client/mixin.py index 2d9ba26..26dd01b 100644 --- a/bmemcached/client/mixin.py +++ b/bmemcached/client/mixin.py @@ -132,19 +132,22 @@ def gets(self, key): def get_multi(self, keys, get_cas=False): raise NotImplementedError() - def set(self, key, value, time=0, compress_level=-1): + def set(self, key, value, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() - def cas(self, key, value, cas, time=0, compress_level=-1): + def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() def set_multi(self, mappings, time=0, compress_level=-1): raise NotImplementedError() - def add(self, key, value, time=0, compress_level=-1): + def set_multi_cas(self, mappings, time=0, compress_level=-1): raise NotImplementedError() - def replace(self, key, value, time=0, compress_level=-1): + def add(self, key, value, time=0, compress_level=-1, get_cas=False): + raise NotImplementedError() + + def replace(self, key, value, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() def delete(self, key, cas=0): # type: (six.string_types, int) -> bool diff --git a/bmemcached/client/replicating.py b/bmemcached/client/replicating.py index 2cd7ff7..c82ae31 100644 --- a/bmemcached/client/replicating.py +++ b/bmemcached/client/replicating.py @@ -6,6 +6,25 @@ class ReplicatingClient(ClientMixin): This is intended to be a client class which implement standard cache interface that common libs do... It replicates values over servers and get a response from the first one it can. + + .. warning:: + CAS operations are fundamentally incompatible with multi-server + replication. Each server maintains its own independent CAS counter, + so a CAS value obtained from one replica will not match any other + replica. As a consequence: + + * :meth:`cas` against more than one replica causes at most one + server to accept the write; the rest silently reject it, leaving + the replicas divergent. The same hazard applies to + :meth:`set_multi` mappings that use ``(key, cas)`` tuple keys. + * :meth:`gets`, :meth:`get` with ``get_cas=True``, and + :meth:`get_multi` with ``get_cas=True`` return a CAS from + whichever replica happens to respond first. That value cannot + be safely passed back to :meth:`cas` on a multi-replica client, + for the reason above. + + If you need CAS semantics, configure this client with exactly one + server (or use :class:`DistributedClient`). """ def _set_retry_delay(self, value): @@ -31,6 +50,12 @@ def get(self, key, default=None, get_cas=False): """ Get a key from server. + .. warning:: + When called with ``get_cas=True`` against more than one replica, + the returned CAS is from whichever replica responded first and + cannot be safely passed to :meth:`cas` on this client. See the + class-level note on CAS and replication. + :param key: Key's name :type key: six.string_types :param default: In case memcached does not find a key, return a default value @@ -59,6 +84,12 @@ def gets(self, key): This method is for API compatibility with other implementations. + .. warning:: + Against more than one replica, the returned CAS is from + whichever replica responded first and cannot be safely passed + to :meth:`cas` on this client. See the class-level note on + CAS and replication. + :param key: Key's name :type key: six.string_types :return: Returns (key data, value), or (None, None) if the value is not in cache. @@ -74,6 +105,13 @@ def get_multi(self, keys, get_cas=False): """ Get multiple keys from server. + .. warning:: + When called with ``get_cas=True`` against more than one replica, + each key's returned CAS is from whichever replica returned that + key first; none of those values can be safely passed to + :meth:`cas` on this client. See the class-level note on CAS + and replication. + :param keys: A list of keys to from server. :type keys: list :param get_cas: If get_cas is true, each value is (data, cas), with each result's CAS value. @@ -95,7 +133,7 @@ def get_multi(self, keys, get_cas=False): break return d - def set(self, key, value, time=0, compress_level=-1): + def set(self, key, value, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server. @@ -109,19 +147,40 @@ def set(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. Only supported when + the client is configured with a single server; see the class + docstring for why CAS and multi-server replication don't mix. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].set(key, value, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.set(key, value, time, compress_level=compress_level)) - return any(returns) - def cas(self, key, value, cas, time=0, compress_level=-1): + def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server if its CAS value matches cas. + .. warning:: + See the class-level note on CAS and replication. Each replica has + its own CAS counter, so a single CAS value cannot match on more + than one server. Calling this against multiple replicas will + silently diverge them -- at most one replica accepts the write. + :param key: Key's name :type key: six.string_types :param value: A value to be stored on server. @@ -134,19 +193,40 @@ def cas(self, key, value, cas, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, new_cas) where new_cas is + the item's new CAS after the operation, or None on failure. Only + supported when the client is configured with a single server; + see the class docstring. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, new_cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].cas(key, value, cas, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.cas(key, value, cas, time, compress_level=compress_level)) - return any(returns) def set_multi(self, mappings, time=0, compress_level=-1): """ Set multiple keys with it's values on server. + .. warning:: + If any key is given as a ``(key, cas)`` tuple, the same CAS-plus- + replication hazard documented on :meth:`cas` applies: the CAS + value can match at most one replica, so those entries will + silently diverge across servers. + :param mappings: A dict with keys/values :type mappings: dict :param time: Time in seconds that your key will expire. @@ -165,7 +245,39 @@ def set_multi(self, mappings, time=0, compress_level=-1): return list(returns) - def add(self, key, value, time=0, compress_level=-1): + def set_multi_cas(self, mappings, time=0, compress_level=-1): + """ + Set multiple keys with their values on the server, returning the new + CAS value for each successfully stored key. + + Only supported when the client is configured with a single server; + see the class docstring for why CAS and multi-server replication + don't mix. + + :param mappings: A dict with keys/values. Keys may be (key, cas) + tuples as in set_multi. + :type mappings: dict + :param time: Time in seconds that your key will expire. + :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int + :return: A dict keyed by the string key of every input mapping. The + value is the new CAS int on success or None on failure. + :rtype: dict + :raises NotImplementedError: if more than one server is configured. + """ + if len(self._servers) > 1: + raise NotImplementedError( + "set_multi_cas is not supported on ReplicatingClient with " + "more than one server." + ) + if not mappings: + return {} + return self._servers[0].set_multi_cas(mappings, time, compress_level=compress_level) + + def add(self, key, value, time=0, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -179,16 +291,31 @@ def add(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. Only supported when + the client is configured with a single server; see the class + docstring. + :type get_cas: bool + :return: True if key is added False if key already exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].add(key, value, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.add(key, value, time, compress_level=compress_level)) - return any(returns) - def replace(self, key, value, time=0, compress_level=-1): + def replace(self, key, value, time=0, compress_level=-1, get_cas=False): """ Replace a key/value to server ony if it does exist. @@ -202,13 +329,28 @@ def replace(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is replace False if key does not exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. Only supported when + the client is configured with a single server; see the class + docstring. + :type get_cas: bool + :return: True if key is replace False if key does not exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].replace(key, value, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.replace(key, value, time, compress_level=compress_level)) - return any(returns) def delete(self, key, cas=0): diff --git a/bmemcached/protocol.py b/bmemcached/protocol.py index 06d7cb3..b1132fb 100644 --- a/bmemcached/protocol.py +++ b/bmemcached/protocol.py @@ -573,8 +573,9 @@ def _set_add_replace(self, command, key, value, time, cas=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :return: A (success, cas) tuple. success is True on success and False + on failure; cas is the new CAS value on success and None otherwise. + :rtype: tuple """ time = time if time >= 0 else self.MAXIMUM_EXPIRE_TIME logger.debug('Setting/adding/replacing key %s.', key) @@ -596,16 +597,16 @@ def _set_add_replace(self, command, key, value, time, cas=0, compress_level=-1): if status != self.STATUS['success']: if status == self.STATUS['key_exists']: - return False + return False, None elif status == self.STATUS['key_not_found']: - return False + return False, None elif status == self.STATUS['server_disconnected']: - return False + return False, None raise MemcachedException('Code: %d Message: %s' % (status, extra_content), status) - return True + return True, cas - def set(self, key, value, time, compress_level=-1): + def set(self, key, value, time, compress_level=-1, get_cas=False): """ Set a value for a key on server. @@ -619,12 +620,19 @@ def set(self, key, value, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ - return self._set_add_replace('set', key, value, time, compress_level=compress_level) + success, cas = self._set_add_replace('set', key, value, time, compress_level=compress_level) + if get_cas: + return success, cas + return success - def cas(self, key, value, cas, time, compress_level=-1): + def cas(self, key, value, cas, time, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -638,8 +646,12 @@ def cas(self, key, value, cas, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists and has a different CAS - :rtype: bool + :param get_cas: If true, return (success, new_cas) where new_cas is + the item's new CAS after the operation, or None on failure. + :type get_cas: bool + :return: True if key is added False if key already exists and has a + different CAS, or a (success, new_cas) tuple if get_cas=True. + :rtype: bool or tuple """ # The protocol CAS value 0 means "no cas". Calling cas() with that value is # probably unintentional. Don't allow it, since it would overwrite the value @@ -649,11 +661,14 @@ def cas(self, key, value, cas, time, compress_level=-1): # If we get a cas of None, interpret that as "compare against nonexistant and set", # which is simply Add. if cas is None: - return self._set_add_replace('add', key, value, time, compress_level=compress_level) + success, new_cas = self._set_add_replace('add', key, value, time, compress_level=compress_level) else: - return self._set_add_replace('set', key, value, time, cas=cas, compress_level=compress_level) + success, new_cas = self._set_add_replace('set', key, value, time, cas=cas, compress_level=compress_level) + if get_cas: + return success, new_cas + return success - def add(self, key, value, time, compress_level=-1): + def add(self, key, value, time, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -667,12 +682,19 @@ def add(self, key, value, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is added False if key already exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ - return self._set_add_replace('add', key, value, time, compress_level=compress_level) + success, cas = self._set_add_replace('add', key, value, time, compress_level=compress_level) + if get_cas: + return success, cas + return success - def replace(self, key, value, time, compress_level=-1): + def replace(self, key, value, time, compress_level=-1, get_cas=False): """ Replace a key/value to server ony if it does exist. @@ -686,10 +708,17 @@ def replace(self, key, value, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is replace False if key does not exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is replace False if key does not exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ - return self._set_add_replace('replace', key, value, time, compress_level=compress_level) + success, cas = self._set_add_replace('replace', key, value, time, compress_level=compress_level) + if get_cas: + return success, cas + return success def set_multi(self, mappings, time=100, compress_level=-1): """ @@ -760,6 +789,72 @@ def set_multi(self, mappings, time=100, compress_level=-1): return failed + def set_multi_cas(self, mappings, time=100, compress_level=-1): + """ + Set multiple keys with their values on server and return the new CAS + value for each successfully stored key. + + If a key is a (key, cas) tuple, insert as if cas(key, value, cas) had + been called. A cas of 0 means add-if-not-exists. + + Unlike set_multi, this uses the non-quiet set/add opcodes so that the + server responds to every request; this costs one response per key but + is what allows per-key CAS values to be returned. + + :param mappings: A dict with keys/values + :type mappings: dict + :param time: Time in seconds that your key will expire. + :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int + :return: A dict keyed by the string key of every input mapping. The + value is the new CAS int on success or None on failure. + :rtype: dict + """ + mappings = list(mappings.items()) + msg = bytearray() + result = {} + + for opaque, (key, value) in enumerate(mappings): + if isinstance(key, tuple): + str_key, cas = key + else: + str_key, cas = key, None + result[str_key] = None + + if cas == 0: + command = 'add' + else: + command = 'set' + + keybytes = str_to_bytes(str_key) + flags, value = self.serialize(value, compress_level=compress_level) + msg += struct.pack(self.HEADER_STRUCT + + self.COMMANDS[command]['struct'] % (len(keybytes), len(value)), + self.MAGIC['request'], + self.COMMANDS[command]['command'], + len(keybytes), + 8, 0, 0, len(keybytes) + len(value) + 8, opaque, cas or 0, + flags, time, keybytes, value) + + self._send(msg) + + # Non-quiet set/add return exactly one response per request, so we can + # read a fixed count rather than relying on a trailing noop sentinel. + for _ in range(len(mappings)): + (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, + cas, extra_content) = self._get_response() + if status == self.STATUS['server_disconnected']: + return result + if status == self.STATUS['success']: + key, value = mappings[opaque] + str_key = key[0] if isinstance(key, tuple) else key + result[str_key] = cas + + return result + def _incr_decr(self, command, key, value, default, time): """ Function which increments and decrements. diff --git a/test/test_simple_functions.py b/test/test_simple_functions.py index f69ce26..e6e625c 100644 --- a/test/test_simple_functions.py +++ b/test/test_simple_functions.py @@ -27,6 +27,7 @@ def tearDown(self): def reset(self): self.client.delete('test_key') self.client.delete('test_key2') + self.client.delete('fresh_key') def testSet(self): self.assertTrue(self.client.set('test_key', 'test')) @@ -120,6 +121,51 @@ def testMultiCas(self): }), []) self.assertEqual(self.client.get('test_key'), 'value4') + def testSetMultiCas(self): + # All-success plain keys: every input gets a non-None CAS, and each + # returned CAS matches what gets() reports afterwards. + result = self.client.set_multi_cas({ + 'test_key': 'value1', + 'test_key2': 'value2', + }) + self.assertEqual(set(result.keys()), {'test_key', 'test_key2'}) + self.assertTrue(result['test_key'] is not None) + self.assertTrue(result['test_key2'] is not None) + _, cas1 = self.client.gets('test_key') + _, cas2 = self.client.gets('test_key2') + self.assertEqual(result['test_key'], cas1) + self.assertEqual(result['test_key2'], cas2) + + # CAS failure: add-if-not-exists when the key already exists returns + # None for that key; unrelated keys still succeed. + result = self.client.set_multi_cas({ + ('test_key', 0): 'shouldnt_store', + 'fresh_key': 'fresh', + }) + self.assertTrue(result['test_key'] is None) + self.assertTrue(result['fresh_key'] is not None) + self.assertEqual(self.client.get('test_key'), 'value1') + self.client.delete('fresh_key') + + # Stale-CAS failure: capture cas, mutate out of band, then set_multi_cas + # with the stale cas must fail and leave the out-of-band value intact. + _, stale_cas = self.client.gets('test_key') + self.client.set('test_key', 'other') + result = self.client.set_multi_cas({ + ('test_key', stale_cas): 'should_fail', + }) + self.assertTrue(result['test_key'] is None) + self.assertEqual(self.client.get('test_key'), 'other') + + # Returned CAS is usable directly in cas() without a gets() round-trip. + self.client.delete('test_key') + result = self.client.set_multi_cas({'test_key': 'v'}) + self.assertTrue(self.client.cas('test_key', 'v2', result['test_key'])) + self.assertEqual(self.client.get('test_key'), 'v2') + + def testSetMultiCasEmpty(self): + self.assertEqual(self.client.set_multi_cas({}), {}) + def testGetMultiCas(self): self.client.set('test_key', 'value1') self.client.set('test_key2', 'value2') @@ -196,6 +242,83 @@ def testAddFail(self): self.client.add('test_key', 'value') self.assertFalse(self.client.add('test_key', 'test')) + def testAddCas(self): + success, cas = self.client.add('test_key', 'value', get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + + # The CAS returned by add() must equal the CAS later returned by gets(). + _, gets_cas = self.client.gets('test_key') + self.assertEqual(cas, gets_cas) + + # A second add of the same key fails; cas is None. + success2, cas2 = self.client.add('test_key', 'value2', get_cas=True) + self.assertFalse(success2) + self.assertTrue(cas2 is None) + + # The CAS returned from add() can be used directly in cas() without + # a separate gets() round-trip. + self.assertTrue(self.client.cas('test_key', 'value3', cas)) + self.assertEqual('value3', self.client.get('test_key')) + + # Backward compatibility: with no get_cas kwarg, add() still returns a plain bool. + result = self.client.add('test_key2', 'value') + self.assertEqual(True, result) + + def testSetCas(self): + # set() with get_cas=True returns (True, cas) and cas matches gets(). + success, cas = self.client.set('test_key', 'v1', get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + _, gets_cas = self.client.gets('test_key') + self.assertEqual(cas, gets_cas) + + # The returned CAS is usable directly in cas() without a gets() round-trip. + self.assertTrue(self.client.cas('test_key', 'v2', cas)) + self.assertEqual('v2', self.client.get('test_key')) + + # Backward compatibility: no get_cas kwarg still returns a plain bool. + self.assertEqual(True, self.client.set('test_key2', 'v')) + + def testReplaceCas(self): + # Replace on a nonexistent key fails; cas is None. + success, cas = self.client.replace('test_key', 'v', get_cas=True) + self.assertFalse(success) + self.assertTrue(cas is None) + + # Replace on an existing key succeeds and returns the new CAS. + self.client.set('test_key', 'original') + success, cas = self.client.replace('test_key', 'new', get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + _, gets_cas = self.client.gets('test_key') + self.assertEqual(cas, gets_cas) + + # Backward compatibility: no get_cas kwarg still returns a plain bool. + self.assertEqual(True, self.client.replace('test_key', 'x')) + + def testCasCas(self): + # cas() with get_cas=True, invoked as add (cas=None): returns new CAS. + success, cas = self.client.cas('test_key', 'v1', None, get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + + # Chain a second CAS using the returned value directly (no gets()). + success2, cas2 = self.client.cas('test_key', 'v2', cas, get_cas=True) + self.assertTrue(success2) + self.assertTrue(cas2 is not None) + self.assertNotEqual(cas, cas2) + self.assertEqual('v2', self.client.get('test_key')) + + # A stale CAS fails; the returned new_cas is None. + success3, cas3 = self.client.cas('test_key', 'v3', cas, get_cas=True) + self.assertFalse(success3) + self.assertTrue(cas3 is None) + self.assertEqual('v2', self.client.get('test_key')) + + # Backward compatibility: no get_cas kwarg still returns a plain bool. + self.assertEqual(True, self.client.cas('test_key', 'v4', cas2)) + def testReplacePass(self): self.client.add('test_key', 'value') self.assertTrue(self.client.replace('test_key', 'value2')) @@ -237,6 +360,32 @@ def testReconnect(self): self.client.disconnect_all() self.assertEqual('test', self.client.get('test_key')) + def testGetCasMultiReplicaRaises(self): + # A ReplicatingClient with >1 server can't safely return a CAS, because + # each replica has its own CAS counter. Confirm every new get_cas path + # raises loudly instead of silently diverging replicas. + client = bmemcached.Client( + ['/tmp/memcached.sock', '{}:11211'.format(os.environ['MEMCACHED_HOST'])], + 'user', 'password', + ) + try: + with self.assertRaises(NotImplementedError): + client.add('test_key', 'v', get_cas=True) + with self.assertRaises(NotImplementedError): + client.set('test_key', 'v', get_cas=True) + with self.assertRaises(NotImplementedError): + client.replace('test_key', 'v', get_cas=True) + with self.assertRaises(NotImplementedError): + client.cas('test_key', 'v', None, get_cas=True) + with self.assertRaises(NotImplementedError): + client.set_multi_cas({'test_key': 'v'}) + + # get_cas=False (default) still works fine on multi-replica. + self.assertTrue(client.set('test_key', 'v')) + finally: + client.delete('test_key') + client.disconnect_all() + class TimeoutMemcachedTests(unittest.TestCase): def setUp(self):