diff --git a/core/engine/src/decision_graph/cleaner.rs b/core/engine/src/decision_graph/cleaner.rs index af1671a4..549c0820 100644 --- a/core/engine/src/decision_graph/cleaner.rs +++ b/core/engine/src/decision_graph/cleaner.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use std::rc::Rc; use zen_types::variable::Variable; -pub(crate) const ZEN_RESERVED_PROPERTIES: &[&str] = &["$nodes"]; +pub(crate) const ZEN_RESERVED_PROPERTIES: &[&str] = &["$nodes", "$params"]; pub(crate) struct VariableCleaner { visited: HashSet, diff --git a/core/engine/src/decision_graph/graph.rs b/core/engine/src/decision_graph/graph.rs index 23f351a7..11edd29f 100644 --- a/core/engine/src/decision_graph/graph.rs +++ b/core/engine/src/decision_graph/graph.rs @@ -136,7 +136,7 @@ impl DecisionGraph { return Err(Box::new(EvaluationError::DepthLimitExceeded)); } - let mut walker = GraphWalker::new(&self.graph); + let mut walker = GraphWalker::new(&self.graph, self.config.content.params.clone()); let mut tracer = NodeTracer::new(self.config.trace); while let Some(nid) = walker.next(&mut self.graph, tracer.trace_callback()) { diff --git a/core/engine/src/decision_graph/walker.rs b/core/engine/src/decision_graph/walker.rs index 9d8d88f0..9fc49a3b 100644 --- a/core/engine/src/decision_graph/walker.rs +++ b/core/engine/src/decision_graph/walker.rs @@ -1,3 +1,8 @@ +use crate::config::ZEN_CONFIG; +use crate::model::{ + DecisionEdge, DecisionNode, DecisionNodeKind, SwitchStatement, SwitchStatementHitPolicy, +}; +use crate::DecisionGraphTrace; use ahash::HashMap; use fixedbitset::FixedBitSet; use petgraph::data::DataMap; @@ -5,17 +10,12 @@ use petgraph::matrix_graph::Zero; use petgraph::prelude::{EdgeIndex, NodeIndex, StableDiGraph}; use petgraph::visit::{EdgeRef, IntoNodeIdentifiers, VisitMap, Visitable}; use petgraph::{Incoming, Outgoing}; +use serde_json::Value; use std::ops::Deref; use std::rc::Rc; use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Instant; - -use crate::config::ZEN_CONFIG; -use crate::model::{ - DecisionEdge, DecisionNode, DecisionNodeKind, SwitchStatement, SwitchStatementHitPolicy, -}; -use crate::DecisionGraphTrace; use zen_expression::variable::{ToVariable, Variable}; use zen_expression::Isolate; @@ -32,6 +32,7 @@ pub(crate) struct GraphWalker { ordered: FixedBitSet, to_visit: Vec, visited_switch_nodes: Vec, + params: Option, nodes_in_context: bool, } @@ -39,9 +40,10 @@ pub(crate) struct GraphWalker { const ITER_MAX: usize = 1_000; impl GraphWalker { - pub fn new(graph: &StableDiDecisionGraph) -> Self { + pub fn new(graph: &StableDiDecisionGraph, params: Option>) -> Self { let mut walker = Self::empty(graph); walker.initialize_input_nodes(graph); + walker.params = params.map(|p| Variable::from(p.deref())); walker } @@ -61,6 +63,7 @@ impl GraphWalker { node_data: Default::default(), visited_switch_nodes: Default::default(), iter: 0, + params: None, nodes_in_context: ZEN_CONFIG.nodes_in_context.load(Ordering::Relaxed), } @@ -114,13 +117,17 @@ impl GraphWalker { ) -> (Variable, Variable) { let value = self.merge_node_data(g.neighbors_directed(node_id, Incoming)); - if self.nodes_in_context && with_nodes { - if let Some(object_ref) = value.as_object() { - let mut new_object = object_ref.borrow().clone(); + if let Some(object_ref) = value.as_object() { + let mut new_object = object_ref.borrow().clone(); + if self.nodes_in_context && with_nodes { new_object.insert(Rc::from("$nodes"), self.get_all_node_data()); + } - return (Variable::from_object(new_object), value); + if let Some(params) = &self.params { + new_object.insert(Rc::from("$params"), params.clone()); } + + return (Variable::from_object(new_object), value); } (value.depth_clone(1), value) @@ -184,6 +191,7 @@ impl GraphWalker { if let Some(on_trace) = &mut on_trace { let output = input_trace.depth_clone(1); output.dot_remove("$nodes"); + output.dot_remove("$params"); on_trace(DecisionGraphTrace { id: decision_node.id.clone(), diff --git a/core/engine/src/model/decision_content.rs b/core/engine/src/model/decision_content.rs index a8d4769d..03db8d9b 100644 --- a/core/engine/src/model/decision_content.rs +++ b/core/engine/src/model/decision_content.rs @@ -1,5 +1,6 @@ use ahash::{HashMap, HashMapExt}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::sync::Arc; use zen_expression::compiler::Opcode; use zen_expression::{ExpressionKind, Isolate}; @@ -15,6 +16,8 @@ pub struct CompilationKey { pub struct DecisionContent { pub nodes: Vec>, pub edges: Vec>, + #[serde(default)] + pub params: Option>, #[serde(skip)] pub compiled_cache: Option>>>, diff --git a/core/engine/src/nodes/transform_attributes.rs b/core/engine/src/nodes/transform_attributes.rs index c4b8933a..ecf1fd40 100644 --- a/core/engine/src/nodes/transform_attributes.rs +++ b/core/engine/src/nodes/transform_attributes.rs @@ -28,6 +28,7 @@ impl TransformAttributesExecution for TransformAttributes { .node_context_message(&ctx, "Failed to evaluate expression")?; let nodes = ctx.input.dot("$nodes").unwrap_or(Variable::Null); + let params = ctx.input.dot("$params").unwrap_or(Variable::Null); match &calculated_input { Variable::Array(arr) => { let arr = arr.borrow(); @@ -36,6 +37,7 @@ impl TransformAttributesExecution for TransformAttributes { .map(|v| { let new_v = v.depth_clone(1); new_v.dot_insert("$nodes", nodes.clone()); + new_v.dot_insert("$params", params.clone()); new_v }) .collect(); @@ -45,6 +47,7 @@ impl TransformAttributesExecution for TransformAttributes { _ => { let new_input = calculated_input.depth_clone(1); new_input.dot_insert("$nodes", nodes); + new_input.dot_insert("$params", params); new_input } } @@ -62,6 +65,7 @@ impl TransformAttributesExecution for TransformAttributes { } response.output.dot_remove("$nodes"); + response.output.dot_remove("$params"); response.output } TransformExecutionMode::Loop => { @@ -91,6 +95,7 @@ impl TransformAttributesExecution for TransformAttributes { } response.output.dot_remove("$nodes"); + response.output.dot_remove("$params"); output_array.push(response.output); } diff --git a/core/engine/tests/engine.rs b/core/engine/tests/engine.rs index 77bd6f28..47c874ad 100644 --- a/core/engine/tests/engine.rs +++ b/core/engine/tests/engine.rs @@ -191,6 +191,7 @@ async fn engine_function_imports() { edges: function_content.edges, nodes: new_nodes, compiled_cache: None, + params: None, }; let decision = DecisionEngine::default().create_decision(function_content.into()); let response = decision.evaluate(json!({}).into()).await.unwrap();