Skip to content

The simulation always collapses after 150 or more steps — can anyone help? #153

@fred2019-dev

Description

@fred2019-dev

Hi, I have tried many of your example scripts and they all work very well. I am now trying to simulate wind flow over a mountainous area; after a lot of back-and-forth with Google AI, I managed to put together the code below. However, the simulation always collapses after 150 or more steps. How can I make it run stably?

Also, during the initial steps I observed that the wind speed seemed to originate at two corners of the domain. Google AI told me this is abnormal — how can I get rid of it?
This image shows the results at step 40, where you can see the two corners:
Image

The data_dem.txt file used is as follows:

data_dem.txt

The code is as follows:

[###########################################################]

import os

# XLA reads its client-memory settings when the JAX backend is initialised,
# which happens at the latest on the first ``jax.devices()`` call below.
# The fraction therefore has to be set before any import that may touch JAX;
# assigning it after ``jax.devices()`` has no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
print(os.environ.get("XLA_PYTHON_CLIENT_PREALLOCATE"))

f_resize_ratio = 0.8  # NOTE(review): not used in the visible script

import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
    FullwayBounceBackBC,
    HalfwayBounceBackBC,
    RegularizedBC,
    ExtrapolationOutflowBC,
)
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_image
import warp as wp
import numpy as np
import jax.numpy as jnp
import jax
import time
from xlb.distribute import distribute

# Detect every device of JAX's default backend. The vertical grid size is
# later padded to a multiple of this count (see range_z_use), presumably so
# the domain can be split evenly by `distribute`.
devices = jax.devices()
num_devices = len(devices)
print(f"THERE ARE {num_devices} GPUS")
n_divide = num_devices

###############################################################################
###############################################################################

from PIL import Image
import cv2
from haomai_utils import *



# ---- Read the DEM point cloud ----------------------------------------------
# Only the first three columns (x, y, z) of each row are kept. Every axis is
# shifted so that it starts at zero, and its extent is taken as an inclusive
# integer count: int(max - min) + 1.
# NOTE(review): using that count as the number of cells assumes the samples
# lie on a unit-spaced integer lattice — confirm against the data file.
str_path_data_dem = 'simu_turbines/data_dem.txt'
arr_data_dem = np.loadtxt(str_path_data_dem)[:, :3]

arr_data_dem0, arr_data_dem1, arr_data_dem2 = (
    np.min(arr_data_dem[:, col]) for col in range(3)
)
arr_data_dem_zero = arr_data_dem - (arr_data_dem0, arr_data_dem1, arr_data_dem2)

range_x, range_y, range_z = (
    int(np.max(arr_data_dem_zero[:, col]) - np.min(arr_data_dem_zero[:, col])) + 1
    for col in range(3)
)

def get_terrain_wind_profile(u_ref,
                             z_ref=100.0,  # ref height
                             z0=0.03,      # coarse ratio
                             ):
    """Build a logarithmic wind-profile callable for a velocity boundary.

    Args:
        u_ref: Speed reached at height ``z_ref``.
        z_ref: Reference height of the log law.
        z0: Roughness length of the log law ``ln(z / z0)``.

    Returns:
        ``profile(*args)``:

        * With no arguments it returns ``jnp.zeros(3)``.
        * Otherwise the LAST positional argument is read as the height ``z``.
          ``z`` is clamped from below to ``z0 + 0.1``. The result is
          ``stack([u, 0, 0], axis=-1)`` with
          ``u = u_ref * ln(z / z0) / ln(z_ref / z0) * (dt / dx)``.
          The module-level ``dt`` and ``dx`` are read when ``profile`` is
          called, not when it is built.

    NOTE(review): if XLB's JAX backend evaluates the profile with no
    arguments, the inlet velocity is zero — confirm against the installed
    XLB version.
    NOTE(review): ``z`` is presumably a lattice index, while ``z_ref`` and
    ``z0`` look like metres. The height is also measured from the domain
    bottom rather than from the local ground.
    """
    inv_log_ratio = 1.0 / jnp.log(z_ref / z0)
    z_floor = z0 + 0.1  # keeps the logarithm finite and positive near the ground

    def profile(*args):
        if len(args) == 0:
            return jnp.zeros(3)

        height = args[-1]
        clamped = jnp.where(height > z_floor, height, z_floor)
        speed_phys = u_ref * (jnp.log(clamped / z0) * inv_log_ratio)
        speed_lattice = speed_phys * (dt / dx)  # physical -> lattice units
        blank = jnp.zeros_like(speed_lattice)
        return jnp.stack((speed_lattice, blank, blank), axis=-1)

    return profile

