-
Notifications
You must be signed in to change notification settings - Fork 68
Closed
Description
Hi, I tried many of your example codes and they all work very well. I am now trying to simulate the wind flow over a mountainous area; after a lot of back-and-forth with Google AI, I managed to put together the code below. However, the simulation always collapses after 150 steps or more. How can I make my simulation run better?
Also, during the initial steps I observed that the wind speed seems to start at the two corners. Google AI told me this is abnormal. How should I get rid of it?
This image shows the results at the 40th step; here you can see the situation at the two corners:

The data_dem.txt used is as follows:
The following is the code:
[###########################################################]
import os
# Echo the XLA preallocation setting inherited from the environment
# (prints None when the variable is not set).
print(os.getenv("XLA_PYTHON_CLIENT_PREALLOCATE"))
# Resize ratio; not referenced by any other statement in this script.
f_resize_ratio = 0.8
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
HalfwayBounceBackBC,
RegularizedBC,
ExtrapolationOutflowBC,
)
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_image
import warp as wp
import numpy as np
import jax.numpy as jnp
import jax
import time
from xlb.distribute import distribute
#############################################################################
#############################################################################
# from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
# from jax.experimental import mesh_utils
# import jax
# 1. Cap the fraction of GPU memory XLA may reserve.
#    This must be set BEFORE the first call that initialises the JAX backend
#    (jax.devices() below): the variable is read when the GPU client is
#    created, so assigning it afterwards has no effect on the allocator.
#    NOTE(review): if an earlier import already touched a JAX device, this
#    still comes too late -- in that case export it in the shell instead.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
# 2. auto detect all gpus
devices = jax.devices()
num_devices = len(devices)
print(f"THERE ARE {num_devices} GPUS")
# The vertical grid extent is later rounded to a multiple of this value.
n_divide = num_devices
###############################################################################
###############################################################################
from PIL import Image
import cv2
from haomai_utils import *
################### read the dem data
# Load the DEM table and keep only its first three columns
# (presumably x, y, elevation -- confirm against the data file).
str_path_data_dem = 'simu_turbines/data_dem.txt'
arr_data_dem = np.loadtxt(str_path_data_dem)[:, :3]

# Per-column minima, used to shift every column so that it starts at zero.
arr_data_dem0, arr_data_dem1, arr_data_dem2 = (
    np.min(arr_data_dem[:, col]) for col in range(3)
)
arr_data_dem_zero = arr_data_dem - (arr_data_dem0, arr_data_dem1, arr_data_dem2)

# Integer extent of each shifted column: truncated (max - min) plus one.
range_x, range_y, range_z = (
    int(np.max(arr_data_dem_zero[:, col]) - np.min(arr_data_dem_zero[:, col])) + 1
    for col in range(3)
)
def get_terrain_wind_profile(u_ref,
                             z_ref = 100.0,  # reference height
                             z0 = 0.03,      # log-law length scale (presumably the roughness length -- confirm)
                             ):
    """Build a logarithmic velocity-profile callable.

    The returned ``profile(*args)`` treats its LAST positional argument as
    the height ``z`` and evaluates
    ``u_ref * log(z / z0) / log(z_ref / z0)``, with ``z`` floored at
    ``z0 + 0.1``. The result is scaled by the module-level ``dt / dx``
    (read when ``profile`` is called, not when the factory is called) and
    returned as the first component of a 3-vector whose other two
    components are zero, stacked on the last axis.

    Called with no arguments, ``profile`` returns ``jnp.zeros(3)``.
    NOTE(review): confirm how the boundary condition invokes the profile;
    if it is called without arguments the prescribed velocity is zero.
    """
    log_scale = 1.0 / jnp.log(z_ref / z0)

    def profile(*args):
        # No coordinates supplied: fall back to a zero velocity vector.
        if len(args) == 0:
            return jnp.zeros(3)

        height = args[-1]
        # Floor the height so the logarithm stays positive near the ground.
        clamped_height = jnp.where(height > z0 + 0.1, height, z0 + 0.1)
        speed = u_ref * (jnp.log(clamped_height / z0) * log_scale)
        # Convert with the dt / dx ratio (m/s to lattice units).
        speed_lattice = speed * (dt / dx)
        no_flow = jnp.zeros_like(speed_lattice)
        return jnp.stack([speed_lattice, no_flow, no_flow], axis=-1)

    return profile
def get_slip_top_profile(u_ref,
                         ):
    """Build a uniform velocity-profile callable for the top boundary.

    The returned ``profile(*args)`` scales ``u_ref`` by the module-level
    ``dt / dx`` (read at call time) and returns it as the first component
    of a 3-vector whose other two components are zero. Called with no
    arguments it returns ``jnp.zeros(3)``; the conversion is still
    evaluated first in that case.
    """
    def profile(*args):
        speed_lattice = u_ref * (dt / dx)
        no_flow = jnp.zeros_like(speed_lattice)
        if args:
            return jnp.stack([speed_lattice, no_flow, no_flow], axis=-1)
        return jnp.zeros(3)

    return profile
