diff --git a/docs/train_eval.md b/docs/train_eval.md index 3fd34d9..b5cdf47 100644 --- a/docs/train_eval.md +++ b/docs/train_eval.md @@ -35,7 +35,8 @@ python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \ trainer.params.max_epochs=10 \ cache_path="${NAVSIM_EXP_ROOT}/training_cache/" \ use_cache_without_dataset=True \ - force_cache_computation=False + force_cache_computation=False \ + +df_version=v2 # Training mode selector python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \ @@ -47,10 +48,25 @@ python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \ trainer.params.max_epochs=20 \ cache_path="${NAVSIM_EXP_ROOT}/training_cache/" \ use_cache_without_dataset=True \ - force_cache_computation=False + force_cache_computation=False \ + +df_version=v2 ``` +If you want to train DiffusionDrive v1, use this command: +``` +python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \ + agent=diffusiondrive_agent \ + experiment_name=training_diffusiondrive_agent \ + train_test_split=mini \ + split=mini \ + trainer.params.max_epochs=100 \ + cache_path="${NAVSIM_EXP_ROOT}/training_cache/" \ + use_cache_without_dataset=True \ + force_cache_computation=False \ + +df_version=v1 + +``` ## 4. 
Evaluation Use the following command to evaluate the trained model rapidly (**several times faster than the official evaluation script**): diff --git a/navsim/planning/script/run_training.py b/navsim/planning/script/run_training.py index a32c592..3ccb82a 100644 --- a/navsim/planning/script/run_training.py +++ b/navsim/planning/script/run_training.py @@ -30,14 +30,20 @@ def build_datasets(cfg: DictConfig, agent: AbstractAgent) -> Tuple[Dataset, Data train_scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter) if train_scene_filter.log_names is not None: train_scene_filter.log_names = [ - log_name for log_name in train_scene_filter.log_names if log_name in cfg.train_logs + log_name + for log_name in train_scene_filter.log_names + if log_name in cfg.train_logs ] else: train_scene_filter.log_names = cfg.train_logs val_scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter) if val_scene_filter.log_names is not None: - val_scene_filter.log_names = [log_name for log_name in val_scene_filter.log_names if log_name in cfg.val_logs] + val_scene_filter.log_names = [ + log_name + for log_name in val_scene_filter.log_names + if log_name in cfg.val_logs + ] else: val_scene_filter.log_names = cfg.val_logs @@ -99,19 +105,21 @@ def main(cfg: DictConfig) -> None: if cfg.use_cache_without_dataset: logger.info("Using cached data without building SceneLoader") - assert ( - not cfg.force_cache_computation - ), "force_cache_computation must be False when using cached data without building SceneLoader" - assert ( - cfg.cache_path is not None - ), "cache_path must be provided when using cached data without building SceneLoader" + assert not cfg.force_cache_computation, ( + "force_cache_computation must be False when using cached data without building SceneLoader" + ) + assert cfg.cache_path is not None, ( + "cache_path must be provided when using cached data without building SceneLoader" + ) train_data = CacheOnlyDataset( + df_version=cfg.df_version, 
cache_path=cfg.cache_path, feature_builders=agent.get_feature_builders(), target_builders=agent.get_target_builders(), log_names=cfg.train_logs, ) val_data = CacheOnlyDataset( + df_version=cfg.df_version, cache_path=cfg.cache_path, feature_builders=agent.get_feature_builders(), target_builders=agent.get_target_builders(), diff --git a/navsim/planning/training/dataset.py b/navsim/planning/training/dataset.py index 76aa49f..5a78115 100644 --- a/navsim/planning/training/dataset.py +++ b/navsim/planning/training/dataset.py @@ -9,7 +9,10 @@ from tqdm import tqdm from navsim.common.dataloader import SceneLoader -from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder +from navsim.planning.training.abstract_feature_target_builder import ( + AbstractFeatureBuilder, + AbstractTargetBuilder, +) from navsim.common.dataloader import SceneLoader, MetricCacheLoader logger = logging.getLogger(__name__) @@ -22,7 +25,9 @@ def load_feature_target_from_pickle(path: Path) -> Dict[str, torch.Tensor]: return data_dict -def dump_feature_target_to_pickle(path: Path, data_dict: Dict[str, torch.Tensor]) -> None: +def dump_feature_target_to_pickle( + path: Path, data_dict: Dict[str, torch.Tensor] +) -> None: """Helper function to save feature/target to pickle.""" # Use compresslevel = 1 to compress the size but also has fast write and read. with gzip.open(path, "wb", compresslevel=1) as f: @@ -34,6 +39,7 @@ class CacheOnlyDataset(torch.utils.data.Dataset): def __init__( self, + df_version: str, cache_path: str, feature_builders: List[AbstractFeatureBuilder], target_builders: List[AbstractTargetBuilder], @@ -49,13 +55,18 @@ def __init__( super().__init__() assert Path(cache_path).is_dir(), f"Cache path {cache_path} does not exist!" 
self._cache_path = Path(cache_path) + self.df_version = df_version if log_names is not None: - self.log_names = [Path(log_name) for log_name in log_names if (self._cache_path / log_name).is_dir()] + self.log_names = [ + Path(log_name) + for log_name in log_names + if (self._cache_path / log_name).is_dir() + ] else: self.log_names = [log_name for log_name in self._cache_path.iterdir()] - if 'metric' in cache_path: + if "metric" in cache_path: self.metric_cache_loader = MetricCacheLoader(Path(cache_path)) self._feature_builders = feature_builders @@ -74,13 +85,20 @@ def __len__(self) -> int: """ return len(self.tokens) - def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def __getitem__(self, index: int): """ Loads and returns pair of feature and target dict from data. :param index: index of sample to load. :return: tuple of feature and target dictionary """ - return self._load_scene_with_token(idx) + if self.df_version == "v1": + return self._getitem_v1(index) + elif self.df_version == "v2": + return self._getitem_v2(index) + else: + raise ValueError( + f"Unexpected df_version: '{self.df_version}'. 
Supported df_versions are v1 or v2" + ) @staticmethod def _load_valid_caches( @@ -112,20 +130,49 @@ def _load_valid_caches( return valid_cache_paths - def _load_scene_with_token(self, idx) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def _getitem_v1( + self, index + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + token: str = self.tokens[index] + token_path = self._valid_cache_paths[token] + + features: Dict[str, torch.Tensor] = {} + for builder in self._feature_builders: + data_dict_path = token_path / (builder.get_unique_name() + ".gz") + data_dict = load_feature_target_from_pickle(data_dict_path) + features.update(data_dict) + + targets: Dict[str, torch.Tensor] = {} + for builder in self._target_builders: + data_dict_path = token_path / (builder.get_unique_name() + ".gz") + data_dict = load_feature_target_from_pickle(data_dict_path) + targets.update(data_dict) + + return (features, targets) + + def _getitem_v2( + self, idx + ) -> Tuple[ + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + str, + str, + ]: """ Helper method to load sample tensors given token :param token: unique string identifier of sample :return: tuple of feature and target dictionaries """ - token = self.tokens[idx] + token: str = self.tokens[idx] token_path = self._valid_cache_paths[token] - if 'training_cache' in str(token_path): - pdm_token_path = str(token_path).replace("training_cache", "train_pdm_cache") - pdm_token_path_parts = pdm_token_path.split('/') - pdm_token_path_parts.insert(-1, 'unknown') - pdm_token_path = '/'.join(pdm_token_path_parts) + "/metric_cache.pkl" - else: + if "training_cache" in str(token_path): + pdm_token_path = str(token_path).replace( + "training_cache", "train_pdm_cache" + ) + pdm_token_path_parts = pdm_token_path.split("/") + pdm_token_path_parts.insert(-1, "unknown") + pdm_token_path = "/".join(pdm_token_path_parts) + "/metric_cache.pkl" + else: pdm_token_path = token_path features: Dict[str, torch.Tensor] = {} @@ -187,7 
+234,9 @@ def _load_valid_caches( for token_path in log_path.iterdir(): found_caches: List[bool] = [] for builder in feature_builders + target_builders: - data_dict_path = token_path / (builder.get_unique_name() + ".gz") + data_dict_path = token_path / ( + builder.get_unique_name() + ".gz" + ) found_caches.append(data_dict_path.is_file()) if all(found_caches): valid_cache_paths[token_path.name] = token_path @@ -219,7 +268,9 @@ def _cache_scene_with_token(self, token: str) -> None: self._valid_cache_paths[token] = token_path - def _load_scene_with_token(self, token: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def _load_scene_with_token( + self, token: str + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: """ Helper function to load feature / targets from cache. :param token: unique identifier of scene to load @@ -252,7 +303,9 @@ def cache_dataset(self) -> None: if self._force_cache_computation: tokens_to_cache = self._scene_loader.tokens else: - tokens_to_cache = set(self._scene_loader.tokens) - set(self._valid_cache_paths.keys()) + tokens_to_cache = set(self._scene_loader.tokens) - set( + self._valid_cache_paths.keys() + ) tokens_to_cache = list(tokens_to_cache) logger.info( f""" @@ -271,7 +324,9 @@ def __len__(self) -> None: """ return len(self._scene_loader) - def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def __getitem__( + self, idx: int + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: """ Get features or targets either from cache or computed on-the-fly. :param idx: index of sample to load. @@ -283,13 +338,15 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torc targets: Dict[str, torch.Tensor] = {} if self._cache_path is not None: - assert ( - token in self._valid_cache_paths.keys() - ), f"The token {token} has not been cached yet, please call cache_dataset first!" 
+ assert token in self._valid_cache_paths.keys(), ( + f"The token {token} has not been cached yet, please call cache_dataset first!" + ) features, targets = self._load_scene_with_token(token) else: - scene = self._scene_loader.get_scene_from_token(self._scene_loader.tokens[idx]) + scene = self._scene_loader.get_scene_from_token( + self._scene_loader.tokens[idx] + ) agent_input = scene.get_agent_input() for builder in self._feature_builders: features.update(builder.compute_features(agent_input)) @@ -298,6 +355,7 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torc return (features, targets) + class CacheOnlyDatasetTest(torch.utils.data.Dataset): def __init__( self, @@ -310,30 +368,48 @@ def __init__( super().__init__() self._feature_cache_path = Path(feature_cache_path) self._metric_cache_path = Path(metric_cache_path) - assert self._feature_cache_path.is_dir(), f"Feature cache path {feature_cache_path} does not exist!" - assert self._metric_cache_path.is_dir(), f"Metric cache path {metric_cache_path} does not exist!" + assert self._feature_cache_path.is_dir(), ( + f"Feature cache path {feature_cache_path} does not exist!" + ) + assert self._metric_cache_path.is_dir(), ( + f"Metric cache path {metric_cache_path} does not exist!" 
+ ) if log_names is not None: - self.log_names = [Path(log_name) for log_name in log_names if (self._feature_cache_path / log_name).is_dir()] + self.log_names = [ + Path(log_name) + for log_name in log_names + if (self._feature_cache_path / log_name).is_dir() + ] else: - self.log_names = [log_name for log_name in self._feature_cache_path.iterdir() if log_name.is_dir()] + self.log_names = [ + log_name + for log_name in self._feature_cache_path.iterdir() + if log_name.is_dir() + ] self._feature_builders = feature_builders self._target_builders = target_builders self._valid_cache_paths: Dict[str, Path] = self._load_valid_caches() self.tokens = list(self._valid_cache_paths.keys()) if not self.tokens: - logger.error(f"CacheOnlyDataset: No valid cache found in {feature_cache_path}. Please ensure it was generated correctly.") + logger.error( + f"CacheOnlyDataset: No valid cache found in {feature_cache_path}. Please ensure it was generated correctly." + ) else: - logger.info(f"CacheOnlyDataset: Found {len(self.tokens)} cached scenarios for evaluation.") + logger.info( + f"CacheOnlyDataset: Found {len(self.tokens)} cached scenarios for evaluation." 
+ ) def __len__(self) -> int: return len(self.tokens) - def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Path, str]: + def __getitem__( + self, idx: int + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Path, str]: token = self.tokens[idx] token_path = self._valid_cache_paths[token] - + features: Dict[str, torch.Tensor] = {} for builder in self._feature_builders: data_dict_path = token_path / (builder.get_unique_name() + ".gz") @@ -345,10 +421,11 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torc data_dict_path = token_path / (builder.get_unique_name() + ".gz") if data_dict_path.is_file(): targets.update(load_feature_target_from_pickle(data_dict_path)) - log_name = token_path.parent.name - metric_cache_path = self._metric_cache_path / log_name / "unknown" / token / "metric_cache.pkl" + metric_cache_path = ( + self._metric_cache_path / log_name / "unknown" / token / "metric_cache.pkl" + ) return (features, targets, metric_cache_path, token) def _load_valid_caches(self) -> Dict[str, Path]: @@ -357,13 +434,15 @@ def _load_valid_caches(self) -> Dict[str, Path]: for log_name in tqdm(self.log_names, desc="Checking Cached Logs"): log_path = self._feature_cache_path / log_name for token_path in log_path.iterdir(): - if not token_path.is_dir(): continue - + if not token_path.is_dir(): + continue + found_caches: List[bool] = [] for builder in self._feature_builders + self._target_builders: data_dict_path = token_path / (builder.get_unique_name() + ".gz") found_caches.append(data_dict_path.is_file()) - + if all(found_caches): valid_cache_paths[token_path.name] = token_path - return valid_cache_paths \ No newline at end of file + return valid_cache_paths +