def get_slip_top_profile(u_ref,
                         ):
    """Build a uniform streamwise velocity callable for the top boundary.

    Args:
        u_ref: Physical speed. It is converted to lattice units with the
            module-level ``dt`` and ``dx``, which are read on every call.

    Returns:
        ``profile(*args)``. It always evaluates ``u_ref * (dt / dx)`` first.
        With no arguments it then returns ``jnp.zeros(3)``. With any arguments
        it ignores them and returns ``stack([u, 0, 0], axis=-1)``, which is a
        length-3 vector for a scalar ``u_ref``.

    NOTE(review): this only supplies a velocity value; it does not by itself
    impose a free-slip (zero-shear) condition.
    """

    def profile(*args):
        lattice_speed = u_ref * (dt / dx)
        transverse = jnp.zeros_like(lattice_speed)
        if args:
            return jnp.stack([lattice_speed, transverse, transverse], axis=-1)
        return jnp.zeros(3)

    return profile
    
# ---- Terrain height map and Warp buffers ------------------------------------
# The DEM rows are reshaped into (range_y, range_x) grids, which assumes the
# file lists points with x varying fastest. NOTE(review): confirm the row order.
h_grid = arr_data_dem_zero[:, 2].reshape(range_y, range_x)
x_grid = arr_data_dem_zero[:, 0].reshape(range_y, range_x)

import warp as wp