# View the height and x columns as (range_y, range_x) rasters.
# NOTE(review): this assumes the DEM rows are ordered y-major on a unit
# grid with exactly range_y * range_x points -- confirm against the data.
h_grid = arr_data_dem_zero[:, 2].reshape((range_y, range_x))
x_grid = arr_data_dem_zero[:, 0].reshape((range_y, range_x))
############# get missing_mask
import warp as wp
# Vertical extent: the smallest multiple of n_divide strictly greater than
# three times range_z.
range_z_use = (1 + (3 * range_z) // n_divide) * n_divide
shape = (range_x, range_y, range_z_use)
# Integer, C-contiguous copy of the height raster for upload to Warp.
h_grid_np = np.ascontiguousarray(h_grid.astype(np.int32))
print(range_x, range_y, range_z_use)
device = 'cuda'
h_grid_wp = wp.from_numpy(h_grid_np, dtype=wp.int32, device=device)
# Zero-initialised 3D mask that the DEM kernel fills in.
missing_mask_wp = wp.zeros(shape=shape, dtype=wp.int32, device=device)
# Kernel: for every (x, y, z) thread, write 1 into `mask` when z is at or
# below the terrain height stored at h_grid[y, x], and 0 otherwise.
# It is defined ONCE and launched for both masks below; the original script
# re-declared a byte-identical kernel before the second launch.
@wp.kernel
def structural_dem_to_mask(h_grid: wp.array2d(dtype=int),
                           mask: wp.array3d(dtype=int)):
    x, y, z = wp.tid()
    if z <= h_grid[y, x]:
        mask[x, y, z] = 1
    else:
        mask[x, y, z] = 0


# Fill the first mask; the launch dimensions match the grid shape.
wp.launch(kernel=structural_dem_to_mask, dim=(range_x, range_y, range_z_use), inputs=[h_grid_wp, missing_mask_wp])
############# get bc_boundary
bc_mask = wp.zeros(
    shape=shape,
    dtype=wp.int32,
    device=device
)
# dim as the same shape with grid
wp.launch(kernel=structural_dem_to_mask, dim=(range_x, range_y, range_z_use), inputs=[h_grid_wp, bc_mask])
# Copy the terrain mask to a NumPy array (1 = terrain) and tag the domain faces.
arr_bc_mask = bc_mask.numpy()
# Later entries overwrite earlier ones where faces intersect, so the order
# below decides the tag on shared edges (the top face wins over inlet/outlet).
# NOTE(review): the inlet/outlet planes also overwrite terrain cells (tag 1)
# lying on those planes -- confirm this is intended.
for face, tag in (
    (np.s_[0, :, :], 2),    # 2 Velocity Boundary (Inlet)
    (np.s_[-1, :, :], 3),   # 3 Pressure Boundary (Outlet)
    (np.s_[:, :, -1], 4),   # 4 Slip Wall (sky)
):
    arr_bc_mask[face] = tag
# --- Scaling between physical and lattice quantities ---
# (units are presumably metres and m/s -- confirm)
u_physical_max = 14.0
u_lbm = 0.06
dx = 30.0
# Time step chosen so that u_physical_max corresponds to u_lbm on the lattice.
dt = u_lbm * dx / u_physical_max

# --- Relaxation / viscosity ---
target_tau = 0.6
nu_lbm = (target_tau - 0.5) / 3.0
nu_phys = 1.5e-5
# NOTE(review): omega is hard-coded and is not 1 / target_tau (about 1.667);
# nu_lbm and nu_phys are not referenced elsewhere in this script -- confirm
# which relaxation setting is actually intended.
omega = 1.7
# Run configuration.
num_steps = 10000
post_process_interval = 10
u_max = 0.04

# Lattice configuration: grid extent, backend, precision and velocity set.
grid_shape = (range_x, range_y, range_z_use)
compute_backend = ComputeBackend.JAX
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q27(
    precision_policy=precision_policy,
    compute_backend=compute_backend,
)

# Initialize XLB
xlb.init(
    velocity_set=velocity_set,
    default_backend=compute_backend,
    default_precision_policy=precision_policy,
)

# Create Grid
grid = grid_factory(grid_shape, compute_backend=compute_backend)
def _indices_for_tag(tag):
    """Return per-axis int lists of the cells in arr_bc_mask equal to `tag`."""
    return [axis.astype(np.int32).tolist() for axis in np.where(arr_bc_mask == tag)]


################## initialize the terrain bc
indices_terrain = _indices_for_tag(1)
bc_terrain = HalfwayBounceBackBC(indices=indices_terrain)

################# initialize the inlet BC
indices_inlet = _indices_for_tag(2)
cur_speed = 1.0
num_nodes_expected = len(indices_inlet[0])
print(num_nodes_expected)
bc_left = RegularizedBC(
    'velocity',
    profile = get_terrain_wind_profile(cur_speed),
    indices=indices_inlet,
)

################## initialize the outlet bc
indices_outlet = _indices_for_tag(3)
bc_outlet = ExtrapolationOutflowBC(indices=indices_outlet)

################## initialize the slipwall
# The top face is driven with a prescribed uniform velocity profile.
indices_slipwall = _indices_for_tag(4)
num_nodes_expected = len(indices_slipwall[0])
print(num_nodes_expected)
bc_slipwall = RegularizedBC(
    'velocity',
    profile = get_slip_top_profile(cur_speed),
    indices=indices_slipwall,
)

import gc
boundary_conditions = [bc_terrain, bc_left, bc_outlet, bc_slipwall]
# Setup Stepper: build the BGK stepper and wrap it with distribute().
stepper = distribute(
    IncompressibleNavierStokesStepper(
        grid=grid,
        boundary_conditions=boundary_conditions,
        collision_type='BGK',
    ),
    grid,
    velocity_set,
)

# Drop JAX caches and collect garbage before allocating the lattice fields.
jax.clear_caches()
gc.collect()

# From here on bc_mask and missing_mask refer to the fields prepared by the
# stepper, not to the Warp arrays built earlier in the script.
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
# Host copy of the prepared boundary mask with its leading axis removed.
arr_bc_mask = np.squeeze(np.array(bc_mask), axis=0)
import numpy as np
import matplotlib.pyplot as plt
# Debug output for the prepared boundary mask:
#   - mask values summed along the last axis,
#   - the slice at the middle index of the first axis,
#   - the raw array.
plt.imsave('top_view_out.png', arr_bc_mask.sum(axis=2))
plt.imsave('side_view_out.png', arr_bc_mask[range_x // 2], origin='lower')
np.save('arr_bc_mask.npy', arr_bc_mask)
# Define Macroscopic Calculation
# Operator called in post_process as `rho, u = macro(f)`; it is built on the
# JAX backend with its own D3Q27 velocity-set instance.
macro = Macroscopic(
    velocity_set=xlb.velocity_set.D3Q27(
        precision_policy=precision_policy,
        compute_backend=ComputeBackend.JAX,
    ),
    precision_policy=precision_policy,
    compute_backend=ComputeBackend.JAX,
)
# Post-Processing Function
def post_process(step, f_current, str_folder_out = 'data/20260109_simu'):
    """Compute macroscopic fields and save one velocity-magnitude slice.

    Args:
        step: Step number, forwarded to ``save_image`` as ``timestep`` and
            printed in the log line.
        f_current: Distribution field; converted with ``wp.to_jax`` when it
            is not already a ``jnp.ndarray``.
        str_folder_out: Output folder, created if missing; used (with a
            trailing '/') as the image prefix.
    """
    # Convert to JAX array if necessary
    if not isinstance(f_current, jnp.ndarray):
        f_current = wp.to_jax(f_current)

    rho, u = macro(f_current)

    # Drop the outermost layer on each of the three trailing axes.
    interior = (slice(None), slice(1, -1), slice(1, -1), slice(1, -1))
    u = u[interior]
    rho = rho[interior]

    u_x, u_y, u_z = u[0], u[1], u[2]
    fields = {
        "u_magnitude": jnp.sqrt(u_x ** 2 + u_y ** 2 + u_z ** 2),
        "u_x": u_x,
        "u_y": u_y,
        "u_z": u_z,
        "rho": rho[0],
    }

    os.makedirs(str_folder_out, exist_ok = True)

    # Save the u_magnitude slice taken along the LAST axis at half of the
    # untrimmed grid extent. The index is applied to the trimmed array, so
    # it sits one cell further along than the same index in the full grid.
    mid_index = grid_shape[2] // 2
    save_image(fields["u_magnitude"][:, :, mid_index], timestep=step, prefix = str_folder_out + '/')
    print(f"Post-processed step {step}: Saved u_magnitude slice at z={mid_index}")
# Main time loop: advance one step, swap the buffers, and report periodically.
start_time = time.time()
for step in range(num_steps):
    f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step)
    f_0, f_1 = f_1, f_0  # Swap the buffers
    print(step)

    # Only post-process every post_process_interval steps and on the last step.
    if step % post_process_interval != 0 and step != num_steps - 1:
        continue

    if compute_backend == ComputeBackend.WARP:
        wp.synchronize()
    post_process(step, f_0)

    # Wall-clock time since the previous report (or since the loop started).
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Completed step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.")
    start_time = time.time()
Metadata
Metadata
Assignees
Labels
No labels