Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/examples/abm_minimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ int main()
model.parameters.get<mio::abm::TimeExposedToNoSymptoms>() = mio::ParameterDistributionLogNormal(4., 1.);

// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
model.parameters.get<mio::abm::AgeGroupGotoSchool>() = false;
model.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59)
model.parameters.get<mio::abm::AgeGroupGotoWork>().set_multiple({age_group_15_to_34, age_group_35_to_59}, true);
Expand Down
43 changes: 43 additions & 0 deletions cpp/memilio/io/history.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define MIO_IO_HISTORY_H

#include "memilio/utils/metaprogramming.h"
#include <memory>
#include <vector>
#include <tuple>

Expand Down Expand Up @@ -108,6 +109,19 @@ class History
return m_data;
}

/**
* @brief Access the data of the given Logger.
* This function only works with Writers that stores its records in a tuple, like DataWriterToMemory.
* @return A read-only reference to the Logger's records.
*/
template <class Logger>
requires(is_type_in_list_v<Logger, Loggers...> &&
std::tuple_size_v<typename WriteWrapper::Data> == sizeof...(Loggers))
const auto& get_log() const
{
return std::get<index_of_type_v<Logger, Loggers...>>(m_data);
}

private:
typename WriteWrapper::Data m_data;
std::tuple<Loggers...> m_loggers;
Expand All @@ -127,6 +141,35 @@ class History
}
};

namespace details
{

template <class T>
class AbstractHistory
{
public:
template <template <class...> class Writer, class... Loggers>
AbstractHistory(History<Writer, Loggers...>& history)
: m_history(static_cast<void*>(&history), [](void*) {})
, m_log([](const std::shared_ptr<void>& h, const T& t) {
using H = History<Writer, Loggers...>;
static_cast<H*>(h.get())->log(t);
})
{
}

void log(const T& t)
{
m_log(m_history, t);
}

private:
std::shared_ptr<void> m_history; ///< A non-owning pointer to the History.
void (*m_log)(const std::shared_ptr<void>&, const T&); ///< Function pointer that stores the
};

} // namespace details

template <class... Loggers>
using HistoryWithMemoryWriter = History<DataWriterToMemory, Loggers...>;

Expand Down
8 changes: 7 additions & 1 deletion cpp/memilio/utils/abstract_parameter_distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@
#include "parameter_distributions.h"
#include <memory>
#include <string>
#include <concepts>

namespace mio
{

template <class T>
concept HasSampleFunction = requires(T t) {
{ t.get_sample(std::declval<RandomNumberGenerator&>()) } -> std::convertible_to<ScalarType>;
};
Comment on lines +35 to +38
Copy link

Copilot AI Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HasSampleFunction uses std::declval but this header doesn’t include <utility> directly. It currently works only via transitive includes (e.g. through memilio/io/io.h), which is brittle. Add #include <utility> here to make the dependency explicit.

Copilot uses AI. Check for mistakes.

/**
* @brief This class represents an arbitrary ParameterDistribution.
* @see mio::ParameterDistribution
Expand All @@ -44,7 +50,7 @@ class AbstractParameterDistribution
* The implementation handed to the constructor should have get_sample function
* overloaded with mio::RandomNumberGenerator and mio::abm::PersonalRandomNumberGenerator as input arguments
*/
template <class Impl>
template <HasSampleFunction Impl>
AbstractParameterDistribution(Impl&& dist)
: m_dist(std::make_shared<Impl>(std::move(dist)))
, sample_impl1([](void* d, RandomNumberGenerator& rng) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/models/abm/mobility_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ enum class ActivityType : uint32_t
PrivateMatters,
OtherActivity,
Home,
UnknownActivity
UnknownActivity,
Count //last!!
};

} // namespace abm
Expand Down
15 changes: 15 additions & 0 deletions cpp/models/abm/simulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "abm/model.h"
#include "abm/time.h"
#include "memilio/io/history.h"
#include <vector>

namespace mio
{
Expand Down Expand Up @@ -77,6 +78,20 @@ class Simulation
}
}

