diff --git a/controller/service/internal/models/provider.go b/controller/service/internal/models/provider.go index dd3b6e9..e082a29 100644 --- a/controller/service/internal/models/provider.go +++ b/controller/service/internal/models/provider.go @@ -46,7 +46,7 @@ func (pl *ProviderList) LoadFromFile(filename string) error { return fmt.Errorf("reading providers: %w", err) } - interpolatedData := interpolateEnvVars(data) + interpolatedData := interpolateEnvVars(data) return json.Unmarshal(data, pl) } diff --git a/controller/service/internal/service/service.go b/controller/service/internal/service/service.go index d01fb1f..f89ece2 100644 --- a/controller/service/internal/service/service.go +++ b/controller/service/internal/service/service.go @@ -20,7 +20,7 @@ type Service struct { func NewService(categoryFile, providerFile string) (*Service, error) { - if err := godotenv.Load(); err != nil { + if err := godotenv.Load(); err != nil { log.Println("Warning: .env file not found, using system environment variables") } @@ -75,7 +75,7 @@ func (s *Service) Check(checkValue string, endpointName string) ([]int, error) { continue } - resp.Body.Close() + resp.Body.Close() if resp.StatusCode != http.StatusOK { log.Printf("Provider %s returned %s", providerName, resp.Status) diff --git a/controller/service/tests/http_client_test.go b/controller/service/tests/http_client_test.go index 581d2d7..d60842e 100644 --- a/controller/service/tests/http_client_test.go +++ b/controller/service/tests/http_client_test.go @@ -113,4 +113,4 @@ func TestHttpClientRequest(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/controller/test/BUILD b/controller/test/BUILD index bcffe54..4fd8b29 100644 --- a/controller/test/BUILD +++ b/controller/test/BUILD @@ -1,11 +1,36 @@ load("@rules_go//go:def.bzl", "go_test") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_go//proto:def.bzl", "go_proto_library") + +proto_library( + name = "test_communication_proto", + srcs = ["communication.proto"], + visibility = ["//visibility:private"], + deps = [ + "@protobuf//:struct_proto", + "@protobuf//:empty_proto", + ], +) + +go_proto_library( + name = "test_communication_go_proto", + compilers = ["@rules_go//proto:go_grpc"], + importpath = "github.com/moevm/grpc_server/controller/test", + proto = ":test_communication_proto", + visibility = ["//visibility:private"], +) go_test( name = "integration_test", srcs = ["worker_test.go"], deps = [ + "//pkg/proto/communication:communication_go_proto", + "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//test/bufconn", + "@org_golang_google_grpc//credentials/insecure", + "@org_golang_google_protobuf//types/known/emptypb:go_default_library", "@com_github_stretchr_testify//assert", ], - args = ["--test.v"], + args = ["-test.v"], tags = ["exclusive"], -) \ No newline at end of file +) diff --git a/controller/test/README.md b/controller/test/README.md index 159ba36..f080fe2 100644 --- a/controller/test/README.md +++ b/controller/test/README.md @@ -3,9 +3,6 @@ ### 1. Сборка компонентов ```bash -cd controller -bazel build //cmd/grpc_server:grpc_server - cd ../worker bazel build //:worker ``` @@ -13,6 +10,7 @@ bazel build //:worker ### 2. Запуск интеграционного теста ```bash +export TEST_CONTROLLER_ADDR="" # например localhost:0 cd ../controller ./test/run.sh ``` \ No newline at end of file diff --git a/controller/test/worker_test.go b/controller/test/worker_test.go index d7f998b..9a14a1a 100644 --- a/controller/test/worker_test.go +++ b/controller/test/worker_test.go @@ -2,152 +2,207 @@ package test import ( "context" + "net" "os" "os/exec" "path/filepath" "testing" "time" + pb "github.com/moevm/grpc_server/pkg/proto/communication" "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" ) +type MockController struct { + pb.UnimplementedDataServiceServer + Policy *pb.WorkerPolicy + t *testing.T +} + +func (m *MockController) GetPolicy(ctx context.Context, req *pb.GetPolicyRequest) (*pb.GetPolicyResponse, error) { + m.t.Logf("GetPolicy called: worker_id=%d, version=%d", req.WorkerId, req.ConfigVersion) + return &pb.GetPolicyResponse{ + Result: pb.GetPolicyResponse_POLICY_PROVIDED, + Policy: m.Policy, + }, nil +} + +func (m *MockController) Classify(ctx context.Context, req *pb.ClassifyRequest) (*pb.ClassifyResponse, error) { + m.t.Logf("Classify called: worker_id=%d, domain=%s", req.WorkerId, req.Domain) + return &pb.ClassifyResponse{ + Categories: []string{"news", "technology"}, + TrustLevel: 3, + }, nil +} + +func (m *MockController) SendStats(ctx context.Context, req *pb.StatsReport) (*emptypb.Empty, error) { + m.t.Logf("SendStats called: worker_id=%d", req.WorkerId) + return &emptypb.Empty{}, nil +} + +func StartMockController(t *testing.T, policy *pb.WorkerPolicy) (string, func()) { + listenAddr := os.Getenv("TEST_CONTROLLER_ADDR") + if listenAddr == "" { + listenAddr = "localhost:0" + } + + lis, err := net.Listen("tcp", listenAddr) + if err != nil { + t.Fatalf("Failed to listen on %s: %v", listenAddr, err) + } + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + + s := grpc.NewServer() + mock := &MockController{ + Policy: policy, + t: t, + } + pb.RegisterDataServiceServer(s, mock) + + go func() { + if err := s.Serve(lis); err != nil { + t.Logf("Server error: %v", err) + } + }() + + addr := lis.Addr().String() + cleanup := func() { + s.Stop() + lis.Close() + } + return addr, cleanup +} + func findProjectRoot() string { dir, _ := os.Getwd() for { - if _, err := os.Stat(filepath.Join(dir, "controller")); err != nil { - if _, err := os.Stat(filepath.Join(dir, "worker")); err != nil { - dir = filepath.Dir(dir) - continue - } + if _, err := os.Stat(filepath.Join(dir, "controller")); err == nil { + return dir } - return dir + if _, err := os.Stat(filepath.Join(dir, "worker")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return dir + } + dir = parent } } -func TestWorkerPolicyRequest(t *testing.T) { +func TestWorkerPolicyContent(t *testing.T) { root := findProjectRoot() - - ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") - if _, err := os.Stat(ctrlBin); err != nil { - t.Skipf("Controller binary not found: %v", err) - } if _, err := os.Stat(workerBin); err != nil { t.Skipf("Worker binary not found: %v", err) } - ctrl := exec.Command(ctrlBin) - - if err := ctrl.Start(); err != nil { - t.Fatalf("Failed to start controller: %v", err) + testPolicy := &pb.WorkerPolicy{ + ConfigVersion: 2, + MinTrustLevel: 2, + BlockCategories: []string{"CATEGORY_ONLINE_SHOPS", "CATEGORY_ANONYMIZERS", "CATEGORY_ALCOHOL"}, + BlockDomains: []string{"1xbet.com"}, + AllowDomains: []string{"github.com", "vk.com"}, + BlockByTrust: map[string]int32{ + "CATEGORY_MALWARE": 3, + "CATEGORY_BETTING": 4, + }, } - defer func() { - if err := ctrl.Process.Kill(); err != nil { - t.Logf("Warning: failed to kill controller: %v", err) - } - }() - - time.Sleep(1 * time.Second) + addr, cleanup := StartMockController(t, testPolicy) + defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() worker := exec.CommandContext(ctx, workerBin) worker.Env = []string{ + "WORKER_ID=1", + "CONTROLLER_GRPC_ADDR=" + addr, "METRICS_GATEWAY_ADDRESS=localhost", "METRICS_GATEWAY_PORT=9091", "TEST_REQUEST_POLICY=true", } output, err := worker.CombinedOutput() - assert.NoError(t, err, "Worker failed") - assert.Contains(t, string(output), "Worker 1 requests policy") - assert.Contains(t, string(output), "Policy received") + assert.NoError(t, err, "Worker failed: %s", string(output)) + + outputStr := string(output) + + assert.Contains(t, outputStr, "Policy received") + assert.Contains(t, outputStr, "Min trust level: 2") + assert.Contains(t, outputStr, "Config version: 2") + + assert.Contains(t, outputStr, "blocked_categories: CATEGORY_ONLINE_SHOPS") + assert.Contains(t, outputStr, "blocked_categories: CATEGORY_ANONYMIZERS") + assert.Contains(t, outputStr, "blocked_categories: CATEGORY_ALCOHOL") + assert.Contains(t, outputStr, "block_domains: 1xbet.com") + assert.Contains(t, outputStr, "allow_domains: github.com") + assert.Contains(t, outputStr, "allow_domains: vk.com") } -func TestWorkerStatsReport(t *testing.T) { +func TestWorkerClassify(t *testing.T) { root := findProjectRoot() - - ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") - if _, err := os.Stat(ctrlBin); err != nil { - t.Skipf("Controller binary not found: %v", err) - } if _, err := os.Stat(workerBin); err != nil { t.Skipf("Worker binary not found: %v", err) } - ctrl := exec.Command(ctrlBin) - - if err := ctrl.Start(); err != nil { - t.Fatalf("Failed to start controller: %v", err) - } - - defer func() { - if err := ctrl.Process.Kill(); err != nil { - t.Logf("Warning: failed to kill controller: %v", err) - } - }() - - time.Sleep(1 * time.Second) + addr, cleanup := StartMockController(t, nil) + defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() worker := exec.CommandContext(ctx, workerBin) worker.Env = []string{ + "WORKER_ID=1", + "CONTROLLER_GRPC_ADDR=" + addr, "METRICS_GATEWAY_ADDRESS=localhost", "METRICS_GATEWAY_PORT=9091", - "TEST_STATS=true", + "TEST_CLASSIFY_DOMAIN=example.com", } output, err := worker.CombinedOutput() - assert.NoError(t, err, "Worker failed") - assert.Contains(t, string(output), "Worker 1 send stats") - assert.Contains(t, string(output), "Policy received") + outputStr := string(output) + assert.NoError(t, err, "Worker failed: %s", string(output)) + assert.Contains(t, outputStr, "Domain 'example.com' classified as categories [news, technology] with trust level 3") } -func TestWorkerClassifyRequest(t *testing.T) { +func TestWorkerSendStats(t *testing.T) { root := findProjectRoot() - - ctrlBin := filepath.Join(root, "controller", "bazel-bin", "cmd", "grpc_server", "grpc_server_", "grpc_server") workerBin := filepath.Join(root, "worker", "bazel-bin", "worker") - if _, err := os.Stat(ctrlBin); err != nil { - t.Skipf("Controller binary not found: %v", err) - } if _, err := os.Stat(workerBin); err != nil { t.Skipf("Worker binary not found: %v", err) } - ctrl := exec.Command(ctrlBin) - - if err := ctrl.Start(); err != nil { - t.Fatalf("Failed to start controller: %v", err) - } - - defer func() { - if err := ctrl.Process.Kill(); err != nil { - t.Logf("Warning: failed to kill controller: %v", err) - } - }() - - time.Sleep(1 * time.Second) + addr, cleanup := StartMockController(t, nil) + defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() worker := exec.CommandContext(ctx, workerBin) worker.Env = []string{ + "WORKER_ID=1", + "CONTROLLER_GRPC_ADDR=" + addr, "METRICS_GATEWAY_ADDRESS=localhost", "METRICS_GATEWAY_PORT=9091", - "TEST_CLASSIFY_DOMAIN=facebook.com", + "TEST_STATS=true", } output, err := worker.CombinedOutput() - assert.NoError(t, err, "Worker failed") - assert.Contains(t, string(output), "Domain 'facebook.com' classified as category") + assert.NoError(t, err, "Worker failed: %s", string(output)) + + outputStr := string(output) + + assert.Contains(t, outputStr, "Stats sent successfully") + } diff --git a/scripts/manager_integration_test.sh b/scripts/manager_integration_test.sh index e891d96..7644c91 100755 --- a/scripts/manager_integration_test.sh +++ b/scripts/manager_integration_test.sh @@ -6,128 +6,47 @@ COL_GREEN="\e[32;1m" COL_LBLUE="\e[96;1m" COL_RESET="\e[0m" -CONTROLLER_LOGFILE="controller.log" -WORKER_NAMES=("worker1" "worker2") -TIMEOUT_DURATION=20 -INIT_DURATION=2 - -always_log="$1" -user=$(id -un) -group=$(id -gn) -manager_pid= -fail= +fail=0 main() { - setup - build - run_manager - run_workers - wait_for_manager - check_results -} - -setup() { - echo -e "${COL_LBLUE}Setting up socket directory...${COL_RESET}" - sudo mkdir -p /run/controller/ - sudo chown ${user}:${group} /run/controller/ + build_worker + run_integration_tests + show_results } -build() { +build_worker() { echo -e "${COL_LBLUE}Building worker...${COL_RESET}" - docker compose --file ./worker/docker-compose.yml build - echo -e "${COL_LBLUE}Building manager...${COL_RESET}" - cd controller && bazel build //cmd/manager:manager && cd .. + cd worker && bazel build //:worker && cd .. } -run_manager() { - echo -e "${COL_LBLUE}Running manager...${COL_RESET}" +run_integration_tests() { + echo -e "${COL_LBLUE}Running integration tests...${COL_RESET}" cd controller - timeout ${TIMEOUT_DURATION} bazel run //cmd/manager:manager &> ${CONTROLLER_LOGFILE} & - manager_pid=$! - - echo -e "${COL_LBLUE}Waiting ${INIT_DURATION} seconds for manager to initialize...${COL_RESET}" - sleep ${INIT_DURATION} -} - -run_worker() { - local worker=$1 - echo -e "${COL_LBLUE}Starting ${worker}...${COL_RESET}" - docker run \ - --user "$(id -u):$(id -g)" \ - -d \ - -v "/run/controller:/run/controller" \ - -e "METRICS_GATEWAY_ADDRESS=metrics" \ - -e "METRICS_GATEWAY_PORT=9091" \ - -e "METRICS_WORKER_NAME=${worker}" \ - --name "${worker}" \ - "worker-worker" -} - -run_workers() { - for worker in "${WORKER_NAMES[@]}"; do - run_worker "${worker}" - done -} - -wait_for_manager() { - echo -e "${COL_LBLUE}Waiting for manager to finish...${COL_RESET}" - set +e - wait $manager_pid - manager_exit_code=$? - set -e -} - -check_results() { - for worker in "${WORKER_NAMES[@]}"; do - check_worker "${worker}" - done - check_manager_exit_code - show_failure_details -} - -check_worker() { - local worker=$1 - if docker logs "${worker}" | grep -qE '\[error\]'; then - echo -e "${COL_RED}* ${worker}: Error logs detected${COL_RESET}" - fail=1 + if ./test/run.sh; then + echo -e "${COL_GREEN}Integration tests passed${COL_RESET}" else - echo -e "${COL_GREEN}* ${worker}: OK${COL_RESET}" + echo -e "${COL_RED}Integration tests failed${COL_RESET}" + fail=1 fi + + cd .. } -check_manager_exit_code() { - if [[ ${manager_exit_code} -ne 0 ]]; then - echo -e "${COL_RED}* Manager: Bad exit code (${manager_exit_code})${COL_RESET}" - fail=1 +show_results() { + if [[ $fail -eq 0 ]]; then + echo -e "\n${COL_GREEN}ALL TESTS PASSED${COL_RESET}" else - echo -e "${COL_GREEN}* Manager: OK${COL_RESET}" + echo -e "\n${COL_RED}SOME TESTS FAILED${COL_RESET}" fi } -show_failure_details() { - [[ -z ${always_log} ]] && [[ -z ${fail} ]] && return - - echo -e "\n${COL_RED}=== TEST FAILED ===${COL_RESET}" - docker ps -a - for worker in "${WORKER_NAMES[@]}"; do - echo -e "\nLogs for ${worker}:" - docker logs "${worker}" - done - - echo -e "\nLogs for manager:" - cat ${CONTROLLER_LOGFILE} - - echo -e "=======================" -} - cleanup() { echo -e "${COL_LBLUE}Cleaning up...${COL_RESET}" - sudo rm -rf /run/controller - docker container rm -f worker1 worker2 > /dev/null - rm -f ${CONTROLLER_LOGFILE} bazel-* MODULE.bazel.lock + cd controller && bazel clean 2>/dev/null || true + cd ../worker && bazel clean 2>/dev/null || true } trap cleanup EXIT main -exit "${fail:-0}" +exit $fail diff --git a/worker/BUILD b/worker/BUILD index a1da9fa..3399bb0 100644 --- a/worker/BUILD +++ b/worker/BUILD @@ -4,8 +4,9 @@ proto_library( name = "communication_proto", srcs = ["communication.proto"], deps = [ - "@com_google_protobuf//:empty_proto", + "@com_google_protobuf//:any_proto", "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:empty_proto", ], ) @@ -18,7 +19,7 @@ generate_cc( name = "communication_cc_grpc_gen", srcs = [":communication_proto"], plugin = "@grpc//src/compiler:grpc_cpp_plugin", - well_known_protos = True, + well_known_protos = False, generate_mocks = True, ) @@ -37,7 +38,15 @@ cc_library( hdrs = [ "include/worker.hpp", "include/metrics_collector.hpp", + "include/dpdk_filter/net_port.h", + "include/dpdk_filter/dns_cache.h", + "include/dpdk_filter/filtr_packets.h", + "include/dpdk_filter/pars_packets.h", + "include/dpdk_filter/proc_packets.h", + "include/dpdk_filter/types.h", + "include/dpdk_filter/constants.h", ], + includes = ["include", "include/dpdk_filter"], srcs = [], visibility = ["//visibility:public"], ) @@ -47,7 +56,12 @@ cc_binary( srcs = [ "src/main.cpp", "src/worker.cpp", + "src/dpdk_filter/dns_cache.c", "src/metrics_collector.cpp", + "src/dpdk_filter/net_port.c", + "src/dpdk_filter/filtr_packets.c", + "src/dpdk_filter/pars_packets.c", + "src/dpdk_filter/proc_packets.c", ], deps = [ ":worker_headers", @@ -58,14 +72,42 @@ cc_binary( "@curl//:curl", ], copts = [ + "-mssse3", + "-msse4.2", + "-mpclmul", + "-maes", "-I$(GENDIR)/..", + "-I/usr/include", + ], + cxxopts = [ + "-std=c++17", ], linkopts = [ "-L/usr/local/openssl/lib", "-lssl", "-lcrypto", + "-L/usr/local/lib", "-lprometheus-cpp-push", "-lprometheus-cpp-core", + + "-L/usr/lib", + "-lrte_eal", + "-lrte_ethdev", + "-lrte_mempool", + "-lrte_mbuf", + "-lrte_bus_vdev", + "-lrte_ring", + "-lrte_telemetry", + "-lrte_kvargs", + "-lrte_log", + "-lrte_net", + "-lrte_hash", + "-lrte_timer", + "-lsqlite3", + + "-lnuma", + "-ldl", + "-lpthread", ], ) diff --git a/worker/Dockerfile.cc_x86_to_x86 b/worker/Dockerfile.cc_x86_to_x86 index 5d6cf09..800c2d8 100644 --- a/worker/Dockerfile.cc_x86_to_x86 +++ b/worker/Dockerfile.cc_x86_to_x86 @@ -1,17 +1,50 @@ -FROM alpine:3.21.3 AS builder - -RUN apk update && apk add --no-cache g++ openssl-dev cmake make curl-dev protobuf-dev -RUN apk add bazel --repository=http://dl-cdn.alpinelinux.org/alpine/edge/testing/ +FROM ubuntu:22.04 AS builder + +RUN apt-get update && apt-get install -y \ + build-essential=12.9* \ + cmake=3.22* \ + curl=7.81* \ + git=1:2.34* \ + wget=1.21* \ + meson=0.61* \ + ninja-build=1.10* \ + libssl-dev=3.0* \ + protobuf-compiler=3.12* \ + libprotobuf-dev=3.12* \ + python3=3.10* \ + python3-pip=22.0* \ + libnuma-dev=2.0* \ + pkg-config=0.29* \ + libcurl4-openssl-dev=7.81* \ + libbpf-dev=1:0.5* \ + gcc=4:11* \ + g++=4:11* \ + m4=1.4* \ + libpcap-dev=1.10* \ + libsqlite3-dev=3.37* \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install pyelftools + +RUN wget https://github.com/bazelbuild/bazel/releases/download/8.2.1/bazel-8.2.1-linux-x86_64 \ + && chmod +x bazel-8.2.1-linux-x86_64 \ + && mv bazel-8.2.1-linux-x86_64 /usr/local/bin/bazel + +RUN wget https://fast.dpdk.org/rel/dpdk-23.11.tar.xz && \ + tar -xf dpdk-23.11.tar.xz && \ + cd dpdk-23.11 && \ + meson setup build --libdir=lib && \ + ninja -C build && \ + ninja -C build install && \ + cd .. && \ + rm -rf dpdk-23.11 dpdk-23.11.tar.xz + +ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig WORKDIR /app COPY scripts/get_prometheus_cpp.sh scripts/ RUN sh scripts/get_prometheus_cpp.sh -RUN apk add --no-cache llvm18 clang18 -RUN ln -s /usr/lib/llvm18/bin/llvm-ar /bin/llvm-ar-18 -RUN ln -s /usr/bin/clang++-18 /usr/bin/clang++ -RUN ln -s /usr/bin/clang-18 /usr/bin/clang - COPY ./src/ ./src/ COPY ./include/ ./include/ @@ -21,17 +54,10 @@ COPY ./communication.proto ./ COPY ./toolchains ./toolchains COPY ./platforms ./platforms - -RUN bazel build //:worker --extra_toolchains=//toolchains/x86_64:cc_toolchain_for_linux_x86_64 --platforms=//platforms:x86_64_linux - -FROM alpine:3.21.3 - -RUN apk update && apk add --no-cache libstdc++ libgcc libssl3 libcurl protobuf-dev - -COPY --from=builder /app/bazel-bin/worker /usr/local/bin/worker -COPY --from=builder /app/prometheus-cpp-with-submodules/build/lib/ /usr/lib - +RUN bazel build //:worker WORKDIR /data -ENTRYPOINT ["/usr/local/bin/worker", "/data/test.txt", "sha256"] +RUN ldconfig + +ENTRYPOINT ["/app/bazel-bin/worker"] diff --git a/worker/Makefile.main_riscv b/worker/Makefile.main_riscv index 3a06a3a..2d9b7dc 100644 --- a/worker/Makefile.main_riscv +++ b/worker/Makefile.main_riscv @@ -1,6 +1,7 @@ CC = riscv64-linux-gnu-gcc DPDK_PREFIX = ./dpdk-riscv-install +SQLITE_PREFIX = ./sqlite3-riscv-install PKG_CONFIG = env PKG_CONFIG_LIBDIR=$(DPDK_PREFIX)/lib/pkgconfig pkg-config CFLAGS_BASE = -Iinclude -O2 $(shell $(PKG_CONFIG) --cflags libdpdk) @@ -13,10 +14,11 @@ LDFLAGS = -L$(DPDK_PREFIX)/lib \ -lrte_net \ -lrte_log -ldl \ -lrte_hash \ - -sqlite3 \ + -lrte_timer \ -Wl,--end-group \ - -latomic - + -latomic \ + -L$(SQLITE_PREFIX)/lib \ + -lsqlite3 SRCS = src/dpdk_filter/main.c src/dpdk_filter/net_port.c src/dpdk_filter/filtr_packets.c src/dpdk_filter/pars_packets.c src/dpdk_filter/proc_packets.c src/dpdk_filter/dns_cache.c diff --git a/worker/Makefile.main_x86 b/worker/Makefile.main_x86 index 71fd148..b4c30ba 100644 --- a/worker/Makefile.main_x86 +++ b/worker/Makefile.main_x86 @@ -1,6 +1,6 @@ CC = gcc CFLAGS_BASE = -Iinclude -O2 -msse4.2 -mpclmul -maes -LDFLAGS = -lrte_eal -lrte_ethdev -lrte_mempool -lrte_mbuf -lrte_bus_vdev -lpthread -lnuma -ldl -lrte_net -lrte_hash -lsqlite3 +LDFLAGS = -lrte_eal -lrte_ethdev -lrte_mempool -lrte_mbuf -lrte_bus_vdev -lpthread -lnuma -ldl -lrte_net -lrte_hash -lsqlite3 -lrte_timer SRCS = src/dpdk_filter/main.c src/dpdk_filter/net_port.c src/dpdk_filter/filtr_packets.c src/dpdk_filter/pars_packets.c src/dpdk_filter/proc_packets.c src/dpdk_filter/dns_cache.c @@ -20,4 +20,4 @@ $(TARGET_VIRT): $(SRCS) clean: rm -f $(TARGET_REAL) $(TARGET_VIRT) -.PHONY: all clean virt \ No newline at end of file +.PHONY: all clean virt diff --git a/worker/README(DPDK FILTRING).md b/worker/README(DPDK FILTRING).md new file mode 100644 index 0000000..cba6538 --- /dev/null +++ b/worker/README(DPDK FILTRING).md @@ -0,0 +1,80 @@ +# Драйвера dpdk +DPDK должен быть собран с драйверами net/af_xdp net/tap + + +# Кросс-компиляция + +## Окружение +Скрипт `scripts/setup-riscv-env.sh` автоматически скачивает (при необходимости) и собирает DPDK 23.11 для архитектуры RISC-V. + +```bash +./scripts/setup-riscv-env.sh +``` + +## SQLite +Если целевая архитектура — RISC-V, SQLite необходимо собрать кросс-компилятором. + +```bash +wget https://www.sqlite.org/2024/sqlite-autoconf-3460100.tar.gz +tar -xzf sqlite-autoconf-3460100.tar.gz +cd sqlite-autoconf-3460100 + +./configure --host=riscv64-linux-gnu --prefix=/path/to/sqlite3-riscv-install +make -j$(nproc) +make install +``` + +После установки в указанном prefix появятся подкаталоги include/ и lib/ с необходимыми файлами. + + + +# Создание пары veth и TAP-устройства + +```bash +sudo ./scripts/set_virt_dev_for_test_xdp.sh +``` +Скрипт создаёт пару veth0 - veth1 + + +```bash +sudo ./scripts/set_tap_dev.sh +``` +Скрипт создаёт TAP-устройство tap0 + + + +# Сборка проекта +Для реальных портов (eth0/eth1): +```bash +make -f Makefile.main_riscv all +``` + +Для виртуальных портов (veth0/veth1 + tap0): +```bash +make -f Makefile.main_riscv virt +``` +Определение макроса -DVIRT_PORTS переключает программу на использование виртуальных интерфейсов. + + +Перед запуском рекомендуется выполнить скрипт настройки виртуальных устройств: +```bash +sudo ./scripts/set_virt_dev_for_test_xdp.sh +``` + + +# Очистка +```bash +make -f Makefile.main_riscv clean +``` + +# Запуск +Программа требует прав суперпользователя (для работы с DPDK и XDP): +```bash +sudo ./main-riscv-virt +``` + + +# Примечания +Кэш DNS автоматически сохраняется в cache.db (SQLite) и восстанавливается при перезапуске. + +Периодическое сохранение кэша происходит каждый час с помощью таймеров DPDK. diff --git a/worker/helper for association with Worker.md b/worker/helper for association with Worker.md index de73307..ca4428e 100644 --- a/worker/helper for association with Worker.md +++ b/worker/helper for association with Worker.md @@ -1,18 +1,24 @@ -REQUESTED_CLASSIFICATION структура для передачи от контроллера к воркеру: +REQUESTED_CLASSIFICATION - структура для передачи от контроллера к воркеру: + +```code struct requested_classification { - char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN] - политика - int get_trust_level - уровень доверия к сайту + char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN] + int get_trust_level } +``` +Структура для хранения категории с минимальным уровнем доверия для этой категории: -Структура для хранения категории с минимальным уровнем доверия для этой категории +```code struct trust_categories_with_lvl { char locked_by_trust_category[CATEGORY_MAX_LEN]; int trust_lvl; } +``` +у нас есть переменные, которые получаем при инициализации воркера и заносим в структуру (периодически обновляем): -у нас есть переменные, которые получаем при инициализации воркера и заносим в структуру (периодически обновляем) +```code struct BASE_POLICY { char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]; @@ -20,7 +26,8 @@ struct BASE_POLICY { char allow_domains[MAX_DOMAINS][MAX_LEN_DOMEIN]; int min_trust_level; } +``` +Добавлен tap порт, по которому проходят пакеты исключений в ядро, обрабатываются и ответ отсылается на входящий порт (port_in) -Добавлен tap порт, по которому проходят пакеты исключений в ядро, обрабатываются и ответ отсылается на входящий порт (port_in) \ No newline at end of file diff --git a/worker/include/dpdk_filter/constants.h b/worker/include/dpdk_filter/constants.h index 6ae6d1d..f046259 100644 --- a/worker/include/dpdk_filter/constants.h +++ b/worker/include/dpdk_filter/constants.h @@ -1,6 +1,7 @@ #ifndef CONSTANTS_H #define CONSTANTS_H +#include #define MAX_CATEGORIES_BY_TRUST_LVL 64 #define MAX_DOMAINS 64 @@ -10,6 +11,6 @@ #define CATEGORY_MAX_LEN 64 #define DNS_CACHE_DEFAULT_TTL (7 * 24 * 60 * 60) #define LEN_LIST_EXCEPTION_PORTS 1 -extern const uint16_t LIST_EXCEPTION_PORTS[LEN_LIST_EXCEPTION_PORTS]; +extern const uint16_t LIST_EXCEPTION_PORTS[LEN_LIST_EXCEPTION_PORTS]; #endif \ No newline at end of file diff --git a/worker/include/dpdk_filter/dns_cache.h b/worker/include/dpdk_filter/dns_cache.h index 7f953f9..8463124 100644 --- a/worker/include/dpdk_filter/dns_cache.h +++ b/worker/include/dpdk_filter/dns_cache.h @@ -6,15 +6,12 @@ #include #include #include +#include #include #include -#include - -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" - - +#include "constants.h" +#include "types.h" void init_dns_cache(void); int lookup_dns_cache(const char *domain, struct node_cache **return_node); diff --git a/worker/include/dpdk_filter/filtr_packets.h b/worker/include/dpdk_filter/filtr_packets.h index dda5332..188830b 100644 --- a/worker/include/dpdk_filter/filtr_packets.h +++ b/worker/include/dpdk_filter/filtr_packets.h @@ -1,22 +1,29 @@ #ifndef FILTR_PAK_H #define FILTR_PAK_H +#include "constants.h" #include "pars_packets.h" +#include "types.h" #include #include -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" -bool check_is_block(char domain[DOMAIN_MAX_LEN], char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); +bool check_is_block(char domain[DOMAIN_MAX_LEN], + char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); -bool check_is_allow(char domain[DOMAIN_MAX_LEN], char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); +bool check_is_allow(char domain[DOMAIN_MAX_LEN], + char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]); bool check_trust_level(int get_trust_level, int min_trust_level); -bool check_categories(char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]); +bool check_categories(char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], + char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]); -bool check_categories_with_lvl(struct requested_classification* req_clas, struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]); +bool check_categories_with_lvl( + struct requested_classification *req_clas, + struct trust_categories_with_lvl + categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]); -bool main_filtring(struct requested_classification* req_clas, struct BASE_POLICY* policy, char domain[DOMAIN_MAX_LEN]); +bool main_filtring(struct requested_classification *req_clas, + struct BASE_POLICY *policy, char domain[DOMAIN_MAX_LEN]); #endif \ No newline at end of file diff --git a/worker/include/dpdk_filter/net_port.h b/worker/include/dpdk_filter/net_port.h index e34f8b1..ebe9224 100644 --- a/worker/include/dpdk_filter/net_port.h +++ b/worker/include/dpdk_filter/net_port.h @@ -1,18 +1,15 @@ #ifndef AF_XDP_PORT_H #define AF_XDP_PORT_H +#include "types.h" #include #include -#include "../../include/dpdk_filter/types.h" - - struct net_port *init_struct_tap_port(const char *tap_iface_name, - struct rte_mempool *mbuf_pool); - + struct rte_mempool *mbuf_pool); struct net_port *init_struct_af_xdp_port(const char *iface_name, - struct rte_mempool *mbuf_pool); + struct rte_mempool *mbuf_pool); int net_port_init(struct net_port *port); diff --git a/worker/include/dpdk_filter/pars_packets.h b/worker/include/dpdk_filter/pars_packets.h index d319225..7d72726 100644 --- a/worker/include/dpdk_filter/pars_packets.h +++ b/worker/include/dpdk_filter/pars_packets.h @@ -1,12 +1,10 @@ #ifndef PARS_PAK_H #define PARS_PAK_H +#include "constants.h" +#include "types.h" #include #include -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" - - void parsing_pakage(struct rte_mbuf *paket, struct info_of_pakage *info_pac); diff --git a/worker/include/dpdk_filter/proc_packets.h b/worker/include/dpdk_filter/proc_packets.h index c7f7737..aafc1de 100644 --- a/worker/include/dpdk_filter/proc_packets.h +++ b/worker/include/dpdk_filter/proc_packets.h @@ -1,24 +1,20 @@ #ifndef PROC_PAK_H #define PROC_PAK_H -#include "../../include/dpdk_filter/net_port.h" -#include "../../include/dpdk_filter/filtr_packets.h" -#include "../../include/dpdk_filter/pars_packets.h" -#include "../../include/dpdk_filter/constants.h" -#include "../../include/dpdk_filter/types.h" +#include "constants.h" +#include "filtr_packets.h" +#include "net_port.h" +#include "pars_packets.h" +#include "types.h" #include #include #include #include #include - - - - - -void pakage_processing(struct net_port *port_in, - struct net_port *port_out, struct net_port *port_exception, uint16_t queue_number, - uint16_t nb_pkts, struct rte_mbuf **pkts, struct BASE_POLICY* policy); +void pakage_processing(struct net_port *port_in, struct net_port *port_out, + struct net_port *port_exception, uint16_t queue_number, + uint16_t nb_pkts, struct rte_mbuf **pkts, + struct BASE_POLICY *policy); #endif \ No newline at end of file diff --git a/worker/include/dpdk_filter/types.h b/worker/include/dpdk_filter/types.h index d056e50..d93881b 100644 --- a/worker/include/dpdk_filter/types.h +++ b/worker/include/dpdk_filter/types.h @@ -2,8 +2,8 @@ #define TYPES_H #include "constants.h" -#include #include +#include struct net_port { uint16_t port_id; @@ -20,24 +20,24 @@ struct info_of_pakage { }; struct trust_categories_with_lvl { - char locked_by_trust_category[CATEGORY_MAX_LEN]; - int trust_lvl; + char locked_by_trust_category[CATEGORY_MAX_LEN]; + int trust_lvl; }; struct BASE_POLICY { char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; - struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]; + struct trust_categories_with_lvl + categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]; char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]; char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]; int min_trust_level; }; struct requested_classification { - char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; - int get_trust_level; + char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; + int get_trust_level; }; - struct node_cache { char categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]; bool solution_is_send; @@ -47,5 +47,4 @@ struct node_cache { char *key_domain; }; - #endif diff --git a/worker/include/worker.hpp b/worker/include/worker.hpp index 58c437a..d1859a2 100644 --- a/worker/include/worker.hpp +++ b/worker/include/worker.hpp @@ -3,9 +3,19 @@ #include "communication.grpc.pb.h" #include "communication.pb.h" +extern "C" { +#include "dpdk_filter/filtr_packets.h" +#include "dpdk_filter/net_port.h" +#include "dpdk_filter/proc_packets.h" +#include "dpdk_filter/types.h" +} #include #include #include +#include +#include +#include +#include #define EXPECTED_POLICY_TIME 60 #define MIN_POLICY_TIME 30 @@ -29,6 +39,14 @@ class Worker { int64_t policy_interval = MIN_POLICY_TIME; int64_t stats_interval = MIN_STATS_TIME; + struct net_port *port_in = nullptr; + struct net_port *port_out = nullptr; + struct net_port *port_exception = nullptr; + struct rte_mempool *mbuf_pool = nullptr; + std::mutex policy_mutex; + struct BASE_POLICY current_policy; + uint16_t queue_number = 0; + std::unique_ptr stub_; WorkerState state; @@ -39,9 +57,13 @@ class Worker { Worker(uint64_t id); ~Worker(); + void initDPDK(int argc, char **argv); inline uint64_t GetID() const { return worker_id; } void requestPolicyFromController(); - void classifyDomain(const std::string &domain); + bool classifyDomain(const std::string &domain, + struct requested_classification *out_req); + void forward_to_out(struct net_port *incoming_port, + struct net_port *outgoing_port, uint16_t queue_number); void statsReport(); WorkerState GetState() const { return state; } void MainLoop(); diff --git a/worker/scripts/set_tap_dev.sh b/worker/scripts/set_tap_dev.sh new file mode 100755 index 0000000..cb0cb03 --- /dev/null +++ b/worker/scripts/set_tap_dev.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +TAP="tap0" + +sudo ip tuntap add $TAP mode tap +sudo ip link set $TAP up +sudo ip addr add 10.0.3.1/24 dev $TAP \ No newline at end of file diff --git a/worker/src/dpdk_filter/af_xdp_port.c b/worker/src/dpdk_filter/af_xdp_port.c deleted file mode 100644 index 25ecb7d..0000000 --- a/worker/src/dpdk_filter/af_xdp_port.c +++ /dev/null @@ -1,176 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../include/dpdk_filter/af_xdp_port.h" - -#define RX_RING_SIZE 1024 -#define TX_RING_SIZE 1024 - -int find_port_by_dev_name(const char *dev_name, uint16_t *port_id_dev) { - uint16_t count_ports = rte_eth_dev_count_avail(); - struct rte_eth_dev_info dev_info; - char name[64]; - - for (uint16_t port_id = 0; port_id < count_ports; port_id++) { - int ret = rte_eth_dev_info_get(port_id, &dev_info); - - if (ret) { - printf("[ERROR] Failed to retrieve the contextual information of an " - "Ethernet device: %s\n", - strerror(-ret)); - return ret; - } - - if (rte_eth_dev_get_name_by_port(port_id, name) == 0 && - strcmp(name, dev_name) == 0) { - *port_id_dev = port_id; - return 0; - } - } - return -1; -} - -struct af_xdp_port *init_struct_af_xdp_port(const char *iface_name, - struct rte_mempool *mbuf_pool) { - struct af_xdp_port *port = calloc(1, sizeof(struct af_xdp_port)); - if (!port) { - printf("[ERROR] Failed to allocate memory for struct af_xdp_port\n"); - return NULL; - } - - snprintf(port->dev_args, sizeof(port->dev_args), - "iface=%s,start_queue=0,queue_count=1", iface_name); - snprintf(port->dev_name, sizeof(port->dev_name), "net_af_xdp_%s", iface_name); - strncpy(port->iface_name, iface_name, sizeof(port->iface_name) - 1); - port->iface_name[sizeof(port->iface_name) - 1] = '\0'; - port->mbuf_pool = mbuf_pool; - port->port_id = -1; - - return port; -} - -int af_xdp_port_init(struct af_xdp_port *port) { - int ret; - struct rte_eth_conf port_conf = {0}; - const char *dev_name = port->dev_name; - uint16_t port_id; - - ret = rte_vdev_init(dev_name, port->dev_args); - - if (ret < 0) { - printf("[ERROR] Failed to create vdev: %s\n", strerror(-ret)); - return ret; - } - - ret = find_port_by_dev_name(port->dev_name, &port_id); - if (ret) { - printf("no port was found that has the same vdev name. vdev = %s", - port->dev_name); - rte_vdev_uninit(dev_name); - return -1; - } - - port->port_id = port_id; - - if (!rte_eth_dev_is_valid_port(port_id)) { - printf("[ERROR] Port %u is not valid\n", port_id); - rte_vdev_uninit(dev_name); - return -EINVAL; - } - - ret = rte_eth_dev_configure(port_id, 1, 1, &port_conf); - if (ret < 0) { - printf("[ERROR] Failed to configure port: %s\n", strerror(-ret)); - rte_vdev_uninit(dev_name); - return ret; - } - - ret = rte_eth_rx_queue_setup(port_id, 0, RX_RING_SIZE, - rte_eth_dev_socket_id(port_id), NULL, - port->mbuf_pool); - if (ret < 0) { - printf("[ERROR] Failed to setup RX queue: %s\n", strerror(-ret)); - rte_vdev_uninit(dev_name); - return ret; - } - - ret = rte_eth_tx_queue_setup(port_id, 0, TX_RING_SIZE, - rte_eth_dev_socket_id(port_id), NULL); - - if (ret < 0) { - printf("[ERROR] Failed to setup TX queue: %s\n", strerror(-ret)); - rte_vdev_uninit(dev_name); - return ret; - } - - printf("Port %u initialized\n", port_id); - return 0; -} - -int af_xdp_port_start(uint16_t port_id) { - int ret; - - ret = rte_eth_dev_start(port_id); - if (ret < 0) { - printf("[ERROR] Failed to start: %s\n", strerror(-ret)); - return ret; - } - - ret = rte_eth_promiscuous_enable(port_id); - if (ret) { - printf("[ERROR] Failed to enable receipt in promiscuous mode for an " - "Ethernet device: %s\n", - strerror(-ret)); - return ret; - } - - printf("Port %u started\n", port_id); - return 0; -} - -void af_xdp_port_destroy(struct af_xdp_port *port) { - if (!port) - return; - free(port); -} - -void af_xdp_port_close(struct af_xdp_port *port) { - - if (!port) - return; - - int ret; - uint16_t port_id = port->port_id; - - ret = rte_eth_dev_stop(port_id); - if (ret) { - printf("[ERROR] Failed to stop an Ethernet device: %s\n", strerror(-ret)); - return; - } - - ret = rte_eth_dev_close(port_id); - if (ret) { - printf("[ERROR] Failed to close a stopped Ethernet device: %s\n", - strerror(-ret)); - return; - } - - ret = rte_vdev_uninit(port->dev_name); - if (ret) { - printf("[ERROR] Failed to uninitialize a driver: %s\n", strerror(-ret)); - return; - } - - port->port_id = -1; - printf("Port %u closed\n", port_id); -} \ No newline at end of file diff --git a/worker/src/dpdk_filter/dns_cache.c b/worker/src/dpdk_filter/dns_cache.c index 4427e1f..7be7373 100644 --- a/worker/src/dpdk_filter/dns_cache.c +++ b/worker/src/dpdk_filter/dns_cache.c @@ -1,4 +1,4 @@ -#include "../../include/dpdk_filter/dns_cache.h" +#include "dns_cache.h" static struct rte_hash *dns_hash; static struct rte_hash_parameters hash_params = { diff --git a/worker/src/dpdk_filter/filtr_packets.c b/worker/src/dpdk_filter/filtr_packets.c index a6aeadd..530204d 100644 --- a/worker/src/dpdk_filter/filtr_packets.c +++ b/worker/src/dpdk_filter/filtr_packets.c @@ -1,7 +1,8 @@ -#include "../../include/dpdk_filter/filtr_packets.h" -#include "../../include/dpdk_filter/pars_packets.h" +#include "filtr_packets.h" +#include "pars_packets.h" -bool check_is_block(char domain[DOMAIN_MAX_LEN], char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { +bool check_is_block(char domain[DOMAIN_MAX_LEN], + char block_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { for (int i = 0; i < MAX_DOMAINS; i++) { if (strcmp(block_domains[i], domain) == 0) { @@ -12,7 +13,8 @@ bool check_is_block(char domain[DOMAIN_MAX_LEN], char block_domains[MAX_DOMAINS] return false; } -bool check_is_allow(char domain[DOMAIN_MAX_LEN], char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { +bool check_is_allow(char domain[DOMAIN_MAX_LEN], + char allow_domains[MAX_DOMAINS][DOMAIN_MAX_LEN]) { for (int i = 0; i < MAX_DOMAINS; i++) { if (strcmp(allow_domains[i], domain) == 0) { @@ -32,34 +34,41 @@ bool check_trust_level(int get_trust_level, int min_trust_level) { return true; } -bool check_categories(char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]) { - +bool check_categories( + char get_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN], + char locked_categories[MAX_CATEGORIES][CATEGORY_MAX_LEN]) { + for (int i = 0; i < MAX_CATEGORIES; i++) { for (int j = 0; j < MAX_CATEGORIES; j++) { if (strcmp(get_categories[i], locked_categories[j]) == 0) { return false; } } - } + } return true; } - -bool check_categories_with_lvl(struct requested_classification* req_clas, struct trust_categories_with_lvl categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]) { +bool check_categories_with_lvl( + struct requested_classification *req_clas, + struct trust_categories_with_lvl + categories_with_lvl[MAX_CATEGORIES_BY_TRUST_LVL]) { for (int i = 0; i < MAX_CATEGORIES; i++) { for (int j = 0; j < MAX_CATEGORIES; j++) { - if (strcmp(req_clas->get_categories[j], categories_with_lvl[i].locked_by_trust_category) == 0 && req_clas->get_trust_level < categories_with_lvl[i].trust_lvl) { - return false; + if (strcmp(req_clas->get_categories[j], + categories_with_lvl[i].locked_by_trust_category) == 0 && + req_clas->get_trust_level < categories_with_lvl[i].trust_lvl) { + return false; } - } - } + } + } return true; } -bool main_filtring(struct requested_classification* req_clas, struct BASE_POLICY* policy, char domain[DOMAIN_MAX_LEN]) { +bool main_filtring(struct requested_classification *req_clas, + struct BASE_POLICY *policy, char domain[DOMAIN_MAX_LEN]) { if (check_is_block(domain, policy->block_domains) == true) { printf("This domain is blocked"); @@ -71,18 +80,22 @@ bool main_filtring(struct requested_classification* req_clas, struct BASE_POLICY return true; } - if (check_categories(req_clas->get_categories, policy->locked_categories) == false) { + if (check_categories(req_clas->get_categories, policy->locked_categories) == + false) { printf("This site has a locked category"); return false; } - if (check_trust_level(req_clas->get_trust_level, policy->min_trust_level) == false) { + if (check_trust_level(req_clas->get_trust_level, policy->min_trust_level) == + false) { printf("This site has a too small trust level"); return false; } - if (check_categories_with_lvl(req_clas, policy->categories_with_lvl) == false) { - printf("This site blocked in accordance with 'trust categories with level'"); + if (check_categories_with_lvl(req_clas, policy->categories_with_lvl) == + false) { + printf( + "This site blocked in accordance with 'trust categories with level'"); return false; } diff --git a/worker/src/dpdk_filter/main.c b/worker/src/dpdk_filter/main.c index 6da81a8..36e59b4 100644 --- a/worker/src/dpdk_filter/main.c +++ b/worker/src/dpdk_filter/main.c @@ -1,6 +1,6 @@ -#include "../../include/dpdk_filter/net_port.h" -#include "../../include/dpdk_filter/dns_cache.h" -#include "../../include/dpdk_filter/proc_packets.h" +#include "dns_cache.h" +#include "net_port.h" +#include "proc_packets.h" #include #include #include @@ -18,21 +18,23 @@ static void signal_handler(int signum) { } } -void forward_tap_to_out(struct net_port *port_exception, struct net_port *port_in, uint16_t queue_number) { - struct rte_mbuf *tap_pkts[32]; - uint16_t nb_tap = rte_eth_rx_burst(port_exception->port_id, queue_number, tap_pkts, 32); - for (int i = 0; i < nb_tap; i++) { - int ret = rte_eth_tx_burst(port_in->port_id, queue_number, &tap_pkts[i], 1); - if (ret < 1) { - printf("[ERROR] Failed to send packet\n"); - // PLUG (to be added later) - need to add processing for this case - rte_pktmbuf_free(tap_pkts[i]); - } +void forward_tap_to_out(struct net_port *port_exception, + struct net_port *port_in, uint16_t queue_number) { + struct rte_mbuf *tap_pkts[32]; + uint16_t nb_tap = + rte_eth_rx_burst(port_exception->port_id, queue_number, tap_pkts, 32); + for (int i = 0; i < nb_tap; i++) { + int ret = rte_eth_tx_burst(port_in->port_id, queue_number, &tap_pkts[i], 1); + if (ret < 1) { + printf("[ERROR] Failed to send packet\n"); + // PLUG (to be added later) - need to add processing for this case + rte_pktmbuf_free(tap_pkts[i]); } + } } int main(int argc, char **argv) { - //since BASE_POLICY is filled when initializing worker, let’s initialize here + // since BASE_POLICY is filled when initializing worker, let’s initialize here struct BASE_POLICY policy; if (signal(SIGINT, signal_handler) == SIG_ERR) { printf("[ERROR] Failed to set SIGINT handler\n"); @@ -43,8 +45,6 @@ int main(int argc, char **argv) { return 1; } - - struct net_port *port_in = NULL; struct net_port *port_out = NULL; struct net_port *port_exception = NULL; @@ -83,24 +83,23 @@ int main(int argc, char **argv) { port_exception = init_struct_tap_port("tap0", mbuf_pool); - if (!port_in || !port_out || !port_exception) { return 1; } - if (net_port_init(port_in) || net_port_init(port_out) || net_port_init(port_exception)) { + if (net_port_init(port_in) || net_port_init(port_out) || + net_port_init(port_exception)) { return 1; } - if (net_port_start(port_in->port_id) || - net_port_start(port_out->port_id) || + if (net_port_start(port_in->port_id) || net_port_start(port_out->port_id) || net_port_start(port_exception->port_id)) { return 1; } ret = system("sudo ip link set tap0 up && " - "sudo ip addr add 10.0.3.1/24 dev tap0"); - if(ret) { + "sudo ip addr add 10.0.3.1/24 dev tap0"); + if (ret) { printf("[ERROR] Failed to set tap0 up\n"); } @@ -110,7 +109,8 @@ int main(int argc, char **argv) { while (running) { forward_tap_to_out(port_exception, port_in, queue_number); - pakage_processing(port_in, port_out, port_exception, queue_number, nb_pkts, pkts, &policy); + pakage_processing(port_in, port_out, port_exception, queue_number, nb_pkts, + pkts, &policy); } // function for save cache info if need diff --git a/worker/src/dpdk_filter/net_port.c b/worker/src/dpdk_filter/net_port.c index 92eeca4..761a1cb 100644 --- a/worker/src/dpdk_filter/net_port.c +++ b/worker/src/dpdk_filter/net_port.c @@ -11,7 +11,7 @@ #include #include -#include "../../include/dpdk_filter/net_port.h" +#include "net_port.h" #define RX_RING_SIZE 1024 #define TX_RING_SIZE 1024 @@ -41,16 +41,18 @@ int find_port_by_dev_name(const char *dev_name, uint16_t *port_id_dev) { } struct net_port *init_struct_tap_port(const char *tap_iface_name, - struct rte_mempool *mbuf_pool) { + struct rte_mempool *mbuf_pool) { struct net_port *port = calloc(1, sizeof(struct net_port)); - if (!port) { + if (!port) { printf("[ERROR] Failed to allocate memory for struct net_port\n"); return NULL; } - snprintf(port->dev_args, sizeof(port->dev_args),"iface=%s", tap_iface_name); - snprintf(port->dev_name, sizeof(port->dev_name), "net_tap_%s", tap_iface_name); + snprintf(port->dev_args, sizeof(port->dev_args), "iface=%s, remote=%s", + tap_iface_name, tap_iface_name); + snprintf(port->dev_name, sizeof(port->dev_name), "net_tap_%s", + tap_iface_name); strncpy(port->iface_name, tap_iface_name, sizeof(port->iface_name) - 1); port->iface_name[sizeof(port->iface_name) - 1] = '\0'; port->mbuf_pool = mbuf_pool; @@ -59,9 +61,8 @@ struct net_port *init_struct_tap_port(const char *tap_iface_name, return port; } - struct net_port *init_struct_af_xdp_port(const char *iface_name, - struct rte_mempool *mbuf_pool) { + struct rte_mempool *mbuf_pool) { struct net_port *port = calloc(1, sizeof(struct net_port)); if (!port) { printf("[ERROR] Failed to allocate memory for struct net_port\n"); diff --git a/worker/src/dpdk_filter/pars_packets.c b/worker/src/dpdk_filter/pars_packets.c index ea59e2d..4092b4b 100644 --- a/worker/src/dpdk_filter/pars_packets.c +++ b/worker/src/dpdk_filter/pars_packets.c @@ -1,4 +1,4 @@ -#include "../../include/dpdk_filter/pars_packets.h" +#include "pars_packets.h" #include #include #include diff --git a/worker/src/dpdk_filter/proc_packets.c b/worker/src/dpdk_filter/proc_packets.c index 397f4f6..6eebebb 100644 --- a/worker/src/dpdk_filter/proc_packets.c +++ b/worker/src/dpdk_filter/proc_packets.c @@ -1,5 +1,8 @@ -#include "../../include/dpdk_filter/proc_packets.h" -#include "../../include/dpdk_filter/dns_cache.h" +#include "proc_packets.h" +#include "dns_cache.h" + +extern bool worker_classify_domain(const char *domain, + struct requested_classification *out_req); const uint16_t LIST_EXCEPTION_PORTS[LEN_LIST_EXCEPTION_PORTS] = {22}; @@ -20,7 +23,6 @@ void package_sending_decision(bool solution_is_send, struct rte_mbuf *pkt, rte_pktmbuf_free(pkt); } - bool check_is_exception(uint16_t number_port) { for (int i = 0; i < LEN_LIST_EXCEPTION_PORTS; i++) { if (number_port == LIST_EXCEPTION_PORTS[i]) { @@ -30,10 +32,10 @@ bool check_is_exception(uint16_t number_port) { return false; } - -void pakage_processing(struct net_port *port_in, - struct net_port *port_out, struct net_port *port_exception, uint16_t queue_number, - uint16_t nb_pkts, struct rte_mbuf **pkts, struct BASE_POLICY* policy) { +void pakage_processing(struct net_port *port_in, struct net_port *port_out, + struct net_port *port_exception, uint16_t queue_number, + uint16_t nb_pkts, struct rte_mbuf **pkts, + struct BASE_POLICY *policy) { uint16_t nb_rx = rte_eth_rx_burst(port_in->port_id, queue_number, pkts, nb_pkts); @@ -44,14 +46,15 @@ void pakage_processing(struct net_port *port_in, memset(&info_pac, 0, sizeof(info_pac)); parsing_pakage(pkts[i], &info_pac); - printf("[PKT] port = %hu; domain = %s\n", ntohs(info_pac.number_port), info_pac.domain); + printf("[PKT] port = %hu; domain = %s\n", ntohs(info_pac.number_port), + info_pac.domain); if (info_pac.domain[0] == '\0') { printf("[INFO] Packet without dns request\n"); package_sending_decision(true, pkts[i], port_out, queue_number); continue; } - if(check_is_exception(info_pac.number_port) == true) { + if (check_is_exception(info_pac.number_port) == true) { package_sending_decision(true, pkts[i], port_exception, queue_number); continue; } @@ -67,7 +70,15 @@ void pakage_processing(struct net_port *port_in, struct requested_classification req_clas; - bool solution_is_send = main_filtring(&req_clas, policy, info_pac.domain); + bool solution_is_send; + bool classification_success = + worker_classify_domain(info_pac.domain, &req_clas); + if (classification_success) { + solution_is_send = main_filtring(&req_clas, policy, info_pac.domain); + } else { + solution_is_send = true; + printf("[WARN] Classification failed for %s\n", info_pac.domain); + } package_sending_decision(solution_is_send, pkts[i], port_out, queue_number); diff --git a/worker/src/main.cpp b/worker/src/main.cpp index c75d4a2..093a2c3 100644 --- a/worker/src/main.cpp +++ b/worker/src/main.cpp @@ -1,5 +1,5 @@ -#include "../include/metrics_collector.hpp" -#include "../include/worker.hpp" +#include "metrics_collector.hpp" +#include "worker.hpp" #include @@ -20,7 +20,7 @@ class FiltrWorker : public Worker { ("worker-" + std::to_string(id)).c_str()) {} }; -int main() { +int main(int argc, char **argv) { const char *worker_id_str = getenv("WORKER_ID"); if (worker_id_str == nullptr) { spdlog::error("WORKER_ID environment variable not set"); @@ -42,7 +42,6 @@ int main() { try { Worker worker(worker_id); - bool test_mode = false; if (getenv("TEST_REQUEST_POLICY") != nullptr) { test_mode = true; @@ -62,14 +61,22 @@ int main() { test_mode = true; spdlog::info("Test mode: classifying domain '{}'", domain); std::this_thread::sleep_for(std::chrono::seconds(1)); - worker.classifyDomain(domain); + struct requested_classification req_clas; + memset(&req_clas, 0, sizeof(req_clas)); + bool success = worker.classifyDomain(domain, &req_clas); + if (success) { + spdlog::info("Classification successful: trust_level={}", + req_clas.get_trust_level); + } else { + spdlog::error("Classification failed"); + } } if (test_mode) { spdlog::info("Test mode completed, exiting"); return 0; } - + worker.initDPDK(argc, argv); worker.MainLoop(); } catch (std::exception &e) { diff --git a/worker/src/worker.cpp b/worker/src/worker.cpp index b63c850..6ce8a58 100644 --- a/worker/src/worker.cpp +++ b/worker/src/worker.cpp @@ -1,12 +1,35 @@ -#include "../include/worker.hpp" - +#include "worker.hpp" #include "communication.grpc.pb.h" +#include "proc_packets.h" #include +#include #include #include +#include #include #include +Worker *g_worker = nullptr; + +extern "C" bool +worker_classify_domain(const char *domain, + struct requested_classification *out_req) { + if (!g_worker) { + fprintf(stderr, "worker_classify_domain: g_worker is null\n"); + return false; + } + return g_worker->classifyDomain(std::string(domain), out_req); +} + +static volatile bool stop_flag = false; + +static void signal_handler(int signum) { + if (signum == SIGINT || signum == SIGTERM) { + spdlog::info("Signal {} received, shutting down.", signum); + stop_flag = true; + } +} + void Worker::LogStateChange(WorkerState new_state) { const char *state_names[] = {"BOOTING", "FREE", "BUSY", "SHUTTING_DOWN", "ERROR"}; @@ -22,6 +45,64 @@ void Worker::SetState(WorkerState new_state) { } } +void Worker::initDPDK(int argc, char **argv) { + unsigned mbuf_quantity_in_pool = 8192; + unsigned cache_size_per_kernel = 250; + uint16_t priv_size = 0; + + int ret = rte_eal_init(argc, argv); + if (ret < 0) { + throw std::runtime_error("EAL init failed"); + } + + mbuf_pool = rte_pktmbuf_pool_create( + "POOL", mbuf_quantity_in_pool, cache_size_per_kernel, priv_size, + RTE_MBUF_DEFAULT_BUF_SIZE, rte_socket_id()); + if (!mbuf_pool) { + throw std::runtime_error("Failed to create mbuf pool"); + } + const char *iface_in = getenv("DPDK_PORT_IN"); + const char *iface_out = getenv("DPDK_PORT_OUT"); + + if (!iface_in || !iface_out) { + throw std::runtime_error("DPDK_PORT_IN and DPDK_PORT_OUT must be set"); + } + + port_in = init_struct_af_xdp_port(iface_in, mbuf_pool); + port_out = init_struct_af_xdp_port(iface_out, mbuf_pool); + port_exception = init_struct_tap_port("tap0", mbuf_pool); + + if (net_port_init(port_in) || net_port_init(port_out) || + net_port_init(port_exception)) { + throw std::runtime_error("Init ports"); + } + + if (net_port_start(port_in->port_id) || net_port_start(port_out->port_id) || + net_port_start(port_exception->port_id)) { + throw std::runtime_error("Start ports"); + } + + spdlog::info("DPDK initialized: in_port={}, out_port={}", port_in->port_id, + port_out->port_id); +} + +void Worker::forward_to_out(struct net_port *incoming_port, + struct net_port *outgoing_port, + uint16_t queue_number) { + struct rte_mbuf *tap_pkts[32]; + uint16_t nb_tap = + rte_eth_rx_burst(incoming_port->port_id, queue_number, tap_pkts, 32); + for (int i = 0; i < nb_tap; i++) { + int ret = + rte_eth_tx_burst(outgoing_port->port_id, queue_number, &tap_pkts[i], 1); + if (ret < 1) { + spdlog::warn("Failed to send packet"); + // PLUG (to be added later) - need to add processing for this case + rte_pktmbuf_free(tap_pkts[i]); + } + } +} + void Worker::requestPolicyFromController() { try { spdlog::info("Worker {} requests policy", worker_id); @@ -40,23 +121,97 @@ void Worker::requestPolicyFromController() { } switch (resp.result()) { - case GetPolicyResponse::POLICY_PROVIDED: + case GetPolicyResponse::POLICY_PROVIDED: { spdlog::info("Policy received"); + const auto &pol = resp.policy(); current_config_version = resp.policy().config_version(); + std::lock_guard lock(policy_mutex); + memset(¤t_policy, 0, sizeof(current_policy)); + + int block_cat_count = pol.block_categories_size(); + for (int i = 0; i < block_cat_count; ++i) { + strncpy(current_policy.locked_categories[i], + pol.block_categories(i).c_str(), CATEGORY_MAX_LEN - 1); + current_policy.locked_categories[i][CATEGORY_MAX_LEN - 1] = '\0'; + } + + int idx = 0; + for (const auto &[category, min_trust] : pol.block_by_trust()) { + if (idx >= MAX_CATEGORIES_BY_TRUST_LVL) + break; + + strncpy( + current_policy.categories_with_lvl[idx].locked_by_trust_category, + category.c_str(), CATEGORY_MAX_LEN - 1); + current_policy.categories_with_lvl[idx] + .locked_by_trust_category[CATEGORY_MAX_LEN - 1] = '\0'; + + current_policy.categories_with_lvl[idx].trust_lvl = min_trust; + + idx++; + } + + int block_dom_count = pol.block_domains_size(); + for (int i = 0; i < block_dom_count; ++i) { + strncpy(current_policy.block_domains[i], pol.block_domains(i).c_str(), + DOMAIN_MAX_LEN - 1); + current_policy.block_domains[i][DOMAIN_MAX_LEN - 1] = '\0'; + } + + int allow_dom_count = pol.allow_domains_size(); + for (int i = 0; i < allow_dom_count; ++i) { + strncpy(current_policy.allow_domains[i], pol.allow_domains(i).c_str(), + DOMAIN_MAX_LEN - 1); + current_policy.allow_domains[i][DOMAIN_MAX_LEN - 1] = '\0'; + } + + current_policy.min_trust_level = pol.min_trust_level(); + + current_config_version = pol.config_version(); + + spdlog::info("POLICY LOADED"); + spdlog::info("Config version: {}", current_config_version); + spdlog::info("Min trust level: {}", current_policy.min_trust_level); + + spdlog::info("Blocked categories ({} total)", block_cat_count); + for (int i = 0; i < block_cat_count && i < MAX_CATEGORIES; ++i) { + if (strlen(current_policy.locked_categories[i]) > 0) { + spdlog::info("blocked_categories: {}", + current_policy.locked_categories[i]); + } + } + + spdlog::info("Blocked domains ({} total)", block_dom_count); + for (int i = 0; i < block_dom_count && i < MAX_DOMAINS; ++i) { + if (strlen(current_policy.block_domains[i]) > 0) { + spdlog::info("block_domains: {}", current_policy.block_domains[i]); + } + } + + spdlog::info("Allowed domains ({} total)", allow_dom_count); + for (int i = 0; i < allow_dom_count && i < MAX_DOMAINS; ++i) { + if (strlen(current_policy.allow_domains[i]) > 0) { + spdlog::info("allow_domains: {}", current_policy.allow_domains[i]); + } + } break; - case GetPolicyResponse::POLICY_UNCHANGED: + } + case GetPolicyResponse::POLICY_UNCHANGED: { spdlog::info("Policy unchanged"); break; - default: + } + default: { spdlog::error("Unknown response result"); } + } } catch (const std::exception &e) { spdlog::error("requestPolicyFromController exception: {}", e.what()); } } -void Worker::classifyDomain(const std::string &domain) { +bool Worker::classifyDomain(const std::string &domain, + struct requested_classification *out_req) { try { spdlog::info("Worker {} classifying domain '{}'", worker_id, domain); @@ -70,16 +225,29 @@ void Worker::classifyDomain(const std::string &domain) { auto status = stub_->Classify(&context, req, &resp); if (!status.ok()) { spdlog::error("Classify failed: " + status.error_message()); - return; + return false; } - std::string cat = - resp.categories_size() > 0 ? resp.categories(0) : "unknown"; - spdlog::info("Domain '{}' classified as category '{}' with trust level {}", - domain, cat, resp.trust_level()); - + std::string categories_str; + for (int i = 0; i < resp.categories_size(); ++i) { + if (i > 0) + categories_str += ", "; + categories_str += resp.categories(i); + } + spdlog::info( + "Domain '{}' classified as categories [{}] with trust level {}", domain, + categories_str, resp.trust_level()); + + out_req->get_trust_level = resp.trust_level(); + int cat_count = std::min(resp.categories_size(), MAX_CATEGORIES); + for (int i = 0; i < cat_count; ++i) { + strncpy(out_req->get_categories[i], resp.categories(i).c_str(), + CATEGORY_MAX_LEN - 1); + } + return true; } catch (const std::exception &e) { spdlog::error(std::string("classifyDomain: ") + e.what()); + return false; } } @@ -108,7 +276,7 @@ void Worker::statsReport() { } Worker::Worker(uint64_t id) : worker_id(id), state(WorkerState::FREE) { - + g_worker = this; std::string controller_addr = "localhost:50051"; if (const char *env_addr = getenv("CONTROLLER_GRPC_ADDR")) { controller_addr = env_addr; @@ -117,6 +285,9 @@ Worker::Worker(uint64_t id) : worker_id(id), state(WorkerState::FREE) { grpc::CreateChannel(controller_addr, grpc::InsecureChannelCredentials()); stub_ = DataService::NewStub(channel); spdlog::info("gRPC channel created to {}", controller_addr); + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + spdlog::info("Signal handlers registered"); srand(time(nullptr)); SetState(WorkerState::FREE); @@ -124,8 +295,18 @@ Worker::Worker(uint64_t id) : worker_id(id), state(WorkerState::FREE) { } Worker::~Worker() { - SetState(WorkerState::SHUTTING_DOWN); spdlog::info("Worker {} shutting down", worker_id); + + if (port_in && port_out) { + net_port_close(port_in); + net_port_close(port_out); + net_port_close(port_exception); + + net_port_destroy(port_in); + net_port_destroy(port_out); + net_port_destroy(port_exception); + spdlog::info("DPDK ports closed"); + } } void Worker::MainLoop() { @@ -134,7 +315,15 @@ void Worker::MainLoop() { last_policy_time = steady_clock::now(); last_stats_time = steady_clock::now(); - while (GetState() != WorkerState::SHUTTING_DOWN) { + struct rte_mbuf *pkts[32]; + uint16_t nb_pkts = 32; + uint16_t queue_number = 0; + while (!stop_flag && GetState() != WorkerState::SHUTTING_DOWN) { + forward_to_out(port_exception, port_in, queue_number); + pakage_processing(port_in, port_out, port_exception, queue_number, nb_pkts, + pkts, ¤t_policy); + forward_to_out(port_out, port_in, queue_number); + auto now = steady_clock::now(); int64_t seconds_since_stats = (now - last_stats_time) / 1s; @@ -155,4 +344,8 @@ void Worker::MainLoop() { std::this_thread::sleep_for(milliseconds(100)); } + + if (stop_flag) { + SetState(WorkerState::SHUTTING_DOWN); + } }