From b4189d70083a611d308ba227a1d7fe6ed85c9c26 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 10:02:03 +0800 Subject: [PATCH 1/7] update experience and buffer --- pyproject.toml | 1 + tests/common/experience_test.py | 312 +++++---------------- trinity/buffer/reader/queue_reader.py | 64 ++++- trinity/buffer/storage/queue.py | 179 +++++++++++- trinity/buffer/writer/queue_writer.py | 33 ++- trinity/common/config.py | 17 ++ trinity/common/experience.py | 374 +++++++++++++------------- 7 files changed, 525 insertions(+), 455 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6cbdf19b3db..ca48f4b9254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "openai", "jsonlines", "sortedcontainers", + "pyzmq>=25.1.0", "word2number", "matplotlib", "transformers>=4.51.0", diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 55eabca7212..883a4110fbb 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- """Test cases for Storage modules.""" import os +import pickle import unittest import torch from trinity.buffer.schema.sql_schema import ExperienceModel -from trinity.common.experience import EID, CustomField, Experience, Experiences +from trinity.common.experience import EID, CustomField, Experience db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db") dataset_path = os.path.join(os.path.dirname(__file__), "data") @@ -79,7 +80,66 @@ def test_serialize_deserialize(self): self.assertTrue(torch.equal(exp.tokens, exp2.tokens)) self.assertEqual(exp.reward, exp2.reward) self.assertEqual(exp.prompt_length, exp2.prompt_length) - self.assertEqual(exp.experience_type, exp2.experience_type) + + def test_serialize_many_deserialize_many(self): + exp1 = Experience( + eid=EID(batch=1, task=1, run=1, step=1), + tokens=torch.tensor([1, 2, 3], dtype=torch.int32), + logprobs=torch.tensor([0.1, 0.2], dtype=torch.float32), + reward=1.0, + prompt_length=1, + info={"source": "a"}, + metrics={"m": 1.0}, + multi_modal_inputs={"image": torch.randn(2, 3)}, + custom_fields=[ + CustomField( + source_field="foo", + destination_field="bar", + data_type=torch.float32, + ) + ], + ) + exp2 = Experience( + eid=EID(batch=1, task=1, run=2, step=1), + tokens=torch.tensor([4, 5, 6, 7], dtype=torch.int32), + reward=2.0, + prompt_length=2, + info={"source": "b"}, + metrics={"m": 2.0}, + ) + + data = Experience.serialize_many([exp1, exp2]) + restored = Experience.deserialize_many(data) + + self.assertEqual(len(restored), 2) + self.assertTrue(torch.equal(restored[0].tokens, exp1.tokens)) + self.assertTrue(torch.equal(restored[0].logprobs, exp1.logprobs)) + self.assertEqual(restored[0].reward, exp1.reward) + self.assertEqual(restored[0].info, exp1.info) + self.assertEqual(restored[0].metrics, exp1.metrics) + self.assertIsNotNone(restored[0].multi_modal_inputs) + self.assertIn("image", restored[0].multi_modal_inputs) + self.assertTrue(torch.equal(restored[0].multi_modal_inputs["image"], exp1.multi_modal_inputs["image"])) + self.assertEqual(len(restored[0].custom_fields), 1) + self.assertEqual(restored[0].custom_fields[0].destination_field, "bar") + + self.assertTrue(torch.equal(restored[1].tokens, exp2.tokens)) + self.assertEqual(restored[1].reward, exp2.reward) + self.assertEqual(restored[1].prompt_length, exp2.prompt_length) + + def test_deserialize_legacy_pickle_payload(self): + exp = Experience(tokens=torch.tensor([1, 2, 3]), reward=1.23, prompt_length=1) + legacy_data = pickle.dumps(exp) + restored = Experience.deserialize(legacy_data) + self.assertTrue(torch.equal(restored.tokens, exp.tokens)) + self.assertEqual(restored.reward, exp.reward) + + def test_deserialize_single_rejects_batch_payload(self): + exp1 = Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=1) + exp2 = Experience(tokens=torch.tensor([4, 5, 6]), prompt_length=1) + payload = Experience.serialize_many([exp1, exp2]) + with self.assertRaises(ValueError): + Experience.deserialize(payload) def test_to_dict(self): tokens = torch.tensor([1, 2, 3]) @@ -94,77 +154,6 @@ def test_to_dict(self): self.assertEqual(d["response_text"], "yo") self.assertEqual(d["reward"], 2.5) - def test_gather(self): - # test empty gathering - batch = Experiences.gather_experiences([]) - self.assertEqual(batch.tokens.numel(), 0) - self.assertEqual(batch.rewards.numel(), 0) - self.assertEqual(batch.eids, []) - - # test single experience gathering - exp = Experience(tokens=torch.tensor([1, 2, 3]), reward=1.0, prompt_length=1) - batch = Experiences.gather_experiences([exp]) - self.assertEqual(batch.batch_size, 1) - self.assertTrue( - torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) - ) - self.assertEqual(batch.prompt_length, 1) - self.assertEqual(batch.rewards[0], 1.0) - - # test multiple experiences gathering - exps = [ - Experience(tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1), - Experience(tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2), - ] - batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 2) - self.assertEqual(batch.prompt_length, 2) - self.assertEqual(batch.tokens.shape[1], 3) - self.assertEqual(batch.rewards[0], 0.1) - self.assertEqual(batch.rewards[1], 0.2) - - def test_gather_with_token_level_reward(self): - # test empty gathering - batch = Experiences.gather_experiences([]) - self.assertEqual(batch.tokens.numel(), 0) - self.assertEqual(batch.rewards.numel(), 0) - self.assertEqual(batch.token_level_rewards.numel(), 0) - self.assertEqual(batch.eids, []) - - # test single experience gathering - exp = Experience( - tokens=torch.tensor([1, 2, 3]), - token_level_reward=torch.tensor([0, 1.0]), - prompt_length=1, - ) - batch = Experiences.gather_experiences([exp]) - self.assertEqual(batch.batch_size, 1) - self.assertTrue( - torch.equal(batch.tokens[0], torch.tensor([0, 1, 2, 3], dtype=torch.int64)[-3:]) - ) - self.assertEqual(batch.prompt_length, 1) - self.assertIsNone(batch.rewards) - self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0, 1.0]))) - - # test multiple experiences gathering - exps = [ - Experience( - tokens=torch.tensor([1, 2]), token_level_reward=torch.tensor([0.1]), prompt_length=1 - ), - Experience( - tokens=torch.tensor([3, 4, 5]), - token_level_reward=torch.tensor([0.2]), - prompt_length=2, - ), - ] - batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 2) - self.assertEqual(batch.prompt_length, 2) - self.assertEqual(batch.tokens.shape[1], 3) - self.assertIsNone(batch.rewards) - self.assertTrue(torch.equal(batch.token_level_rewards[0], torch.tensor([0.1]))) - self.assertTrue(torch.equal(batch.token_level_rewards[1], torch.tensor([0.2]))) - def test_action_mask_and_logprobs_type(self): exp = Experience(tokens=[1, 2, 3], logprobs=[0.1, 0.2, 0.3], prompt_length=1) self.assertIsInstance(exp.tokens, torch.Tensor) @@ -264,183 +253,6 @@ def test_experience_model_experience_conversion(self): self.assertTrue(torch.equal(new_experience.logprobs, logprobs)) self.assertTrue(torch.equal(new_experience.action_mask, experience.action_mask)) - def test_batch_conversion(self): - exps = [ - Experience( - tokens=torch.tensor([1, 2]), - prompt_length=1, - reward=float(0.1), - logprobs=torch.tensor([0.1]), - advantages=torch.tensor([0.1]), - returns=torch.tensor([0.4]), - ), - Experience( - tokens=torch.tensor([1, 2, 3]), - prompt_length=2, - reward=float(0.2), - logprobs=torch.tensor([0.1]), - advantages=torch.tensor([0.3]), - returns=torch.tensor([0.2]), - ), - ] - batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 2) - self.assertEqual(batch.prompt_length, 2) - prompt_length = batch.prompt_length - for i in range(batch.batch_size): - self.assertEqual(batch.rewards[i], exps[i].reward) - self.assertTrue( - torch.all( - batch.tokens[i][ - prompt_length - - exps[i].prompt_length : prompt_length - - exps[i].prompt_length - + exps[i].tokens.size(0) - ] - == exps[i].tokens - ) - ) - self.assertTrue( - torch.all( - batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].logprobs - ) - ) - self.assertTrue( - torch.all( - batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].action_mask - ) - ) - self.assertTrue( - torch.all( - batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].advantages - ) - ) - self.assertTrue( - torch.all( - batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].returns - ) - ) - - def test_multiturn_experience_batch_converstion(self): - exps = [ - Experience( - tokens=torch.tensor([1, 2, 3, 4, 5, 6]), - reward=float(0.3), - logprobs=torch.tensor([0, 0.1, 0.2, 0.3]), - prompt_length=2, - action_mask=torch.tensor([1, 0, 1, 1]), - advantages=torch.tensor([0.1, 0, 0.2, 0.3]), - returns=torch.tensor([0.5, 0, 0.7, 0.8]), - ), - Experience( - tokens=torch.tensor([1, 2, 3, 4]), - reward=float(0.4), - logprobs=torch.tensor([0, 0.1]), - prompt_length=2, - action_mask=torch.tensor([1, 1]), - advantages=torch.tensor([0.2, 0.3]), - returns=torch.tensor([0.6, 0.9]), - ), - ] - batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 2) - self.assertEqual(batch.prompt_length, 2) - prompt_length = batch.prompt_length - for i in range(batch.batch_size): - self.assertEqual(batch.rewards[i], exps[i].reward) - self.assertTrue( - torch.all( - batch.tokens[i][ - prompt_length - - exps[i].prompt_length : prompt_length - - exps[i].prompt_length - + exps[i].tokens.size(0) - ] - == exps[i].tokens - ) - ) - self.assertTrue( - torch.all( - batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].logprobs - ) - ) - self.assertTrue( - torch.all( - batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].action_mask - ) - ) - self.assertTrue( - torch.all( - batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].advantages - ) - ) - self.assertTrue( - torch.all( - batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length] - == exps[i].returns - ) - ) - - def test_dpo_experience_batch_conversion(self): - exps = [ - Experience( - tokens=torch.tensor([1, 2]), - chosen=torch.tensor([3, 4]), - rejected=torch.tensor([5, 6]), - ), - Experience( - tokens=torch.tensor([7, 8, 9]), - chosen=torch.tensor([10, 11]), - rejected=torch.tensor([12, 13]), - ), - ] - batch = Experiences.gather_experiences(exps) - self.assertEqual(batch.batch_size, 4) - self.assertEqual(batch.prompt_length, 3) - prompt_length = batch.prompt_length - for i in range(batch.batch_size): - j = i // 2 - self.assertTrue( - torch.all( - batch.tokens[i][ - prompt_length - - exps[j].prompt_length : prompt_length - - exps[j].prompt_length - + exps[j].tokens.size(0) - ] - == exps[j].tokens - ) - ) - - def test_gather_experiences_with_custom_fields(self): - # test multiple experiences gathering - exps = [ - Experience( - tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1, info={"a": 1.0, "b": 3} - ), - Experience( - tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2, info={"a": 2, "c": 4} - ), - ] - batch = Experiences.gather_experiences( - exps, custom_fields=[CustomField("a", "a", torch.float32)] - ) - self.assertEqual(batch.batch_size, 2) - self.assertEqual(batch.prompt_length, 2) - self.assertEqual(batch.tokens.shape[1], 3) - self.assertEqual(batch.rewards[0], 0.1) - self.assertEqual(batch.rewards[1], 0.2) - self.assertIn("a", batch.custom_fields) - self.assertEqual(batch.a[0], 1.0) - self.assertEqual(batch.a[1], 2.0) - if __name__ == "__main__": unittest.main() diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 6f8e7c03343..39c6b30a096 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -1,6 +1,9 @@ """Reader of the Queue buffer.""" +import asyncio + from typing import Dict, List, Optional +from typing import Any, cast import ray @@ -8,6 +11,7 @@ from trinity.buffer.storage.queue import QueueStorage from trinity.common.config import StorageConfig from trinity.common.constants import StorageType +from trinity.common.experience import Experience class QueueReader(BufferReader): @@ -18,21 +22,64 @@ def __init__(self, config: StorageConfig): self.timeout = config.max_read_timeout self.read_batch_size = config.batch_size self.queue = QueueStorage.get_wrapper(config) + self._use_zmq = False + self._zmq_socket = None - def read(self, batch_size: Optional[int] = None, **kwargs) -> List: - try: - batch_size = self.read_batch_size if batch_size is None else batch_size - exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) + zmq_config = config.zmq + if zmq_config is not None and zmq_config.enable: + endpoints = cast(Dict[str, Any], ray.get(self.queue.get_zmq_endpoints.remote())) + if endpoints.get("enabled", False): + import zmq.asyncio + + self._zmq_socket = zmq.asyncio.Context.instance().socket(zmq.REQ) + self._zmq_socket.setsockopt(zmq.SNDHWM, zmq_config.sndhwm) + self._zmq_socket.setsockopt(zmq.RCVHWM, zmq_config.rcvhwm) + self._zmq_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) + self._zmq_socket.setsockopt(zmq.SNDTIMEO, int(config.max_read_timeout * 1000)) + self._zmq_socket.setsockopt(zmq.RCVTIMEO, int(config.max_read_timeout * 1000)) + self._zmq_socket.connect(endpoints["reader_endpoint"]) + self._use_zmq = True + + async def _read_via_zmq(self, batch_size: int, **kwargs) -> List: + assert self._zmq_socket is not None + min_model_version = int(kwargs.get("min_model_version", 0)) + request = { + "cmd": "get_batch", + "batch_size": batch_size, + "timeout": float(self.timeout), + "min_model_version": min_model_version, + } + await self._zmq_socket.send_json(request) + status, payload = await self._zmq_socket.recv_multipart() + status_text = status.decode("utf-8") + + if status_text == "ok": + exps = Experience.deserialize_many(payload) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." ) - except StopAsyncIteration: + return exps + + if status_text == "eos": raise StopIteration() - return exps + + if status_text == "error": + raise RuntimeError(payload.decode("utf-8")) + + raise RuntimeError(f"Unknown queue reader response status: {status_text}") + + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: + raise NotImplementedError("This function is deprecated, please use read_async instead.") async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: batch_size = self.read_batch_size if batch_size is None else batch_size + if self._use_zmq: + try: + return await self._read_via_zmq(batch_size, **kwargs) + except StopIteration as e: + raise StopAsyncIteration() from e + exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) if len(exps) != batch_size: raise TimeoutError( @@ -47,3 +94,8 @@ def state_dict(self) -> Dict: def load_state_dict(self, state_dict): # Queue Not supporting state dict yet return None + + def __del__(self): + if self._zmq_socket is not None: + self._zmq_socket.close(0) + self._zmq_socket = None diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index a0da0438954..bd6e43ae173 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import deque from copy import deepcopy -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import ray @@ -127,16 +127,17 @@ def stopped(self) -> bool: def get_queue(cls, config: StorageConfig) -> "QueueBuffer": """Get a queue instance based on the storage configuration.""" logger = get_logger(__name__) - if config.replay_buffer.enable: + replay_buffer = config.replay_buffer + if replay_buffer is not None and replay_buffer.enable: capacity = config.capacity logger.info( - f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {config.replay_buffer.reuse_cooldown_time}." + f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {replay_buffer.reuse_cooldown_time}." ) return AsyncPriorityQueue( capacity=capacity, - reuse_cooldown_time=config.replay_buffer.reuse_cooldown_time, - priority_fn=config.replay_buffer.priority_fn, - priority_fn_args=config.replay_buffer.priority_fn_args, + reuse_cooldown_time=replay_buffer.reuse_cooldown_time, + priority_fn=replay_buffer.priority_fn, + priority_fn_args=replay_buffer.priority_fn_args, ) else: return AsyncQueue(capacity=config.capacity) @@ -171,10 +172,11 @@ async def get(self): async def close(self) -> None: """Close the queue.""" self._closed = True - for getter in self._getters: + getters = getattr(self, "_getters", []) + for getter in getters: if not getter.done(): getter.set_exception(StopAsyncIteration()) - self._getters.clear() + getters.clear() def stopped(self) -> bool: """Check if there is no more data to read.""" @@ -351,7 +353,165 @@ def __init__(self, config: StorageConfig) -> None: self.exp_pool = deque() # A pool to store experiences self.closed = False + self.zmq_config = config.zmq + self._zmq_enabled = bool(self.zmq_config and self.zmq_config.enable) + self._zmq_context = None + self._zmq_pull_socket = None + self._zmq_rep_socket = None + self._zmq_server_task = None + self._zmq_server_lock = asyncio.Lock() + self._zmq_endpoints: Dict[str, str] = {} + + if self._zmq_enabled: + self.logger.warning("QueueStorage ZeroMQ data transport is enabled.") + + async def _ensure_zmq_server(self) -> None: + if not self._zmq_enabled: + return + zmq_config = self.zmq_config + if zmq_config is None: + return + async with self._zmq_server_lock: + if self._zmq_server_task is not None and not self._zmq_server_task.done(): + return + + try: + import zmq + import zmq.asyncio + except ImportError as exc: + raise RuntimeError( + "ZeroMQ transport is enabled, but dependency `pyzmq` is not installed." + ) from exc + + self._zmq_context = zmq.asyncio.Context.instance() + self._zmq_pull_socket = self._zmq_context.socket(zmq.PULL) + self._zmq_rep_socket = self._zmq_context.socket(zmq.REP) + + self._zmq_pull_socket.setsockopt(zmq.RCVHWM, zmq_config.rcvhwm) + self._zmq_pull_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) + self._zmq_rep_socket.setsockopt(zmq.SNDHWM, zmq_config.sndhwm) + self._zmq_rep_socket.setsockopt(zmq.RCVHWM, zmq_config.rcvhwm) + self._zmq_rep_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) + + bind_host = zmq_config.bind_host + writer_port = zmq_config.writer_port + reader_port = zmq_config.reader_port + + if writer_port > 0: + self._zmq_pull_socket.bind(f"tcp://{bind_host}:{writer_port}") + else: + writer_port = self._zmq_pull_socket.bind_to_random_port(f"tcp://{bind_host}") + + if reader_port > 0: + self._zmq_rep_socket.bind(f"tcp://{bind_host}:{reader_port}") + else: + reader_port = self._zmq_rep_socket.bind_to_random_port(f"tcp://{bind_host}") + + connect_host = zmq_config.connect_host or ray.util.get_node_ip_address() + self._zmq_endpoints = { + "writer_endpoint": f"tcp://{connect_host}:{writer_port}", + "reader_endpoint": f"tcp://{connect_host}:{reader_port}", + } + + self._zmq_server_task = asyncio.create_task(self._zmq_server_loop()) + self.logger.warning( + "ZeroMQ server started for queue %s, writer=%s, reader=%s", + self.config.name, + self._zmq_endpoints["writer_endpoint"], + self._zmq_endpoints["reader_endpoint"], + ) + + async def _stop_zmq_server(self) -> None: + task = self._zmq_server_task + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.warning("Error when stopping ZeroMQ server: %s", e) + self._zmq_server_task = None + + if self._zmq_pull_socket is not None: + self._zmq_pull_socket.close(0) + self._zmq_pull_socket = None + if self._zmq_rep_socket is not None: + self._zmq_rep_socket.close(0) + self._zmq_rep_socket = None + self._zmq_endpoints = {} + + async def _zmq_server_loop(self) -> None: + import zmq + import zmq.asyncio + + if self._zmq_pull_socket is None or self._zmq_rep_socket is None: + return + + poller = zmq.asyncio.Poller() + poller.register(self._zmq_pull_socket, zmq.POLLIN) + poller.register(self._zmq_rep_socket, zmq.POLLIN) + + try: + while True: + events = dict(await poller.poll(timeout=1000)) + + if self._zmq_pull_socket in events: + payload = await self._zmq_pull_socket.recv() + exps = Experience.deserialize_many(payload) + await self.put_batch(exps) + + if self._zmq_rep_socket in events: + request = await self._zmq_rep_socket.recv_json() + command = request.get("cmd", "get_batch") + + if command == "ping": + await self._zmq_rep_socket.send_multipart([b"ok", b"pong"]) + continue + + if command != "get_batch": + await self._zmq_rep_socket.send_multipart( + [b"error", f"Unknown command: {command}".encode("utf-8")] + ) + continue + + batch_size = int(request.get("batch_size", self.config.batch_size or 1)) + timeout_sec = float(request.get("timeout", self.config.max_read_timeout)) + min_model_version = int(request.get("min_model_version", 0)) + try: + exps = await self.get_batch( + batch_size=batch_size, + timeout=timeout_sec, + min_model_version=min_model_version, + ) + payload = Experience.serialize_many(exps) + await self._zmq_rep_socket.send_multipart([b"ok", payload]) + except StopAsyncIteration: + await self._zmq_rep_socket.send_multipart([b"eos", b""]) + except Exception as e: + await self._zmq_rep_socket.send_multipart([b"error", str(e).encode("utf-8")]) + except asyncio.CancelledError: + return + except Exception as e: + self.logger.exception("ZeroMQ server loop crashed: %s", e) + + async def get_zmq_endpoints(self) -> Dict[str, Any]: + if not self._zmq_enabled: + return {"enabled": False} + zmq_config = self.zmq_config + if zmq_config is None: + return {"enabled": False} + + await self._ensure_zmq_server() + return { + "enabled": True, + "writer_endpoint": self._zmq_endpoints["writer_endpoint"], + "reader_endpoint": self._zmq_endpoints["reader_endpoint"], + } + async def acquire(self) -> int: + if self._zmq_enabled: + await self._ensure_zmq_server() self.ref_count += 1 return self.ref_count @@ -359,6 +519,9 @@ async def release(self) -> int: """Release the queue.""" self.ref_count -= 1 if self.ref_count <= 0: + self.closed = True + if self._zmq_enabled: + await self._stop_zmq_server() await self.queue.close() if self.writer is not None: await self.writer.release() diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 951f5445b7d..185abb09b38 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -1,5 +1,6 @@ """Writer of the Queue buffer.""" -from typing import List +import asyncio +from typing import Any, Dict, List, cast import ray @@ -7,6 +8,7 @@ from trinity.buffer.storage.queue import QueueStorage from trinity.common.config import StorageConfig from trinity.common.constants import StorageType +from trinity.common.experience import Experience class QueueWriter(BufferWriter): @@ -15,15 +17,44 @@ class QueueWriter(BufferWriter): def __init__(self, config: StorageConfig): assert config.storage_type == StorageType.QUEUE.value self.queue = QueueStorage.get_wrapper(config) + self._use_zmq = False + self._zmq_socket = None + + zmq_config = config.zmq + if zmq_config is not None and zmq_config.enable: + endpoints = cast(Dict[str, Any], ray.get(self.queue.get_zmq_endpoints.remote())) + if endpoints.get("enabled", False): + import zmq + + self._zmq_socket = zmq.Context.instance().socket(zmq.PUSH) + self._zmq_socket.setsockopt(zmq.SNDHWM, zmq_config.sndhwm) + self._zmq_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) + self._zmq_socket.connect(endpoints["writer_endpoint"]) + self._use_zmq = True def write(self, data: List) -> None: + if self._use_zmq: + assert self._zmq_socket is not None + payload = Experience.serialize_many(data) + self._zmq_socket.send(payload) + return ray.get(self.queue.put_batch.remote(data)) async def write_async(self, data): + if self._use_zmq: + return await asyncio.to_thread(self.write, data) return await self.queue.put_batch.remote(data) async def acquire(self) -> int: return await self.queue.acquire.remote() async def release(self) -> int: + if self._zmq_socket is not None: + self._zmq_socket.close(0) + self._zmq_socket = None return await self.queue.release.remote() + + def __del__(self): + if self._zmq_socket is not None: + self._zmq_socket.close(0) + self._zmq_socket = None diff --git a/trinity/common/config.py b/trinity/common/config.py index 20a159fb103..432aec1a678 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -138,6 +138,20 @@ class ReplayBufferConfig: priority_fn_args: Dict = field(default_factory=lambda: {"decay": 2.0}) +@dataclass +class ZMQConfig: + """Config for optional ZeroMQ data transport in queue storage.""" + + enable: bool = False + bind_host: str = "0.0.0.0" + connect_host: Optional[str] = None + writer_port: int = 0 + reader_port: int = 0 + sndhwm: int = 10000 + rcvhwm: int = 10000 + linger_ms: int = 0 + + @dataclass class OverRolloutConfig: """Config for over-rollout in explorer.""" @@ -179,6 +193,7 @@ class StorageConfig: capacity: int = 10000 max_read_timeout: float = 1800 replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig) + zmq: Optional[ZMQConfig] = field(default_factory=ZMQConfig) # used for StorageType.SQL max_retry_times: int = 3 @@ -299,6 +314,7 @@ class ExperienceBufferConfig: capacity: int = 10000 max_read_timeout: float = 1800 replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig) + zmq: Optional[ZMQConfig] = field(default_factory=ZMQConfig) # used for StorageType.SQL max_retry_times: int = 3 @@ -333,6 +349,7 @@ def to_storage_config(self) -> StorageConfig: capacity=self.capacity, max_read_timeout=self.max_read_timeout, replay_buffer=self.replay_buffer, + zmq=self.zmq, max_retry_times=self.max_retry_times, max_retry_interval=self.max_retry_interval, split=self.split, diff --git a/trinity/common/experience.py b/trinity/common/experience.py index d9d519252e7..9e9359df4be 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -3,11 +3,14 @@ from __future__ import annotations import pickle +import struct import uuid from dataclasses import asdict, dataclass, field, fields -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union import torch +from safetensors.torch import load as st_load +from safetensors.torch import save as st_save from torch import Tensor if TYPE_CHECKING: @@ -98,6 +101,34 @@ class CustomField: @dataclass class Experience: + _SER_MAGIC = b"TRXP" + _SER_VERSION = 1 + _TENSOR_FIELDS = ( + "tokens", + "logprobs", + "token_level_reward", + "advantages", + "returns", + "action_mask", + "chosen", + "rejected", + "teacher_logprobs", + ) + _META_FIELDS = ( + "eid", + "reward", + "truncate_status", + "info", + "metrics", + "prompt_length", + "response_text", + "prompt_text", + "messages", + "tools", + "chosen_messages", + "rejected_messages", + ) + eid: EID = field(default_factory=EID) # Unique identifier for the experience tokens: Optional[Tensor] = None # [seq_length] prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks @@ -259,11 +290,161 @@ def __init__( # noqa: C901 def serialize(self) -> bytes: """Serialize the experience to bytes.""" - return pickle.dumps(self) + return self.serialize_many([self]) @classmethod def deserialize(cls, data: bytes) -> Experience: - return pickle.loads(data) + experiences = cls.deserialize_many(data) + if len(experiences) != 1: + raise ValueError( + f"Expected a single Experience payload, got batch size {len(experiences)}. " + "Use Experience.deserialize_many for batched payloads." + ) + return experiences[0] + + @staticmethod + def _serialize_custom_fields(custom_fields: Optional[List[CustomField]]) -> list[dict]: + if not custom_fields: + return [] + return [ + { + "source_field": field.source_field, + "destination_field": field.destination_field, + "data_type": str(field.data_type), + } + for field in custom_fields + ] + + @staticmethod + def _deserialize_custom_fields(serialized_fields: Optional[List[dict]]) -> List[CustomField]: + if not serialized_fields: + return [] + custom_fields = [] + for field_dict in serialized_fields: + dtype_name = field_dict["data_type"].replace("torch.", "") + dtype = getattr(torch, dtype_name) + custom_fields.append( + CustomField( + source_field=field_dict["source_field"], + destination_field=field_dict["destination_field"], + data_type=dtype, + ) + ) + return custom_fields + + @classmethod + def serialize_many(cls, experiences: List[Experience]) -> bytes: + """Serialize a list of experiences into a compact bytes payload. + + Tensor fields are packed with safetensors while non-tensor fields are packed + as metadata via pickle. + """ + metadata = {"version": cls._SER_VERSION, "num_items": len(experiences), "items": []} + tensor_data = {} + + for index, exp in enumerate(experiences): + item_meta = {} + for field_name in cls._META_FIELDS: + value = getattr(exp, field_name) + if field_name == "eid" and value is not None: + item_meta[field_name] = value.to_dict() if isinstance(value, EID) else value + else: + item_meta[field_name] = value + + item_meta["custom_fields"] = cls._serialize_custom_fields(exp.custom_fields) + + for field_name in cls._TENSOR_FIELDS: + value = getattr(exp, field_name) + if value is None: + continue + tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous() + + if exp.multi_modal_inputs is None: + item_meta["multi_modal_input_keys"] = [] + else: + mm_keys = list(exp.multi_modal_inputs.keys()) + item_meta["multi_modal_input_keys"] = mm_keys + for key in mm_keys: + value = exp.multi_modal_inputs[key] + tensor_data[f"{index}:multi_modal_inputs:{key}"] = value.detach().cpu().contiguous() + + metadata["items"].append(item_meta) + + metadata_bytes = pickle.dumps(metadata, protocol=pickle.HIGHEST_PROTOCOL) + tensor_bytes = st_save(tensor_data) + header = ( + cls._SER_MAGIC + + struct.pack(" List[Experience]: + """Deserialize bytes into a list of experiences. + + Supports both new batched payloads and legacy single-experience pickle payloads. + """ + if not data.startswith(cls._SER_MAGIC): + legacy = pickle.loads(data) + if isinstance(legacy, list): + return legacy + return [legacy] + + offset = len(cls._SER_MAGIC) + version = struct.unpack(" dict: """Convert the experience to a dictionary.""" @@ -293,106 +474,6 @@ def to_dict(self) -> dict: res["truncate_status"] = self.truncate_status return res - @classmethod - def gather( - cls, - experiences: List[Experience], - pad_token_id: int = 0, - custom_fields: Optional[List[CustomField]] = None, - ) -> Experiences: - if len(experiences) == 0: - return empty_experiences(custom_fields) - exp_type = experiences[0].experience_type - if exp_type == "dpo": - experiences = split_dpo_experience_to_single_turn(experiences) - max_prompt_length = max([exp.prompt_length for exp in experiences]) # type: ignore [type-var] - max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore [arg-type] - eids = [exp.eid for exp in experiences] - - # Gather tokens - tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id) - - # Gather rewards - if experiences[0].reward is not None: - rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float) - else: - rewards = None - - # Gather token level rewards - if all(exp.token_level_reward is not None for exp in experiences): - token_level_rewards = gather_response_attrs( - experiences, "token_level_reward", max_response_length - ) - else: - token_level_rewards = None - - # gather action_masks - action_masks = gather_action_masks(experiences, max_response_length) - - # gather attention_masks - attention_masks = gather_attention_masks( - experiences, max_prompt_length, max_response_length - ) - - # gather logprobs - if all(exp.logprobs is not None for exp in experiences): - logprobs = gather_response_attrs(experiences, "logprobs", max_response_length) - else: - logprobs = None - - # gather advantages - if all(exp.advantages is not None for exp in experiences): - advantages = gather_response_attrs(experiences, "advantages", max_response_length) - else: - advantages = None - - # gather returns - if all(exp.returns is not None for exp in experiences): - returns = gather_response_attrs(experiences, "returns", max_response_length) - else: - returns = None - - # gather multi_modal_inputs - if all(exp.multi_modal_inputs is not None for exp in experiences): - multi_modal_inputs = gather_multi_modal_inputs(experiences) - else: - multi_modal_inputs = None - - # gather teacher_logprobs - if all(exp.teacher_logprobs is not None for exp in experiences): - teacher_logprobs = gather_response_attrs( - experiences, "teacher_logprobs", max_response_length - ) - else: - teacher_logprobs = None - - exps = Experiences( - eids=eids, - tokens=tokens, - rewards=rewards, - token_level_rewards=token_level_rewards, - advantages=advantages, - returns=returns, - attention_masks=attention_masks, - action_masks=action_masks, - prompt_length=max_prompt_length, - logprobs=logprobs, - multi_modal_inputs=multi_modal_inputs, - teacher_logprobs=teacher_logprobs, - ) - if custom_fields is not None: - for custom_field in custom_fields: - exps.custom_fields.append(custom_field.destination_field) - setattr( - exps, - custom_field.destination_field, - torch.tensor( - [exp.info[custom_field.source_field] for exp in experiences], - dtype=custom_field.data_type, - ), - ) - return exps - def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]: single_turn_experiences = [] @@ -434,93 +515,6 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E return single_turn_experiences -@dataclass -class Experiences: - """A container for a batch of experiences, for high performance communication usage. - - Example: - - >>> |<- prompt_length ->| | - >>> tokens: ('P' represents prompt, 'O' represents output) - >>> exp1: |........PPPPPPPPPPP|OOOOOOOOOO.....| - >>> exp2: |......PPPPPPPPPPPPP|OOOOOOO........| - >>> - >>> attention_masks: ('.' represents False and '1' represents True) - >>> exp1: |........11111111111|1111111111.....| - >>> exp2: |......1111111111111|1111111........| - """ - - eids: List[EID] # Experience IDs of each experience in the batch - tokens: Tensor # [batch_size, seq_length] - - # At least one of `rewards` or `token_level_rewards` must be provided (not None). - # If both are provided, `token_level_rewards` will be used and `rewards` will be ignored. - rewards: Tensor # [batch_size] - token_level_rewards: Tensor # [batch_size, response_length] - - advantages: Optional[Tensor] # [batch_size, response_length] - returns: Optional[Tensor] # [batch_size, response_length] - attention_masks: Tensor # [batch_size, sequence_length] - action_masks: Optional[Tensor] # [batch_size, response_length] - prompt_length: int - logprobs: Optional[Tensor] # [batch_size, response_length] - multi_modal_inputs: Optional[Any] - custom_fields: List[str] = field( - default_factory=list - ) # Custom fields to include in the gathered experiences - teacher_logprobs: Optional[Tensor] = None # [batch_size, response_length] - - @property - def batch_size(self) -> int: - """Get the batch size.""" - return self.tokens.size(0) - - @classmethod - def gather_experiences( - cls, - experiences: list[Experience], - pad_token_id: int = 0, - custom_fields: Optional[List[CustomField]] = None, - ) -> Experiences: - """Gather a batch of experiences from a list of experiences. - - This method will automatically pad the `tokens` and `logprobs` of input experiences to the same length. - - Args: - experiences (list[Experience]): A list of experiences to gather. - pad_token_id (int): The token ID to use for padding. Default is 0. - custom_fields (Optional[List[CustomField]]): Custom fields to include in the gathered experiences. - """ - if len(experiences) == 0: - return empty_experiences(custom_fields) - return experiences[0].__class__.gather( - experiences, pad_token_id=pad_token_id, custom_fields=custom_fields - ) - - -def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences: - exps = Experiences( - tokens=torch.empty(0, dtype=torch.int32), - rewards=torch.empty(0, dtype=torch.float32), - token_level_rewards=torch.empty(0, dtype=torch.float32), - advantages=torch.empty(0, dtype=torch.float32), - returns=torch.empty(0, dtype=torch.float32), - attention_masks=torch.empty(0, dtype=torch.bool), - action_masks=torch.empty(0, dtype=torch.bool), - logprobs=torch.empty(0, dtype=torch.float32), - prompt_length=torch.empty(0, dtype=torch.int32), - eids=[], - multi_modal_inputs=torch.empty(0, dtype=torch.float32), - ) - if custom_fields is not None: - for custom_field in custom_fields: - exps.custom_fields.append(custom_field.destination_field) - setattr( - exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type) - ) - return exps - - def gather_token_ids( experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int ) -> Tensor: From a2a014039a7a4f2607cc5a71f3d9a6b5883b5551 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 17:46:11 +0800 Subject: [PATCH 2/7] queue use serialized exp --- tests/buffer/queue_test.py | 10 +- trinity/buffer/reader/queue_reader.py | 71 ++-------- trinity/buffer/storage/queue.py | 190 ++------------------------ trinity/buffer/writer/queue_writer.py | 41 +----- trinity/common/config.py | 17 --- 5 files changed, 36 insertions(+), 293 deletions(-) diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 695a5ce619c..cc7d2dbe41c 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -169,14 +169,14 @@ async def test_queue_buffer_capacity(self): config = config.to_storage_config() writer = QueueWriter(config) reader = QueueReader(config) - writer.write([{"content": "hello"}]) - writer.write([{"content": "hi"}]) - writer.write([{"content": "hello"}]) - writer.write([{"content": "hi"}]) + writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 0, "use_count": 0})]) + writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 1, "use_count": 0})]) + writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 2, "use_count": 0})]) + writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 3, "use_count": 0})]) # should be blocked def write_blocking_call(): - writer.write([{"content": "blocked"}]) + writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 4, "use_count": 0})]) thread = threading.Thread(target=write_blocking_call) thread.start() diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 39c6b30a096..f18d57dcb15 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -1,9 +1,6 @@ """Reader of the Queue buffer.""" -import asyncio - from typing import Dict, List, Optional -from typing import Any, cast import ray @@ -22,65 +19,24 @@ def __init__(self, config: StorageConfig): self.timeout = config.max_read_timeout self.read_batch_size = config.batch_size self.queue = QueueStorage.get_wrapper(config) - self._use_zmq = False - self._zmq_socket = None - - zmq_config = config.zmq - if zmq_config is not None and zmq_config.enable: - endpoints = cast(Dict[str, Any], ray.get(self.queue.get_zmq_endpoints.remote())) - if endpoints.get("enabled", False): - import zmq.asyncio - - self._zmq_socket = zmq.asyncio.Context.instance().socket(zmq.REQ) - self._zmq_socket.setsockopt(zmq.SNDHWM, zmq_config.sndhwm) - self._zmq_socket.setsockopt(zmq.RCVHWM, zmq_config.rcvhwm) - self._zmq_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) - self._zmq_socket.setsockopt(zmq.SNDTIMEO, int(config.max_read_timeout * 1000)) - self._zmq_socket.setsockopt(zmq.RCVTIMEO, int(config.max_read_timeout * 1000)) - self._zmq_socket.connect(endpoints["reader_endpoint"]) - self._use_zmq = True - async def _read_via_zmq(self, batch_size: int, **kwargs) -> List: - assert self._zmq_socket is not None - min_model_version = int(kwargs.get("min_model_version", 0)) - request = { - "cmd": "get_batch", - "batch_size": batch_size, - "timeout": float(self.timeout), - "min_model_version": min_model_version, - } - await self._zmq_socket.send_json(request) - status, payload = await self._zmq_socket.recv_multipart() - status_text = status.decode("utf-8") - - if status_text == "ok": - exps = Experience.deserialize_many(payload) + def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: + try: + batch_size = self.read_batch_size if batch_size is None else batch_size + exp_bytes = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) + exps = Experience.deserialize_many(exp_bytes) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." ) - return exps - - if status_text == "eos": + except StopAsyncIteration: raise StopIteration() + return exps - if status_text == "error": - raise RuntimeError(payload.decode("utf-8")) - - raise RuntimeError(f"Unknown queue reader response status: {status_text}") - - def read(self, batch_size: Optional[int] = None, **kwargs) -> List: - raise NotImplementedError("This function is deprecated, please use read_async instead.") - - async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: batch_size = self.read_batch_size if batch_size is None else batch_size - if self._use_zmq: - try: - return await self._read_via_zmq(batch_size, **kwargs) - except StopIteration as e: - raise StopAsyncIteration() from e - - exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) + exp_bytes = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) + exps = Experience.deserialize_many(exp_bytes) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." @@ -93,9 +49,4 @@ def state_dict(self) -> Dict: def load_state_dict(self, state_dict): # Queue Not supporting state dict yet - return None - - def __del__(self): - if self._zmq_socket is not None: - self._zmq_socket.close(0) - self._zmq_socket = None + return None \ No newline at end of file diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index bd6e43ae173..9f6a628e700 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import deque from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import ray @@ -127,17 +127,16 @@ def stopped(self) -> bool: def get_queue(cls, config: StorageConfig) -> "QueueBuffer": """Get a queue instance based on the storage configuration.""" logger = get_logger(__name__) - replay_buffer = config.replay_buffer - if replay_buffer is not None and replay_buffer.enable: + if config.replay_buffer.enable: capacity = config.capacity logger.info( - f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {replay_buffer.reuse_cooldown_time}." + f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {config.replay_buffer.reuse_cooldown_time}." ) return AsyncPriorityQueue( capacity=capacity, - reuse_cooldown_time=replay_buffer.reuse_cooldown_time, - priority_fn=replay_buffer.priority_fn, - priority_fn_args=replay_buffer.priority_fn_args, + reuse_cooldown_time=config.replay_buffer.reuse_cooldown_time, + priority_fn=config.replay_buffer.priority_fn, + priority_fn_args=config.replay_buffer.priority_fn_args, ) else: return AsyncQueue(capacity=config.capacity) @@ -172,11 +171,10 @@ async def get(self): async def close(self) -> None: """Close the queue.""" self._closed = True - getters = getattr(self, "_getters", []) - for getter in getters: + for getter in self._getters: if not getter.done(): getter.set_exception(StopAsyncIteration()) - getters.clear() + self._getters.clear() def stopped(self) -> bool: """Check if there is no more data to read.""" @@ -353,165 +351,7 @@ def __init__(self, config: StorageConfig) -> None: self.exp_pool = deque() # A pool to store experiences self.closed = False - self.zmq_config = config.zmq - self._zmq_enabled = bool(self.zmq_config and self.zmq_config.enable) - self._zmq_context = None - self._zmq_pull_socket = None - self._zmq_rep_socket = None - self._zmq_server_task = None - self._zmq_server_lock = asyncio.Lock() - self._zmq_endpoints: Dict[str, str] = {} - - if self._zmq_enabled: - self.logger.warning("QueueStorage ZeroMQ data transport is enabled.") - - async def _ensure_zmq_server(self) -> None: - if not self._zmq_enabled: - return - zmq_config = self.zmq_config - if zmq_config is None: - return - async with self._zmq_server_lock: - if self._zmq_server_task is not None and not self._zmq_server_task.done(): - return - - try: - import zmq - import zmq.asyncio - except ImportError as exc: - raise RuntimeError( - "ZeroMQ transport is enabled, but dependency `pyzmq` is not installed." - ) from exc - - self._zmq_context = zmq.asyncio.Context.instance() - self._zmq_pull_socket = self._zmq_context.socket(zmq.PULL) - self._zmq_rep_socket = self._zmq_context.socket(zmq.REP) - - self._zmq_pull_socket.setsockopt(zmq.RCVHWM, zmq_config.rcvhwm) - self._zmq_pull_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) - self._zmq_rep_socket.setsockopt(zmq.SNDHWM, zmq_config.sndhwm) - self._zmq_rep_socket.setsockopt(zmq.RCVHWM, zmq_config.rcvhwm) - self._zmq_rep_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) - - bind_host = zmq_config.bind_host - writer_port = zmq_config.writer_port - reader_port = zmq_config.reader_port - - if writer_port > 0: - self._zmq_pull_socket.bind(f"tcp://{bind_host}:{writer_port}") - else: - writer_port = self._zmq_pull_socket.bind_to_random_port(f"tcp://{bind_host}") - - if reader_port > 0: - self._zmq_rep_socket.bind(f"tcp://{bind_host}:{reader_port}") - else: - reader_port = self._zmq_rep_socket.bind_to_random_port(f"tcp://{bind_host}") - - connect_host = zmq_config.connect_host or ray.util.get_node_ip_address() - self._zmq_endpoints = { - "writer_endpoint": f"tcp://{connect_host}:{writer_port}", - "reader_endpoint": f"tcp://{connect_host}:{reader_port}", - } - - self._zmq_server_task = asyncio.create_task(self._zmq_server_loop()) - self.logger.warning( - "ZeroMQ server started for queue %s, writer=%s, reader=%s", - self.config.name, - self._zmq_endpoints["writer_endpoint"], - self._zmq_endpoints["reader_endpoint"], - ) - - async def _stop_zmq_server(self) -> None: - task = self._zmq_server_task - if task is not None: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - except Exception as e: - self.logger.warning("Error when stopping ZeroMQ server: %s", e) - self._zmq_server_task = None - - if self._zmq_pull_socket is not None: - self._zmq_pull_socket.close(0) - self._zmq_pull_socket = None - if self._zmq_rep_socket is not None: - self._zmq_rep_socket.close(0) - self._zmq_rep_socket = None - self._zmq_endpoints = {} - - async def _zmq_server_loop(self) -> None: - import zmq - import zmq.asyncio - - if self._zmq_pull_socket is None or self._zmq_rep_socket is None: - return - - poller = zmq.asyncio.Poller() - poller.register(self._zmq_pull_socket, zmq.POLLIN) - poller.register(self._zmq_rep_socket, zmq.POLLIN) - - try: - while True: - events = dict(await poller.poll(timeout=1000)) - - if self._zmq_pull_socket in events: - payload = await self._zmq_pull_socket.recv() - exps = Experience.deserialize_many(payload) - await self.put_batch(exps) - - if self._zmq_rep_socket in events: - request = await self._zmq_rep_socket.recv_json() - command = request.get("cmd", "get_batch") - - if command == "ping": - await self._zmq_rep_socket.send_multipart([b"ok", b"pong"]) - continue - - if command != "get_batch": - await self._zmq_rep_socket.send_multipart( - [b"error", f"Unknown command: {command}".encode("utf-8")] - ) - continue - - batch_size = int(request.get("batch_size", self.config.batch_size or 1)) - timeout_sec = float(request.get("timeout", self.config.max_read_timeout)) - min_model_version = int(request.get("min_model_version", 0)) - try: - exps = await self.get_batch( - batch_size=batch_size, - timeout=timeout_sec, - min_model_version=min_model_version, - ) - payload = Experience.serialize_many(exps) - await self._zmq_rep_socket.send_multipart([b"ok", payload]) - except StopAsyncIteration: - await self._zmq_rep_socket.send_multipart([b"eos", b""]) - except Exception as e: - await self._zmq_rep_socket.send_multipart([b"error", str(e).encode("utf-8")]) - except asyncio.CancelledError: - return - except Exception as e: - self.logger.exception("ZeroMQ server loop crashed: %s", e) - - async def get_zmq_endpoints(self) -> Dict[str, Any]: - if not self._zmq_enabled: - return {"enabled": False} - zmq_config = self.zmq_config - if zmq_config is None: - return {"enabled": False} - - await self._ensure_zmq_server() - return { - "enabled": True, - "writer_endpoint": self._zmq_endpoints["writer_endpoint"], - "reader_endpoint": self._zmq_endpoints["reader_endpoint"], - } - async def acquire(self) -> int: - if self._zmq_enabled: - await self._ensure_zmq_server() self.ref_count += 1 return self.ref_count @@ -519,9 +359,6 @@ async def release(self) -> int: """Release the queue.""" self.ref_count -= 1 if self.ref_count <= 0: - self.closed = True - if self._zmq_enabled: - await self._stop_zmq_server() await self.queue.close() if self.writer is not None: await self.writer.release() @@ -531,13 +368,14 @@ def length(self) -> int: """The length of the queue.""" return self.queue.qsize() - async def put_batch(self, exp_list: List) -> None: + async def put_batch(self, exp_bytes: bytes) -> None: """Put batch of experience.""" + exp_list = Experience.deserialize_many(exp_bytes) await self.queue.put(exp_list) if self.writer is not None: self.writer.write(exp_list) - async def get_batch(self, batch_size: int, timeout: float, min_model_version: int = 0) -> List: + async def get_batch(self, batch_size: int, timeout: float, min_model_version: int = 0) -> bytes: """Get batch of experience.""" await self.queue.set_min_model_version(min_model_version) start_time = time.time() @@ -566,8 +404,8 @@ async def get_batch(self, batch_size: int, timeout: float, min_model_version: in ) batch = list(self.exp_pool) self.exp_pool.clear() - return batch - return result + return Experience.serialize_many(batch) + return Experience.serialize_many(result) @classmethod def get_wrapper(cls, config: StorageConfig): @@ -580,4 +418,4 @@ def get_wrapper(cls, config: StorageConfig): get_if_exists=True, ) .remote(config) - ) + ) \ No newline at end of file diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 185abb09b38..3e9411a89bc 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -17,44 +17,15 @@ class QueueWriter(BufferWriter): def __init__(self, config: StorageConfig): assert config.storage_type == StorageType.QUEUE.value self.queue = QueueStorage.get_wrapper(config) - self._use_zmq = False - self._zmq_socket = None - - zmq_config = config.zmq - if zmq_config is not None and zmq_config.enable: - endpoints = cast(Dict[str, Any], ray.get(self.queue.get_zmq_endpoints.remote())) - if endpoints.get("enabled", False): - import zmq - - self._zmq_socket = zmq.Context.instance().socket(zmq.PUSH) - self._zmq_socket.setsockopt(zmq.SNDHWM, zmq_config.sndhwm) - self._zmq_socket.setsockopt(zmq.LINGER, zmq_config.linger_ms) - self._zmq_socket.connect(endpoints["writer_endpoint"]) - self._use_zmq = True - - def write(self, data: List) -> None: - if self._use_zmq: - assert self._zmq_socket is not None - payload = Experience.serialize_many(data) - self._zmq_socket.send(payload) - return - ray.get(self.queue.put_batch.remote(data)) - - async def write_async(self, data): - if self._use_zmq: - return await asyncio.to_thread(self.write, data) - return await self.queue.put_batch.remote(data) + + def write(self, data: List[Experience]) -> None: + ray.get(self.queue.put_batch.remote(Experience.serialize_many(data))) + + async def write_async(self, data: List[Experience]) -> None: + return await self.queue.put_batch.remote(Experience.serialize_many(data)) async def acquire(self) -> int: return await self.queue.acquire.remote() async def release(self) -> int: - if self._zmq_socket is not None: - self._zmq_socket.close(0) - self._zmq_socket = None return await self.queue.release.remote() - - def __del__(self): - if self._zmq_socket is not None: - self._zmq_socket.close(0) - self._zmq_socket = None diff --git a/trinity/common/config.py b/trinity/common/config.py index 432aec1a678..20a159fb103 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -138,20 +138,6 @@ class ReplayBufferConfig: priority_fn_args: Dict = field(default_factory=lambda: {"decay": 2.0}) -@dataclass -class ZMQConfig: - """Config for optional ZeroMQ data transport in queue storage.""" - - enable: bool = False - bind_host: str = "0.0.0.0" - connect_host: Optional[str] = None - writer_port: int = 0 - reader_port: int = 0 - sndhwm: int = 10000 - rcvhwm: int = 10000 - linger_ms: int = 0 - - @dataclass class OverRolloutConfig: """Config for over-rollout in explorer.""" @@ -193,7 +179,6 @@ class StorageConfig: capacity: int = 10000 max_read_timeout: float = 1800 replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig) - zmq: Optional[ZMQConfig] = field(default_factory=ZMQConfig) # used for StorageType.SQL max_retry_times: int = 3 @@ -314,7 +299,6 @@ class ExperienceBufferConfig: capacity: int = 10000 max_read_timeout: float = 1800 replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig) - zmq: Optional[ZMQConfig] = field(default_factory=ZMQConfig) # used for StorageType.SQL max_retry_times: int = 3 @@ -349,7 +333,6 @@ def to_storage_config(self) -> StorageConfig: capacity=self.capacity, max_read_timeout=self.max_read_timeout, replay_buffer=self.replay_buffer, - zmq=self.zmq, max_retry_times=self.max_retry_times, max_retry_interval=self.max_retry_interval, split=self.split, From 6351da79a6d1f0ef5dd41d0c8ae2c13ca3633f33 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 17:51:19 +0800 Subject: [PATCH 3/7] update experience passing --- tests/buffer/queue_test.py | 50 +++++++++++++++++-- tests/common/experience_test.py | 4 +- .../buffer/pipelines/experience_pipeline.py | 7 +-- trinity/buffer/reader/queue_reader.py | 6 ++- trinity/buffer/storage/queue.py | 2 +- trinity/buffer/writer/queue_writer.py | 3 +- trinity/common/experience.py | 4 +- trinity/explorer/explorer.py | 5 +- 8 files changed, 65 insertions(+), 16 deletions(-) diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index cc7d2dbe41c..1c4ca493a64 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -169,14 +169,54 @@ async def test_queue_buffer_capacity(self): config = config.to_storage_config() writer = QueueWriter(config) reader = QueueReader(config) - writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 0, "use_count": 0})]) - writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 1, "use_count": 0})]) - writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 2, "use_count": 0})]) - writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 3, "use_count": 0})]) + writer.write( + [ + Experience( + tokens=torch.tensor([1, 2, 3]), + prompt_length=2, + info={"model_version": 0, "use_count": 0}, + ) + ] + ) + writer.write( + [ + Experience( + tokens=torch.tensor([1, 2, 3]), + prompt_length=2, + info={"model_version": 1, "use_count": 0}, + ) + ] + ) + writer.write( + [ + Experience( + tokens=torch.tensor([1, 2, 3]), + prompt_length=2, + info={"model_version": 2, "use_count": 0}, + ) + ] + ) + writer.write( + [ + Experience( + tokens=torch.tensor([1, 2, 3]), + prompt_length=2, + info={"model_version": 3, "use_count": 0}, + ) + ] + ) # should be blocked def write_blocking_call(): - writer.write([Experience(tokens=torch.tensor([1, 2, 3]), prompt_length=2, info={"model_version": 4, "use_count": 0})]) + writer.write( + [ + Experience( + tokens=torch.tensor([1, 2, 3]), + prompt_length=2, + info={"model_version": 4, "use_count": 0}, + ) + ] + ) thread = threading.Thread(target=write_blocking_call) thread.start() diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 883a4110fbb..9d8417fba74 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -119,7 +119,9 @@ def test_serialize_many_deserialize_many(self): self.assertEqual(restored[0].metrics, exp1.metrics) self.assertIsNotNone(restored[0].multi_modal_inputs) self.assertIn("image", restored[0].multi_modal_inputs) - self.assertTrue(torch.equal(restored[0].multi_modal_inputs["image"], exp1.multi_modal_inputs["image"])) + self.assertTrue( + torch.equal(restored[0].multi_modal_inputs["image"], exp1.multi_modal_inputs["image"]) + ) self.assertEqual(len(restored[0].custom_fields), 1) self.assertEqual(restored[0].custom_fields[0].destination_field, "bar") diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index d0edf774dcb..8052d8ccd1e 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -1,7 +1,7 @@ import asyncio import time import traceback -from typing import Dict, List, Optional +from typing import Dict, Optional from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer from trinity.buffer.operators.experience_operator import ( @@ -128,16 +128,17 @@ async def prepare(self) -> None: self.logger.error(f"Failed to create experience operators: {traceback.format_exc()}") raise e - async def process(self, exps: List[Experience]) -> Dict: + async def process(self, exp_bytes: bytes) -> Dict: """Process a batch of experiences. Args: - exps (List[Experience]): List of experiences to process. These experiences are typically generated by an explorer in one step. + exp_bytes (bytes): Serialized experiences to process. These experiences are typically generated by an explorer in one step. Returns: Dict: A dictionary containing metrics collected during the processing of experiences. """ st = time.time() + exps = Experience.deserialize_many(exp_bytes) if self.input_store is not None: await self.input_store.write_async(exps) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index f18d57dcb15..d474a2819cc 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -23,7 +23,9 @@ def __init__(self, config: StorageConfig): def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: try: batch_size = self.read_batch_size if batch_size is None else batch_size - exp_bytes = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) + exp_bytes = ray.get( + self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) + ) exps = Experience.deserialize_many(exp_bytes) if len(exps) != batch_size: raise TimeoutError( @@ -49,4 +51,4 @@ def state_dict(self) -> Dict: def load_state_dict(self, state_dict): # Queue Not supporting state dict yet - return None \ No newline at end of file + return None diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 9f6a628e700..7f178f78787 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -418,4 +418,4 @@ def get_wrapper(cls, config: StorageConfig): get_if_exists=True, ) .remote(config) - ) \ No newline at end of file + ) diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 3e9411a89bc..e00c7559150 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -1,6 +1,5 @@ """Writer of the Queue buffer.""" -import asyncio -from typing import Any, Dict, List, cast +from typing import List import ray diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 9e9359df4be..948354236ad 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -366,7 +366,9 @@ def serialize_many(cls, experiences: List[Experience]) -> bytes: item_meta["multi_modal_input_keys"] = mm_keys for key in mm_keys: value = exp.multi_modal_inputs[key] - tensor_data[f"{index}:multi_modal_inputs:{key}"] = value.detach().cpu().contiguous() + tensor_data[f"{index}:multi_modal_inputs:{key}"] = ( + value.detach().cpu().contiguous() + ) metadata["items"].append(item_meta) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 236b05f7a01..833a43cced2 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -24,6 +24,7 @@ SyncMethod, SyncStyle, ) +from trinity.common.experience import Experience from trinity.common.models import create_explorer_models from trinity.explorer.scheduler import Scheduler from trinity.manager.state_manager import StateManager @@ -393,7 +394,9 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: batch_id=step, min_num=self.min_wait_num ) if self.experience_pipeline is not None: - pipeline_metrics = await self.experience_pipeline.process.remote(exps) + pipeline_metrics = await self.experience_pipeline.process.remote( + Experience.serialize_many(exps) + ) self.taskset.feedback(pipeline_metrics) metric.update(pipeline_metrics) if statuses: From 695df36a903d4b18ccc28ed48eadb2f5147fb647 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 17:57:52 +0800 Subject: [PATCH 4/7] fix tests --- tests/buffer/experience_pipeline_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/buffer/experience_pipeline_test.py b/tests/buffer/experience_pipeline_test.py index 0605dc0d73f..cc13db5508e 100644 --- a/tests/buffer/experience_pipeline_test.py +++ b/tests/buffer/experience_pipeline_test.py @@ -71,7 +71,7 @@ async def test_experience_pipeline(self): task_num = 8 repeat_times = 4 experiences = get_experiences(task_num=task_num, repeat_times=repeat_times) - metrics = await pipeline.process.remote(experiences) + metrics = await pipeline.process.remote(Experience.serialize_many(experiences)) self.assertEqual( metrics["experience_pipeline/experience_count"], task_num * (repeat_times - 1) ) # first experience of each task will be filtered out by the reward filter @@ -116,7 +116,7 @@ async def test_pass_rate_calculation(self) -> None: "taskset_id": 0, "index": exp.eid.task, } - metrics = await pipeline.process.remote(experiences) + metrics = await pipeline.process.remote(Experience.serialize_many(experiences)) self.assertIn(SELECTOR_METRIC, metrics) selector_metrics = metrics[SELECTOR_METRIC] self.assertEqual(len(selector_metrics), 1) From a195a57e3f6d4b9b350178bf3a4f0c6e9918485b Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 18:00:01 +0800 Subject: [PATCH 5/7] update pyproject --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ca48f4b9254..6cbdf19b3db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "openai", "jsonlines", "sortedcontainers", - "pyzmq>=25.1.0", "word2number", "matplotlib", "transformers>=4.51.0", From 5574bca707dad7a548f16c9d185f24ae2c8c57b8 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 18:02:18 +0800 Subject: [PATCH 6/7] fix tests --- tests/service/data_juicer_test.py | 2 +- trinity/explorer/proxy/service.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/service/data_juicer_test.py b/tests/service/data_juicer_test.py index 2a966b5546a..2dd6e58909d 100644 --- a/tests/service/data_juicer_test.py +++ b/tests/service/data_juicer_test.py @@ -206,7 +206,7 @@ async def test_data_juicer_operators(self): info={"model_version": 0}, ), ] - metrics = await pipeline.process.remote(exps) + metrics = await pipeline.process.remote(Experience.serialize_many(exps)) self.assertIsInstance(metrics, dict) reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer) filtered_exps = reader.read(batch_size=2) diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index 214f1a02251..2411432d10e 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -169,7 +169,9 @@ async def submit_experiences(self) -> None: async with self.commit_lock: experiences = list(self.ready_experiences) self.ready_experiences.clear() - metrics = await self.explorer.experience_pipeline.process.remote(experiences) + metrics = await self.explorer.experience_pipeline.process.remote( + Experience.serialize_many(experiences) + ) metrics.update(self.collect_metrics()) self.explorer.explore_step_num += 1 self.explorer.monitor.log(metrics, self.explorer.explore_step_num) From d4bff5e1a0ba30bb35b69285fda738d1a7cde904 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 4 Mar 2026 21:09:41 +0800 Subject: [PATCH 7/7] fix pre-commit --- tests/common/experience_test.py | 31 ++++++++++++++++++++++++ tests/manager/synchronizer_test.py | 2 +- trinity/common/experience.py | 5 +++- trinity/explorer/proxy/recorder.py | 38 +++++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 3 deletions(-) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 9d8417fba74..41608ca0dad 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -129,6 +129,37 @@ def test_serialize_many_deserialize_many(self): self.assertEqual(restored[1].reward, exp2.reward) self.assertEqual(restored[1].prompt_length, exp2.prompt_length) + def test_serialize_many_with_shared_multimodal_tensor(self): + shared_pixel_values = torch.randn(2, 3) + shared_image_grid_thw = torch.tensor([1, 2, 3], dtype=torch.int64) + exps = [] + + for i in range(4): + exps.append( + Experience( + eid=EID(batch=1, task=1, run=i, step=1), + tokens=torch.tensor([1, 2, 3], dtype=torch.int32), + prompt_length=1, + multi_modal_inputs={ + "pixel_values": shared_pixel_values, + "image_grid_thw": shared_image_grid_thw, + }, + ) + ) + + data = Experience.serialize_many(exps) + restored = Experience.deserialize_many(data) + + self.assertEqual(len(restored), 4) + for exp in restored: + self.assertIsNotNone(exp.multi_modal_inputs) + self.assertTrue( + torch.equal(exp.multi_modal_inputs["pixel_values"], shared_pixel_values) + ) + self.assertTrue( + torch.equal(exp.multi_modal_inputs["image_grid_thw"], shared_image_grid_thw) + ) + def test_deserialize_legacy_pickle_payload(self): exp = Experience(tokens=torch.tensor([1, 2, 3]), reward=1.23, prompt_length=1) legacy_data = pickle.dumps(exp) diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index 3c5588a37f4..dd9d4c8c4c8 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -71,7 +71,7 @@ async def new_finish_explore_step(self: Explorer, step: int, model_version: int) ) for _ in range(self.config.buffer.train_batch_size) ] - await self.experience_pipeline.process.remote(dummy_exps) + await self.experience_pipeline.process.remote(Experience.serialize_many(dummy_exps)) self.monitor.log(metric, step=step) Explorer.explore_step = new_explore_step diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 948354236ad..f8bb1e2722e 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -367,7 +367,10 @@ def serialize_many(cls, experiences: List[Experience]) -> bytes: for key in mm_keys: value = exp.multi_modal_inputs[key] tensor_data[f"{index}:multi_modal_inputs:{key}"] = ( - value.detach().cpu().contiguous() + value.detach() + .cpu() + .contiguous() + .clone() # clone to avoid shared memory issues ) metadata["items"].append(item_meta) diff --git a/trinity/explorer/proxy/recorder.py b/trinity/explorer/proxy/recorder.py index d5eaded27e7..c70ab07b82f 100644 --- a/trinity/explorer/proxy/recorder.py +++ b/trinity/explorer/proxy/recorder.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List, Set from sqlalchemy.orm import sessionmaker @@ -73,3 +73,39 @@ def update_reward( # The session commit is handled by the `retry_session` context manager. updated_experiences = [record.to_experience() for record in records] return updated_experiences + + +class MemoryHistoryRecorder: + """ + In-memory version of HistoryRecorder for high-performance reward update and history recording. + All data is stored in memory, and can be flushed to persistent storage as needed. + """ + + def __init__(self): + self.logger = get_logger() + # msg_id -> Experience + self._exp_map: Dict[str, Experience] = {} + # Set of msg_id that are not consumed + self._unconsumed: Set[str] = set() + + def record_history(self, experiences: List[Experience]) -> None: + """Save experiences in memory.""" + for exp in experiences: + self._exp_map[exp.eid.suffix] = exp + if getattr(exp, "consumed", 0) == 0: + self._unconsumed.add(exp.eid.suffix) + + def update_reward( + self, reward: float, msg_ids: list, run_id: int, task_id: str + ) -> List[Experience]: + """Update reward for given response IDs and return the updated experiences.""" + updated = [] + for msg_id in msg_ids: + if msg_id in self._unconsumed and msg_id in self._exp_map: + exp = self._exp_map.pop(msg_id) + exp.reward = reward + exp.eid.run = run_id + exp.eid.task = task_id + self._unconsumed.remove(msg_id) + updated.append(exp) + return updated