Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 115 additions & 54 deletions pycode/memilio-epidata/memilio/epidata/getContactData.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
Prem et al., 2017 (DOI: https://doi.org/10.1371/journal.pcbi.1005697).
The module can download the supporting ZIP from
https://doi.org/10.1371/journal.pcbi.1005697.s002 (contains the
``MUestimates_all_locations_1.xlsx`` workbook) or read a defined local
workbook path. By default, downloads are done in memory and no
files are written.
``MUestimates_all_locations_1.xlsx`` and
``MUestimates_all_locations_2.xlsx`` workbooks) or read a defined local
workbook path. By default, downloads are done in memory and no files are
written.
"""

import io
Expand All @@ -43,7 +44,10 @@
"https://journals.plos.org/ploscompbiol/article/file"
"?id=10.1371/journal.pcbi.1005697.s002&type=supplementary"
)
CONTACT_WORKBOOK_NAME = "MUestimates_all_locations_1.xlsx"
CONTACT_WORKBOOK_NAMES = (
"MUestimates_all_locations_1.xlsx",
"MUestimates_all_locations_2.xlsx",
)

AGE_GROUP_LABELS = [
"0-4",
Expand Down Expand Up @@ -84,59 +88,64 @@ def _normalize_country_name(country: str):
return "".join(ch for ch in country.casefold() if ch.isalnum())


def _download_contact_workbook(
url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME):
def _download_contact_workbooks():
"""
Download the ZIP from the url and return the workbook.
Download the ZIP and return all contact workbooks.

:param url: URL to download the ZIP from.
:param target_filename: Name of the workbook file within the ZIP.
:returns: Content of the workbook.
:returns: List of workbook contents.
"""
response = requests.get(url, timeout=30)
response = requests.get(CONTACT_ZIP_URL, timeout=30)
response.raise_for_status()

workbooks = []
with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
candidates = [name for name in zf.namelist()
if name.endswith(target_filename)]
if not candidates:
raise FileNotFoundError(
f"'{target_filename}' not found in downloaded workbook.")
with zf.open(candidates[0]) as f:
return f.read()
for target_filename in CONTACT_WORKBOOK_NAMES:
candidates = [name for name in zf.namelist()
if name.endswith(target_filename)]
if not candidates:
raise FileNotFoundError(
f"'{target_filename}' not found in downloaded workbook.")
with zf.open(candidates[0]) as f:
workbooks.append(f.read())

return workbooks

def _load_workbook_bytes(
contact_path: Optional[str],
url: str = CONTACT_ZIP_URL,
target_filename: str = CONTACT_WORKBOOK_NAME):

def _load_workbooks_bytes(
contact_path: Optional[str]):
"""
Return workbook either from a user path or by downloading the ZIP.
Return one explicit workbook or, by default, all downloaded workbooks.

:param contact_path: Optional local path to the workbook.
:param url: Url to download the ZIP from if no path is provided.
:param target_filename: Name of the workbook file within the ZIP.
:returns: Content of the workbook.
:param contact_path: Optional local path to a single workbook.
:returns: List of workbook contents.
"""
if contact_path:
if not os.path.exists(contact_path):
raise FileNotFoundError(
f"Contact matrix file not found at {contact_path}")
with open(contact_path, "rb") as f:
return f.read()
return _download_contact_workbook(url=url, target_filename=target_filename)
return [f.read()]
return _download_contact_workbooks()


def list_available_contact_countries(
contact_path: Optional[str] = None):
"""
List all country names available in the contact matrix workbook.
List all country names available in the contact matrix workbooks.

:param contact_path: Optional local path to the workbook.
:returns: List of all country names.
"""
xls_bytes = _load_workbook_bytes(contact_path)
xls = pd.ExcelFile(io.BytesIO(xls_bytes))
return xls.sheet_names
countries = []
seen = set()
for xls_bytes in _load_workbooks_bytes(contact_path):
xls = pd.ExcelFile(io.BytesIO(xls_bytes))
for sheet_name in xls.sheet_names:
key = _normalize_country_name(sheet_name)
if key not in seen:
countries.append(sheet_name)
seen.add(key)
return countries


def _select_sheet_name(country: str, sheet_names: Iterable[str]):
Expand All @@ -157,6 +166,58 @@ def _select_sheet_name(country: str, sheet_names: Iterable[str]):
return lookup[key]


def _read_contact_sheet(xls: pd.ExcelFile, sheet_name: str, country: str):
Comment thread
HenrZu marked this conversation as resolved.
"""
Read a contact sheet and extract its numeric 16x16 matrix.

Some source workbooks contain an explicit header row, others contain only
the matrix. Reading without headers and selecting the numeric block handles
both formats.