# Vertical size: the smallest multiple of n_divide that is strictly greater
# than three times the terrain relief.
range_z_use = ((range_z * 3) // n_divide + 1) * n_divide
shape = (range_x, range_y, range_z_use)

# Heights are truncated to whole cells and handed to Warp as a contiguous
# int32 buffer. NOTE(review): the heights are used directly as z-cell counts.
# If the DEM stores heights in metres while x/y are grid indices of dx metres,
# the terrain is stretched vertically by dx — confirm the units.
h_grid_np = np.ascontiguousarray(h_grid.astype(np.int32))
print(range_x, range_y, range_z_use)

device = 'cuda'
h_grid_wp = wp.from_numpy(h_grid_np, dtype=wp.int32, device=device)

# NOTE(review): this mask is filled by the kernel launch that follows but is
# never handed to XLB in the visible script; the stepper returns its own
# missing_mask from prepare_fields().
missing_mask_wp = wp.zeros(shape=shape, dtype=wp.int32, device=device)

# Warp kernel that voxelises the terrain. It is launched with one thread per
# cell of the (range_x, range_y, range_z_use) grid. A cell (x, y, z) is set to
# 1 when z is at or below the ground height h_grid[y, x], and to 0 otherwise.
# h_grid is indexed (y, x) because it was reshaped as (range_y, range_x).
@wp.kernel
def structural_dem_to_mask(h_grid: wp.array2d(dtype=int),
                           mask: wp.array3d(dtype=int)):
    x,y,z = wp.tid()

    # Solid at or below the local ground height (inclusive comparison).
    if z <= h_grid[y, x]:
        mask[x,y,z] = 1
    else:
        mask[x,y,z] = 0
# NOTE(review): the missing_mask_wp filled here is not used later in the
# visible script; stepper.prepare_fields() returns its own missing_mask.
wp.launch(kernel=structural_dem_to_mask, dim=(range_x, range_y, range_z_use), inputs=[h_grid_wp, missing_mask_wp])

############# get bc_boundary
# Tag map used to extract BC index sets:
#   0 fluid, 1 terrain (solid), 2 inlet, 3 outlet, 4 top boundary.
# The terrain is voxelised with the `structural_dem_to_mask` kernel defined
# earlier in this script; an identical second copy of the kernel is no longer
# defined here.
bc_mask = wp.zeros(
    shape=shape,
    dtype=wp.int32,
    device=device
)

# dim as the same shape with grid
wp.launch(kernel=structural_dem_to_mask, dim=(range_x, range_y, range_z_use), inputs=[h_grid_wp, bc_mask])
arr_bc_mask = bc_mask.numpy()

# Only the fluid part of the inflow and outflow planes gets the inlet/outlet
# tags. Overwriting the whole plane would turn the terrain cells on those
# faces into velocity/outflow nodes. Flow would then be injected (or
# extracted) from inside the ground instead of that ground acting as a wall.
inlet_fluid = arr_bc_mask[0] == 0
arr_bc_mask[0][inlet_fluid] = 2    # 2 Velocity Boundary (Inlet)

outlet_fluid = arr_bc_mask[-1] == 0
arr_bc_mask[-1][outlet_fluid] = 3  # 3 Pressure Boundary (Outlet)

# Top plane ("sky"). It is assigned last, so it also owns the top edges of
# the inlet and outlet planes.
arr_bc_mask[:, :, -1] = 4          # 4 Slip Wall (see BC construction)

# NOTE(review): the lateral faces (y = 0 and y = -1) receive no tag. They
# fall back to the solver's default edge treatment — confirm this is the
# intended lateral condition.



# ---- Physical and lattice parameters ----------------------------------------
dx = 30.0                          # grid spacing (presumably metres)
u_lbm = 0.06                       # lattice speed that u_physical_max maps onto
target_tau = 0.6                   # intended BGK relaxation time
u_physical_max = 14.0              # reference physical speed (presumably m/s)
dt = u_lbm * dx / u_physical_max   # time step mapping u_physical_max -> u_lbm
nu_phys = 1.5e-5                   # NOTE(review): not used in the visible script
nu_lbm = (target_tau - 0.5) / 3.0  # lattice viscosity implied by target_tau

# BGK relaxation rate. It is derived from target_tau (omega = 1 / tau) so that
# omega, target_tau and nu_lbm cannot drift apart. A hand-typed 1.7 would mean
# tau ~= 0.588, not the 0.6 declared above.
omega = 1.0 / target_tau
grid_shape = (range_x, range_y, range_z_use)
compute_backend = ComputeBackend.JAX
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend)
u_max = 0.04                       # NOTE(review): not used in the visible script
num_steps = 10000
post_process_interval = 10



# Initialize XLB: register the lattice, backend and precision chosen above as
# the library defaults.
_xlb_init_kwargs = {
    "velocity_set": velocity_set,
    "default_backend": compute_backend,
    "default_precision_policy": precision_policy,
}
xlb.init(**_xlb_init_kwargs)

# Create the computational grid; its shape matches the boundary tag map.
grid = grid_factory(grid_shape, compute_backend=compute_backend)

def _as_index_lists(raw_indices):
    """Turn the tuple of index arrays from ``np.nonzero`` into per-axis int lists."""
    return [axis.astype(np.int32).tolist() for axis in raw_indices]


cur_speed = 1.0  # reference wind speed handed to both velocity profiles

# Terrain (tag 1): halfway bounce-back on every solid cell.
indices_raw = np.nonzero(arr_bc_mask == 1)
indices_terrain = _as_index_lists(indices_raw)
bc_terrain = HalfwayBounceBackBC(indices=indices_terrain)