void advance(TimePoint tmax, std::vector<details::AbstractHistory<Simulation>> histories)
{
//log initial system state
for (auto& history : histories) {
history.log(*this);
}
while (m_t < tmax) {
evolve_model(tmax);
for (auto& history : histories) {
history.log(*this);
}
}
}

/**
* @brief Get the current time of the Simulation.
*/
Expand Down
189 changes: 189 additions & 0 deletions pycode/examples/simulation/abm_minimal_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
#############################################################################
# Copyright (C) 2020-2026 MEmilio
#
# Authors: Carlotta Gerstein
#
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#############################################################################

from memilio.simulation import AgeGroup
import memilio.simulation.abm as abm
import memilio.simulation as mio

import numpy as np
import random

num_age_groups = 4

model = abm.Model(num_age_groups)

# Set parameters

for age_group in range(num_age_groups):
model.parameters.TimeExposedToNoSymptoms[abm.VirusVariant.Wildtype, AgeGroup(age_group)] = mio.AbstractParameterLogNormalDistribution(mio.ParameterDistributionLogNormal(
4., 1.))

model.parameters.AgeGroupGotoSchool[AgeGroup(1)] = True
model.parameters.AgeGroupGotoWork[AgeGroup(2)] = True
model.parameters.AgeGroupGotoWork[AgeGroup(3)] = True

for age in range(num_age_groups):
model.parameters.InfectionProtectionFactor[abm.ProtectionType.GenericVaccine, AgeGroup(
age), abm.VirusVariant.Wildtype] = mio.TimeSeriesFunctor(
[[0, 0.0], [14, 0.67], [180, 0.4]])

model.parameters.SeverityProtectionFactor[abm.ProtectionType.GenericVaccine, AgeGroup(
age), abm.VirusVariant.Wildtype] = mio.TimeSeriesFunctor(
[[0, 0.0], [14, 0.85], [180, 0.7]])
Comment on lines +42 to +49
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is not in the .cpp file. We should probably stay as close as possible to the .cpp example. If we want to add this, then we should also add it in the .cpp example.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the cpp version this is divided into the abm_minimal and the abm_vaccination example. Agree that we should probably stay close, so either merge them in the cpp version or split it here. Do you have a favorite solution?


model.parameters.check_constraints()

# Set populations

n_households = 10

child = abm.HouseholdMember(num_age_groups)
child.age_weights[AgeGroup(0)] = 1.
child.age_weights[AgeGroup(1)] = 1.

parent = abm.HouseholdMember(num_age_groups)
parent.age_weights[AgeGroup(2)] = 1.
parent.age_weights[AgeGroup(3)] = 1.

twoPersonHousehold_group = abm.HouseholdGroup()
twoPersonHousehold_full = abm.Household()
twoPersonHousehold_full.add_members(child, 1)
twoPersonHousehold_full.add_members(parent, 1)
twoPersonHousehold_group.add_households(twoPersonHousehold_full, n_households)
abm.add_household_group_to_model(model, twoPersonHousehold_group)

threePersonHousehold_group = abm.HouseholdGroup()
threePersonHousehold_full = abm.Household()
threePersonHousehold_full.add_members(child, 1)
threePersonHousehold_full.add_members(parent, 2)
threePersonHousehold_group.add_households(
threePersonHousehold_full, n_households)
abm.add_household_group_to_model(model, threePersonHousehold_group)

# Set locations

event = model.add_location(abm.LocationType.SocialEvent)
model.get_location(event).infection_parameters.MaximumContacts = 5

hospital = model.add_location(abm.LocationType.Hospital)
model.get_location(hospital).infection_parameters.MaximumContacts = 5
icu = model.add_location(abm.LocationType.ICU)
model.get_location(icu).infection_parameters.MaximumContacts = 5

shop = model.add_location(abm.LocationType.BasicsShop)
model.get_location(shop).infection_parameters.MaximumContacts = 20

