diff --git a/src/agentlab/agents/agent_args.py b/src/agentlab/agents/agent_args.py
index 70f370cb..b2cd0eb6 100644
--- a/src/agentlab/agents/agent_args.py
+++ b/src/agentlab/agents/agent_args.py
@@ -3,6 +3,16 @@
 class AgentArgs(AbstractAgentArgs):
+    """Base class for agent arguments used to instantiate an agent.
+
+    Define agent arguments as dataclass variables of this class. For example:
+
+        class MyAgentArgs(AgentArgs):
+            my_arg: str = "default_value"
+            my_other_arg: int = 42
+
+    Note: to work properly with AgentXRay, the arguments need to be serializable and hashable.
+    """
 
     def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
         """Optional method to set benchmark specific flags.
diff --git a/src/agentlab/experiments/multi_server.py b/src/agentlab/experiments/multi_server.py
new file mode 100644
index 00000000..4825258a
--- /dev/null
+++ b/src/agentlab/experiments/multi_server.py
@@ -0,0 +1,90 @@
+from copy import deepcopy
+from dataclasses import dataclass
+import os
+import sys
+from browsergym.webarena.instance import WebArenaInstance
+
+
+class BaseServer:
+    """Base class for server instances.
+
+    Behaves like an identity function for running in parallel on servers that don't need multiple
+    instances.
+    """
+
+    def init(self):
+        pass
+
+
+@dataclass
+class WebArenaInstanceVars(BaseServer):
+    base_url: str
+    shopping: str
+    shopping_admin: str
+    reddit: str
+    gitlab: str
+    wikipedia: str
+    map: str
+    homepage: str
+    full_reset: str
+    module_name: str = "webarena"
+    prefix: str = "WA_"
+
+    def make_env_vars(self):
+        """Return a dictionary of environment variables."""
+        return {
+            f"{self.prefix}SHOPPING": f"{self.base_url}:{self.shopping}",
+            f"{self.prefix}SHOPPING_ADMIN": f"{self.base_url}:{self.shopping_admin}",
+            f"{self.prefix}REDDIT": f"{self.base_url}:{self.reddit}",
+            f"{self.prefix}GITLAB": f"{self.base_url}:{self.gitlab}",
+            f"{self.prefix}WIKIPEDIA": f"{self.base_url}:{self.wikipedia}",
+            f"{self.prefix}MAP": f"{self.base_url}:{self.map}",
+            f"{self.prefix}HOMEPAGE": f"{self.base_url}:{self.homepage}",
+            f"{self.prefix}FULL_RESET": f"{self.base_url}:{self.full_reset}",
+        }
+
+    def init(self):
+        # necessary for webarena to re-import the env vars
+        unimport_modules(self.module_name)
+        for key, value in self.make_env_vars().items():
+            os.environ[key] = value
+
+        # dynamic check that the env vars were set correctly
+        bgym_instance = WebArenaInstance()
+        base_url, _ = _split_url(bgym_instance.urls["reddit"])
+        assert base_url == self.base_url, f"Expected {self.base_url}, got {base_url}"
+
+    @staticmethod
+    def from_env_vars(prefix="WA_", module_name="webarena"):
+        kwargs = {"module_name": module_name}
+        base_urls = set()
+        for key, url in os.environ.items():
+            if key.startswith(prefix):
+                base_url, url_tail = _split_url(url)
+                base_urls.add(base_url)
+                kwargs[key[len(prefix) :].lower()] = url_tail
+
+        if len(base_urls) > 1:
+            raise ValueError("Multiple base urls found in environment variables")
+
+        kwargs["base_url"] = base_urls.pop()
+        return WebArenaInstanceVars(**kwargs)
+
+    def clone(self):
+        """Return a deep copy of the instance."""
+        return deepcopy(self)
+
+
+def unimport_modules(base_name):
+    """Un-import any module whose name starts with base_name."""
+    for module in sys.modules.copy():
+        if module.startswith(base_name):
+            del sys.modules[module]
+
+
+def _split_url(url: str):
+    """Extract the base url and the port/page tail from a url."""
+    parts = url.split(":")
+    base_url = ":".join(parts[0:2])
+    url_tail = ":".join(parts[2:])
+    return base_url, url_tail
diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py
index 8a65b3a2..f49ccd0e 100644
--- a/src/agentlab/experiments/study.py
+++ b/src/agentlab/experiments/study.py
@@ -1,5 +1,6 @@
 import gzip
 import logging
+import os
 import pickle
 import uuid
 from abc import ABC, abstractmethod
@@ -16,6 +17,8 @@
 from agentlab.experiments import reproducibility_util as repro
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
 from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
+from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
+from multiprocessing import Pool, Manager, Queue
 
 logger = logging.getLogger(__name__)
 
@@ -27,6 +30,7 @@ def make_study(
     suffix="",
     comment=None,
     ignore_dependencies=False,
+    parallel_servers=None,
 ):
     """Run a list of agents on a benchmark.
@@ -57,10 +61,18 @@
             3x compare to sequential executionz. To accelerate execution, you can ignore
             dependencies and run in full parallel. This leads to a decrease in performance of about
             1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
+        parallel_servers: list[WebArenaInstanceVars] | int
+            A pool of servers on which to dispatch agent_args in parallel when `"webarena" in
+            benchmark.name`, or an int for generic parallel workers. If len(agent_args) >
+            len(parallel_servers), servers are reused for the next evaluation (after a reset) as
+            soon as they become available.
 
     Returns:
-        Study object or SequentialStudies object if the benchmark requires manual reset after each
-        evaluation such as WebArena and VisualWebArena.
+        Study | SequentialStudies | ParallelStudies object.
+            SequentialStudies: if the benchmark requires manual reset after each evaluation,
+                such as WebArena and VisualWebArena.
+            ParallelStudies: if parallel_servers is specified.
+            Study: otherwise.
     """
 
     if not isinstance(agent_args, (list, tuple)):
         agent_args = [agent_args]
 
     if isinstance(benchmark, str):
         benchmark = bgym.DEFAULT_BENCHMARKS[benchmark.lower()]()
 
-    if "webarena" in benchmark.name and len(agent_args) > 1:
+    if len(agent_args) > 1 and ("webarena" in benchmark.name or parallel_servers is not None):
         logger.warning(
             "*WebArena* requires manual reset after each evaluation. Running through SequentialStudies."
         )
                     ignore_dependencies=ignore_dependencies,
                 )
             )
-
-        return SequentialStudies(studies)
+        if parallel_servers is not None:
+            return ParallelStudies(studies, parallel_servers=parallel_servers)
+        else:
+            return SequentialStudies(studies)
     else:
         return Study(
             agent_args,
@@ -164,7 +178,7 @@ class Study(AbstractStudy):
             A suffix to add to the study name. This can be useful to keep track of your
             experiments. By default the study name contains agent name, benchmark name and date.
         uuid: str
-            A unique identifier for the study.
+            A unique identifier for the study. Will be generated automatically.
         reproducibility_info: dict
             Information about the study that may affect the reproducibility of the experiment.
            e.g.: versions of BrowserGym, benchmark, AgentLab...
@@ -178,12 +192,12 @@
             information. Leave any extra information that can explain why results could be
             different than expected.
         ignore_dependencies: bool
-            If True, ignore the dependencies of the tasks in the benchmark. *Use with caution.* So
+            If True, ignore the dependencies of the tasks in the benchmark. *Use with caution*. So
             far, only WebArena and VisualWebArena have dependencies between tasks to minimize
             the influence of solving one task before another one. This dependency graph allows
             experiments to run in parallel while respecting task dependencies. However, it still
             can't run more than 4 and, in practice it's speeding up evaluation by a factor of only
-            3x compare to sequential executionz. To accelerate execution, you can ignore
+            3x compared to sequential execution. To accelerate execution, you can ignore
             dependencies and run in full parallel. This leads to a decrease in performance of about
             1%-2%, and could be more. Note: ignore_dependencies on VisualWebArena doesn't work.
         avg_step_timeout: int
@@ -455,13 +469,15 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_
             study.make_dir(exp_root=self.dir)
         self.save()
-
-        for study in self.studies:
-            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+        self._run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
 
         _, summary_df, _ = self.get_results()
         logger.info("\n" + str(summary_df))
         logger.info(f"SequentialStudies {self.name} finished.")
 
+    def _run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3):
+        for study in self.studies:
+            study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch)
+
     def override_max_steps(self, max_steps):
         for study in self.studies:
             study.override_max_steps(max_steps)
@@ -471,6 +487,57 @@ def append_to_journal(self, strict_reproducibility=True):
         study.append_to_journal(strict_reproducibility=strict_reproducibility)
 
 
+def _init_worker(server_queue: Queue):
+    """Run once at the initialization of each worker in the multiprocessing.Pool.
+
+    This is typically used to set a different set of WebArena server environment variables for
+    each instance running in parallel.
+
+    Args:
+        server_queue: Queue
+            A queue of objects implementing BaseServer to initialize (or anything with an init
+            method).
+ """ + server_instance = server_queue.get() # type: "WebArenaInstanceVars" + logger.warning(f"Initializing server instance {server_instance} from process {os.getpid()}") + server_instance.init() + + +def _run_study(study: Study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch): + """Wrapper to run a study remotely.""" + study.run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch) + + +@dataclass +class ParallelStudies(SequentialStudies): + + parallel_servers: list[BaseServer] | int = None + + def _run( + self, + n_jobs=1, + parallel_backend="ray", + strict_reproducibility=False, + n_relaunch=3, + ): + parallel_servers = self.parallel_servers + if isinstance(parallel_servers, int): + parallel_servers = [BaseServer() for _ in range(parallel_servers)] + + server_queue = Manager().Queue() + for server in parallel_servers: + server_queue.put(server) + + with Pool(len(parallel_servers), initializer=_init_worker, initargs=(server_queue,)) as p: + p.starmap( + _run_study, + [ + (study, n_jobs, parallel_backend, strict_reproducibility, n_relaunch) + for study in self.studies + ], + ) + + def get_most_recent_study( root_dir: Path = None, date_format: str = "%Y-%m-%d_%H-%M-%S", contains=None ): diff --git a/tests/experiments/test_multi_server.py b/tests/experiments/test_multi_server.py new file mode 100644 index 00000000..8254c2a3 --- /dev/null +++ b/tests/experiments/test_multi_server.py @@ -0,0 +1,37 @@ +from agentlab.experiments.multi_server import WebArenaInstanceVars +from browsergym.webarena.instance import WebArenaInstance + + +def test_webarena_multiserver(): + + instance_1 = WebArenaInstanceVars( + base_url="http://webarena1.eastus.cloudapp.azure.com", + shopping="8082/", + shopping_admin="8083/admin", + reddit="8080", + gitlab="9001", + wikipedia="8081/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing", + map="443", + homepage="80", + full_reset="7565", + module_name="webarena", + prefix="WA_", + ) + + instance_1.init() + + bgym_instance = WebArenaInstance() + base_url_1 = bgym_instance.urls["reddit"].rsplit(":", 1)[0] + assert base_url_1 == instance_1.base_url + + instance_2 = instance_1.clone() + instance_2.base_url = "http://webarena2.eastus.cloudapp.azure.com" + instance_2.init() + + bgym_instance = WebArenaInstance() + base_url_2 = bgym_instance.urls["reddit"].rsplit(":", 1)[0] + assert base_url_2 == instance_2.base_url + + +if __name__ == "__main__": + test_webarena_multiserver() diff --git a/tests/experiments/test_ray.py b/tests/experiments/test_ray.py index 9af5959a..a509742f 100644 --- a/tests/experiments/test_ray.py +++ b/tests/experiments/test_ray.py @@ -31,7 +31,7 @@ def test_execute_task_graph(): # Verify that parallel tasks (task2 and task3) started within a short time of each other parallel_start_diff = abs(exp_args_list[1].start_time - exp_args_list[2].start_time) print(f"parallel_start_diff: {parallel_start_diff}") - assert parallel_start_diff < 1.5 # Allow for a small delay + assert parallel_start_diff < 2 # Allow for a small delay # Ensure that the entire task graph took the expected amount of time total_time = exp_args_list[-1].end_time - exp_args_list[0].start_time diff --git a/tests/experiments/test_study.py b/tests/experiments/test_study.py new file mode 100644 index 00000000..0bc24161 --- /dev/null +++ b/tests/experiments/test_study.py @@ -0,0 +1,58 @@ +import pytest +from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_4o +from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs +from 
agentlab.llm.chat_api import CheatMiniWoBLLMArgs +from agentlab.experiments.study import ParallelStudies, make_study, Study +from agentlab.experiments.multi_server import WebArenaInstanceVars + + +def _make_agent_args_list(): + # CheatMiniWoB agents won't succeed on WebArena, this is just for testing parallelization + agent_args_list = [] + for i in range(2): + agent_args = GenericAgentArgs( + chat_model_args=CheatMiniWoBLLMArgs(), + flags=FLAGS_GPT_4o, + ) + + agent_args.agent_name = agent_args.agent_name + f"_{i}" + agent_args_list.append(agent_args) + return agent_args_list + + +@pytest.mark.skip(reason="This test requires WebArena instances to be running") +def manual_test_launch_parallel_study_webarena(): + agent_args_list = _make_agent_args_list() + + server_instance_1 = WebArenaInstanceVars.from_env_vars() + server_instance_2 = server_instance_1.clone() + server_instance_2.base_url = "http://webarena-slow.eastus.cloudapp.azure.com" + parallel_servers = [server_instance_1, server_instance_2] + + for server in parallel_servers: + print(server) + + study = make_study( + agent_args_list, benchmark="webarena_tiny", parallel_servers=parallel_servers + ) + assert isinstance(study, ParallelStudies) + + study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1) + + +def test_launch_parallel_study(): + agent_args_list = _make_agent_args_list() + + study = make_study(agent_args_list, benchmark="miniwob_tiny_test", parallel_servers=2) + assert isinstance(study, ParallelStudies) + + study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1) + _, summary_df, _ = study.get_results() + assert len(summary_df) == 2 + for n_completed in summary_df["n_completed"]: + assert n_completed == "4/4" + + +if __name__ == "__main__": + # test_launch_parallel_study() + manual_test_launch_parallel_study_webarena()
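
A minimal usage sketch of the parallel_servers API introduced above. It assumes the WA_* environment variables already point at a first WebArena deployment (as read by WebArenaInstanceVars.from_env_vars), that the second hostname below is a placeholder, and that agent_args_list is a list of two or more AgentArgs, e.g. built as in tests/experiments/test_study.py:

    from agentlab.experiments.multi_server import WebArenaInstanceVars
    from agentlab.experiments.study import make_study

    # Build the first server config from the WA_* environment variables.
    server_1 = WebArenaInstanceVars.from_env_vars()

    # A second deployment differs only by its base url; ports/pages carry over.
    server_2 = server_1.clone()
    server_2.base_url = "http://webarena2.example.com"  # placeholder hostname

    # With parallel_servers given (and more than one agent), make_study returns a
    # ParallelStudies; each worker process pops one server from a shared queue in
    # _init_worker and exports its env vars before running its sub-study.
    study = make_study(
        agent_args_list,  # assumed: a list of 2+ AgentArgs, as in test_study.py
        benchmark="webarena_tiny",
        parallel_servers=[server_1, server_2],
    )
    study.run(n_jobs=4, parallel_backend="ray", n_relaunch=1)

When no per-server setup is needed, an int works as well (e.g. parallel_servers=2, as in test_launch_parallel_study): identity BaseServer instances are queued and the per-agent studies are simply dispatched across worker processes.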