# Inlet (tag 2): regularized velocity BC driven by the log-law profile.
indices_raw = np.nonzero(arr_bc_mask == 2)
indices_inlet = _as_index_lists(indices_raw)
num_nodes_expected = len(indices_inlet[0])
print(num_nodes_expected)
bc_left = RegularizedBC(
    'velocity',
    profile=get_terrain_wind_profile(cur_speed),
    indices=indices_inlet,
)

# Outlet (tag 3): extrapolation outflow.
indices_raw = np.nonzero(arr_bc_mask == 3)
indices_outlet = _as_index_lists(indices_raw)
bc_outlet = ExtrapolationOutflowBC(indices=indices_outlet)

# Top (tag 4): labelled a "slip wall", but realised as a regularized
# *velocity* BC with a uniform profile. NOTE(review): a prescribed velocity is
# a Dirichlet condition. It imposes shear at the lid rather than free slip —
# confirm this is the intended top condition.
indices_raw = np.nonzero(arr_bc_mask == 4)
indices_slipwall = _as_index_lists(indices_raw)
num_nodes_expected = len(indices_slipwall[0])
print(num_nodes_expected)
bc_slipwall = RegularizedBC(
    'velocity',
    profile=get_slip_top_profile(cur_speed),
    indices=indices_slipwall,
)

import gc

# All boundary conditions handed to the stepper. NOTE(review): where the index
# sets of different BCs overlap, the order here may decide which one wins —
# confirm against XLB's boundary masking.
boundary_conditions = [bc_terrain, bc_left, bc_outlet, bc_slipwall]


# Setup Stepper: incompressible Navier-Stokes LBM update on `grid` with the
# boundary conditions above.
stepper = IncompressibleNavierStokesStepper(
    grid=grid,
    boundary_conditions=boundary_conditions,
    # NOTE(review): plain BGK is the least robust collision model for
    # high-Reynolds flow over a staircase terrain. This script's relaxation
    # time is close to 0.5 (target_tau = 0.6, omega = 1.7). The previously
    # tried 'SmagorinskyLESBGK' or, for D3Q27, XLB's 'KBC' model are the usual
    # more stable choices — verify the names supported by the installed XLB.
    collision_type='BGK'


)
# Optional non-rest initial condition (currently disabled). Without it the
# fields from prepare_fields() presumably start at rest. The inlet velocity is
# then switched on impulsively at step 0, which can launch strong pressure
# waves early in the run.
# u_lbm_init = cur_speed * (dt / dx)
# u_init = jnp.zeros((*grid_shape, 3)).at[:, :, :, 0].set(u_lbm_init)
# stepper.initial_state(velocity=u_init, density=1.0)

# Wrap the stepper for execution across the detected devices (presumably
# splitting the grid that was padded to a multiple of n_divide).
stepper = distribute(
            stepper,
            grid,
            velocity_set,
        )

# Release cached JAX compilations and run the Python garbage collector before
# the large simulation fields are allocated.
jax.clear_caches()
gc.collect()

# Allocate the two distribution-function buffers plus XLB's own boundary mask
# and missing mask.
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
# Host copy of XLB's boundary mask with the leading singleton axis removed.
# NOTE(review): its values are presumably XLB's internal BC ids rather than
# the 1-4 tags assigned earlier in this script.
arr_bc_mask = np.array(bc_mask).squeeze(0)
import numpy as np
import matplotlib.pyplot as plt

