Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions docs/train_eval.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \
trainer.params.max_epochs=10 \
cache_path="${NAVSIM_EXP_ROOT}/training_cache/" \
use_cache_without_dataset=True \
force_cache_computation=False
force_cache_computation=False \
+df_version=v2

# Training mode selector
python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \
Expand All @@ -47,10 +48,25 @@ python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \
trainer.params.max_epochs=20 \
cache_path="${NAVSIM_EXP_ROOT}/training_cache/" \
use_cache_without_dataset=True \
force_cache_computation=False
force_cache_computation=False \
+df_version=v2

```

If you want to train DiffusionDrive v1, use this command:
```
python $NAVSIM_DEVKIT_ROOT/navsim/planning/script/run_training.py \
agent=diffusiondrive_agent \
experiment_name=training_diffusiondrive_agent \
train_test_split=mini \
split=mini \
trainer.params.max_epochs=100 \
cache_path="${NAVSIM_EXP_ROOT}/training_cache/" \
use_cache_without_dataset=True \
force_cache_computation=False \
+df_version=v1

```

## 4. Evaluation
Use the following command to evaluate the trained model rapidly (**several times faster than the official evaluation script**):
Expand Down
24 changes: 16 additions & 8 deletions navsim/planning/script/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,20 @@ def build_datasets(cfg: DictConfig, agent: AbstractAgent) -> Tuple[Dataset, Data
train_scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter)
if train_scene_filter.log_names is not None:
train_scene_filter.log_names = [
log_name for log_name in train_scene_filter.log_names if log_name in cfg.train_logs
log_name
for log_name in train_scene_filter.log_names
if log_name in cfg.train_logs
]
else:
train_scene_filter.log_names = cfg.train_logs

val_scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter)
if val_scene_filter.log_names is not None:
val_scene_filter.log_names = [log_name for log_name in val_scene_filter.log_names if log_name in cfg.val_logs]
val_scene_filter.log_names = [
log_name
for log_name in val_scene_filter.log_names
if log_name in cfg.val_logs
]
else:
val_scene_filter.log_names = cfg.val_logs

Expand Down Expand Up @@ -99,19 +105,21 @@ def main(cfg: DictConfig) -> None:

if cfg.use_cache_without_dataset:
logger.info("Using cached data without building SceneLoader")
assert (
not cfg.force_cache_computation
), "force_cache_computation must be False when using cached data without building SceneLoader"
assert (
cfg.cache_path is not None
), "cache_path must be provided when using cached data without building SceneLoader"
assert not cfg.force_cache_computation, (
"force_cache_computation must be False when using cached data without building SceneLoader"
)
assert cfg.cache_path is not None, (
"cache_path must be provided when using cached data without building SceneLoader"
)
train_data = CacheOnlyDataset(
df_version=cfg.df_version,
cache_path=cfg.cache_path,
feature_builders=agent.get_feature_builders(),
target_builders=agent.get_target_builders(),
log_names=cfg.train_logs,
)
val_data = CacheOnlyDataset(
df_version=cfg.df_version,
cache_path=cfg.cache_path,
feature_builders=agent.get_feature_builders(),
target_builders=agent.get_target_builders(),
Expand Down
151 changes: 115 additions & 36 deletions navsim/planning/training/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from tqdm import tqdm

from navsim.common.dataloader import SceneLoader
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
from navsim.planning.training.abstract_feature_target_builder import (
AbstractFeatureBuilder,
AbstractTargetBuilder,
)
from navsim.common.dataloader import SceneLoader, MetricCacheLoader

logger = logging.getLogger(__name__)
Expand All @@ -22,7 +25,9 @@ def load_feature_target_from_pickle(path: Path) -> Dict[str, torch.Tensor]:
return data_dict


def dump_feature_target_to_pickle(path: Path, data_dict: Dict[str, torch.Tensor]) -> None:
def dump_feature_target_to_pickle(
path: Path, data_dict: Dict[str, torch.Tensor]
) -> None:
"""Helper function to save feature/target to pickle."""
# Use compresslevel = 1 to compress the size but also has fast write and read.
with gzip.open(path, "wb", compresslevel=1) as f:
Expand All @@ -34,6 +39,7 @@ class CacheOnlyDataset(torch.utils.data.Dataset):

def __init__(
self,
df_version: str,
cache_path: str,
feature_builders: List[AbstractFeatureBuilder],
target_builders: List[AbstractTargetBuilder],
Expand All @@ -49,13 +55,18 @@ def __init__(
super().__init__()
assert Path(cache_path).is_dir(), f"Cache path {cache_path} does not exist!"
self._cache_path = Path(cache_path)
self.df_version = df_version

if log_names is not None:
self.log_names = [Path(log_name) for log_name in log_names if (self._cache_path / log_name).is_dir()]
self.log_names = [
Path(log_name)
for log_name in log_names
if (self._cache_path / log_name).is_dir()
]
else:
self.log_names = [log_name for log_name in self._cache_path.iterdir()]

if 'metric' in cache_path:
if "metric" in cache_path:
self.metric_cache_loader = MetricCacheLoader(Path(cache_path))

self._feature_builders = feature_builders
Expand All @@ -74,13 +85,20 @@ def __len__(self) -> int:
"""
return len(self.tokens)

def __getitem__(self, index: int):
    """
    Load and return the feature/target sample at the given index,
    dispatching on the configured DiffusionDrive cache layout version.
    :param index: index of the sample to load.
    :return: tuple of feature and target dictionaries (the v2 path may
        return additional metric-cache information -- see _getitem_v2).
    :raises ValueError: if ``self.df_version`` is neither "v1" nor "v2".
    """
    if self.df_version == "v1":
        return self._getitem_v1(index)
    # Bug fix: the original used `self.df_version in "v2"`, which is a
    # substring test -- "v", "2", and "" would all be accepted and silently
    # routed to the v2 loader. Equality is the intended check.
    elif self.df_version == "v2":
        return self._getitem_v2(index)
    else:
        raise ValueError(
            f"Unexpected df_version: '{self.df_version}'. Supported df_versions are v1 or v2"
        )

@staticmethod
def _load_valid_caches(
Expand Down Expand Up @@ -112,20 +130,49 @@ def _load_valid_caches(

return valid_cache_paths

def _load_scene_with_token(self, idx) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
def _getitem_v1(
self, index
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Load a cached sample using the original (v1) cache layout.
Each feature/target builder's output is stored as a separate
``<unique_name>.gz`` pickle inside the token's cache directory.
:param index: index of the sample to load.
:return: tuple of (features, targets) dictionaries of tensors.
"""
token: str = self.tokens[index]
# _valid_cache_paths maps token -> directory verified to hold all builder caches
token_path = self._valid_cache_paths[token]

# Merge every feature builder's cached dict into one feature dict.
features: Dict[str, torch.Tensor] = {}
for builder in self._feature_builders:
data_dict_path = token_path / (builder.get_unique_name() + ".gz")
data_dict = load_feature_target_from_pickle(data_dict_path)
features.update(data_dict)

# Same for target builders.
targets: Dict[str, torch.Tensor] = {}
for builder in self._target_builders:
data_dict_path = token_path / (builder.get_unique_name() + ".gz")
data_dict = load_feature_target_from_pickle(data_dict_path)
targets.update(data_dict)

return (features, targets)

def _getitem_v2(
self, idx
) -> Tuple[
Dict[str, torch.Tensor],
Dict[str, torch.Tensor],
str,
str,
]:
"""
Helper method to load sample tensors given token
:param token: unique string identifier of sample
:return: tuple of feature and target dictionaries
"""
token = self.tokens[idx]
token: str = self.tokens[idx]
token_path = self._valid_cache_paths[token]
if 'training_cache' in str(token_path):
pdm_token_path = str(token_path).replace("training_cache", "train_pdm_cache")
pdm_token_path_parts = pdm_token_path.split('/')
pdm_token_path_parts.insert(-1, 'unknown')
pdm_token_path = '/'.join(pdm_token_path_parts) + "/metric_cache.pkl"
else:
if "training_cache" in str(token_path):
pdm_token_path = str(token_path).replace(
"training_cache", "train_pdm_cache"
)
pdm_token_path_parts = pdm_token_path.split("/")
pdm_token_path_parts.insert(-1, "unknown")
pdm_token_path = "/".join(pdm_token_path_parts) + "/metric_cache.pkl"
else:
pdm_token_path = token_path

features: Dict[str, torch.Tensor] = {}
Expand Down Expand Up @@ -187,7 +234,9 @@ def _load_valid_caches(
for token_path in log_path.iterdir():
found_caches: List[bool] = []
for builder in feature_builders + target_builders:
data_dict_path = token_path / (builder.get_unique_name() + ".gz")
data_dict_path = token_path / (
builder.get_unique_name() + ".gz"
)
found_caches.append(data_dict_path.is_file())
if all(found_caches):
valid_cache_paths[token_path.name] = token_path
Expand Down Expand Up @@ -219,7 +268,9 @@ def _cache_scene_with_token(self, token: str) -> None:

self._valid_cache_paths[token] = token_path

def _load_scene_with_token(self, token: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
def _load_scene_with_token(
self, token: str
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Helper function to load feature / targets from cache.
:param token: unique identifier of scene to load
Expand Down Expand Up @@ -252,7 +303,9 @@ def cache_dataset(self) -> None:
if self._force_cache_computation:
tokens_to_cache = self._scene_loader.tokens
else:
tokens_to_cache = set(self._scene_loader.tokens) - set(self._valid_cache_paths.keys())
tokens_to_cache = set(self._scene_loader.tokens) - set(
self._valid_cache_paths.keys()
)
tokens_to_cache = list(tokens_to_cache)
logger.info(
f"""
Expand All @@ -271,7 +324,9 @@ def __len__(self) -> None:
"""
return len(self._scene_loader)

def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
def __getitem__(
self, idx: int
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Get features or targets either from cache or computed on-the-fly.
:param idx: index of sample to load.
Expand All @@ -283,13 +338,15 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torc
targets: Dict[str, torch.Tensor] = {}

if self._cache_path is not None:
assert (
token in self._valid_cache_paths.keys()
), f"The token {token} has not been cached yet, please call cache_dataset first!"
assert token in self._valid_cache_paths.keys(), (
f"The token {token} has not been cached yet, please call cache_dataset first!"
)

features, targets = self._load_scene_with_token(token)
else:
scene = self._scene_loader.get_scene_from_token(self._scene_loader.tokens[idx])
scene = self._scene_loader.get_scene_from_token(
self._scene_loader.tokens[idx]
)
agent_input = scene.get_agent_input()
for builder in self._feature_builders:
features.update(builder.compute_features(agent_input))
Expand All @@ -298,6 +355,7 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torc

return (features, targets)


class CacheOnlyDatasetTest(torch.utils.data.Dataset):
def __init__(
self,
Expand All @@ -310,30 +368,48 @@ def __init__(
super().__init__()
self._feature_cache_path = Path(feature_cache_path)
self._metric_cache_path = Path(metric_cache_path)
assert self._feature_cache_path.is_dir(), f"Feature cache path {feature_cache_path} does not exist!"
assert self._metric_cache_path.is_dir(), f"Metric cache path {metric_cache_path} does not exist!"
assert self._feature_cache_path.is_dir(), (
f"Feature cache path {feature_cache_path} does not exist!"
)
assert self._metric_cache_path.is_dir(), (
f"Metric cache path {metric_cache_path} does not exist!"
)

if log_names is not None:
self.log_names = [Path(log_name) for log_name in log_names if (self._feature_cache_path / log_name).is_dir()]
self.log_names = [
Path(log_name)
for log_name in log_names
if (self._feature_cache_path / log_name).is_dir()
]
else:
self.log_names = [log_name for log_name in self._feature_cache_path.iterdir() if log_name.is_dir()]
self.log_names = [
log_name
for log_name in self._feature_cache_path.iterdir()
if log_name.is_dir()
]

self._feature_builders = feature_builders
self._target_builders = target_builders
self._valid_cache_paths: Dict[str, Path] = self._load_valid_caches()
self.tokens = list(self._valid_cache_paths.keys())
if not self.tokens:
logger.error(f"CacheOnlyDataset: No valid cache found in {feature_cache_path}. Please ensure it was generated correctly.")
logger.error(
f"CacheOnlyDataset: No valid cache found in {feature_cache_path}. Please ensure it was generated correctly."
)
else:
logger.info(f"CacheOnlyDataset: Found {len(self.tokens)} cached scenarios for evaluation.")
logger.info(
f"CacheOnlyDataset: Found {len(self.tokens)} cached scenarios for evaluation."
)

def __len__(self) -> int:
return len(self.tokens)

def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Path, str]:
def __getitem__(
self, idx: int
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Path, str]:
token = self.tokens[idx]
token_path = self._valid_cache_paths[token]

features: Dict[str, torch.Tensor] = {}
for builder in self._feature_builders:
data_dict_path = token_path / (builder.get_unique_name() + ".gz")
Expand All @@ -345,10 +421,11 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torc
data_dict_path = token_path / (builder.get_unique_name() + ".gz")
if data_dict_path.is_file():
targets.update(load_feature_target_from_pickle(data_dict_path))


log_name = token_path.parent.name
metric_cache_path = self._metric_cache_path / log_name / "unknown" / token / "metric_cache.pkl"
metric_cache_path = (
self._metric_cache_path / log_name / "unknown" / token / "metric_cache.pkl"
)
return (features, targets, metric_cache_path, token)

def _load_valid_caches(self) -> Dict[str, Path]:
Expand All @@ -357,13 +434,15 @@ def _load_valid_caches(self) -> Dict[str, Path]:
for log_name in tqdm(self.log_names, desc="Checking Cached Logs"):
log_path = self._feature_cache_path / log_name
for token_path in log_path.iterdir():
if not token_path.is_dir(): continue

if not token_path.is_dir():
continue

found_caches: List[bool] = []
for builder in self._feature_builders + self._target_builders:
data_dict_path = token_path / (builder.get_unique_name() + ".gz")
found_caches.append(data_dict_path.is_file())

if all(found_caches):
valid_cache_paths[token_path.name] = token_path
return valid_cache_paths
return valid_cache_paths