Skip to content

Custom expression function, automatically bound to js context#374

Open
Cassielxd wants to merge 4 commits intogorules:masterfrom
Cassielxd:dev
Open

Custom expression function, automatically bound to js context#374
Cassielxd wants to merge 4 commits intogorules:masterfrom
Cassielxd:dev

Conversation

@Cassielxd
Copy link
Copy Markdown

Custom expression function, automatically bound to js context

@stefan-gorules
Copy link
Copy Markdown
Contributor

Hi, could you give more context into what this PR does?

@Cassielxd
Copy link
Copy Markdown
Author

My scenario is to use it in the pricing software of the construction industry. I need to customize business-related functions and use them in the expression engine. In the current business, if all metadata are passed in as requests, a lot of serialization is required. Therefore, a separate state mechanism is designed, which is only used in custom functions and the state will be released after execution.

@Cassielxd
Copy link
Copy Markdown
Author

use zen_expression::{Isolate, Variable};
use zen_expression::functions::mf_function::{
MfFunctionHelper, MfFunctionRegistry,
};
use zen_expression::variable::VariableType;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// 1. 定义一个简单的、我们自己的状态
#[derive(Debug)]
struct MyState {
call_count: Mutex,
}

impl MyState {
fn new() -> Self {
Self { call_count: Mutex::new(0) }
}

fn increment(&self) -> u32 {
    let mut count = self.call_count.lock().unwrap();
    *count += 1;
    *count
}

}

fn main() -> anyhow::Result<()> {
println!("=== 自定义函数与泛型State集成演示 ===\n");

// === 第一部分: 演示使用我们自定义的 `MyState` ===
println!("--- 场景1: 使用自定义的 MyState ---");
let my_state = Arc::new(MyState::new());

// 2. 为 `MyState` 创建一个 Helper
let my_helper = MfFunctionHelper::<MyState>::new();

// 3. 注册一个可以访问 `MyState` 的函数
println!("注册函数: getMyStateCallCount()");
my_helper
    .register_function(
        "getMyStateCallCount".to_string(),
        vec![],
        VariableType::Number,
        Box::new(|_args, state_opt: Option<&MyState>| {
            if let Some(state) = state_opt {
                // `state` 的类型是 &MyState
                let count = state.increment();
                Ok(Variable::Number(count.into()))
            } else {
                Ok(Variable::Number((-1i32).into()))
            }
        }),
    )
    .map_err(|e| anyhow::anyhow!(e))?;

// 4. 创建 Isolate 并使用 `MyState` 执行表达式
let mut isolate = Isolate::new();
println!("使用 `MyState` 执行 'getMyStateCallCount()'");
let result1 = isolate
    .run_standard_with_state("getMyStateCallCount()", my_state.clone())?;
println!("  第一次调用结果: {}", result1);
let result2 = isolate
    .run_standard_with_state("getMyStateCallCount()", my_state.clone())?;
println!("  第二次调用结果: {}", result2);

// === 第三部分: 验证两种函数可以共存 ===
println!("\n--- 场景3: 验证两种状态的函数可以共存 ---");
println!("再次调用 `getMyStateCallCount` (应为3)");
let result4 = isolate
    .run_standard_with_state("getMyStateCallCount()", my_state.clone())?;
println!("  结果: {}", result4);

// 显示所有已注册的自定义函数
println!("\n=== 已注册的自定义函数 ===");
let functions = MfFunctionRegistry::list_functions();
for func in functions {
    println!("- {}", func);
}

// 清理
println!("\n清理所有自定义函数...");
MfFunctionRegistry::clear();

println!("演示完成!");
Ok(())

}

@Cassielxd
Copy link
Copy Markdown
Author

Registration is complete and is bound to the js context

use std::future::Future;
use std::pin::Pin;
use crate::handler::function::error::{FunctionResult, ResultExt};
use crate::handler::function::listener::{RuntimeEvent, RuntimeListener};
use crate::handler::function::module::export_default;
use crate::handler::function::serde::JsValue;
use rquickjs::module::{Declarations, Exports, ModuleDef};
use rquickjs::prelude::{Async, Func};
use rquickjs::{CatchResultExt, Ctx};
use zen_expression::functions::arguments::Arguments;
use zen_expression::functions::mf_function::MfFunctionRegistry;

pub struct ModuforgeListener {
// 目前为空结构体,后续可以添加配置或状态字段
}

impl RuntimeListener for ModuforgeListener {

fn on_event<'js>(
    &self,
    ctx: Ctx<'js>,
    event: RuntimeEvent,
) -> Pin<Box<dyn Future<Output = FunctionResult> + 'js>> {
    Box::pin(async move {
        // 只在运行时启动事件时执行函数注册
        if event != RuntimeEvent::Startup {
            return Ok(());
        };

        // 设置全局函数及变量
        // 创建或获取 md 命名空间对象
        let md_namespace = if ctx.globals().contains_key("md")? {
            // 如果 md 已存在,获取它
            ctx.globals().get("md")?
        } else {
            // 如果 md 不存在,创建一个新的空对象
            let md_obj = rquickjs::Object::new(ctx.clone())?;
            ctx.globals().set("md", md_obj.clone())?;
            md_obj
        };

        // 从自定义函数注册表中获取所有函数名称
        let functions_keys = MfFunctionRegistry::list_functions();

        // 遍历每个注册的函数
        for function_key in functions_keys {
            // 根据函数名获取函数定义
            let function_definition =
                MfFunctionRegistry::get_definition(&function_key);

            if let Some(function_definition) = function_definition {
                // 将Rust函数包装为JavaScript异步函数并注册到md命名空间下

                let function_definition = function_definition.clone();
                let parameters = function_definition.required_parameters();
                match parameters {
                    0 => {
                        md_namespace
                            .set(
                                function_key, // 函数名作为md对象的属性名
                                Func::from(Async(move |ctx: Ctx<'js>| {
                                    // 克隆函数定义以避免生命周期问题
                                    let function_definition =
                                        function_definition.clone();

                                    async move {
                                        // 调用Rust函数,传入JavaScript参数
                                        let response = function_definition
                                            .call(Arguments(&[]))
                                            .or_throw(&ctx)?;

                                        // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值
                                        let k =
                                            serde_json::to_value(response)
                                                .or_throw(&ctx)?
                                                .into();

                                        return rquickjs::Result::Ok(
                                            JsValue(k),
                                        );
                                    }
                                })),
                            )
                            .catch(&ctx)?; // 捕获并处理可能的JavaScript异常
                    },
                    1 => {
                        md_namespace
                        .set(
                            function_key, // 函数名作为md对象的属性名
                            Func::from(Async(
                                move |ctx: Ctx<'js>, context: JsValue| {
                                    // 克隆函数定义以避免生命周期问题
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        // 调用Rust函数,传入JavaScript参数
                                        let response = function_definition
                                            .call(Arguments(&[context.0]))
                                            .or_throw(&ctx)?;
                                        // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值
                                        let k = serde_json::to_value(response)
                                            .or_throw(&ctx)?
                                            .into();
                                        return rquickjs::Result::Ok(JsValue(
                                            k,
                                        ));
                                    }
                                },
                            )),
                        )
                        .catch(&ctx)?; // 捕获并处理可能的JavaScript异常
                    },
                    2 => {
                        md_namespace
                        .set(
                            function_key, // 函数名作为md对象的属性名
                            Func::from(Async(
                                move |ctx: Ctx<'js>, context: JsValue,context2: JsValue| {
                                    // 克隆函数定义以避免生命周期问题
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        // 调用Rust函数,传入JavaScript参数
                                        let response = function_definition
                                            .call(Arguments(&[context.0,context2.0]))
                                            .or_throw(&ctx)?;
                                        // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值
                                        let k = serde_json::to_value(response)
                                            .or_throw(&ctx)?
                                            .into();
                                        return rquickjs::Result::Ok(JsValue(
                                            k,
                                        ));
                                    }
                                },
                            )),
                        )
                        .catch(&ctx)?; // 捕获并处理可能的JavaScript异常
                    },
                    3 => {
                        md_namespace
                        .set(
                            function_key, // 函数名作为md对象的属性名
                            Func::from(Async(
                                move |ctx: Ctx<'js>, context: JsValue,context2: JsValue,context3: JsValue| {
                                    // 克隆函数定义以避免生命周期问题
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        // 调用Rust函数,传入JavaScript参数
                                        let response = function_definition
                                            .call(Arguments(&[context.0,context2.0,context3.0]))
                                            .or_throw(&ctx)?;
                                        // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值
                                        let k: zen_expression::Variable = serde_json::to_value(response)
                                            .or_throw(&ctx)?
                                            .into();
                                        return rquickjs::Result::Ok(JsValue(
                                            k,
                                        ));
                                    }
                                },
                            )),
                        )
                        .catch(&ctx)?; // 捕获并处理可能的JavaScript异常
                    },
                    _ => {
                        md_namespace
                        .set(
                            function_key, // 函数名作为md对象的属性名
                            Func::from(Async(
                                move |ctx: Ctx<'js>, context: Vec<JsValue>| {
                                    // 克隆函数定义以避免生命周期问题
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        // 调用Rust函数,传入JavaScript参数
                                        let response = function_definition
                                            .call(Arguments(&context.iter().map(|arg| arg.0.clone()).collect::<Vec<_>>()))
                                            .or_throw(&ctx)?;
                                        // 将Rust函数的返回值序列化为JSON,再转换为JavaScript值
                                        let k = serde_json::to_value(response)
                                            .or_throw(&ctx)?
                                            .into();
                                        return rquickjs::Result::Ok(JsValue(
                                            k,
                                        ));
                                    }
                                },
                            )),
                        )
                        .catch(&ctx)?; // 捕获并处理可能的JavaScript异常
                    },
                }
            }
        }

        Ok(()) // 成功完成函数注册
    })
}

}

