From 9c740746d49c5546ad7c936d9ef66b998ae01137 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 21 Jan 2026 14:59:15 -0500 Subject: [PATCH 1/4] feat: font size slider for visualizer --- temoa/utilities/graph_utils.py | 21 +++-- .../network_vis_templates/graph_script.js | 78 +++++++++++++++---- .../network_vis_templates/graph_styles.css | 3 + .../network_vis_templates/graph_template.html | 7 ++ temoa/utilities/visualizer.py | 6 +- 5 files changed, 94 insertions(+), 21 deletions(-) diff --git a/temoa/utilities/graph_utils.py b/temoa/utilities/graph_utils.py index a1541f0a..7c417a76 100644 --- a/temoa/utilities/graph_utils.py +++ b/temoa/utilities/graph_utils.py @@ -175,10 +175,15 @@ def calculate_initial_positions( return positions # Arrange sector "anchors" in a large circle - layout_radius = 2000 # The radius of the main circle for sectors - jitter_radius = 1000 # How far nodes can be from their sector anchor - sector_anchors = {} + # Scale radius based on the number of sectors and nodes to handle small models better num_sectors = len(sectors_to_place) + num_nodes = len(nodes_to_place) + + # Base radius + incremental scaling + layout_radius = max(800, min(2000, 400 + 200 * num_sectors + 2 * num_nodes)) + jitter_radius = layout_radius // 2 + + sector_anchors = {} for i, sector in enumerate(sectors_to_place): angle = (i / num_sectors) * 2 * math.pi @@ -226,10 +231,14 @@ def calculate_tech_graph_positions( return {} # 2. Arrange sector "anchors" in a large circle - layout_radius = 2500 # Use a large radius to ensure initial separation - jitter_radius = 600 # Controls the size of the initial clusters - sector_anchors = {} + # Scale radius based on the number of sectors and edges num_sectors = len(sectors_to_place) + num_edges = sum(1 for _ in all_edges) + + layout_radius = max(1000, min(2500, 500 + 300 * num_sectors + num_edges)) + jitter_radius = layout_radius // 4 + + sector_anchors = {} for i, sector in enumerate(sectors_to_place): angle = (i / num_sectors) * 2 * math.pi diff --git a/temoa/utilities/network_vis_templates/graph_script.js b/temoa/utilities/network_vis_templates/graph_script.js index 78e77b91..d7f392aa 100644 --- a/temoa/utilities/network_vis_templates/graph_script.js +++ b/temoa/utilities/network_vis_templates/graph_script.js @@ -28,11 +28,25 @@ document.addEventListener('DOMContentLoaded', function () { primary_view_name: primaryViewName, secondary_view_name: secondaryViewName, } = data; - const optionsObject = (typeof optionsRaw === 'string') ? JSON.parse(optionsRaw) : optionsRaw; - // --- State --- + + const optionsObject = + typeof optionsRaw === "string" ? JSON.parse(optionsRaw) : optionsRaw; + + window.__graph = { + data, + allNodesPrimary, + allEdgesPrimary, + allNodesSecondary, + allEdgesSecondary, + optionsObject, + }; + // --- Visual State --- let currentView = 'primary'; let primaryViewPositions = null; let secondaryViewPositions = null; + let visualState = { + fontSize: 14 + }; // --- DOM Elements --- const configWrapper = document.getElementById('config-panel-wrapper'); @@ -40,6 +54,7 @@ document.addEventListener('DOMContentLoaded', function () { const configToggleButton = document.querySelector('.config-toggle-btn'); const advancedControlsToggle = document.getElementById('advanced-controls-toggle'); const visConfigContainer = document.getElementById('vis-config-container'); + const fontSizeSlider = document.getElementById('font-size-slider'); const searchInput = document.getElementById('search-input'); const resetButton = document.getElementById('reset-view-btn'); const sectorTogglesContainer = document.getElementById('sector-toggles'); @@ -61,9 +76,33 @@ document.addEventListener('DOMContentLoaded', function () { }); } + // --- Visual Settings Sliders --- + function updateVisualSettings() { + if (fontSizeSlider) visualState.fontSize = parseInt(fontSizeSlider.value, 10); + + // Use setOptions for global font size - works for edges with smooth enabled + // Note: Don't set per-edge font as it breaks rendering with smooth edges + network.setOptions({ + nodes: { font: { size: visualState.fontSize } }, + edges: { font: { size: visualState.fontSize, align: 'top' } } + }); + + // Also update nodes individually since they have per-node font from addWithCurrentFontSize + const nodeUpdates = nodes.get().map(n => ({ + id: n.id, + font: { ...(n.font ?? {}), size: visualState.fontSize } + })); + nodes.update(nodeUpdates); + + network.redraw(); + } + + if (fontSizeSlider) fontSizeSlider.addEventListener('input', updateVisualSettings); + + // --- Vis.js Network Initialization --- - const nodes = new vis.DataSet(allNodesPrimary); - const edges = new vis.DataSet(allEdgesPrimary); + const nodes = new vis.DataSet(); + const edges = new vis.DataSet(); const network = new vis.Network(graphContainer, { nodes, edges }, optionsObject); // --- Core Functions --- @@ -84,13 +123,13 @@ document.addEventListener('DOMContentLoaded', function () { nodes.clear(); edges.clear(); if (currentView === 'primary') { - nodes.add(allNodesSecondary); edges.add(allEdgesSecondary); + addWithCurrentFontSize(allNodesSecondary, allEdgesSecondary); currentView = 'secondary'; viewToggleButton.textContent = `Switch to ${primaryViewName}`; viewToggleButton.setAttribute('aria-pressed', 'true'); applyPositions(secondaryViewPositions); } else { - nodes.add(allNodesPrimary); edges.add(allEdgesPrimary); + addWithCurrentFontSize(allNodesPrimary, allEdgesPrimary); currentView = 'primary'; viewToggleButton.textContent = `Switch to ${secondaryViewName}`; viewToggleButton.setAttribute('aria-pressed', 'false'); @@ -134,8 +173,8 @@ document.addEventListener('DOMContentLoaded', function () { const visibleNodeIds = new Set(visibleNodes.map(n => n.id)); visibleEdges = activeEdgesData.filter(edge => visibleNodeIds.has(edge.from) && visibleNodeIds.has(edge.to)); } - nodes.clear(); edges.clear(); - nodes.add(visibleNodes); edges.add(visibleEdges); + + addWithCurrentFontSize(visibleNodes, visibleEdges); applyPositions(currentPositions); } @@ -205,6 +244,20 @@ document.addEventListener('DOMContentLoaded', function () { }); } + function addWithCurrentFontSize(newNodes, newEdges) { + nodes.clear(); + edges.clear(); + nodes.add( + newNodes.map(n => ({ + ...n, + font: { ...(n.font ?? {}), size: visualState.fontSize }, + })), + ); + // Don't set per-edge font - let network.setOptions() handle it + // vis.js ignores global font options when edges have per-item font set + edges.add(newEdges); + } + function resetView() { searchInput.value = ""; primaryViewPositions = null; @@ -213,8 +266,7 @@ document.addEventListener('DOMContentLoaded', function () { switchView(); // This will switch back to primary and apply null positions } else { // If already on primary, just reload the original data - nodes.clear(); edges.clear(); - nodes.add(allNodesPrimary); edges.add(allEdgesPrimary); + addWithCurrentFontSize(allNodesPrimary, allEdgesPrimary); applyPositions(primaryViewPositions); // Apply null to reset network.fit(); } @@ -233,9 +285,7 @@ document.addEventListener('DOMContentLoaded', function () { }); const filteredNodes = activeNodes.filter(node => nodesToShow.has(node.id)); const filteredEdges = activeEdges.filter(edge => nodesToShow.has(edge.from) && nodesToShow.has(edge.to)); - nodes.clear(); edges.clear(); - nodes.add(filteredNodes); - edges.add(filteredEdges); + addWithCurrentFontSize(filteredNodes, filteredEdges); network.fit(); } @@ -257,4 +307,6 @@ document.addEventListener('DOMContentLoaded', function () { createStyleLegend(); createSectorLegend(); createSectorToggles(); + // Initial data load with consistent font handling + addWithCurrentFontSize(allNodesPrimary, allEdgesPrimary); }); diff --git a/temoa/utilities/network_vis_templates/graph_styles.css b/temoa/utilities/network_vis_templates/graph_styles.css index 0402c090..41e36997 100644 --- a/temoa/utilities/network_vis_templates/graph_styles.css +++ b/temoa/utilities/network_vis_templates/graph_styles.css @@ -24,6 +24,9 @@ body, html { .legend-item { display: flex; align-items: center; margin-bottom: 6px; } .legend-color-swatch { width: 18px; height: 18px; margin-right: 8px; flex-shrink: 0; border: 1px solid #ccc; background-color: #f0f0f0; box-sizing: border-box; } .legend-label { font-size: 13px; } +.control-group { display: flex; align-items: center; gap: 15px; margin-bottom: 8px; } +.control-group label { min-width: 120px; font-size: 13px; font-weight: 500; } +.control-group input[type=range] { flex-grow: 1; max-width: 250px; } #advanced-controls-toggle { font-size: 12px; color: #007bff; cursor: pointer; text-decoration: none; margin-top: 15px; display: block; } .view-toggle-panel { padding: 8px 15px; background-color: #343a40; color: white; display: flex; justify-content: center; align-items: center; } .view-toggle-panel button { font-size: 14px; font-weight: 600; padding: 8px 16px; border-radius: 5px; border: 1px solid #6c757d; background-color: #495057; color: white; cursor: pointer; } diff --git a/temoa/utilities/network_vis_templates/graph_template.html b/temoa/utilities/network_vis_templates/graph_template.html index 8ec0a1f0..7ddb6caf 100644 --- a/temoa/utilities/network_vis_templates/graph_template.html +++ b/temoa/utilities/network_vis_templates/graph_template.html @@ -17,6 +17,13 @@

Configuration & Legend

aria-controls="config-container-content">
+
+

Visual Settings

+
+ + +
+

Style Legend

diff --git a/temoa/utilities/visualizer.py b/temoa/utilities/visualizer.py index 9dc04369..50885e7f 100644 --- a/temoa/utilities/visualizer.py +++ b/temoa/utilities/visualizer.py @@ -180,7 +180,8 @@ def make_nx_graph( if any(info['attrs'].get('dashes', False) for info in techs_info): combined_attrs['dashes'] = True - combined_attrs['value'] = sum(info['attrs'].get('value', 1) for info in techs_info) + # Use 'width' for thickness, 'value' breaks font rendering with smooth edges + combined_attrs['width'] = 2 + len(techs_info) # Base width + 1 per tech multi_edge_key = f'{ic}-{oc}-{uuid.uuid4().hex[:8]}' dg.add_edge(ic, oc, key=multi_edge_key, **combined_attrs) @@ -280,6 +281,7 @@ def nx_to_vis( 'width': 2, 'smooth': {'type': 'continuous', 'roundness': 0.5}, 'arrows': {'to': {'enabled': False, 'scaleFactor': 1}}, + 'font': {'align': 'top', 'size': 14}, }, 'physics': { 'enabled': False, @@ -304,7 +306,7 @@ def nx_to_vis( 'navigationButtons': False, 'keyboard': {'enabled': True, 'bindToWindow': False}, }, - 'layout': {'randomSeed': None, 'improvedLayout': True}, + 'layout': {'improvedLayout': True}, 'configure': { 'enabled': True, 'showButton': False, # We have our own header, so hide the default floating button From 2e63ad0c10ffaf6744f1c009c87d76ced33a6666 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 21 Jan 2026 15:00:01 -0500 Subject: [PATCH 2/4] tests: adding tests for special item styling in visualizer --- tests/test_commodity_visualizer.py | 90 ++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/test_commodity_visualizer.py diff --git a/tests/test_commodity_visualizer.py b/tests/test_commodity_visualizer.py new file mode 100644 index 00000000..d8c9c2fc --- /dev/null +++ b/tests/test_commodity_visualizer.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock + +from temoa.model_checking.commodity_graph import generate_commodity_graph +from temoa.model_checking.network_model_data import EdgeTuple, NetworkModelData +from temoa.types.core_types import Period, Region, Sector, Technology, Vintage + + +def test_special_items_styling(): + """ + Test that demand orphans, other orphans, and driven techs + are correctly styled in the commodity graph. + """ + region = Region('test_region') + period = Period(2025) + + # Mock NetworkModelData + network_data = MagicMock(spec=NetworkModelData) + network_data.physical_commodities = {'comm_inter'} + network_data.source_commodities = {(region, period): {'comm_source'}} + network_data.demand_commodities = {(region, period): {'comm_demand'}} + network_data.available_techs = {(region, period): {}} + network_data.tech_data = {} + + # Define some special items + demand_orphans = [ + EdgeTuple( + region, + 'comm_inter', + Technology('tech_demand_orphan'), + Vintage(2020), + 'comm_demand', + sector=Sector('S1'), + ) + ] + other_orphans = [ + EdgeTuple( + region, + 'comm_source', + Technology('tech_other_orphan'), + Vintage(2020), + 'comm_inter', + sector=Sector('S2'), + ) + ] + driven_techs = [ + EdgeTuple( + region, + 'comm_source', + Technology('tech_driven'), + Vintage(2020), + 'comm_demand', + sector=Sector('S3'), + ) + ] + + # Generate the graph + dg, _sector_colors = generate_commodity_graph( + region, + period, + network_data, + demand_orphans=demand_orphans, + other_orphans=other_orphans, + driven_techs=driven_techs, + ) + + # 1. Check Node Styling + assert dg.nodes['comm_demand']['color']['border'] == '#d62728' + assert dg.nodes['comm_demand']['borderWidth'] == 4 + assert 'Connected to Demand Orphan' in dg.nodes['comm_demand']['title'] + + assert dg.nodes['comm_inter']['color']['border'] == '#d62728' + assert dg.nodes['comm_inter']['borderWidth'] == 4 + + assert dg.nodes['comm_source']['color']['border'] == '#ff7f0e' + assert dg.nodes['comm_source']['borderWidth'] == 4 + + # 2. Check Edge Styling + edges = list(dg.edges(data=True)) + + edge_do = next(e for e in edges if (e[0] == 'comm_inter' and e[1] == 'comm_demand')) + assert edge_do[2]['dashes'] is True + assert edge_do[2]['color'] == '#d62728' + + edge_oo = next(e for e in edges if (e[0] == 'comm_source' and e[1] == 'comm_inter')) + assert edge_oo[2]['dashes'] is True + assert edge_oo[2]['color'] == '#ff7f0e' + + edge_dt = next(e for e in edges if (e[0] == 'comm_source' and e[1] == 'comm_demand')) + assert edge_dt[2]['dashes'] is True + assert edge_dt[2]['color'] == '#1f77b4' From 2b7148b2bd8b80107c9dfb7ba7aa90253eca71de Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 21 Jan 2026 15:00:55 -0500 Subject: [PATCH 3/4] docs: updating visualizer embed in docs to have font size slider --- .../source/default/static/Network_Graph_utopia_1990.html | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/default/static/Network_Graph_utopia_1990.html b/docs/source/default/static/Network_Graph_utopia_1990.html index 24303675..84b1755e 100644 --- a/docs/source/default/static/Network_Graph_utopia_1990.html +++ b/docs/source/default/static/Network_Graph_utopia_1990.html @@ -17,6 +17,13 @@

Configuration & Legend

aria-controls="config-container-content">
+
+

Visual Settings

+
+ + +
+

Style Legend

@@ -47,7 +54,7 @@

Sector Legend

From 64f0278a1c7890d8042cb40d491a28c29f70e399 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 21 Jan 2026 15:21:19 -0500 Subject: [PATCH 4/4] PR feedback fixes --- .../static/Network_Graph_utopia_1990.html | 2 +- temoa/utilities/graph_utils.py | 30 ++++-- .../network_vis_templates/graph_script.js | 92 ++++++++++++------- .../network_vis_templates/graph_template.html | 2 +- tests/test_commodity_visualizer.py | 51 +++++----- 5 files changed, 107 insertions(+), 70 deletions(-) diff --git a/docs/source/default/static/Network_Graph_utopia_1990.html b/docs/source/default/static/Network_Graph_utopia_1990.html index 84b1755e..f4089ae4 100644 --- a/docs/source/default/static/Network_Graph_utopia_1990.html +++ b/docs/source/default/static/Network_Graph_utopia_1990.html @@ -21,7 +21,7 @@

Configuration & Legend

Visual Settings

- +
diff --git a/temoa/utilities/graph_utils.py b/temoa/utilities/graph_utils.py index 7c417a76..4ee5ba5f 100644 --- a/temoa/utilities/graph_utils.py +++ b/temoa/utilities/graph_utils.py @@ -223,19 +223,28 @@ def calculate_tech_graph_positions( """ positions = {} + # Materialize the iterable to avoid consumption issues + all_edges_list = list(all_edges) + # 1. Identify all unique sectors present in the technology list - sectors_to_place = sorted({tech.sector for tech in all_edges if tech.sector}) + sectors_to_place = sorted({edge.sector for edge in all_edges_list if edge.sector}) if not sectors_to_place: # If no sectors, just return empty positions and let physics handle it return {} # 2. Arrange sector "anchors" in a large circle - # Scale radius based on the number of sectors and edges + # Scale radius based on the number of sectors and unique technologies + unique_techs_to_place = sorted( + {edge.tech for edge in all_edges_list if edge.tech}, key=lambda t: str(t) + ) num_sectors = len(sectors_to_place) - num_edges = sum(1 for _ in all_edges) + num_nodes = len(unique_techs_to_place) - layout_radius = max(1000, min(2500, 500 + 300 * num_sectors + num_edges)) + if not unique_techs_to_place: + return {} + + layout_radius = max(1000, min(2500, 500 + 300 * num_sectors + 5 * num_nodes)) jitter_radius = layout_radius // 4 sector_anchors = {} @@ -246,9 +255,12 @@ def calculate_tech_graph_positions( cy = layout_radius * math.sin(angle) sector_anchors[sector] = (cx, cy) - # 3. Place each technology node near its sector's anchor point with jitter - for edge_tuple in all_edges: - primary_sector = edge_tuple.sector + # 3. Place each unique technology node near its sector's anchor point with jitter + # Create a mapping of tech to its primary sector from the edges + tech_to_sector = {edge.tech: edge.sector for edge in all_edges_list if edge.tech} + + for tech in unique_techs_to_place: + primary_sector = tech_to_sector.get(tech) if not primary_sector or primary_sector not in sector_anchors: # Place nodes without a defined sector at the center cx, cy = 0, 0 @@ -256,13 +268,13 @@ def calculate_tech_graph_positions( cx, cy = sector_anchors[primary_sector] # Apply deterministic "jitter" to prevent stacking (stable per-tech) - seed = uuid.uuid5(uuid.NAMESPACE_DNS, str(edge_tuple.tech)).int + seed = uuid.uuid5(uuid.NAMESPACE_DNS, str(tech)).int rng = random.Random(seed) rand_angle = rng.uniform(0, 2 * math.pi) rand_radius = rng.uniform(0, jitter_radius) x = cx + rand_radius * math.cos(rand_angle) y = cy + rand_radius * math.sin(rand_angle) - positions[edge_tuple.tech] = {'x': x, 'y': y} + positions[tech] = {'x': x, 'y': y} return positions diff --git a/temoa/utilities/network_vis_templates/graph_script.js b/temoa/utilities/network_vis_templates/graph_script.js index d7f392aa..f0bf39a9 100644 --- a/temoa/utilities/network_vis_templates/graph_script.js +++ b/temoa/utilities/network_vis_templates/graph_script.js @@ -29,40 +29,58 @@ document.addEventListener('DOMContentLoaded', function () { secondary_view_name: secondaryViewName, } = data; - const optionsObject = - typeof optionsRaw === "string" ? JSON.parse(optionsRaw) : optionsRaw; - - window.__graph = { - data, - allNodesPrimary, - allEdgesPrimary, - allNodesSecondary, - allEdgesSecondary, - optionsObject, - }; - // --- Visual State --- - let currentView = 'primary'; - let primaryViewPositions = null; - let secondaryViewPositions = null; - let visualState = { - fontSize: 14 - }; + let optionsObject = {}; + if (typeof optionsRaw === "string") { + try { + optionsObject = JSON.parse(optionsRaw); + } catch (e) { + console.error('Failed to parse graph options JSON:', e); + optionsObject = {}; + } + } else { + optionsObject = optionsRaw || {}; + } + // Expose for debugging only — enable in production. + const isDebug = (typeof window !== 'undefined' && window.DEBUG_GRAPH) || + (typeof URLSearchParams !== 'undefined' && new URLSearchParams(window.location.search).has('debugGraph')); + if (isDebug) { + window.__graph = { + data, + allNodesPrimary, + allEdgesPrimary, + allNodesSecondary, + allEdgesSecondary, + optionsObject, + }; + } // --- DOM Elements --- + const fontSizeSlider = document.getElementById('font-size-slider'); const configWrapper = document.getElementById('config-panel-wrapper'); const configHeader = document.querySelector('.config-panel-header'); const configToggleButton = document.querySelector('.config-toggle-btn'); const advancedControlsToggle = document.getElementById('advanced-controls-toggle'); const visConfigContainer = document.getElementById('vis-config-container'); - const fontSizeSlider = document.getElementById('font-size-slider'); const searchInput = document.getElementById('search-input'); + + // --- Visual State --- + let currentView = 'primary'; + let primaryViewPositions = null; + let secondaryViewPositions = null; + let visualState = { + fontSize: (optionsObject?.nodes?.font?.size) || 14 + }; + + if (fontSizeSlider) { + fontSizeSlider.value = String(visualState.fontSize); + } const resetButton = document.getElementById('reset-view-btn'); const sectorTogglesContainer = document.getElementById('sector-toggles'); const viewToggleButton = document.getElementById('view-toggle-btn'); const graphContainer = document.getElementById('mynetwork'); // --- Config Panel Toggle --- - if (optionsObject.configure && optionsObject.configure.enabled) { + if (optionsObject?.configure?.enabled) { optionsObject.configure.container = visConfigContainer; configHeader.addEventListener('click', () => { const isCollapsed = configWrapper.classList.toggle('collapsed'); @@ -77,24 +95,32 @@ document.addEventListener('DOMContentLoaded', function () { } // --- Visual Settings Sliders --- + let pendingRaf = null; function updateVisualSettings() { if (fontSizeSlider) visualState.fontSize = parseInt(fontSizeSlider.value, 10); - // Use setOptions for global font size - works for edges with smooth enabled - // Note: Don't set per-edge font as it breaks rendering with smooth edges - network.setOptions({ - nodes: { font: { size: visualState.fontSize } }, - edges: { font: { size: visualState.fontSize, align: 'top' } } - }); + if (pendingRaf) return; - // Also update nodes individually since they have per-node font from addWithCurrentFontSize - const nodeUpdates = nodes.get().map(n => ({ - id: n.id, - font: { ...(n.font ?? {}), size: visualState.fontSize } - })); - nodes.update(nodeUpdates); + pendingRaf = requestAnimationFrame(() => { + pendingRaf = null; - network.redraw(); + // Use setOptions for global font size - works for edges with smooth enabled + // Note: Don't set per-edge font as it breaks rendering with smooth edges + network.setOptions({ + nodes: { font: { size: visualState.fontSize } }, + edges: { font: { size: visualState.fontSize, align: 'top' } } + }); + + // Also update nodes individually since they have per-node font from addWithCurrentFontSize + // Note: Per-node font properties must be overwritten because they would otherwise take precedence over the global setting + const nodeUpdates = nodes.get().map(n => ({ + id: n.id, + font: { ...(n.font ?? {}), size: visualState.fontSize } + })); + nodes.update(nodeUpdates); + + network.redraw(); + }); } if (fontSizeSlider) fontSizeSlider.addEventListener('input', updateVisualSettings); diff --git a/temoa/utilities/network_vis_templates/graph_template.html b/temoa/utilities/network_vis_templates/graph_template.html index 7ddb6caf..e54218ed 100644 --- a/temoa/utilities/network_vis_templates/graph_template.html +++ b/temoa/utilities/network_vis_templates/graph_template.html @@ -21,7 +21,7 @@

Configuration & Legend

Visual Settings

- +
diff --git a/tests/test_commodity_visualizer.py b/tests/test_commodity_visualizer.py index d8c9c2fc..69e2efe5 100644 --- a/tests/test_commodity_visualizer.py +++ b/tests/test_commodity_visualizer.py @@ -1,11 +1,9 @@ -from unittest.mock import MagicMock - from temoa.model_checking.commodity_graph import generate_commodity_graph from temoa.model_checking.network_model_data import EdgeTuple, NetworkModelData -from temoa.types.core_types import Period, Region, Sector, Technology, Vintage +from temoa.types.core_types import Commodity, Period, Region, Sector, Technology, Vintage -def test_special_items_styling(): +def test_special_items_styling() -> None: """ Test that demand orphans, other orphans, and driven techs are correctly styled in the commodity graph. @@ -13,42 +11,40 @@ def test_special_items_styling(): region = Region('test_region') period = Period(2025) - # Mock NetworkModelData - network_data = MagicMock(spec=NetworkModelData) - network_data.physical_commodities = {'comm_inter'} - network_data.source_commodities = {(region, period): {'comm_source'}} - network_data.demand_commodities = {(region, period): {'comm_demand'}} - network_data.available_techs = {(region, period): {}} - network_data.tech_data = {} + # Concrete NetworkModelData + network_data = NetworkModelData() + network_data.physical_commodities = {Commodity('comm_inter')} + network_data.source_commodities[(region, period)] = {Commodity('comm_source')} + network_data.demand_commodities[(region, period)] = {Commodity('comm_demand')} # Define some special items demand_orphans = [ EdgeTuple( region, - 'comm_inter', + Commodity('comm_inter'), Technology('tech_demand_orphan'), Vintage(2020), - 'comm_demand', + Commodity('comm_demand'), sector=Sector('S1'), ) ] other_orphans = [ EdgeTuple( region, - 'comm_source', + Commodity('comm_source'), Technology('tech_other_orphan'), Vintage(2020), - 'comm_inter', + Commodity('comm_inter'), sector=Sector('S2'), ) ] driven_techs = [ EdgeTuple( region, - 'comm_source', + Commodity('comm_source'), Technology('tech_driven'), Vintage(2020), - 'comm_demand', + Commodity('comm_demand'), sector=Sector('S3'), ) ] @@ -77,14 +73,17 @@ def test_special_items_styling(): # 2. Check Edge Styling edges = list(dg.edges(data=True)) - edge_do = next(e for e in edges if (e[0] == 'comm_inter' and e[1] == 'comm_demand')) - assert edge_do[2]['dashes'] is True - assert edge_do[2]['color'] == '#d62728' + edge_do = next((e for e in edges if (e[0] == 'comm_inter' and e[1] == 'comm_demand')), None) + assert edge_do is not None, 'Edge (comm_inter -> comm_demand) not found' + assert edge_do[2].get('dashes') is True + assert edge_do[2].get('color') == '#d62728' - edge_oo = next(e for e in edges if (e[0] == 'comm_source' and e[1] == 'comm_inter')) - assert edge_oo[2]['dashes'] is True - assert edge_oo[2]['color'] == '#ff7f0e' + edge_oo = next((e for e in edges if (e[0] == 'comm_source' and e[1] == 'comm_inter')), None) + assert edge_oo is not None, 'Edge (comm_source -> comm_inter) not found' + assert edge_oo[2].get('dashes') is True + assert edge_oo[2].get('color') == '#ff7f0e' - edge_dt = next(e for e in edges if (e[0] == 'comm_source' and e[1] == 'comm_demand')) - assert edge_dt[2]['dashes'] is True - assert edge_dt[2]['color'] == '#1f77b4' + edge_dt = next((e for e in edges if (e[0] == 'comm_source' and e[1] == 'comm_demand')), None) + assert edge_dt is not None, 'Edge (comm_source -> comm_demand) not found' + assert edge_dt[2].get('dashes') is True + assert edge_dt[2].get('color') == '#1f77b4'