:param xls: Opened Excel workbook containing contact matrix sheets.
:param sheet_name: Name of the sheet to read.
:param country: Country name used for error messages.
:returns: DataFrame containing the extracted 16x16 contact matrix.
"""
df = pd.read_excel(
xls,
sheet_name=sheet_name,
engine="openpyxl",
header=None)

return _extract_contact_matrix(df, country)


def _extract_contact_matrix(df: pd.DataFrame, country: str):
Comment thread
HenrZu marked this conversation as resolved.
"""
Extract the 16x16 numeric contact matrix from a raw Excel sheet.

:param df: Raw sheet data read from the workbook without headers.
:param country: Country name used for error messages.
:returns: DataFrame with age-group labels as index and columns.
"""
matrix_size = len(AGE_GROUP_LABELS)
numeric = df.apply(pd.to_numeric, errors="coerce")
max_row_start = numeric.shape[0] - matrix_size
max_col_start = numeric.shape[1] - matrix_size

for row_start in range(max_row_start, -1, -1):
for col_start in range(max_col_start, -1, -1):
matrix = numeric.iloc[
row_start:row_start + matrix_size,
col_start:col_start + matrix_size]
if matrix.shape == (matrix_size, matrix_size):
if not matrix.isnull().any().any():
matrix = matrix.copy()
matrix.columns = AGE_GROUP_LABELS
matrix.index = AGE_GROUP_LABELS
return matrix

raise ValueError(
f"Contact matrix for '{country}' does not contain a numeric "
f"{matrix_size}x{matrix_size} block. Raw shape: {df.shape}")


def load_contact_matrix(
country: str,
contact_path: Optional[str] = None,
Expand All @@ -165,31 +226,31 @@ def load_contact_matrix(
"""
Load the all-locations contact matrix for the given country. If
``contact_path`` is not provided, the function downloads the
``MUestimates_all_locations_1.xlsx`` workbook from Prem et al., 2017.
``MUestimates_all_locations_*.xlsx`` workbooks from Prem et al., 2017.

:param country: Country name as listed in the workbook (case-insensitive).
:param contact_path: Optional path to ``MUestimates_all_locations_1.xlsx``.
:param contact_path: Optional path to one ``MUestimates_all_locations``
workbook.
:param reduce_to_rki_groups: If True, aggregate to the six RKI age groups
(0-4, 5-14, 15-34, 35-59, 60-79, 80+ years). Default True.
:param population: An iterable of 16 float values representing the population
size for each original age group. Required if reduce_to_rki_groups is True.
:param population: An iterable of 16 float values representing the
population size for each original age group. Required if
reduce_to_rki_groups is True.
:returns: DataFrame indexed by age group with floats.
"""
xls_bytes = _load_workbook_bytes(contact_path)
xls = pd.ExcelFile(io.BytesIO(xls_bytes))
sheet_names = xls.sheet_names
sheet = _select_sheet_name(country, sheet_names)
df = pd.read_excel(xls, sheet_name=sheet, engine="openpyxl")

# Ensure numeric values and trim potential trailing rows/cols.
matrix = df.apply(pd.to_numeric, errors="coerce")
matrix = matrix.iloc[:len(AGE_GROUP_LABELS), :len(AGE_GROUP_LABELS)]
matrix.columns = AGE_GROUP_LABELS[:matrix.shape[1]]
matrix.index = AGE_GROUP_LABELS[:matrix.shape[0]]

if matrix.isnull().any().any():
raise ValueError(
f"Contact matrix for '{country}' contains non-numeric entries.")
all_sheet_names = []
for xls_bytes in _load_workbooks_bytes(contact_path):
xls = pd.ExcelFile(io.BytesIO(xls_bytes))
sheet_names = xls.sheet_names
all_sheet_names.extend(sheet_names)
try:
sheet = _select_sheet_name(country, sheet_names)
except ValueError:
continue
matrix = _read_contact_sheet(xls, sheet, country)
break
else:
_select_sheet_name(country, all_sheet_names)

if matrix.shape[0] != matrix.shape[1]:
raise ValueError(
Expand All @@ -210,8 +271,8 @@ def load_contact_matrix(
def _aggregate_to_rki_age_groups(
matrix: pd.DataFrame, population: Iterable[float]):
"""
Aggregate an age-structured 16x16 contact matrix to the 6-group RKI scheme using
population-weighted averages.
Aggregate an age-structured 16x16 contact matrix to the 6-group RKI
scheme using population-weighted averages.
Assumes the original columns/rows follow AGE_GROUP_LABELS order.
Note: The source only provides data up to 70-74 and a 75+ group.
We map 60-74 to the 60-79 RKI group and 75+ to the 80-99 RKI group.
Expand Down
Loading
Loading