# Diagnostic images of the mask: the sum along z (top view) and the y-z slice
# at mid-x (side view).
plt.imsave('top_view_out.png', np.sum(arr_bc_mask, axis=2))
plt.imsave('side_view_out.png', arr_bc_mask[range_x // 2, :, :], origin='lower')
np.save('arr_bc_mask.npy', arr_bc_mask)


# Density/velocity recovery used by post_process. It is pinned to the JAX
# backend with its own JAX D3Q27 lattice, because post_process always feeds
# it JAX arrays (converting Warp arrays first).
_macro_velocity_set = xlb.velocity_set.D3Q27(
    precision_policy=precision_policy,
    compute_backend=ComputeBackend.JAX,
)
macro = Macroscopic(
    velocity_set=_macro_velocity_set,
    precision_policy=precision_policy,
    compute_backend=ComputeBackend.JAX,
)

# Post-Processing Function
def post_process(step, f_current, str_folder_out='data/20260109_simu'):
    """Recover macroscopic fields from ``f_current`` and save a z-slice image.

    Steps performed:

    1. Convert the distributions to a JAX array when necessary.
    2. Reduce them to density and velocity with the module-level ``macro``
       operator.
    3. Trim one cell from every side.
    4. Write the velocity magnitude on the x-y plane at global grid index
       ``z = grid_shape[2] // 2`` into ``str_folder_out`` with ``save_image``.

    The largest velocity magnitude is also printed as a cheap stability
    indicator, with a warning once it is no longer finite.

    Args:
        step: Time step number, used for the image name and the log lines.
        f_current: Distribution functions, either a JAX array or an array
            that ``wp.to_jax`` can convert.
        str_folder_out: Output directory, created if it does not exist.
    """
    # Convert to JAX array if necessary
    if not isinstance(f_current, jnp.ndarray):
        f_current = wp.to_jax(f_current)

    rho, u = macro(f_current)

    # Remove the outermost cell layer on every side. After trimming, local
    # index k along a spatial axis corresponds to global grid index k + 1.
    u = u[:, 1:-1, 1:-1, 1:-1]
    rho = rho[:, 1:-1, 1:-1, 1:-1]
    u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2)

    fields = {
        "u_magnitude": u_magnitude,
        "u_x": u[0],
        "u_y": u[1],
        "u_z": u[2],
        "rho": rho[0],
    }

    # Divergence check. NaN/Inf means the run has already blown up, and a
    # lattice speed growing far beyond the inlet value often precedes that.
    u_mag_max = float(jnp.max(u_magnitude))
    if np.isfinite(u_mag_max):
        print(f"Step {step}: max |u| = {u_mag_max:.6f} (lattice units)")
    else:
        print(f"WARNING: step {step}: non-finite velocity detected, the simulation has diverged")

    os.makedirs(str_folder_out, exist_ok=True)

    # Save the u_magnitude slice on the mid-height x-y plane. The "- 1"
    # compensates for the trimmed layer, so the image is global plane
    # z = grid_shape[2] // 2, matching the log line below.
    z_mid = grid_shape[2] // 2
    save_image(fields["u_magnitude"][:, :, z_mid - 1], timestep=step, prefix=str_folder_out + '/')
    print(f"Post-processed step {step}: Saved u_magnitude slice at z={z_mid}")

# ---- Main time loop ---------------------------------------------------------
start_time = time.time()
last_report_step = -1  # step index of the previous timing report
for step in range(num_steps):
    # Advance one LBM step. The returned buffer pair is swapped afterwards, so
    # f_0 is the state handed to post_process (presumably the freshly written
    # buffer).
    f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step)
    f_0, f_1 = f_1, f_0  # Swap the buffers
    print(step)

    if step % post_process_interval == 0 or step == num_steps - 1:
        if compute_backend == ComputeBackend.WARP:
            wp.synchronize()
        post_process(step, f_0)
        end_time = time.time()
        elapsed = end_time - start_time
        # Report the real number of steps since the previous report. The first
        # report covers a single step, and the final one may cover fewer than
        # post_process_interval. The time includes post-processing.
        steps_done = step - last_report_step
        print(f"Completed step {step}. Time elapsed for {steps_done} steps: {elapsed:.6f} seconds.")
        last_report_step = step
        start_time = time.time()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions