diff --git a/worker/BUILD b/worker/BUILD index a1da9fa..3399bb0 100644 --- a/worker/BUILD +++ b/worker/BUILD @@ -4,8 +4,9 @@ proto_library( name = "communication_proto", srcs = ["communication.proto"], deps = [ - "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:any_proto", "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:empty_proto", ], ) @@ -18,7 +19,7 @@ generate_cc( name = "communication_cc_grpc_gen", srcs = [":communication_proto"], plugin = "@grpc//src/compiler:grpc_cpp_plugin", - well_known_protos = True, + well_known_protos = False, generate_mocks = True, ) @@ -37,7 +38,15 @@ cc_library( hdrs = [ "include/worker.hpp", "include/metrics_collector.hpp", + "include/dpdk_filter/net_port.h", + "include/dpdk_filter/dns_cache.h", + "include/dpdk_filter/filtr_packets.h", + "include/dpdk_filter/pars_packets.h", + "include/dpdk_filter/proc_packets.h", + "include/dpdk_filter/types.h", + "include/dpdk_filter/constants.h", ], + includes = ["include", "include/dpdk_filter"], srcs = [], visibility = ["//visibility:public"], ) @@ -47,7 +56,12 @@ cc_binary( srcs = [ "src/main.cpp", "src/worker.cpp", + "src/dpdk_filter/dns_cache.c", "src/metrics_collector.cpp", + "src/dpdk_filter/net_port.c", + "src/dpdk_filter/filtr_packets.c", + "src/dpdk_filter/pars_packets.c", + "src/dpdk_filter/proc_packets.c", ], deps = [ ":worker_headers", @@ -58,14 +72,42 @@ cc_binary( "@curl//:curl", ], copts = [ + "-mssse3", + "-msse4.2", + "-mpclmul", + "-maes", "-I$(GENDIR)/..", + "-I/usr/include", + ], + cxxopts = [ + "-std=c++17", ], linkopts = [ "-L/usr/local/openssl/lib", "-lssl", "-lcrypto", + "-L/usr/local/lib", "-lprometheus-cpp-push", "-lprometheus-cpp-core", + + "-L/usr/lib", + "-lrte_eal", + "-lrte_ethdev", + "-lrte_mempool", + "-lrte_mbuf", + "-lrte_bus_vdev", + "-lrte_ring", + "-lrte_telemetry", + "-lrte_kvargs", + "-lrte_log", + "-lrte_net", + "-lrte_hash", + "-lrte_timer", + "-lsqlite3", + + "-lnuma", + "-ldl", + "-lpthread", ], ) diff --git a/worker/Dockerfile.cc_x86_to_x86 b/worker/Dockerfile.cc_x86_to_x86 index 5d6cf09..800c2d8 100644 --- a/worker/Dockerfile.cc_x86_to_x86 +++ b/worker/Dockerfile.cc_x86_to_x86 @@ -1,17 +1,50 @@ -FROM alpine:3.21.3 AS builder - -RUN apk update && apk add --no-cache g++ openssl-dev cmake make curl-dev protobuf-dev -RUN apk add bazel --repository=http://dl-cdn.alpinelinux.org/alpine/edge/testing/ +FROM ubuntu:22.04 AS builder + +RUN apt-get update && apt-get install -y \ + build-essential=12.9* \ + cmake=3.22* \ + curl=7.81* \ + git=1:2.34* \ + wget=1.21* \ + meson=0.61* \ + ninja-build=1.10* \ + libssl-dev=3.0* \ + protobuf-compiler=3.12* \ + libprotobuf-dev=3.12* \ + python3=3.10* \ + python3-pip=22.0* \ + libnuma-dev=2.0* \ + pkg-config=0.29* \ + libcurl4-openssl-dev=7.81* \ + libbpf-dev=1:0.5* \ + gcc=4:11* \ + g++=4:11* \ + m4=1.4* \ + libpcap-dev=1.10* \ + libsqlite3-dev=3.37* \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install pyelftools + +RUN wget https://github.com/bazelbuild/bazel/releases/download/8.2.1/bazel-8.2.1-linux-x86_64 \ + && chmod +x bazel-8.2.1-linux-x86_64 \ + && mv bazel-8.2.1-linux-x86_64 /usr/local/bin/bazel + +RUN wget https://fast.dpdk.org/rel/dpdk-23.11.tar.xz && \ + tar -xf dpdk-23.11.tar.xz && \ + cd dpdk-23.11 && \ + meson setup build --libdir=lib && \ + ninja -C build && \ + ninja -C build install && \ + cd .. && \ + rm -rf dpdk-23.11 dpdk-23.11.tar.xz + +ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig WORKDIR /app COPY scripts/get_prometheus_cpp.sh scripts/ RUN sh scripts/get_prometheus_cpp.sh -RUN apk add --no-cache llvm18 clang18 -RUN ln -s /usr/lib/llvm18/bin/llvm-ar /bin/llvm-ar-18 -RUN ln -s /usr/bin/clang++-18 /usr/bin/clang++ -RUN ln -s /usr/bin/clang-18 /usr/bin/clang - COPY ./src/ ./src/ COPY ./include/ ./include/ @@ -21,17 +54,10 @@ COPY ./communication.proto ./ COPY ./toolchains ./toolchains COPY ./platforms ./platforms - -RUN bazel build //:worker --extra_toolchains=//toolchains/x86_64:cc_toolchain_for_linux_x86_64 --platforms=//platforms:x86_64_linux - -FROM alpine:3.21.3 - -RUN apk update && apk add --no-cache libstdc++ libgcc libssl3 libcurl protobuf-dev - -COPY --from=builder /app/bazel-bin/worker /usr/local/bin/worker -COPY --from=builder /app/prometheus-cpp-with-submodules/build/lib/ /usr/lib - +RUN bazel build //:worker WORKDIR /data -ENTRYPOINT ["/usr/local/bin/worker", "/data/test.txt", "sha256"] +RUN ldconfig + +ENTRYPOINT ["/app/bazel-bin/worker"] diff --git a/worker/Makefile.main_riscv b/worker/Makefile.main_riscv index 3a06a3a..2d9b7dc 100644 --- a/worker/Makefile.main_riscv +++ b/worker/Makefile.main_riscv @@ -1,6 +1,7 @@ CC = riscv64-linux-gnu-gcc DPDK_PREFIX = ./dpdk-riscv-install +SQLITE_PREFIX = ./sqlite3-riscv-install PKG_CONFIG = env PKG_CONFIG_LIBDIR=$(DPDK_PREFIX)/lib/pkgconfig pkg-config CFLAGS_BASE = -Iinclude -O2 $(shell $(PKG_CONFIG) --cflags libdpdk) @@ -13,10 +14,11 @@ LDFLAGS = -L$(DPDK_PREFIX)/lib \ -lrte_net \ -lrte_log -ldl \ -lrte_hash \ - -sqlite3 \ + -lrte_timer \ -Wl,--end-group \ - -latomic - + -latomic \ + -L$(SQLITE_PREFIX)/lib \ + -lsqlite3 SRCS = src/dpdk_filter/main.c src/dpdk_filter/net_port.c src/dpdk_filter/filtr_packets.c src/dpdk_filter/pars_packets.c src/dpdk_filter/proc_packets.c src/dpdk_filter/dns_cache.c diff --git a/worker/Makefile.main_x86 b/worker/Makefile.main_x86 index 71fd148..b4c30ba 100644 --- a/worker/Makefile.main_x86 +++ b/worker/Makefile.main_x86 @@ -1,6 +1,6 @@ CC = gcc CFLAGS_BASE = -Iinclude -O2 -msse4.2 -mpclmul -maes -LDFLAGS = -lrte_eal -lrte_ethdev -lrte_mempool -lrte_mbuf -lrte_bus_vdev -lpthread -lnuma -ldl -lrte_net -lrte_hash -lsqlite3 +LDFLAGS = -lrte_eal -lrte_ethdev -lrte_mempool -lrte_mbuf -lrte_bus_vdev -lpthread -lnuma -ldl -lrte_net -lrte_hash -lsqlite3 -lrte_timer SRCS = src/dpdk_filter/main.c src/dpdk_filter/net_port.c src/dpdk_filter/filtr_packets.c src/dpdk_filter/pars_packets.c src/dpdk_filter/proc_packets.c src/dpdk_filter/dns_cache.c @@ -20,4 +20,4 @@ $(TARGET_VIRT): $(SRCS) clean: rm -f $(TARGET_REAL) $(TARGET_VIRT) -.PHONY: all clean virt \ No newline at end of file +.PHONY: all clean virt diff --git a/worker/README(DPDK FILTRING).md b/worker/README(DPDK FILTRING).md new file mode 100644 index 0000000..cba6538 --- /dev/null +++ b/worker/README(DPDK FILTRING).md @@ -0,0 +1,80 @@ +# Драйвера dpdk +DPDK должен быть собран с драйверами net/af_xdp net/tap + + +# Кросс-компиляция + +## Окружение +Скрипт `scripts/setup-riscv-env.sh` автоматически скачивает (при необходимости) и собирает DPDK 23.11 для архитектуры RISC-V. + +```bash +./scripts/setup-riscv-env.sh +``` + +## SQLite +Если целевая архитектура — RISC-V, SQLite необходимо собрать кросс-компилятором. + +```bash +wget https://www.sqlite.org/2024/sqlite-autoconf-3460100.tar.gz +tar -xzf sqlite-autoconf-3460100.tar.gz +cd sqlite-autoconf-3460100 + +./configure --host=riscv64-linux-gnu --prefix=/path/to/sqlite3-riscv-install +make -j$(nproc) +make install +``` + +После установки в указанном prefix появятся подкаталоги include/ и lib/ с необходимыми файлами. + + + +# Создание пары veth и TAP-устройства + +```bash +sudo ./scripts/set_virt_dev_for_test_xdp.sh +``` +Скрипт создаёт пару veth0 - veth1 + + +```bash +sudo ./scripts/set_tap_dev.sh +``` +Скрипт создаёт TAP-устройство tap0 + + + +# Сборка проекта +Для реальных портов (eth0/eth1): +```bash +make -f Makefile.main_riscv all +``` + +Для виртуальных портов (veth0/veth1 + tap0): +```bash +make -f Makefile.main_riscv virt +``` +Определение макроса -DVIRT_PORTS переключает программу на использование виртуальных интерфейсов. + + +Перед запуском рекомендуется выполнить скрипт настройки виртуальных устройств: +```bash +sudo ./scripts/set_virt_dev_for_test_xdp.sh +``` + + +# Очистка +```bash +make -f Makefile.main_riscv clean +``` + +# Запуск +Программа требует прав суперпользователя (для работы с DPDK и XDP): +```bash +sudo ./main-riscv-virt +``` + + +# Примечания +Кэш DNS автоматически сохраняется в cache.db (SQLite) и восстанавливается при перезапуске. + +Периодическое сохранение кэша происходит каждый час с помощью таймеров DPDK. diff --git a/worker/helper for association with Worker.md b/worker/helper for association with Worker.md index de73307..ca4428e 100644 --- a/worker/helper for association with Worker.md +++ b/worker/helper for association with Worker.md @@ -1,18 +1,24 @@ -REQUESTED_CLASSIFICATION структура для передачи от контроллера к воркеру: +REQUESTED_CLASSIFICATION - структура для передачи от контроллера к воркеру: + +```code struct requested_classification { - char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN] - политика - int get_trust_level - уровень доверия к сайту + char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN] + int get_trust_level } +``` +Структура для хранения категории с минимальным уровнем доверия для этой категории: -Структура для хранения категории с минимальным уровнем доверия для этой категории +```code struct trust_categories_with_lvl { char locked_by_trust_category[CATEGORY_MAX_LEN]; int trust_lvl; } +``` +у нас есть переменные, которые получаем при инициализации воркера и заносим в структуру (периодически обновляем): -у нас есть переменные, которые получаем при инициализации воркера и заносим в структуру (периодически обновляем) +```code struct BASE_POLICY { char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]; @@ -20,7 +26,8 @@ struct BASE_POLICY { char allow_domains[MAX_DOMAINS][MAX_LEN_DOMEIN]; int min_trust_level; } +``` +Добавлен tap порт, по которому проходят пакеты исключений в ядро, обрабатываются и ответ отсылается на входящий порт (port_in) -Добавлен tap порт, по которому проходят пакеты исключений в ядро, обрабатываются и ответ отсылается на входящий порт (port_in) \ No newline at end of file diff --git a/worker/include/dpdk_filter/constants.h b/worker/include/dpdk_filter/constants.h index 6ae6d1d..f046259 100644 --- a/worker/include/dpdk_filter/constants.h +++ b/worker/include/dpdk_filter/constants.h @@ -1,6 +1,7 @@ #ifndef CONSTANTS_H #define CONSTANTS_H +#include #define MAX_CATEGORIES_BY_TRUST_LVL 64 #define MAX_DOMAINS 64 @@ -10,6 +11,6 @@ #define CATEGORY_MAX_LEN 64 #define DNS_CACHE_DEFAULT_TTL (7 * 24 * 60 * 60) #define LEN_LIST_EXCEPTION_PORTS 1 -extern const uint16_t LIST_EXCEPTION_PORTS[LEN_LIST_EXCEPTION_PORTS]; +extern const uint16_t LIST_EXCEPTION_PORTS[LEN_LIST_EXCEPTION_PORTS]; #endif \ No newline at end of file diff --git a/worker/include/dpdk_filter/dns_cache.h b/worker/include/dpdk_filter/dns_cache.h index 7f953f9..8463124 100644 --- a/worker/include/dpdk_filter/dns_cache.h +++ b/worker/include/dpdk_filter/dns_cache.h @@ -6,15 +6,12 @@ #include #include #include +#include #include #include -#include - -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" - - +#include "constants.h" +#include "types.h" void init_dns_cache(void); int lookup_dns_cache(const char *domain, struct node_cache **return_node); diff --git a/worker/include/dpdk_filter/filtr_packets.h b/worker/include/dpdk_filter/filtr_packets.h index dda5332..188830b 100644 --- a/worker/include/dpdk_filter/filtr_packets.h +++ b/worker/include/dpdk_filter/filtr_packets.h @@ -1,22 +1,29 @@ #ifndef FILTR_PAK_H #define FILTR_PAK_H +#include "constants.h" #include "pars_packets.h" +#include "types.h" #include #include -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" -bool check_is_block(char domain[DOMAIN_MAX_LEN], char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); +bool check_is_block(char domain[DOMAIN_MAX_LEN], + char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); -bool check_is_allow(char domain[DOMAIN_MAX_LEN], char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); +bool check_is_allow(char domain[DOMAIN_MAX_LEN], + char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); bool check_trust_level(int get_trust_level, int min_trust_level); -bool check_categories(char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]); +bool check_categories(char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], + char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]); -bool check_categories_with_lvl(struct requested_classification* req_clas, struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]); +bool check_categories_with_lvl( + struct requested_classification *req_clas, + struct trust_categories_with_lvl + categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]); -bool main_filtring(struct requested_classification* req_clas, struct BASE_POLICY* policy, char domain[DOMAIN_MAX_LEN]); +bool main_filtring(struct requested_classification *req_clas, + struct BASE_POLICY *policy, char domain[DOMAIN_MAX_LEN]); #endif \ No newline at end of file diff --git a/worker/include/dpdk_filter/net_port.h b/worker/include/dpdk_filter/net_port.h index e34f8b1..ebe9224 100644 --- a/worker/include/dpdk_filter/net_port.h +++ b/worker/include/dpdk_filter/net_port.h @@ -1,18 +1,15 @@ #ifndef AF_XDP_PORT_H #define AF_XDP_PORT_H +#include "types.h" #include #include -#include "../../include/dpdk_filter/types.h" - - struct net_port *init_struct_tap_port(const char *tap_iface_name, - struct rte_mempool *mbuf_pool); - + struct rte_mempool *mbuf_pool); struct net_port *init_struct_af_xdp_port(const char *iface_name, - struct rte_mempool *mbuf_pool); + struct rte_mempool *mbuf_pool); int net_port_init(struct net_port *port); diff --git a/worker/include/dpdk_filter/pars_packets.h b/worker/include/dpdk_filter/pars_packets.h index d319225..7d72726 100644 --- a/worker/include/dpdk_filter/pars_packets.h +++ b/worker/include/dpdk_filter/pars_packets.h @@ -1,12 +1,10 @@ #ifndef PARS_PAK_H #define PARS_PAK_H +#include "constants.h" +#include "types.h" #include #include -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" - - void parsing_pakage(struct rte_mbuf *paket, struct info_of_pakage *info_pac); diff --git a/worker/include/dpdk_filter/proc_packets.h b/worker/include/dpdk_filter/proc_packets.h index c7f7737..aafc1de 100644 --- a/worker/include/dpdk_filter/proc_packets.h +++ b/worker/include/dpdk_filter/proc_packets.h @@ -1,24 +1,20 @@ #ifndef PROC_PAK_H #define PROC_PAK_H -#include "../../include/dpdk_filter/net_port.h" -#include "../../include/dpdk_filter/filtr_packets.h" -#include "../../include/dpdk_filter/pars_packets.h" -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" +#include "constants.h" +#include "filtr_packets.h" +#include "net_port.h" +#include "pars_packets.h" +#include "types.h" #include #include #include #include #include - - - - - -void pakage_processing(struct net_port *port_in, - struct net_port *port_out, struct net_port *port_exception, uint16_t queue_number, - uint16_t nb_pkts, struct rte_mbuf **pkts, struct BASE_POLICY* policy); +void pakage_processing(struct net_port *port_in, struct net_port *port_out, + struct net_port *port_exception, uint16_t queue_number, + uint16_t nb_pkts, struct rte_mbuf **pkts, + struct BASE_POLICY *policy); #endif \ No newline at end of file diff --git a/worker/include/dpdk_filter/types.h b/worker/include/dpdk_filter/types.h index d056e50..d93881b 100644 --- a/worker/include/dpdk_filter/types.h +++ b/worker/include/dpdk_filter/types.h @@ -2,8 +2,8 @@ #define TYPES_H #include "constants.h" -#include #include +#include struct net_port { uint16_t port_id; @@ -20,24 +20,24 @@ struct info_of_pakage { }; struct trust_categories_with_lvl { - char locked_by_trust_category[CATEGORY_MAX_LEN]; - int trust_lvl; + char locked_by_trust_category[CATEGORY_MAX_LEN]; + int trust_lvl; }; struct BASE_POLICY { char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; - struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]; + struct trust_categories_with_lvl + categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]; char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]; char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]; int min_trust_level; }; struct requested_classification { - char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; - int get_trust_level; + char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; + int get_trust_level; }; - struct node_cache { char categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; bool solution_is_send; @@ -47,5 +47,4 @@ struct node_cache { char *key_domain; }; - #endif diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 58c437a..d1859a2 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -3,9 +3,19 @@ #include "communication.grpc.pb.h" #include "communication.pb.h" +extern "C" { +#include "dpdk_filter/filtr_packets.h" +#include "dpdk_filter/net_port.h" +#include "dpdk_filter/proc_packets.h" +#include "dpdk_filter/types.h" +} #include #include #include +#include +#include +#include +#include #define EXPECTED_POLICY_TIME 60 #define MIN_POLICY_TIME 30 @@ -29,6 +39,14 @@ class Worker { int64_t policy_interval = MIN_POLICY_TIME; int64_t stats_interval = MIN_STATS_TIME; + struct net_port *port_in = nullptr; + struct net_port *port_out = nullptr; + struct net_port *port_exception = nullptr; + struct rte_mempool *mbuf_pool = nullptr; + std::mutex policy_mutex; + struct BASE_POLICY current_policy; + uint16_t queue_number = 0; + std::unique_ptr stub_; WorkerState state; @@ -39,9 +57,13 @@ class Worker { Worker(uint64_t id); ~Worker(); + void initDPDK(int argc, char **argv); inline uint64_t GetID() const { return worker_id; } void requestPolicyFromController(); - void classifyDomain(const std::string &domain); + bool classifyDomain(const std::string &domain, + struct requested_classification *out_req); + void forward_to_out(struct net_port *incoming_port, + struct net_port *outgoing_port, uint16_t queue_number); void statsReport(); WorkerState GetState() const { return state; } void MainLoop(); diff --git a/worker/scripts/set_tap_dev.sh b/worker/scripts/set_tap_dev.sh new file mode 100755 index 0000000..cb0cb03 --- /dev/null +++ b/worker/scripts/set_tap_dev.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +TAP="tap0" + +sudo ip tuntap add $TAP mode tap +sudo ip link set $TAP up +sudo ip addr add 10.0.3.1/24 dev $TAP \ No newline at end of file diff --git a/worker/src/dpdk_filter/af_xdp_port.c b/worker/src/dpdk_filter/af_xdp_port.c deleted file mode 100644 index 25ecb7d..0000000 --- a/worker/src/dpdk_filter/af_xdp_port.c +++ /dev/null @@ -1,176 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../include/dpdk_filter/af_xdp_port.h" - -#define RX_RING_SIZE 1024 -#define TX_RING_SIZE 1024 - -int find_port_by_dev_name(const char *dev_name, uint16_t *port_id_dev) { - uint16_t count_ports = rte_eth_dev_count_avail(); - struct rte_eth_dev_info dev_info; - char name[64]; - - for (uint16_t port_id = 0; port_id < count_ports; port_id++) { - int ret = rte_eth_dev_info_get(port_id, &dev_info); - - if (ret) { - printf("[ERROR] Failed to retrieve the contextual information of an " - "Ethernet device: %s\n", - strerror(-ret)); - return ret; - } - - if (rte_eth_dev_get_name_by_port(port_id, name) == 0 && - strcmp(name, dev_name) == 0) { - *port_id_dev = port_id; - return 0; - } - } - return -1; -} - -struct af_xdp_port *init_struct_af_xdp_port(const char *iface_name, - struct rte_mempool *mbuf_pool) { - struct af_xdp_port *port = calloc(1, sizeof(struct af_xdp_port)); - if (!port) { - printf("[ERROR] Failed to allocate memory for struct af_xdp_port\n"); - return NULL; - } - - snprintf(port->dev_args, sizeof(port->dev_args), - "iface=%s,start_queue=0,queue_count=1", iface_name); - snprintf(port->dev_name, sizeof(port->dev_name), "net_af_xdp_%s", iface_name); - strncpy(port->iface_name, iface_name, sizeof(port->iface_name) - 1); - port->iface_name[sizeof(port->iface_name) - 1] = '\0'; - port->mbuf_pool = mbuf_pool; - port->port_id = -1; - - return port; -} - -int af_xdp_port_init(struct af_xdp_port *port) { - int ret; - struct rte_eth_conf port_conf = {0}; - const char *dev_name = port->dev_name; - uint16_t port_id; - - ret = rte_vdev_init(dev_name, port->dev_args); - - if (ret < 0) { - printf("[ERROR] Failed to create vdev: %s\n", strerror(-ret)); - return ret; - } - - ret = find_port_by_dev_name(port->dev_name, &port_id); - if (ret) { - printf("no port was found that has the same vdev name. vdev = %s", - port->dev_name); - rte_vdev_uninit(dev_name); - return -1; - } - - port->port_id = port_id; - - if (!rte_eth_dev_is_valid_port(port_id)) { - printf("[ERROR] Port %u is not valid\n", port_id); - rte_vdev_uninit(dev_name); - return -EINVAL; - } - - ret = rte_eth_dev_configure(port_id, 1, 1, &port_conf); - if (ret < 0) { - printf("[ERROR] Failed to configure port: %s\n", strerror(-ret)); - rte_vdev_uninit(dev_name); - return ret; - } - - ret = rte_eth_rx_queue_setup(port_id, 0, RX_RING_SIZE, - rte_eth_dev_socket_id(port_id), NULL, - port->mbuf_pool); - if (ret < 0) { - printf("[ERROR] Failed to setup RX queue: %s\n", strerror(-ret)); - rte_vdev_uninit(dev_name); - return ret; - } - - ret = rte_eth_tx_queue_setup(port_id, 0, TX_RING_SIZE, - rte_eth_dev_socket_id(port_id), NULL); - - if (ret < 0) { - printf("[ERROR] Failed to setup TX queue: %s\n", strerror(-ret)); - rte_vdev_uninit(dev_name); - return ret; - } - - printf("Port %u initialized\n", port_id); - return 0; -} - -int af_xdp_port_start(uint16_t port_id) { - int ret; - - ret = rte_eth_dev_start(port_id); - if (ret < 0) { - printf("[ERROR] Failed to start: %s\n", strerror(-ret)); - return ret; - } - - ret = rte_eth_promiscuous_enable(port_id); - if (ret) { - printf("[ERROR] Failed to enable receipt in promiscuous mode for an " - "Ethernet device: %s\n", - strerror(-ret)); - return ret; - } - - printf("Port %u started\n", port_id); - return 0; -} - -void af_xdp_port_destroy(struct af_xdp_port *port) { - if (!port) - return; - free(port); -} - -void af_xdp_port_close(struct af_xdp_port *port) { - - if (!port) - return; - - int ret; - uint16_t port_id = port->port_id; - - ret = rte_eth_dev_stop(port_id); - if (ret) { - printf("[ERROR] Failed to stop an Ethernet device: %s\n", strerror(-ret)); - return; - } - - ret = rte_eth_dev_close(port_id); - if (ret) { - printf("[ERROR] Failed to close a stopped Ethernet device: %s\n", - strerror(-ret)); - return; - } - - ret = rte_vdev_uninit(port->dev_name); - if (ret) { - printf("[ERROR] Failed to uninitialize a driver: %s\n", strerror(-ret)); - return; - } - - port->port_id = -1; - printf("Port %u closed\n", port_id); -} \ No newline at end of file diff --git a/worker/src/dpdk_filter/dns_cache.c b/worker/src/dpdk_filter/dns_cache.c index 4427e1f..7be7373 100644 --- a/worker/src/dpdk_filter/dns_cache.c +++ b/worker/src/dpdk_filter/dns_cache.c @@ -1,4 +1,4 @@ -#include "../../include/dpdk_filter/dns_cache.h" +#include "dns_cache.h" static struct rte_hash *dns_hash; static struct rte_hash_parameters hash_params = { diff --git a/worker/src/dpdk_filter/filtr_packets.c b/worker/src/dpdk_filter/filtr_packets.c index a6aeadd..530204d 100644 --- a/worker/src/dpdk_filter/filtr_packets.c +++ b/worker/src/dpdk_filter/filtr_packets.c @@ -1,7 +1,8 @@ -#include "../../include/dpdk_filter/filtr_packets.h" -#include "../../include/dpdk_filter/pars_packets.h" +#include "filtr_packets.h" +#include "pars_packets.h" -bool check_is_block(char domain[DOMAIN_MAX_LEN], char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { +bool check_is_block(char domain[DOMAIN_MAX_LEN], + char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { for (int i = 0; i < MAX_DOMAINS; i++) { if (strcmp(block_domains[i], domain) == 0) { @@ -12,7 +13,8 @@ bool check_is_block(char domain[DOMAIN_MAX_LEN], char block_domains[MAX_DOMAINS] return false; } -bool check_is_allow(char domain[DOMAIN_MAX_LEN], char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { +bool check_is_allow(char domain[DOMAIN_MAX_LEN], + char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { for (int i = 0; i < MAX_DOMAINS; i++) { if (strcmp(allow_domains[i], domain) == 0) { @@ -32,34 +34,41 @@ bool check_trust_level(int get_trust_level, int min_trust_level) { return true; } -bool check_categories(char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]) { - +bool check_categories( + char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], + char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]) { + for (int i = 0; i < MAX_CATEGORIES; i++) { for (int j = 0; j < MAX_CATEGORIES; j++) { if (strcmp(get_categories[i], locked_categories[j]) == 0) { return false; } } - } + } return true; } - -bool check_categories_with_lvl(struct requested_classification* req_clas, struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]) { +bool check_categories_with_lvl( + struct requested_classification *req_clas, + struct trust_categories_with_lvl + categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]) { for (int i = 0; i < MAX_CATEGORIES; i++) { for (int j = 0; j < MAX_CATEGORIES; j++) { - if (strcmp(req_clas->get_categories[j], categories_with_lvl[i].locked_by_trust_category) == 0 && req_clas->get_trust_level < categories_with_lvl[i].trust_lvl) { - return false; + if (strcmp(req_clas->get_categories[j], + categories_with_lvl[i].locked_by_trust_category) == 0 && + req_clas->get_trust_level < categories_with_lvl[i].trust_lvl) { + return false; } - } - } + } + } return true; } -bool main_filtring(struct requested_classification* req_clas, struct BASE_POLICY* policy, char domain[DOMAIN_MAX_LEN]) { +bool main_filtring(struct requested_classification *req_clas, + struct BASE_POLICY *policy, char domain[DOMAIN_MAX_LEN]) { if (check_is_block(domain, policy->block_domains) == true) { printf("This domain is blocked"); @@ -71,18 +80,22 @@ bool main_filtring(struct requested_classification* req_clas, struct BASE_POLICY return true; } - if (check_categories(req_clas->get_categories, policy->locked_categories) == false) { + if (check_categories(req_clas->get_categories, policy->locked_categories) == + false) { printf("This site has a locked category"); return false; } - if (check_trust_level(req_clas->get_trust_level, policy->min_trust_level) == false) { + if (check_trust_level(req_clas->get_trust_level, policy->min_trust_level) == + false) { printf("This site has a too small trust level"); return false; } - if (check_categories_with_lvl(req_clas, policy->categories_with_lvl) == false) { - printf("This site blocked in accordance with 'trust categories with level'"); + if (check_categories_with_lvl(req_clas, policy->categories_with_lvl) == + false) { + printf( + "This site blocked in accordance with 'trust categories with level'"); return false; } diff --git a/worker/src/dpdk_filter/main.c b/worker/src/dpdk_filter/main.c index 6da81a8..36e59b4 100644 --- a/worker/src/dpdk_filter/main.c +++ b/worker/src/dpdk_filter/main.c @@ -1,6 +1,6 @@ -#include "../../include/dpdk_filter/net_port.h" -#include "../../include/dpdk_filter/dns_cache.h" -#include "../../include/dpdk_filter/proc_packets.h" +#include "dns_cache.h" +#include "net_port.h" +#include "proc_packets.h" #include #include #include @@ -18,21 +18,23 @@ static void signal_handler(int signum) { } } -void forward_tap_to_out(struct net_port *port_exception, struct net_port *port_in, uint16_t queue_number) { - struct rte_mbuf *tap_pkts[32]; - uint16_t nb_tap = rte_eth_rx_burst(port_exception->port_id, queue_number, tap_pkts, 32); - for (int i = 0; i < nb_tap; i++) { - int ret = rte_eth_tx_burst(port_in->port_id, queue_number, &tap_pkts[i], 1); - if (ret < 1) { - printf("[ERROR] Failed to send packet\n"); - // PLUG (to be added later) - need to add processing for this case - rte_pktmbuf_free(tap_pkts[i]); - } +void forward_tap_to_out(struct net_port *port_exception, + struct net_port *port_in, uint16_t queue_number) { + struct rte_mbuf *tap_pkts[32]; + uint16_t nb_tap = + rte_eth_rx_burst(port_exception->port_id, queue_number, tap_pkts, 32); + for (int i = 0; i < nb_tap; i++) { + int ret = rte_eth_tx_burst(port_in->port_id, queue_number, &tap_pkts[i], 1); + if (ret < 1) { + printf("[ERROR] Failed to send packet\n"); + // PLUG (to be added later) - need to add processing for this case + rte_pktmbuf_free(tap_pkts[i]); } + } } int main(int argc, char **argv) { - //since BASE_POLICY is filled when initializing worker, let’s initialize here + // since BASE_POLICY is filled when initializing worker, let’s initialize here struct BASE_POLICY policy; if (signal(SIGINT, signal_handler) == SIG_ERR) { printf("[ERROR] Failed to set SIGINT handler\n"); @@ -43,8 +45,6 @@ int main(int argc, char **argv) { return 1; } - - struct net_port *port_in = NULL; struct net_port *port_out = NULL; struct net_port *port_exception = NULL; @@ -83,24 +83,23 @@ int main(int argc, char **argv) { port_exception = init_struct_tap_port("tap0", mbuf_pool); - if (!port_in || !port_out || !port_exception) { return 1; } - if (net_port_init(port_in) || net_port_init(port_out) || net_port_init(port_exception)) { + if (net_port_init(port_in) || net_port_init(port_out) || + net_port_init(port_exception)) { return 1; } - if (net_port_start(port_in->port_id) || - net_port_start(port_out->port_id) || + if (net_port_start(port_in->port_id) || net_port_start(port_out->port_id) || net_port_start(port_exception->port_id)) { return 1; } ret = system("sudo ip link set tap0 up && " - "sudo ip addr add 10.0.3.1/24 dev tap0"); - if(ret) { + "sudo ip addr add 10.0.3.1/24 dev tap0"); + if (ret) { printf("[ERROR] Failed to set tap0 up\n"); } @@ -110,7 +109,8 @@ int main(int argc, char **argv) { while (running) { forward_tap_to_out(port_exception, port_in, queue_number); - pakage_processing(port_in, port_out, port_exception, queue_number, nb_pkts, pkts, &policy); + pakage_processing(port_in, port_out, port_exception, queue_number, nb_pkts, + pkts, &policy); } // function for save cache info if need diff --git a/worker/src/dpdk_filter/net_port.c b/worker/src/dpdk_filter/net_port.c index 92eeca4..761a1cb 100644 --- a/worker/src/dpdk_filter/net_port.c +++ b/worker/src/dpdk_filter/net_port.c @@ -11,7 +11,7 @@ #include #include -#include "../../include/dpdk_filter/net_port.h" +#include "net_port.h" #define RX_RING_SIZE 1024 #define TX_RING_SIZE 1024 @@ -41,16 +41,18 @@ int find_port_by_dev_name(const char *dev_name, uint16_t *port_id_dev) { } struct net_port *init_struct_tap_port(const char *tap_iface_name, - struct rte_mempool *mbuf_pool) { + struct rte_mempool *mbuf_pool) { struct net_port *port = calloc(1, sizeof(struct net_port)); - if (!port) { + if (!port) { printf("[ERROR] Failed to allocate memory for struct net_port\n"); return NULL; } - snprintf(port->dev_args, sizeof(port->dev_args),"iface=%s", tap_iface_name); - snprintf(port->dev_name, sizeof(port->dev_name), "net_tap_%s", tap_iface_name); + snprintf(port->dev_args, sizeof(port->dev_args), "iface=%s, remote=%s", + tap_iface_name, tap_iface_name); + snprintf(port->dev_name, sizeof(port->dev_name), "net_tap_%s", + tap_iface_name); strncpy(port->iface_name, tap_iface_name, sizeof(port->iface_name) - 1); port->iface_name[sizeof(port->iface_name) - 1] = '\0'; port->mbuf_pool = mbuf_pool; @@ -59,9 +61,8 @@ struct net_port *init_struct_tap_port(const char *tap_iface_name, return port; } - struct net_port *init_struct_af_xdp_port(const char *iface_name, - struct rte_mempool *mbuf_pool) { + struct rte_mempool *mbuf_pool) { struct net_port *port = calloc(1, sizeof(struct net_port)); if (!port) { printf("[ERROR] Failed to allocate memory for struct net_port\n"); diff --git a/worker/src/dpdk_filter/pars_packets.c b/worker/src/dpdk_filter/pars_packets.c index ea59e2d..4092b4b 100644 --- a/worker/src/dpdk_filter/pars_packets.c +++ b/worker/src/dpdk_filter/pars_packets.c @@ -1,4 +1,4 @@ -#include "../../include/dpdk_filter/pars_packets.h" +#include "pars_packets.h" #include #include #include diff --git a/worker/src/dpdk_filter/proc_packets.c b/worker/src/dpdk_filter/proc_packets.c index 397f4f6..6eebebb 100644 --- a/worker/src/dpdk_filter/proc_packets.c +++ b/worker/src/dpdk_filter/proc_packets.c @@ -1,5 +1,8 @@ -#include "../../include/dpdk_filter/proc_packets.h" -#include "../../include/dpdk_filter/dns_cache.h" +#include "proc_packets.h" +#include "dns_cache.h" + +extern bool worker_classify_domain(const char *domain, + struct requested_classification *out_req); const uint16_t LIST_EXCEPTION_PORTS[LEN_LIST_EXCEPTION_PORTS] = {22}; @@ -20,7 +23,6 @@ void package_sending_decision(bool solution_is_send, struct rte_mbuf *pkt, rte_pktmbuf_free(pkt); } - bool check_is_exception(uint16_t number_port) { for (int i = 0; i < LEN_LIST_EXCEPTION_PORTS; i++) { if (number_port == LIST_EXCEPTION_PORTS[i]) { @@ -30,10 +32,10 @@ bool check_is_exception(uint16_t number_port) { return false; } - -void pakage_processing(struct net_port *port_in, - struct net_port *port_out, struct net_port *port_exception, uint16_t queue_number, - uint16_t nb_pkts, struct rte_mbuf **pkts, struct BASE_POLICY* policy) { +void pakage_processing(struct net_port *port_in, struct net_port *port_out, + struct net_port *port_exception, uint16_t queue_number, + uint16_t nb_pkts, struct rte_mbuf **pkts, + struct BASE_POLICY *policy) { uint16_t nb_rx = rte_eth_rx_burst(port_in->port_id, queue_number, pkts, nb_pkts); @@ -44,14 +46,15 @@ void pakage_processing(struct net_port *port_in, memset(&info_pac, 0, sizeof(info_pac)); parsing_pakage(pkts[i], &info_pac); - printf("[PKT] port = %hu; domain = %s\n", ntohs(info_pac.number_port), info_pac.domain); + printf("[PKT] port = %hu; domain = %s\n", ntohs(info_pac.number_port), + info_pac.domain); if (info_pac.domain[0] == '\0') { printf("[INFO] Packet without dns request\n"); package_sending_decision(true, pkts[i], port_out, queue_number); continue; } - if(check_is_exception(info_pac.number_port) == true) { + if (check_is_exception(info_pac.number_port) == true) { package_sending_decision(true, pkts[i], port_exception, queue_number); continue; } @@ -67,7 +70,15 @@ void pakage_processing(struct net_port *port_in, struct requested_classification req_clas; - bool solution_is_send = main_filtring(&req_clas, policy, info_pac.domain); + bool solution_is_send; + bool classification_success = + worker_classify_domain(info_pac.domain, &req_clas); + if (classification_success) { + solution_is_send = main_filtring(&req_clas, policy, info_pac.domain); + } else { + solution_is_send = true; + printf("[WARN] Classification failed for %s\n", info_pac.domain); + } package_sending_decision(solution_is_send, pkts[i], port_out, queue_number); diff --git a/worker/src/main.cpp b/worker/src/main.cpp index c75d4a2..91a4c95 100644 --- a/worker/src/main.cpp +++ b/worker/src/main.cpp @@ -1,5 +1,5 @@ -#include "../include/metrics_collector.hpp" -#include "../include/worker.hpp" +#include "metrics_collector.hpp" +#include "worker.hpp" #include @@ -20,7 +20,7 @@ class FiltrWorker : public Worker { ("worker-" + std::to_string(id)).c_str()) {} }; -int main() { +int main(int argc, char **argv) { const char *worker_id_str = getenv("WORKER_ID"); if (worker_id_str == nullptr) { spdlog::error("WORKER_ID environment variable not set"); @@ -42,7 +42,7 @@ int main() { try { Worker worker(worker_id); - + worker.initDPDK(argc, argv); bool test_mode = false; if (getenv("TEST_REQUEST_POLICY") != nullptr) { test_mode = true; @@ -62,7 +62,15 @@ int main() { test_mode = true; spdlog::info("Test mode: classifying domain '{}'", domain); std::this_thread::sleep_for(std::chrono::seconds(1)); - worker.classifyDomain(domain); + struct requested_classification req_clas; + memset(&req_clas, 0, sizeof(req_clas)); + bool success = worker.classifyDomain(domain, &req_clas); + if (success) { + spdlog::info("Classification successful: trust_level={}", + req_clas.get_trust_level); + } else { + spdlog::error("Classification failed"); + } } if (test_mode) { diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index b63c850..0e66607 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -1,12 +1,35 @@ -#include "../include/worker.hpp" - +#include "worker.hpp" #include "communication.grpc.pb.h" +#include "proc_packets.h" #include +#include #include #include +#include #include #include +Worker *g_worker = nullptr; + +extern "C" bool +worker_classify_domain(const char *domain, + struct requested_classification *out_req) { + if (!g_worker) { + fprintf(stderr, "worker_classify_domain: g_worker is null\n"); + return false; + } + return g_worker->classifyDomain(std::string(domain), out_req); +} + +static volatile bool stop_flag = false; + +static void signal_handler(int signum) { + if (signum == SIGINT || signum == SIGTERM) { + spdlog::info("Signal {} received, shutting down.", signum); + stop_flag = true; + } +} + void Worker::LogStateChange(WorkerState new_state) { const char *state_names[] = {"BOOTING", "FREE", "BUSY", "SHUTTING_DOWN", "ERROR"}; @@ -22,6 +45,64 @@ void Worker::SetState(WorkerState new_state) { } } +void Worker::initDPDK(int argc, char **argv) { + unsigned mbuf_quantity_in_pool = 8192; + unsigned cache_size_per_kernel = 250; + uint16_t priv_size = 0; + + int ret = rte_eal_init(argc, argv); + if (ret < 0) { + throw std::runtime_error("EAL init failed"); + } + + mbuf_pool = rte_pktmbuf_pool_create( + "POOL", mbuf_quantity_in_pool, cache_size_per_kernel, priv_size, + RTE_MBUF_DEFAULT_BUF_SIZE, rte_socket_id()); + if (!mbuf_pool) { + throw std::runtime_error("Failed to create mbuf pool"); + } + const char *iface_in = getenv("DPDK_PORT_IN"); + const char *iface_out = getenv("DPDK_PORT_OUT"); + + if (!iface_in || !iface_out) { + throw std::runtime_error("DPDK_PORT_IN and DPDK_PORT_OUT must be set"); + } + + port_in = init_struct_af_xdp_port(iface_in, mbuf_pool); + port_out = init_struct_af_xdp_port(iface_out, mbuf_pool); + port_exception = init_struct_tap_port("tap0", mbuf_pool); + + if (net_port_init(port_in) || net_port_init(port_out) || + net_port_init(port_exception)) { + throw std::runtime_error("Init ports"); + } + + if (net_port_start(port_in->port_id) || net_port_start(port_out->port_id) || + net_port_start(port_exception->port_id)) { + throw std::runtime_error("Start ports"); + } + + spdlog::info("DPDK initialized: in_port={}, out_port={}", port_in->port_id, + port_out->port_id); +} + +void Worker::forward_to_out(struct net_port *incoming_port, + struct net_port *outgoing_port, + uint16_t queue_number) { + struct rte_mbuf *tap_pkts[32]; + uint16_t nb_tap = + rte_eth_rx_burst(incoming_port->port_id, queue_number, tap_pkts, 32); + for (int i = 0; i < nb_tap; i++) { + int ret = + rte_eth_tx_burst(outgoing_port->port_id, queue_number, &tap_pkts[i], 1); + if (ret < 1) { + spdlog::warn("Failed to send packet"); + // PLUG (to be added later) - need to add processing for this case + rte_pktmbuf_free(tap_pkts[i]); + } + } +} + void Worker::requestPolicyFromController() { try { spdlog::info("Worker {} requests policy", worker_id); @@ -40,23 +121,70 @@ void Worker::requestPolicyFromController() { } switch (resp.result()) { - case GetPolicyResponse::POLICY_PROVIDED: + case GetPolicyResponse::POLICY_PROVIDED: { spdlog::info("Policy received"); - current_config_version = resp.policy().config_version(); + const auto &pol = resp.policy(); + std::lock_guard lock(policy_mutex); + memset(¤t_policy, 0, sizeof(current_policy)); + + int block_cat_count = pol.block_categories_size(); + for (int i = 0; i < block_cat_count; ++i) { + strncpy(current_policy.locked_categories[i], + pol.block_categories(i).c_str(), CATEGORY_MAX_LEN - 1); + current_policy.locked_categories[i][CATEGORY_MAX_LEN - 1] = '\0'; + } + + int idx = 0; + for (const auto &[category, min_trust] : pol.block_by_trust()) { + if (idx >= MAX_CATEGORIES_BY_TRUST_LVL) + break; + + strncpy( + current_policy.categories_with_lvl[idx].locked_by_trust_category, + category.c_str(), CATEGORY_MAX_LEN - 1); + current_policy.categories_with_lvl[idx] + .locked_by_trust_category[CATEGORY_MAX_LEN - 1] = '\0'; + + current_policy.categories_with_lvl[idx].trust_lvl = min_trust; + + idx++; + } + + int block_dom_count = pol.block_domains_size(); + for (int i = 0; i < block_dom_count; ++i) { + strncpy(current_policy.block_domains[i], pol.block_domains(i).c_str(), + DOMAIN_MAX_LEN - 1); + current_policy.block_domains[i][DOMAIN_MAX_LEN - 1] = '\0'; + } + + int allow_dom_count = pol.allow_domains_size(); + for (int i = 0; i < allow_dom_count; ++i) { + strncpy(current_policy.allow_domains[i], pol.allow_domains(i).c_str(), + DOMAIN_MAX_LEN - 1); + current_policy.allow_domains[i][DOMAIN_MAX_LEN - 1] = '\0'; + } + + current_policy.min_trust_level = pol.min_trust_level(); + + current_config_version = pol.config_version(); break; - case GetPolicyResponse::POLICY_UNCHANGED: + } + case GetPolicyResponse::POLICY_UNCHANGED: { spdlog::info("Policy unchanged"); break; - default: + } + default: { spdlog::error("Unknown response result"); } + } } catch (const std::exception &e) { spdlog::error("requestPolicyFromController exception: {}", e.what()); } } -void Worker::classifyDomain(const std::string &domain) { +bool Worker::classifyDomain(const std::string &domain, + struct requested_classification *out_req) { try { spdlog::info("Worker {} classifying domain '{}'", worker_id, domain); @@ -70,7 +198,7 @@ void Worker::classifyDomain(const std::string &domain) { auto status = stub_->Classify(&context, req, &resp); if (!status.ok()) { spdlog::error("Classify failed: " + status.error_message()); - return; + return false; } std::string cat = @@ -78,8 +206,16 @@ void Worker::classifyDomain(const std::string &domain) { spdlog::info("Domain '{}' classified as category '{}' with trust level {}", domain, cat, resp.trust_level()); + out_req->get_trust_level = resp.trust_level(); + int cat_count = std::min(resp.categories_size(), MAX_CATEGORIES); + for (int i = 0; i < cat_count; ++i) { + strncpy(out_req->get_categories[i], resp.categories(i).c_str(), + CATEGORY_MAX_LEN - 1); + } + return true; } catch (const std::exception &e) { spdlog::error(std::string("classifyDomain: ") + e.what()); + return false; } } @@ -108,7 +244,7 @@ void Worker::statsReport() { } Worker::Worker(uint64_t id) : worker_id(id), state(WorkerState::FREE) { - + g_worker = this; std::string controller_addr = "localhost:50051"; if (const char *env_addr = getenv("CONTROLLER_GRPC_ADDR")) { controller_addr = env_addr; @@ -117,6 +253,9 @@ Worker::Worker(uint64_t id) : worker_id(id), state(WorkerState::FREE) { grpc::CreateChannel(controller_addr, grpc::InsecureChannelCredentials()); stub_ = DataService::NewStub(channel); spdlog::info("gRPC channel created to {}", controller_addr); + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + spdlog::info("Signal handlers registered"); srand(time(nullptr)); SetState(WorkerState::FREE); @@ -124,17 +263,40 @@ Worker::Worker(uint64_t id) : worker_id(id), state(WorkerState::FREE) { } Worker::~Worker() { - SetState(WorkerState::SHUTTING_DOWN); spdlog::info("Worker {} shutting down", worker_id); + + if (port_in && port_out) { + net_port_close(port_in); + net_port_close(port_out); + net_port_close(port_exception); + + net_port_destroy(port_in); + net_port_destroy(port_out); + net_port_destroy(port_exception); + spdlog::info("DPDK ports closed"); + } } void Worker::MainLoop() { + struct BASE_POLICY local_policy; using namespace std::chrono; last_policy_time = steady_clock::now(); last_stats_time = steady_clock::now(); - while (GetState() != WorkerState::SHUTTING_DOWN) { + struct rte_mbuf *pkts[32]; + uint16_t nb_pkts = 32; + uint16_t queue_number = 0; + while (!stop_flag && GetState() != WorkerState::SHUTTING_DOWN) { + { + std::lock_guard lock(policy_mutex); + local_policy = current_policy; + } + forward_to_out(port_exception, port_in, queue_number); + pakage_processing(port_in, port_out, port_exception, queue_number, nb_pkts, + pkts, &local_policy); + forward_to_out(port_out, port_in, queue_number); + auto now = steady_clock::now(); int64_t seconds_since_stats = (now - last_stats_time) / 1s; @@ -152,7 +314,9 @@ void Worker::MainLoop() { policy_interval = MIN_POLICY_TIME + (rand() % (MAX_POLICY_TIME - MIN_POLICY_TIME + 1)); } + } - std::this_thread::sleep_for(milliseconds(100)); + if (stop_flag) { + SetState(WorkerState::SHUTTING_DOWN); } }