From 96299ecffb69ef2d3ed1a99e9d998387d912a155 Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Fri, 24 Apr 2026 12:58:13 +0000 Subject: [PATCH 1/3] Document CAS-plus-replication hazard on ReplicatingClient ReplicatingClient's CAS-touching methods have always had a silent correctness problem when used against more than one replica: each server maintains its own CAS counter, so a CAS value cannot match on more than one replica. any(returns) then reports success as long as one server accepted the write, but the other replicas silently rejected it, leaving them divergent. The get-side methods do not warn about this -- `gets()`, `get(get_cas=True)`, and `get_multi(get_cas=True)` return a CAS from whichever replica happened to respond first, even though that value cannot be safely passed to cas() on a multi-replica client. Add warnings on the class docstring and on each affected method, so callers have some hope of noticing the hazard. For backwards compatibility, the behavior itself is left unchanged. --- bmemcached/client/replicating.py | 50 ++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/bmemcached/client/replicating.py b/bmemcached/client/replicating.py index 2cd7ff7..7d99ca4 100644 --- a/bmemcached/client/replicating.py +++ b/bmemcached/client/replicating.py @@ -6,6 +6,25 @@ class ReplicatingClient(ClientMixin): This is intended to be a client class which implement standard cache interface that common libs do... It replicates values over servers and get a response from the first one it can. + + .. warning:: + CAS operations are fundamentally incompatible with multi-server + replication. Each server maintains its own independent CAS counter, + so a CAS value obtained from one replica will not match any other + replica. As a consequence: + + * :meth:`cas` against more than one replica causes at most one + server to accept the write; the rest silently reject it, leaving + the replicas divergent. The same hazard applies to + :meth:`set_multi` mappings that use ``(key, cas)`` tuple keys. + * :meth:`gets`, :meth:`get` with ``get_cas=True``, and + :meth:`get_multi` with ``get_cas=True`` return a CAS from + whichever replica happens to respond first. That value cannot + be safely passed back to :meth:`cas` on a multi-replica client, + for the reason above. + + If you need CAS semantics, configure this client with exactly one + server (or use :class:`DistributedClient`). """ def _set_retry_delay(self, value): @@ -31,6 +50,12 @@ def get(self, key, default=None, get_cas=False): """ Get a key from server. + .. warning:: + When called with ``get_cas=True`` against more than one replica, + the returned CAS is from whichever replica responded first and + cannot be safely passed to :meth:`cas` on this client. See the + class-level note on CAS and replication. + :param key: Key's name :type key: six.string_types :param default: In case memcached does not find a key, return a default value @@ -59,6 +84,12 @@ def gets(self, key): This method is for API compatibility with other implementations. + .. warning:: + Against more than one replica, the returned CAS is from + whichever replica responded first and cannot be safely passed + to :meth:`cas` on this client. See the class-level note on + CAS and replication. + :param key: Key's name :type key: six.string_types :return: Returns (key data, value), or (None, None) if the value is not in cache. @@ -74,6 +105,13 @@ def get_multi(self, keys, get_cas=False): """ Get multiple keys from server. + .. warning:: + When called with ``get_cas=True`` against more than one replica, + each key's returned CAS is from whichever replica returned that + key first; none of those values can be safely passed to + :meth:`cas` on this client. See the class-level note on CAS + and replication. + :param keys: A list of keys to from server. :type keys: list :param get_cas: If get_cas is true, each value is (data, cas), with each result's CAS value. @@ -122,6 +160,12 @@ def cas(self, key, value, cas, time=0, compress_level=-1): """ Set a value for a key on server if its CAS value matches cas. + .. warning:: + See the class-level note on CAS and replication. Each replica has + its own CAS counter, so a single CAS value cannot match on more + than one server. Calling this against multiple replicas will + silently diverge them -- at most one replica accepts the write. + :param key: Key's name :type key: six.string_types :param value: A value to be stored on server. @@ -147,6 +191,12 @@ def set_multi(self, mappings, time=0, compress_level=-1): """ Set multiple keys with it's values on server. + .. warning:: + If any key is given as a ``(key, cas)`` tuple, the same CAS-plus- + replication hazard documented on :meth:`cas` applies: the CAS + value can match at most one replica, so those entries will + silently diverge across servers. + :param mappings: A dict with keys/values :type mappings: dict :param time: Time in seconds that your key will expire. From 52c948b60dc21c47d79a0cfea99faa9f48b067ce Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Fri, 24 Apr 2026 03:03:32 +0000 Subject: [PATCH 2/3] Return CAS from single-key mutators via `get_cas` kwarg add(), set(), replace(), and cas() all produce an item with a new CAS value on success, and the memcached binary protocol already returns it in the response header -- the client was simply discarding it. Callers who want to chain a CAS-guarded update after a write had to follow up with a separate gets() round-trip, which is both slower and racy (another writer could slip in between):. Add an optional `get_cas=False` kwarg matching the existing convention on get()/get_multi(). When True, these methods now return a tuple of `(success, cas)` instead of a plain bool; `cas` is the new CAS on success, or None on failure. For ReplicatingClient, the returned CAS comes from the first replica that reported success, matching how get/gets already handle replicas. CAS is inherently per-server, so only that server's CAS is meaningful for subsequent CAS operations against the same server. --- bmemcached/client/distributed.py | 48 ++++++++++----- bmemcached/client/mixin.py | 8 +-- bmemcached/client/replicating.py | 92 +++++++++++++++++++++++----- bmemcached/protocol.py | 77 +++++++++++++++-------- test/test_simple_functions.py | 101 +++++++++++++++++++++++++++++++ 5 files changed, 266 insertions(+), 60 deletions(-) diff --git a/bmemcached/client/distributed.py b/bmemcached/client/distributed.py index 80cb242..3e31164 100644 --- a/bmemcached/client/distributed.py +++ b/bmemcached/client/distributed.py @@ -39,7 +39,7 @@ def delete_multi(self, keys): servers[server_key].append(key) return all([server.delete_multi(keys_) for server, keys_ in servers.items()]) - def set(self, key, value, time=0, compress_level=-1): + def set(self, key, value, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server. @@ -53,11 +53,15 @@ def set(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.set(key, value, time, compress_level) + return server.set(key, value, time, compress_level, get_cas=get_cas) def set_multi(self, mappings, time=0, compress_level=-1): """ @@ -86,7 +90,7 @@ def set_multi(self, mappings, time=0, compress_level=-1): return list(returns) - def add(self, key, value, time=0, compress_level=-1): + def add(self, key, value, time=0, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -100,13 +104,17 @@ def add(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is added False if key already exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.add(key, value, time, compress_level) + return server.add(key, value, time, compress_level, get_cas=get_cas) - def replace(self, key, value, time=0, compress_level=-1): + def replace(self, key, value, time=0, compress_level=-1, get_cas=False): """ Replace a key/value to server ony if it does exist. @@ -120,11 +128,15 @@ def replace(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is replace False if key does not exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is replace False if key does not exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.replace(key, value, time, compress_level) + return server.replace(key, value, time, compress_level, get_cas=get_cas) def get(self, key, default=None, get_cas=False): """ @@ -182,7 +194,7 @@ def gets(self, key): server = self._get_server(key) return server.get(key) - def cas(self, key, value, cas, time=0, compress_level=-1): + def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server if its CAS value matches cas. @@ -198,11 +210,15 @@ def cas(self, key, value, cas, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, new_cas) where new_cas is + the item's new CAS after the operation, or None on failure. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, new_cas) tuple if get_cas=True. + :rtype: bool or tuple """ server = self._get_server(key) - return server.cas(key, value, cas, time, compress_level) + return server.cas(key, value, cas, time, compress_level, get_cas=get_cas) def incr(self, key, value, default=0, time=1000000): """ diff --git a/bmemcached/client/mixin.py b/bmemcached/client/mixin.py index 2d9ba26..80401af 100644 --- a/bmemcached/client/mixin.py +++ b/bmemcached/client/mixin.py @@ -132,19 +132,19 @@ def gets(self, key): def get_multi(self, keys, get_cas=False): raise NotImplementedError() - def set(self, key, value, time=0, compress_level=-1): + def set(self, key, value, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() - def cas(self, key, value, cas, time=0, compress_level=-1): + def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() def set_multi(self, mappings, time=0, compress_level=-1): raise NotImplementedError() - def add(self, key, value, time=0, compress_level=-1): + def add(self, key, value, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() - def replace(self, key, value, time=0, compress_level=-1): + def replace(self, key, value, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() def delete(self, key, cas=0): # type: (six.string_types, int) -> bool diff --git a/bmemcached/client/replicating.py b/bmemcached/client/replicating.py index 7d99ca4..706c4db 100644 --- a/bmemcached/client/replicating.py +++ b/bmemcached/client/replicating.py @@ -133,7 +133,7 @@ def get_multi(self, keys, get_cas=False): break return d - def set(self, key, value, time=0, compress_level=-1): + def set(self, key, value, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server. @@ -147,16 +147,31 @@ def set(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. Only supported when + the client is configured with a single server; see the class + docstring for why CAS and multi-server replication don't mix. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].set(key, value, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.set(key, value, time, compress_level=compress_level)) - return any(returns) - def cas(self, key, value, cas, time=0, compress_level=-1): + def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): """ Set a value for a key on server if its CAS value matches cas. @@ -178,13 +193,28 @@ def cas(self, key, value, cas, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, new_cas) where new_cas is + the item's new CAS after the operation, or None on failure. Only + supported when the client is configured with a single server; + see the class docstring. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, new_cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].cas(key, value, cas, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.cas(key, value, cas, time, compress_level=compress_level)) - return any(returns) def set_multi(self, mappings, time=0, compress_level=-1): @@ -215,7 +245,7 @@ def set_multi(self, mappings, time=0, compress_level=-1): return list(returns) - def add(self, key, value, time=0, compress_level=-1): + def add(self, key, value, time=0, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -229,16 +259,31 @@ def add(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. Only supported when + the client is configured with a single server; see the class + docstring. + :type get_cas: bool + :return: True if key is added False if key already exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].add(key, value, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.add(key, value, time, compress_level=compress_level)) - return any(returns) - def replace(self, key, value, time=0, compress_level=-1): + def replace(self, key, value, time=0, compress_level=-1, get_cas=False): """ Replace a key/value to server ony if it does exist. @@ -252,13 +297,28 @@ def replace(self, key, value, time=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is replace False if key does not exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. Only supported when + the client is configured with a single server; see the class + docstring. + :type get_cas: bool + :return: True if key is replace False if key does not exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + :raises NotImplementedError: if get_cas=True and more than one + server is configured. """ + if get_cas: + if len(self._servers) > 1: + raise NotImplementedError( + "get_cas=True is not supported on ReplicatingClient with " + "more than one server." + ) + return self._servers[0].replace(key, value, time, compress_level=compress_level, get_cas=True) + returns = [] for server in self.servers: returns.append(server.replace(key, value, time, compress_level=compress_level)) - return any(returns) def delete(self, key, cas=0): diff --git a/bmemcached/protocol.py b/bmemcached/protocol.py index 06d7cb3..29a213a 100644 --- a/bmemcached/protocol.py +++ b/bmemcached/protocol.py @@ -573,8 +573,9 @@ def _set_add_replace(self, command, key, value, time, cas=0, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :return: A (success, cas) tuple. success is True on success and False + on failure; cas is the new CAS value on success and None otherwise. + :rtype: tuple """ time = time if time >= 0 else self.MAXIMUM_EXPIRE_TIME logger.debug('Setting/adding/replacing key %s.', key) @@ -596,16 +597,16 @@ def _set_add_replace(self, command, key, value, time, cas=0, compress_level=-1): if status != self.STATUS['success']: if status == self.STATUS['key_exists']: - return False + return False, None elif status == self.STATUS['key_not_found']: - return False + return False, None elif status == self.STATUS['server_disconnected']: - return False + return False, None raise MemcachedException('Code: %d Message: %s' % (status, extra_content), status) - return True + return True, cas - def set(self, key, value, time, compress_level=-1): + def set(self, key, value, time, compress_level=-1, get_cas=False): """ Set a value for a key on server. @@ -619,12 +620,19 @@ def set(self, key, value, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True in case of success and False in case of failure - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True in case of success and False in case of failure, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ - return self._set_add_replace('set', key, value, time, compress_level=compress_level) + success, cas = self._set_add_replace('set', key, value, time, compress_level=compress_level) + if get_cas: + return success, cas + return success - def cas(self, key, value, cas, time, compress_level=-1): + def cas(self, key, value, cas, time, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -638,8 +646,12 @@ def cas(self, key, value, cas, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists and has a different CAS - :rtype: bool + :param get_cas: If true, return (success, new_cas) where new_cas is + the item's new CAS after the operation, or None on failure. + :type get_cas: bool + :return: True if key is added False if key already exists and has a + different CAS, or a (success, new_cas) tuple if get_cas=True. + :rtype: bool or tuple """ # The protocol CAS value 0 means "no cas". Calling cas() with that value is # probably unintentional. Don't allow it, since it would overwrite the value @@ -649,11 +661,14 @@ def cas(self, key, value, cas, time, compress_level=-1): # If we get a cas of None, interpret that as "compare against nonexistant and set", # which is simply Add. if cas is None: - return self._set_add_replace('add', key, value, time, compress_level=compress_level) + success, new_cas = self._set_add_replace('add', key, value, time, compress_level=compress_level) else: - return self._set_add_replace('set', key, value, time, cas=cas, compress_level=compress_level) + success, new_cas = self._set_add_replace('set', key, value, time, cas=cas, compress_level=compress_level) + if get_cas: + return success, new_cas + return success - def add(self, key, value, time, compress_level=-1): + def add(self, key, value, time, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. @@ -667,12 +682,19 @@ def add(self, key, value, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is added False if key already exists - :rtype: bool + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is added False if key already exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple """ - return self._set_add_replace('add', key, value, time, compress_level=compress_level) + success, cas = self._set_add_replace('add', key, value, time, compress_level=compress_level) + if get_cas: + return success, cas + return success - def replace(self, key, value, time, compress_level=-1): + def replace(self, key, value, time, compress_level=-1, get_cas=False): """ Replace a key/value to server ony if it does exist. @@ -686,10 +708,17 @@ def replace(self, key, value, time, compress_level=-1): 0 = no compression, 1 = fastest, 9 = slowest but best, -1 = default compression level. :type compress_level: int - :return: True if key is replace False if key does not exists - :rtype: bool - """ - return self._set_add_replace('replace', key, value, time, compress_level=compress_level) + :param get_cas: If true, return (success, cas) where cas is the new + CAS value on success and None on failure. + :type get_cas: bool + :return: True if key is replace False if key does not exists, or a + (success, cas) tuple if get_cas=True. + :rtype: bool or tuple + """ + success, cas = self._set_add_replace('replace', key, value, time, compress_level=compress_level) + if get_cas: + return success, cas + return success def set_multi(self, mappings, time=100, compress_level=-1): """ diff --git a/test/test_simple_functions.py b/test/test_simple_functions.py index f69ce26..f97554a 100644 --- a/test/test_simple_functions.py +++ b/test/test_simple_functions.py @@ -196,6 +196,83 @@ def testAddFail(self): self.client.add('test_key', 'value') self.assertFalse(self.client.add('test_key', 'test')) + def testAddCas(self): + success, cas = self.client.add('test_key', 'value', get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + + # The CAS returned by add() must equal the CAS later returned by gets(). + _, gets_cas = self.client.gets('test_key') + self.assertEqual(cas, gets_cas) + + # A second add of the same key fails; cas is None. + success2, cas2 = self.client.add('test_key', 'value2', get_cas=True) + self.assertFalse(success2) + self.assertTrue(cas2 is None) + + # The CAS returned from add() can be used directly in cas() without + # a separate gets() round-trip. + self.assertTrue(self.client.cas('test_key', 'value3', cas)) + self.assertEqual('value3', self.client.get('test_key')) + + # Backward compatibility: with no get_cas kwarg, add() still returns a plain bool. + result = self.client.add('test_key2', 'value') + self.assertEqual(True, result) + + def testSetCas(self): + # set() with get_cas=True returns (True, cas) and cas matches gets(). + success, cas = self.client.set('test_key', 'v1', get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + _, gets_cas = self.client.gets('test_key') + self.assertEqual(cas, gets_cas) + + # The returned CAS is usable directly in cas() without a gets() round-trip. + self.assertTrue(self.client.cas('test_key', 'v2', cas)) + self.assertEqual('v2', self.client.get('test_key')) + + # Backward compatibility: no get_cas kwarg still returns a plain bool. + self.assertEqual(True, self.client.set('test_key2', 'v')) + + def testReplaceCas(self): + # Replace on a nonexistent key fails; cas is None. + success, cas = self.client.replace('test_key', 'v', get_cas=True) + self.assertFalse(success) + self.assertTrue(cas is None) + + # Replace on an existing key succeeds and returns the new CAS. + self.client.set('test_key', 'original') + success, cas = self.client.replace('test_key', 'new', get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + _, gets_cas = self.client.gets('test_key') + self.assertEqual(cas, gets_cas) + + # Backward compatibility: no get_cas kwarg still returns a plain bool. + self.assertEqual(True, self.client.replace('test_key', 'x')) + + def testCasCas(self): + # cas() with get_cas=True, invoked as add (cas=None): returns new CAS. + success, cas = self.client.cas('test_key', 'v1', None, get_cas=True) + self.assertTrue(success) + self.assertTrue(cas is not None) + + # Chain a second CAS using the returned value directly (no gets()). + success2, cas2 = self.client.cas('test_key', 'v2', cas, get_cas=True) + self.assertTrue(success2) + self.assertTrue(cas2 is not None) + self.assertNotEqual(cas, cas2) + self.assertEqual('v2', self.client.get('test_key')) + + # A stale CAS fails; the returned new_cas is None. + success3, cas3 = self.client.cas('test_key', 'v3', cas, get_cas=True) + self.assertFalse(success3) + self.assertTrue(cas3 is None) + self.assertEqual('v2', self.client.get('test_key')) + + # Backward compatibility: no get_cas kwarg still returns a plain bool. + self.assertEqual(True, self.client.cas('test_key', 'v4', cas2)) + def testReplacePass(self): self.client.add('test_key', 'value') self.assertTrue(self.client.replace('test_key', 'value2')) @@ -237,6 +314,30 @@ def testReconnect(self): self.client.disconnect_all() self.assertEqual('test', self.client.get('test_key')) + def testGetCasMultiReplicaRaises(self): + # A ReplicatingClient with >1 server can't safely return a CAS, because + # each replica has its own CAS counter. Confirm every new get_cas path + # raises loudly instead of silently diverging replicas. + client = bmemcached.Client( + ['/tmp/memcached.sock', '{}:11211'.format(os.environ['MEMCACHED_HOST'])], + 'user', 'password', + ) + try: + with self.assertRaises(NotImplementedError): + client.add('test_key', 'v', get_cas=True) + with self.assertRaises(NotImplementedError): + client.set('test_key', 'v', get_cas=True) + with self.assertRaises(NotImplementedError): + client.replace('test_key', 'v', get_cas=True) + with self.assertRaises(NotImplementedError): + client.cas('test_key', 'v', None, get_cas=True) + + # get_cas=False (default) still works fine on multi-replica. + self.assertTrue(client.set('test_key', 'v')) + finally: + client.delete('test_key') + client.disconnect_all() + class TimeoutMemcachedTests(unittest.TestCase): def setUp(self): From 2ec16f4e8aee128efacac58688e8eca2e04785b6 Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Fri, 24 Apr 2026 03:04:52 +0000 Subject: [PATCH 3/3] Add set_multi_cas for per-key CAS return from batched writes set_multi's current return shape is a list of failed keys, which can't carry a per-key CAS value. It also uses the quiet setq/addq opcodes, which intentionally suppress successful responses -- so even if the shape allowed it, the wire protocol wouldn't return a CAS per key. Add a separate `set_multi_cas` method that uses the non-quiet set/add opcodes (one response per key) and returns `{str_key: int | None}` for every input key -- int on success, None on failure. The existing `{(key, cas): value}` input syntax from set_multi is preserved; the result dict is keyed by the string key regardless of which form was passed. For ReplicatingClient, the returned CAS per key is the first non-None CAS from any replica, matching the single-key helpers. --- bmemcached/client/distributed.py | 30 +++++++++++++++ bmemcached/client/mixin.py | 3 ++ bmemcached/client/replicating.py | 32 ++++++++++++++++ bmemcached/protocol.py | 66 ++++++++++++++++++++++++++++++++ test/test_simple_functions.py | 48 +++++++++++++++++++++++ 5 files changed, 179 insertions(+) diff --git a/bmemcached/client/distributed.py b/bmemcached/client/distributed.py index 3e31164..1f7e52b 100644 --- a/bmemcached/client/distributed.py +++ b/bmemcached/client/distributed.py @@ -90,6 +90,36 @@ def set_multi(self, mappings, time=0, compress_level=-1): return list(returns) + def set_multi_cas(self, mappings, time=0, compress_level=-1): + """ + Set multiple keys with their values on server, returning the new CAS + value for each successfully stored key. + + :param mappings: A dict with keys/values. Keys may be (key, cas) + tuples as in set_multi. + :type mappings: dict + :param time: Time in seconds that your key will expire. + :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int + :return: A dict keyed by the string key of every input mapping. The + value is the new CAS int on success or None on failure. + :rtype: dict + """ + if not mappings: + return {} + result = {} + server_mappings = defaultdict(dict) + for key, value in mappings.items(): + str_key = key[0] if isinstance(key, tuple) else key + server_key = self._get_server(str_key) + server_mappings[server_key][key] = value + for server, m in server_mappings.items(): + result.update(server.set_multi_cas(m, time, compress_level)) + return result + def add(self, key, value, time=0, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. diff --git a/bmemcached/client/mixin.py b/bmemcached/client/mixin.py index 80401af..26dd01b 100644 --- a/bmemcached/client/mixin.py +++ b/bmemcached/client/mixin.py @@ -141,6 +141,9 @@ def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False): def set_multi(self, mappings, time=0, compress_level=-1): raise NotImplementedError() + def set_multi_cas(self, mappings, time=0, compress_level=-1): + raise NotImplementedError() + def add(self, key, value, time=0, compress_level=-1, get_cas=False): raise NotImplementedError() diff --git a/bmemcached/client/replicating.py b/bmemcached/client/replicating.py index 706c4db..c82ae31 100644 --- a/bmemcached/client/replicating.py +++ b/bmemcached/client/replicating.py @@ -245,6 +245,38 @@ def set_multi(self, mappings, time=0, compress_level=-1): return list(returns) + def set_multi_cas(self, mappings, time=0, compress_level=-1): + """ + Set multiple keys with their values on the server, returning the new + CAS value for each successfully stored key. + + Only supported when the client is configured with a single server; + see the class docstring for why CAS and multi-server replication + don't mix. + + :param mappings: A dict with keys/values. Keys may be (key, cas) + tuples as in set_multi. + :type mappings: dict + :param time: Time in seconds that your key will expire. + :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int + :return: A dict keyed by the string key of every input mapping. The + value is the new CAS int on success or None on failure. + :rtype: dict + :raises NotImplementedError: if more than one server is configured. + """ + if len(self._servers) > 1: + raise NotImplementedError( + "set_multi_cas is not supported on ReplicatingClient with " + "more than one server." + ) + if not mappings: + return {} + return self._servers[0].set_multi_cas(mappings, time, compress_level=compress_level) + def add(self, key, value, time=0, compress_level=-1, get_cas=False): """ Add a key/value to server ony if it does not exist. diff --git a/bmemcached/protocol.py b/bmemcached/protocol.py index 29a213a..b1132fb 100644 --- a/bmemcached/protocol.py +++ b/bmemcached/protocol.py @@ -789,6 +789,72 @@ def set_multi(self, mappings, time=100, compress_level=-1): return failed + def set_multi_cas(self, mappings, time=100, compress_level=-1): + """ + Set multiple keys with their values on server and return the new CAS + value for each successfully stored key. + + If a key is a (key, cas) tuple, insert as if cas(key, value, cas) had + been called. A cas of 0 means add-if-not-exists. + + Unlike set_multi, this uses the non-quiet set/add opcodes so that the + server responds to every request; this costs one response per key but + is what allows per-key CAS values to be returned. + + :param mappings: A dict with keys/values + :type mappings: dict + :param time: Time in seconds that your key will expire. + :type time: int + :param compress_level: How much to compress. + 0 = no compression, 1 = fastest, 9 = slowest but best, + -1 = default compression level. + :type compress_level: int + :return: A dict keyed by the string key of every input mapping. The + value is the new CAS int on success or None on failure. + :rtype: dict + """ + mappings = list(mappings.items()) + msg = bytearray() + result = {} + + for opaque, (key, value) in enumerate(mappings): + if isinstance(key, tuple): + str_key, cas = key + else: + str_key, cas = key, None + result[str_key] = None + + if cas == 0: + command = 'add' + else: + command = 'set' + + keybytes = str_to_bytes(str_key) + flags, value = self.serialize(value, compress_level=compress_level) + msg += struct.pack(self.HEADER_STRUCT + + self.COMMANDS[command]['struct'] % (len(keybytes), len(value)), + self.MAGIC['request'], + self.COMMANDS[command]['command'], + len(keybytes), + 8, 0, 0, len(keybytes) + len(value) + 8, opaque, cas or 0, + flags, time, keybytes, value) + + self._send(msg) + + # Non-quiet set/add return exactly one response per request, so we can + # read a fixed count rather than relying on a trailing noop sentinel. + for _ in range(len(mappings)): + (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, + cas, extra_content) = self._get_response() + if status == self.STATUS['server_disconnected']: + return result + if status == self.STATUS['success']: + key, value = mappings[opaque] + str_key = key[0] if isinstance(key, tuple) else key + result[str_key] = cas + + return result + def _incr_decr(self, command, key, value, default, time): """ Function which increments and decrements. diff --git a/test/test_simple_functions.py b/test/test_simple_functions.py index f97554a..e6e625c 100644 --- a/test/test_simple_functions.py +++ b/test/test_simple_functions.py @@ -27,6 +27,7 @@ def tearDown(self): def reset(self): self.client.delete('test_key') self.client.delete('test_key2') + self.client.delete('fresh_key') def testSet(self): self.assertTrue(self.client.set('test_key', 'test')) @@ -120,6 +121,51 @@ def testMultiCas(self): }), []) self.assertEqual(self.client.get('test_key'), 'value4') + def testSetMultiCas(self): + # All-success plain keys: every input gets a non-None CAS, and each + # returned CAS matches what gets() reports afterwards. + result = self.client.set_multi_cas({ + 'test_key': 'value1', + 'test_key2': 'value2', + }) + self.assertEqual(set(result.keys()), {'test_key', 'test_key2'}) + self.assertTrue(result['test_key'] is not None) + self.assertTrue(result['test_key2'] is not None) + _, cas1 = self.client.gets('test_key') + _, cas2 = self.client.gets('test_key2') + self.assertEqual(result['test_key'], cas1) + self.assertEqual(result['test_key2'], cas2) + + # CAS failure: add-if-not-exists when the key already exists returns + # None for that key; unrelated keys still succeed. + result = self.client.set_multi_cas({ + ('test_key', 0): 'shouldnt_store', + 'fresh_key': 'fresh', + }) + self.assertTrue(result['test_key'] is None) + self.assertTrue(result['fresh_key'] is not None) + self.assertEqual(self.client.get('test_key'), 'value1') + self.client.delete('fresh_key') + + # Stale-CAS failure: capture cas, mutate out of band, then set_multi_cas + # with the stale cas must fail and leave the out-of-band value intact. + _, stale_cas = self.client.gets('test_key') + self.client.set('test_key', 'other') + result = self.client.set_multi_cas({ + ('test_key', stale_cas): 'should_fail', + }) + self.assertTrue(result['test_key'] is None) + self.assertEqual(self.client.get('test_key'), 'other') + + # Returned CAS is usable directly in cas() without a gets() round-trip. + self.client.delete('test_key') + result = self.client.set_multi_cas({'test_key': 'v'}) + self.assertTrue(self.client.cas('test_key', 'v2', result['test_key'])) + self.assertEqual(self.client.get('test_key'), 'v2') + + def testSetMultiCasEmpty(self): + self.assertEqual(self.client.set_multi_cas({}), {}) + def testGetMultiCas(self): self.client.set('test_key', 'value1') self.client.set('test_key2', 'value2') @@ -331,6 +377,8 @@ def testGetCasMultiReplicaRaises(self): client.replace('test_key', 'v', get_cas=True) with self.assertRaises(NotImplementedError): client.cas('test_key', 'v', None, get_cas=True) + with self.assertRaises(NotImplementedError): + client.set_multi_cas({'test_key': 'v'}) # get_cas=False (default) still works fine on multi-replica. self.assertTrue(client.set('test_key', 'v'))