Note
Go to the end to download the full example code.
Cell sorting in 3D
This script simulates 3D cell aggregates with the following weighted \(L^2\) cost and force for the cell \(i\)
where \(\gamma_{ij}\) and \(\eta_{ij}\) are surface tension parameters, \(R_i\) is the radius of cell \(i\), \(\Gamma_{ij}\) is the interface between cells \(i\) and \(j\), \(\kappa\) is the local mean curvature and \(\vec{n}\) is the inward normal of cell \(\mathscr{L}_i\).
In this script, there are two cell types (indexed by \(b\) and \(o\)) and the surface tension parameters only depend on the type. Varying them leads to various cell sorting phenomena, which can be classified according to the following ratios:
Representative situations are obtained as follows:
Separation : \(\overline{\eta} = 3, \quad \overline{\gamma} = 2,\quad \overline{k} = 1\)
Checkerboard : \(\overline{\eta} = 0.3, \quad \overline{\gamma} = 2,\quad \overline{k} = 1\)
Internalization : \(\overline{\eta} = 3, \quad \overline{\gamma} = 2,\quad \overline{k}\overline{\gamma}\overline{\eta} = 1\)
Engulfment with initial segregation : \(\overline{\eta} = 3, \quad \overline{\gamma} = 2,\quad \overline{k}\overline{\gamma}\overline{\eta} = 1\)
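Note: this script does not render any figure; it only saves the data (VTK surface meshes, TIFF label images and the final simulation state), which can then be loaded in VTK, PyVista, ParaView, etc.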
# sphinx_gallery_thumbnail_path = '_static/3Dengulfment.png'
import os
import sys
sys.path.append("..")
import time
import pickle
import math
import torch
import numpy as np
from matplotlib import colors
from matplotlib.colors import ListedColormap
from iceshot import cells
from iceshot import costs
from iceshot import OT
from iceshot.OT import OT_solver
from iceshot import plot_cells
from iceshot import sample
from iceshot import utils
import tifffile as tif
import pyvista as pv
import vtk as vtk
from pyvista.core import _vtk_core as _vtk
from pyvista.core.filters import _get_output, _update_alg
from typing import Literal, Optional, cast
from pyvista.core.utilities.arrays import FieldAssociation, set_default_active_scalars
# Headless rendering setup for PyVista (virtual framebuffer + static
# notebook backend).
pv.start_xvfb()
pv.set_jupyter_backend('static')

use_cuda = torch.cuda.is_available()
# Bind ``device`` unconditionally: it used to be defined only when CUDA was
# available, so any use of it on a CPU-only machine raised NameError.
device = "cuda" if use_cuda else "cpu"
if use_cuda:
    # New float tensors default to CUDA float32.
    # NOTE(review): set_default_tensor_type is deprecated since PyTorch 2.1;
    # torch.set_default_device / torch.set_default_dtype are the replacements.
    torch.set_default_tensor_type("torch.cuda.FloatTensor")
def run_simu(params,title=None,init="uniform"):
    """Run one 3D cell-sorting simulation and save its output.

    Args:
        params: dict of surface-tension parameters, e.g. as built by
            ``ratio_to_stparams``. The keys ``gb``, ``g12``, ``g11``, ``g22``,
            ``g10``, ``g20``, ``b12``, ``b11``, ``b22``, ``r_b1``, ``r_g1``
            and ``r_k1`` are read.
        title: suffix of the output directory name. When ``None`` the name is
            built from the three ratios.
        init: seed initialisation. ``"uniform"`` puts all seeds in the cube
            ``[0.36, 0.64]^3``. ``"split"`` puts the first ``N1`` seeds in
            ``[0.3, 0.5]^3`` and the others in ``[0.5, 0.7]^3``.

    Raises:
        ValueError: if ``init`` is not ``"uniform"`` or ``"split"``.
        FileExistsError: if the output directory already exists
            (``os.mkdir``).

    Output is written under ``simu_3Dsorting_<...>/``:
    - ``data/``: ``.vtk`` surface meshes of the saved frames, plus TIFF label
      images for the first and last frames.
    - ``simu_final.pkl``: a pickle of the final ``simu`` object.
    """
    ot_algo = OT.LBFGSB
    N = 128  # total number of cells
    N1 = 64  # cells [0, N1) have the first type, cells [N1, N) the second
    # N = 96
    # N1 = 24
    M = 256  # the source is an M**dim grid (sample.sample_grid(M, dim=dim))
    dim = 3
    R0 = 1.0
    R00 = 0.055
    RN = torch.ones(N)*R00
    vol_x = 4./3.*math.pi*RN**3  # target volume: a ball of radius R00 per cell
    if init == "uniform":
        seeds = 0.5 + 0.14*2.*(torch.rand((N,dim))-0.5)
    elif init == "split":
        seeds1 = 0.4 + 0.1*2.*(torch.rand((N1,dim))-0.5)
        seeds2 = 0.6 + 0.1*2.*(torch.rand((N-N1,dim))-0.5)
        seeds = torch.cat((seeds1,seeds2),dim=0)
    else:
        # Fail early instead of hitting an unbound ``seeds`` further down.
        raise ValueError(f"init must be 'uniform' or 'split', got {init!r}")
    #================ SURFACE TENSION PARAMETERS ==================#
    gb = params["gb"]
    g12 = params["g12"]
    g11 = params["g11"]
    g22 = params["g22"]
    g10 = params["g10"]
    g20 = params["g20"]
    b12 = params["b12"]
    b11 = params["b11"]
    b22 = params["b22"]
    print("===============================================================")
    print("Surface Tension Parameters",flush=True)
    print(f"g12={g12}",flush=True)
    print(f"g11={g11}",flush=True)
    print(f"g22={g22}",flush=True)
    print(f"g10={g10}",flush=True)
    print(f"g20={g20}",flush=True)
    print(f"b12={b12}",flush=True)
    print(f"b11={b11}",flush=True)
    print(f"b22={b22}",flush=True)
    print("Compaction Tension Parameters",flush=True)
    print(f"k10={0.5*b11/g10}",flush=True)
    print(f"k20={0.5*b22/g20}",flush=True)
    print(f"k12={0.5*b11/g12}",flush=True)
    r_b1 = params["r_b1"]
    r_g1 = params["r_g1"]
    r_k1 = params["r_k1"]
    print("Ratios",flush=True)
    print(f"r_b1 = {r_b1}")
    print(f"r_g1 = {r_g1}")
    print(f"r_k1 = {r_k1}")
    if title is None:
        simu_name = f"simu_3Dsorting_b_{r_b1}_g_{r_g1}_k_{r_k1}"
    else:
        simu_name = "simu_3Dsorting_" + title
    # os.mkdir raises FileExistsError if an earlier run used the same name.
    os.mkdir(simu_name)
    os.mkdir(simu_name+"/frames")
    os.mkdir(simu_name+"/data")
    print("===============================================================")
    #===============================================#
    tau = 0.0
    # Per-cell factor, later installed as cost_params["C"]; second-type cells
    # get an extra g20/g10 factor.
    sc = R0/RN
    sc[N1:] *= g20/g10
    source = sample.sample_grid(M,dim=dim)
    simu = cells.Cells(
        seeds=seeds,source=source,
        vol_x=vol_x,extra_space="void",jct_method='Kmin'
    )
    cost_params = {
        "p" : 2,
        "scaling" : "constant",
        "C" : 1.0
    }
    solver = OT_solver(
        n_sinkhorn=800,n_sinkhorn_last=2000,n_lloyds=10,s0=2.0,
        cost_function=costs.power_cost,cost_params=cost_params
    )
    T = 4.0
    dt = 0.0003 # This is too small!
    plot_every = 20
    save_every = 1
    t = 0.0
    t_iter = 0
    t_plot = 0
    #===========================================================#
def radius(simu):
    """Return the equivalent radius of every cell, computed from its volume.

    The last entry of ``simu.volumes`` is skipped (presumably the extra
    "void" region; TODO confirm). In 2D the radius of the disc with that
    area is returned; otherwise the radius of the ball with that volume.
    """
    cell_volumes = simu.volumes[:-1]
    if simu.d == 2:
        return torch.sqrt(cell_volumes / math.pi)
    return (cell_volumes / (4. / 3. * math.pi)) ** (1. / 3.)
def compute_mesh(img):
    """Extract smoothed, labelled cell surfaces from a 3D label image.

    Steps:
    1. Pad ``img`` with one voxel of label ``-2`` on every side.
    2. Extract surfaces for the labels ``0 .. simu.N_cells - 1`` with VTK's
       ``vtkSurfaceNets3D``, then Taubin-smooth them.
    3. Compute normals, cell sizes and curvature.
    4. Add a ``"Particle"`` array holding the smaller of the two boundary
       labels when both are non-negative, otherwise the larger one.
    5. Convert point data to cell data.

    Args:
        img: 3D label array (the callers relabel the largest label to -1
            before calling).

    Returns:
        The resulting ``pyvista.PolyData``.
    """
    img = np.pad(img,1,mode='constant',constant_values=-2.0)
    vol = pv.wrap(img)
    alg = vtk.vtkSurfaceNets3D()
    set_default_active_scalars(vol)  # type: ignore
    field, scalars = vol.active_scalars_info  # type: ignore
    # args: (idx, port, connection, field, name)
    alg.SetInputArrayToProcess(0, 0, 0, field.value, scalars)
    alg.SetInputData(vol)
    alg.GenerateValues(simu.N_cells, 0, simu.N_cells-1)
    # Suppress improperly used INFO for debugging messages in vtkSurfaceNets3D,
    # and restore the original verbosity even if the update raises.
    verbosity = _vtk.vtkLogger.GetCurrentVerbosityCutoff()
    _vtk.vtkLogger.SetStderrVerbosity(_vtk.vtkLogger.VERBOSITY_OFF)
    try:
        _update_alg(alg, False, 'Performing Labeled Surface Extraction')
    finally:
        _vtk.vtkLogger.SetStderrVerbosity(verbosity)
    surfaces = cast(pv.PolyData, pv.wrap(alg.GetOutput()))
    surfaces = surfaces.smooth_taubin(n_iter=100, pass_band=0.05, normalize_coordinates=True)
    surfaces = surfaces.compute_normals(consistent_normals=True,
                                        auto_orient_normals=True,
                                        flip_normals=True,
                                        non_manifold_traversal=False)
    surfaces = surfaces.compute_cell_sizes()
    surfaces["Curvature"] = surfaces.curvature()
    boundary_labels = surfaces["BoundaryLabels"]
    lowest = boundary_labels.min(axis=1)
    highest = boundary_labels.max(axis=1)
    surfaces["Particle"] = lowest*(lowest>=0) + highest*(lowest<0)
    surfaces.set_active_scalars("Particle")
    surfaces = surfaces.point_data_to_cell_data()
    return surfaces
def extract_stuff(surfaces,M=M,simu=simu,eps=None):
    """Turn the cell data of ``surfaces`` into per-face torch tensors.

    Builds ``(normals, labels, areas, curvatures, centers)`` and passes them
    to ``good_stuff``, which filters and orients them.
    - Normals are normalised to unit length.
    - Areas are divided by ``(M + 2)**2``; curvatures are multiplied by
      ``M + 2``. ``M + 2`` is the side of the image padded by
      ``compute_mesh``.
    """
    padded_side = M + 2
    unit_normals = torch.tensor(surfaces["Normals"])
    unit_normals /= torch.norm(unit_normals, dim=1)[:, None]
    face_labels = torch.tensor(surfaces["BoundaryLabels"])
    face_areas = torch.tensor(surfaces["Area"]) / (padded_side ** 2)
    face_curvatures = torch.tensor(surfaces["Curvature"]) * padded_side
    # NOTE(review): the offset subtracted here is 1/(M+2), while the points
    # are in voxel units of the padded image (one padding voxel per side).
    # Confirm against the coordinate convention of sample.sample_grid.
    face_centers = torch.tensor((surfaces.cell_centers().points - 1.0 / padded_side) / M)
    face_data = (unit_normals, face_labels, face_areas, face_curvatures, face_centers)
    return good_stuff(simu, face_data, eps=eps)
def good_stuff(simu,stuff,eps=None):
    """Filter and orient the per-face mesh data; delegates to ``reorient_normals``."""
    oriented = reorient_normals(simu, stuff, eps=eps)
    return oriented
def belongs_to(simu,x):
    """Return the grid label of the voxel containing each 3D point of ``x``.

    The grid has ``round(simu.M_grid ** (1/simu.d))`` voxels per side. Points
    outside ``[0, 1]`` are clamped to the nearest border voxel. Voxels
    ``(i, j, k)`` are flattened row-major, as ``i*M**2 + j*M + k``. Labels
    larger than ``simu.N_cells - 1`` are reported as ``-1``.
    ``simu.labels`` itself is not modified (advanced indexing copies).
    """
    side = round(simu.M_grid ** (1/simu.d))
    voxel = torch.clamp(torch.floor(x * side).type(torch.long), 0, side - 1)
    flat_index = voxel[:, 2] + side * (voxel[:, 1] + side * voxel[:, 0])
    found = simu.labels[flat_index]
    found[found > simu.N_cells - 1] = -1.0
    return found
def reorient_normals(simu,stuff,eps=None):
    """Keep the unambiguous interface faces and orient their normals.

    For each face, two probe points are placed at distance ``eps`` on either
    side of its centre, along its normal: ``test_m`` behind and ``test_p`` in
    front. The grid label containing each probe is looked up with
    ``belongs_to``. A face is kept when either:
    - the two probes fall in the face's two labels, in either order; or
    - one probe leaves the unit cube and the other falls in one of the
      face's labels.

    Normals, and the sign of ``curv``, are flipped in place when the probes
    are found in reverse label order or when the front probe leaves the
    unit cube.

    Args:
        simu: simulation object (``M_grid``, ``d`` and whatever
            ``belongs_to`` reads).
        stuff: tuple ``(normals, lab, area, curv, centers)`` of per-face
            tensors. ``normals`` and ``curv`` are modified in place.
        eps: probe distance. Defaults to ``3 / M_grid**(1/d)``, i.e. three
            grid steps.

    Returns:
        The same five tensors restricted to the kept faces.
    """
    # normals should go from lab[:,0] to lab[:,1]
    normals, lab, area, curv, centers = stuff
    if eps is None:
        eps = 3.0/((simu.M_grid)**(1./simu.d))
    # Probe points just behind (m) and in front of (p) each face.
    test_m = centers - eps*normals
    test_p = centers + eps*normals
    lab_test_m = belongs_to(simu,test_m)
    lab_test_p = belongs_to(simu,test_p)
    # Does each probe land in the first / second label of its face?
    tm_fst = lab_test_m == lab[:,0]
    tm_scd = lab_test_m == lab[:,1]
    tp_fst = lab_test_p == lab[:,0]
    tp_scd = lab_test_p == lab[:,1]
    # Probes outside the unit cube (belongs_to clamps such points back onto
    # the border voxels, so tm_*/tp_* above may still be True for them).
    tm_out = ((test_m.max(dim=1).values>1) | (test_m.min(dim=1).values<0))
    tp_out = ((test_p.max(dim=1).values>1) | (test_p.min(dim=1).values<0))
    # Presumably faces on the domain wall: one probe outside, the other in
    # one of the face's labels.
    out = (tm_out & tp_scd) | (tm_out & tp_fst) | (tm_fst & tp_out) | (tm_scd & tp_out)
    good = (((tm_fst) & (tp_scd)) | ((tm_scd) & (tp_fst)) | out)
    # NOTE(review): for faces kept through the out-of-domain rule, the flip is
    # decided by ``tp_out`` rather than by label order, so the lab[:,0] ->
    # lab[:,1] convention above may not hold for them. Confirm this matches
    # the sign of the boundary term in compute_forces.
    to_reorient = ((tm_scd) & (tp_fst)) | (tp_out)
    normals[to_reorient,:] *= -1
    curv[to_reorient] *= -1
    return normals[good], lab[good], area[good], curv[good], centers[good]
def compute_forces(simu,normals,lab,area,curv,centers):
    """Surface-tension force on every cell, assembled from its interface faces.

    Labels and coefficients:
    - Each face separates the labels ``lab[:, 0]`` and ``lab[:, 1]``: a cell
      index ``>= 0``, ``-1`` for the void, or ``-2`` for the padded domain
      boundary.
    - Cells with index ``< N1`` have the first type, the others the second.
    - ``N1`` and the coefficients ``gb, g12, g11, g22, g10, g20, b12, b11,
      b22`` are read from the enclosing scope.

    Per face, with ``n`` its normal and ``A`` its area, three terms are used:
    - curvature, on every face: ``n * |curv| * A * g_ij``;
    - boundary, on faces against label ``-2``: ``n * A * g_ij``;
    - positional, on cell-cell faces: ``n * A / max(d_ij, 0.01) * b_ij``,
      where ``d_ij`` is the distance between the two seeds in ``simu.x``.

    Each term is subtracted from the cell in ``lab[:, 0]`` and added to the
    cell in ``lab[:, 1]``. The scatter is vectorised with ``index_add_``
    instead of looping over cells, so the cost is linear in the number of
    faces rather than (cells x faces).

    Args:
        simu: simulation object; only ``simu.x`` (seed positions) is read.
        normals, lab, area, curv: per-face tensors as produced by
            ``extract_stuff``.
        centers: per-face centres (unused; kept for the call signature).

    Returns:
        Tensor with the dtype/device/shape of ``simu.x``: the force on each
        cell.
    """
    n_cells = len(simu.x)
    if n_cells == 0:
        return torch.zeros_like(simu.x)
    lab0 = lab[:, 0]
    lab1 = lab[:, 1]

    # Surface tension per face. Order matters: the ``< N1`` masks also match
    # the negative labels, which the later void (-1) and boundary (-2)
    # assignments then override.
    g_ij = torch.zeros(len(lab))
    g_ij[(lab0<N1) & (lab1>=N1)] = g12
    g_ij[(lab1<N1) & (lab0>=N1)] = g12
    g_ij[(lab0<N1) & (lab1<N1)] = g11
    g_ij[(lab0>=N1) & (lab1>=N1)] = g22
    g_ij[(lab0==-1) & (lab1<N1)] = g10
    g_ij[(lab0==-1) & (lab1>=N1)] = g20
    g_ij[(lab1==-1) & (lab0<N1)] = g10
    g_ij[(lab1==-1) & (lab0>=N1)] = g20
    g_ij[(lab0==-2) | (lab1==-2)] = gb
    # Positional coefficient: zero against the void. Its value on boundary
    # faces is never used, since only cell-cell faces get this term.
    b_ij = torch.zeros(len(lab))
    b_ij[(lab0<N1) & (lab1>=N1)] = b12
    b_ij[(lab1<N1) & (lab0>=N1)] = b12
    b_ij[(lab0<N1) & (lab1<N1)] = b11
    b_ij[(lab0>=N1) & (lab1>=N1)] = b22
    b_ij[(lab0==-1) | (lab1==-1)] = 0
    b_ij[(lab0==-2) | (lab1==-2)] = gb

    # Per-face vectors, oriented along the face normal.
    f_crv = normals * curv.abs()[:, None] * area[:, None] * g_ij[:, None]
    f_bnd = normals * area[:, None] * g_ij[:, None]
    # Seed distance across each face. Indices are clamped only so the gather
    # is valid for faces that are masked out below.
    safe0 = lab0.clamp(0, n_cells - 1).long()
    safe1 = lab1.clamp(0, n_cells - 1).long()
    d_ij = torch.norm(simu.x[safe0] - simu.x[safe1], dim=1).clamp_min(0.01)
    f_pos = normals * area[:, None] * 1.0 / d_ij[:, None] * b_ij[:, None]

    dtype = torch.promote_types(torch.promote_types(f_crv.dtype, f_bnd.dtype), f_pos.dtype)
    acc = torch.zeros((n_cells, simu.x.shape[1]), dtype=dtype, device=normals.device)
    in0 = (lab0 >= 0) & (lab0 < n_cells)
    in1 = (lab1 >= 0) & (lab1 < n_cells)

    def _scatter(target, mask, values):
        # Add the masked per-face vectors to the rows of their target cells.
        acc.index_add_(0, target[mask].long(), values[mask].to(dtype))

    _scatter(lab0, in0, -f_crv)
    _scatter(lab1, in1, f_crv)
    _scatter(lab0, in0 & (lab1 == -2), -f_bnd)
    _scatter(lab1, in1 & (lab0 == -2), f_bnd)
    _scatter(lab0, in0 & (lab1 >= 0), -f_pos)
    _scatter(lab1, in1 & (lab0 >= 0), f_pos)
    return acc.to(device=simu.x.device, dtype=simu.x.dtype)
def neigh_list_to_cc(neigh_list):
    """Group cells into the connected components of a neighbour graph.

    Args:
        neigh_list: integer tensor of shape ``(E, 2)``. Each row is an edge
            between two cell indices in ``[0, N)``, with ``N`` read from the
            enclosing scope.

    Returns:
        ``(cc, cc_index)``:
        - ``cc`` is a list of components (lists of cell indices, in discovery
          order).
        - ``cc_index[k]`` is the position in ``cc`` of cell ``k``'s
          component (0 for cells that appear on no edge), as a tensor of
          dtype ``neigh_list.dtype``.

    The edges are converted to Python ints once, and the bookkeeping is kept
    in Python lists. The previous per-edge ``.item()`` calls and 0-d tensor
    indexing forced a host/device round-trip for every edge on CUDA. The
    components and their order are unchanged.
    """
    edges = neigh_list.tolist()
    cc = []
    cc_index = [0] * N
    seen = [False] * N

    def add(b, index):
        # Put cell ``b`` into the existing component ``index``.
        cc[index].append(b)
        seen[b] = True
        cc_index[b] = index

    def merge(i, j):
        # Append component ``j`` to component ``i``, drop ``j`` and renumber
        # every component, since popping shifts the later ones down by one.
        cc[i] = cc[i] + cc[j]
        cc.pop(j)
        for index, component in enumerate(cc):
            for member in component:
                cc_index[member] = index

    for edge in edges:
        a, b = edge[0], edge[1]
        if seen[a] and not seen[b]:
            add(b, cc_index[a])
        elif seen[b] and not seen[a]:
            add(a, cc_index[b])
        elif not seen[a] and not seen[b]:
            cc.append([a, b])
            cc_index[a] = len(cc) - 1
            cc_index[b] = len(cc) - 1
            seen[a] = True
            seen[b] = True
        elif cc_index[a] != cc_index[b]:
            merge(cc_index[a], cc_index[b])
    return cc, torch.tensor(cc_index, dtype=neigh_list.dtype)
def correction_force(F,lab):
    """Cancel the mean force of every cluster of touching cells.

    Cells are grouped into the connected components of the contact graph
    formed by the faces of ``lab`` whose two labels are both non-negative.
    Every cell of a component receives minus the component's mean force, so
    ``F + correction`` sums to zero over it. Cells without any cell
    neighbour get no correction.

    Args:
        F: per-cell force tensor, indexed as ``F[cells, :]``.
        lab: per-face label pairs, shape ``(n_faces, 2)``.

    Returns:
        Correction tensor shaped like ``F``.
    """
    cell_cell = (lab[:, 0] >= 0) & (lab[:, 1] >= 0)
    pairs = torch.unique(lab[cell_cell, :], dim=0)
    edges = pairs.to(device=lab.device, dtype=torch.long)
    components, _ = neigh_list_to_cc(edges)
    F_correction = torch.zeros_like(F)
    for members in components:
        F_correction[members, :] = -F[members, :].mean(dim=0)[None, :]
    return F_correction
#======================= INITIALISE ========================#
solver.solve(simu,
             sinkhorn_algo=ot_algo,
             tau=0.4,
             to_bary=True,
             show_progress=False,
             default_init=False,
             weight=1.0,
             bsr=True)
img = simu.labels.reshape(M,M,M).cpu().numpy()
# Relabel the largest label (presumably the extra "void" space) as -1.
img[img==img.max()] = -1.0
surfaces = compute_mesh(img)
surfaces.save(simu_name + "/data/"+f"t_{int(t_plot/save_every)}.vtk")
# Pad a copy for the TIFF export. Rebinding ``img`` to the padded array made
# the final save after the loop pad it a second time when no step ran.
img_padded = np.pad(img,1,mode='constant',constant_values=-2.0)
tif.imwrite(simu_name + "/data/"+f"t_{int(t_plot/save_every)}.tif", img_padded, bigtiff=True)
# From now on the cost uses the per-cell factor ``sc``.
solver.cost_params["C"] = sc
t += dt
t_iter += 1
t_plot += 1
#=========================== RUN ===========================#
while t<T:
    print("--------------------------",flush=True)
    print(f"t={t}",flush=True)
    print("--------------------------",flush=True)
    plotting_time = t_iter%plot_every==0
    if plotting_time:
        print("I plot.",flush=True)
        # More Sinkhorn iterations on the steps whose output is saved.
        solver.n_sinkhorn_last = 2000
        solver.n_sinkhorn = 2000
        solver.s0 = 2.0
    else:
        print("I do not plot.",flush=True)
        solver.n_sinkhorn_last = 250
        solver.n_sinkhorn = 250
        solver.s0 = 2*simu.R_mean
    F_inc = solver.lloyd_step(simu,
                              sinkhorn_algo=OT.LBFGSB,
                              tau=tau/(radius(simu)**(simu.d - 1)),
                              to_bary=False,
                              show_progress=False,
                              default_init=False,bsr=True)
    img = simu.labels.reshape(M,M,M).cpu().numpy()
    img[img==img.max()] = -1.0
    surfaces = compute_mesh(img)
    stime = time.time()
    stuff = extract_stuff(surfaces)
    print(f"Mesh extraction time: {time.time()-stime}",flush=True)
    stime = time.time()
    F_att = compute_forces(simu,*stuff)
    print(f"Force computation time: {time.time()-stime}",flush=True)
    stime = time.time()
    F_correct = correction_force(F_att,stuff[1])
    # F_correct = torch.tensor([[0.0,0.0,0.0]])
    print(f"Correction force computation time: {time.time()-stime}",flush=True)
    # Explicit Euler step with all three forces.
    simu.x += F_att*dt + F_inc*dt + F_correct*dt
    print(f"Maximal incompressibility force: {torch.max(torch.norm(F_inc,dim=1))}",flush=True)
    print(f"Maximal attraction force: {torch.max(torch.norm(F_att,dim=1))}",flush=True)
    print(f"Mean attraction force: {torch.mean(torch.norm(F_att,dim=1))}",flush=True)
    print(f"Maximal correction force: {torch.max(torch.norm(F_correct,dim=1))}",flush=True)
    print(f"Mean correction force: {torch.mean(torch.norm(F_correct,dim=1))}",flush=True)
    print(f"Maximal force: {torch.max(torch.norm(F_att + F_inc + F_correct,dim=1))}",flush=True)
    print(f"Mean force: {torch.mean(torch.norm(F_att + F_inc + F_correct,dim=1))}",flush=True)
    if plotting_time:
        if t_plot%save_every==0:
            surfaces.save(simu_name + "/data/"+f"t_{int(t_plot/save_every)}.vtk")
            # Per-frame TIFF export is disabled. If re-enabled, pad a *copy*
            # of ``img``: rebinding ``img`` here to its padded version made
            # the final save after the loop pad the last frame twice.
        t_plot += 1
    t += dt
    t_iter += 1
t_plot +=1
frame_path = simu_name + "/data/" + f"t_{int(t_plot/save_every)}"
# Last frame: surface mesh, plus the label image padded as in compute_mesh.
surfaces.save(frame_path + ".vtk")
img = np.pad(img,1,mode='constant',constant_values=-2.0)
tif.imwrite(frame_path + ".tif", img, bigtiff=True)
# Pickle the final simulation object next to the data directory.
with open(simu_name + "/simu_final.pkl", 'wb') as pkl_file:
    pickle.dump(simu, pkl_file)
def ratio_to_stparams(r_b1,r_g1,r_k1,g20=10.0):
    """Build the surface-tension parameter dict from the three sorting ratios.

    Fixed values and derived quantities:
    - ``k20`` (and ``k12``) are 0.4, and ``k10 = r_k1 * k20``.
    - The second-type/void tension is ``g20``; the boundary tension ``gb``
      equals ``g20``; the same-type tensions ``g11`` and ``g22`` are 0.

    With this parametrisation ``b12 / b11 == r_b1``, ``g10 / g20 == r_g1``
    and ``0.5 * b11 / g10 == k10``. The input ratios are stored back in the
    dict as well.
    """
    k20 = 0.4
    k12 = k20
    k10 = r_k1 * k20
    return {
        "r_b1": r_b1,
        "r_g1": r_g1,
        "r_k1": r_k1,
        "g20": g20,
        "gb": g20,
        "g11": 0.0,
        "g22": 0.0,
        "g10": r_g1 * g20,
        "b11": 2 * k10 * r_g1 * g20,
        "b22": 2 * k20 * g20,
        "b12": 2 * k10 * r_g1 * r_b1 * g20,
        "g12": k10/k12 * r_g1 * g20,
    }
# Representative scenarios:
# (r_b1, r_g1, r_k1, output title, seed initialisation).
scenarios = [
    (3.0, 1.0, 1.0, "separation", "uniform"),
    (0.3, 1.0, 1.0, "checkerboard", "uniform"),
    (3.0, 2.0, 0.8/(3.0*2.0), "engulfment_bgeq1", "uniform"),
    (0.3, 2.0, 0.8/2.0, "engulfment_bleq1", "uniform"),
    (3.0, 2.0, 0.8/(3.0*2.0), "engulfment_bgeq1_split", "split"),
    (0.3, 2.0, 0.8/2.0, "engulfment_bleq1_split", "split"),
]
for r_b1, r_g1, r_k1, title, init in scenarios:
    run_simu(ratio_to_stparams(r_b1,r_g1,r_k1,g20=5.0),title=title,init=init)