school = model.add_location(abm.LocationType.School)
model.get_location(school).infection_parameters.MaximumContacts = 20

work = model.add_location(abm.LocationType.Work)
model.get_location(work).infection_parameters.MaximumContacts = 20

model.parameters.AerosolTransmissionRates[abm.VirusVariant.Wildtype] = 10

contacts = np.zeros((num_age_groups, num_age_groups))
contacts[2, 3] = 10

model.get_location(
work).infection_parameters.ContactRates.baseline = contacts

# Testing Schemes

validity_period = abm.days(1)
probability = 0.5
start_date = abm.TimePoint(0)
end_date = abm.TimePoint(0) + abm.days(10)
test_type = abm.TestType.Antigen
test_parameters = model.parameters.TestData[test_type]

testing_criteria_work = abm.TestingCriteria()
testing_scheme_work = abm.TestingScheme(
testing_criteria_work, validity_period, start_date, end_date, test_parameters, probability)

model.testing_strategy.add_scheme(
abm.LocationType.Work, testing_scheme_work)

# Seed infections

infection_distribution = [0.5, 0.3, 0.05, 0.05, 0.05, 0.05, 0.0, 0.0]
for person in model.persons:
infection_state = abm.InfectionState(
mio.DiscreteDistribution.get_instance()(model.rng, infection_distribution))
prng = abm.PersonalRandomNumberGenerator(model.rng, person)
if infection_state != abm.InfectionState.Susceptible:
person.add_new_infection(mio.abm.Infection(
prng, abm.VirusVariant.Wildtype, person.age, model.parameters, start_date, infection_state))

# Assign locations

for person in model.persons:
person_id = person.id

model.assign_location(person_id, event)
model.assign_location(person_id, shop)

model.assign_location(person_id, hospital)
model.assign_location(person_id, icu)

if person.age == AgeGroup(1):
model.assign_location(person_id, school)

if person.age == AgeGroup(2) or person.age == AgeGroup(3):
model.assign_location(person_id, work)

# Vaccinations

vacc_rate = 0.7
vaccination_priority = [AgeGroup(3), AgeGroup(2), AgeGroup(1)]
vaccination_time = start_date - abm.days(20)

persons_by_age = [[] for _ in range(num_age_groups)]
for idx, person in enumerate(model.persons):
persons_by_age[person.age.get()].append(idx)

for age in vaccination_priority:
indices = persons_by_age[age.get()]

random.shuffle(indices)

n_to_vaccinate = int(np.round(vacc_rate * len(indices)))
for i in range(n_to_vaccinate):
person = model.persons[indices[i]]
if person.get_infection_state(vaccination_time) == abm.InfectionState.Susceptible:
person.add_new_vaccination(
abm.ProtectionType.GenericVaccine, vaccination_time)

# Simulate

t_lockdown = start_date + abm.days(10)
abm.close_social_events(t_lockdown, 0.9, model.parameters)

t0 = start_date
tmax = t0 + abm.days(10)
sim = abm.Simulation(t0, model)


history = abm.TimeSeriesWriterLogInfectionStateHistory(
mio.TimeSeries(len(abm.InfectionState.values())))
history2 = abm.DataWriterLogDataForMobilityHistory()

sim.advance(tmax, history, history2)

history.get_log()[0].print_table()
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (C) 2020-2026 MEmilio
*
* Authors: Carlotta Gerstein
*
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PYMIO_GEOLOCATION_H
#define PYMIO_GEOLOCATION_H

#include "memilio/math/integrator.h"
#include "memilio/geography/geolocation.h"
#include "pybind_util.h"

#include "pybind11/pybind11.h"
#include <pybind11/eigen.h>

namespace py = pybind11;

namespace pymio
{

void bind_geolocation(pybind11::module_& m, std::string const& name)
{
bind_class<mio::geo::GeographicalLocation, EnablePickling::Never>(m, name.c_str()).def(py::init<double, double>());
}
} // namespace pymio

#endif //PYMIO_GEOLOCATION_H
Loading
Loading