pub struct ModuforgeModule;

impl ModuleDef for ModuforgeModule {
fn declare<'js>(decl: &Declarations<'js>) -> rquickjs::Result<()> {
// 声明所有可用的函数
for function_key in MfFunctionRegistry::list_functions() {
decl.declare(function_key.as_str())?;
}
decl.declare("default")?;
Ok(())
}

fn evaluate<'js>(
    ctx: &Ctx<'js>,
    exports: &Exports<'js>,
) -> rquickjs::Result<()> {
    export_default(ctx, exports, |default| {
        // 为每个函数创建对应的异步函数
        for function_key in MfFunctionRegistry::list_functions() {
            if let Some(function_definition) =
                MfFunctionRegistry::get_definition(&function_key)
            {
                let function_definition = function_definition.clone();
                let parameters = function_definition.required_parameters();
                match parameters {
                    0 => {
                        default.set(
                            &function_key,
                            Func::from(Async(move |ctx: Ctx<'js>| {
                                let function_definition =
                                    function_definition.clone();
                                async move {
                                    let response = function_definition
                                        .call(Arguments(&[]))
                                        .or_throw(&ctx)?;

                                    let result =
                                        serde_json::to_value(response)
                                            .or_throw(&ctx)?
                                            .into();

                                    Ok::<JsValue, rquickjs::Error>(JsValue(
                                        result,
                                    ))
                                }
                            })),
                        )?;
                    },
                    1 => {
                        //只有一个参数
                        default.set(
                            &function_key,
                            Func::from(Async(
                                move |ctx: Ctx<'js>, args: JsValue| {
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        let response = function_definition
                                            .call(Arguments(&[args.0]))
                                            .or_throw(&ctx)?;

                                        let result =
                                            serde_json::to_value(response)
                                                .or_throw(&ctx)?
                                                .into();

                                        Ok::<JsValue, rquickjs::Error>(
                                            JsValue(result),
                                        )
                                    }
                                },
                            )),
                        )?;
                    },
                    2 => {
                        //有两个参数
                        default.set(
                            &function_key,
                            Func::from(Async(
                                move |ctx: Ctx<'js>, args: JsValue,args2: JsValue| {
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        let response = function_definition
                                            .call(Arguments(&[args.0,args2.0]))
                                            .or_throw(&ctx)?;

                                        let result =
                                            serde_json::to_value(response)
                                                .or_throw(&ctx)?
                                                .into();

                                        Ok::<JsValue, rquickjs::Error>(
                                            JsValue(result),
                                        )
                                    }
                                },
                            )),
                        )?;
                    },
                    3 => {
                        //有三个参数
                        default.set(
                            &function_key,
                            Func::from(Async(
                                move |ctx: Ctx<'js>, args: JsValue,args2: JsValue,args3: JsValue| {
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        let response = function_definition
                                            .call(Arguments(&[args.0,args2.0,args3.0]))
                                            .or_throw(&ctx)?;

                                        let result =
                                            serde_json::to_value(response)
                                                .or_throw(&ctx)?
                                                .into();

                                        Ok::<JsValue, rquickjs::Error>(
                                            JsValue(result),
                                        )
                                    }
                                },
                            )),
                        )?;
                    },
                    _ => {
                        //4个以上参数 的参数必须以数组的形式传入
                        default.set(
                            &function_key,
                            Func::from(Async(
                                move |ctx: Ctx<'js>, args: Vec<JsValue>| {
                                    let function_definition =
                                        function_definition.clone();
                                    async move {
                                        let args_vec = args
                                            .iter()
                                            .map(|arg| arg.0.clone())
                                            .collect::<Vec<_>>();
                                        let response = function_definition
                                            .call(Arguments(&args_vec))
                                            .or_throw(&ctx)?;

                                        let result =
                                            serde_json::to_value(response)
                                                .or_throw(&ctx)?
                                                .into();

                                        Ok::<JsValue, rquickjs::Error>(
                                            JsValue(result),
                                        )
                                    }
                                },
                            )),
                        )?;
                    },
                }
            }
        }

        Ok(())
    })
}

}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants