Note
Go to the end to download the full example code.
Growth of a 3D cell aggregate
We consider a 3D cell aggregate growing according to a basic somatic cell cycle. Starting from one cell, each cell grows at a linear speed until a target volume is reached, then it divides after a random exponential time producing two daughter cells with identical half volumes.
We keep a constant resolution throughout the simulation by progressibely zooming out as the aggregate grows.
# sphinx_gallery_thumbnail_path = '_static/tissue_growth_3D.png'
import os
import sys
sys.path.append("..")
import time
import pickle
import math
import torch
import numpy as np
from matplotlib import colors
from matplotlib.colors import ListedColormap
from iceshot import cells
from iceshot import costs
from iceshot import OT
from iceshot.OT import OT_solver
from iceshot import plot_cells
from iceshot import sample
from iceshot import utils
import tifffile as tif
use_cuda = torch.cuda.is_available()
if use_cuda:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
device = "cuda"
p = 2
ot_algo = OT.LBFGSB
simu_name = "simu_TissueGrowth_3D"
os.mkdir(simu_name)
os.mkdir(simu_name+"/frames")
os.mkdir(simu_name+"/data")
Nmax = 50000
N = 1
M = 400
R1 = 0.15
vol1 = 4./3. * math.pi * (R1**3)
vol0 = 0.5*vol1
cnt = torch.tensor([[0.5,0.5,0.5]])
vol1_end = 1.0/Nmax
R1_end = (vol1_end/(4./3.*math.pi)) ** (1./3.)
scale = R1_end/R1
seeds = torch.tensor([[0.5,0.5,0.5]])
source = sample.sample_grid(M,dim=3)
vol_x = vol1*torch.ones(N)
simu = cells.Cells(
seeds=seeds,source=source,
vol_x=vol_x,extra_space="void",
bc=None
)
cost_params = {
"p" : p,
"scaling" : "volume",
"R" : simu.R_mean,
"C" : 0.1
}
solver = OT_solver(
n_sinkhorn=300,n_sinkhorn_last=1000,n_lloyds=4,s0=2.0,
cost_function=costs.l2_cost,cost_params=cost_params
)
dt = 0.002
plot_every = 3
t = 0.0
t_iter = 0
t_plot = 0
growth_rate = 6.0*(vol1-vol0)
growth_rate_factor = 0.5 + 1.5*torch.rand(simu.N_cells)
div_rate = 3.0
cap = None
def insert(x,ind,elem1,elem2):
sh = list(x.shape)
sh[0] += 1
new_x = torch.zeros(sh)
new_x[:ind] = x[:ind]
new_x[(ind+2):] = x[(ind+1):]
new_x[ind] = elem1
new_x[ind+1] = elem2
return new_x
def sample_unit(N,d):
x = torch.randn((N,d))
x /= torch.norm(x,dim=1).reshape((N,1))
return x
def divide(simu,ind,R1):
simu.x = insert(simu.x,ind,simu.x[ind]-0.5*R1*simu.axis[ind],simu.x[ind]+0.5*R1*simu.axis[ind])
simu.axis = insert(simu.axis,ind,sample_unit(1,simu.d),sample_unit(1,simu.d))
simu.ar = insert(simu.ar,ind,1.0,1.0)
simu.orientation = simu.orientation_from_axis()
simu.N_cells += 1
simu.volumes = insert(simu.volumes,ind,0.5*simu.volumes[ind],0.5*simu.volumes[ind])
simu.f_x = insert(simu.f_x,ind,simu.f_x[ind],simu.f_x[ind])
total_vol = simu.volumes[:-1].sum().item()
data = {
"N" : [1],
"T" : [0.0],
"vol" : [total_vol],
"scale" : [scale]
}
#======================= INITIALISE ========================#
solver.solve(simu,
sinkhorn_algo=ot_algo,cap=cap,
tau=0.7,
to_bary=True,
show_progress=False,
default_init=False,
stopping_criterion="average",
tol=0.01)
t += dt
t_iter += 1
t_plot += 1
solver.n_lloyds = 1
solver.cost_params["p"] = p
#=========================== RUN ===========================#
stime = time.time()
while True:
print("--------------------------",flush=True)
print(f"t={t}",flush=True)
print(f"N={simu.N_cells}",flush=True)
print(f"V={total_vol}",flush=True)
print("--------------------------",flush=True)
plotting_time = t_iter%plot_every==0
if plotting_time:
print("I plot.",flush=True)
solver.n_sinkhorn_last = 200
solver.n_sinkhorn = 200
else:
print("I do not plot.",flush=True)
solver.n_sinkhorn_last = 200
solver.n_sinkhorn = 200
simu.volumes[:-1] += growth_rate_factor * growth_rate*dt
simu.volumes[:-1] = torch.minimum(simu.volumes[:-1],torch.tensor([vol1]))
simu.volumes[-1] = 1.0 - simu.volumes[:-1].sum()
who_divide = (simu.volumes[:-1] > 0.8*vol1) & (torch.rand(simu.N_cells) > math.exp(-dt*div_rate)) & (torch.max(torch.abs(simu.x - cnt),dim=1)[0] < 0.5 - R1)
for ind,who in enumerate(who_divide):
if who:
if simu.N_cells<=Nmax:
divide(simu,ind,R1)
growth_rate_factor = insert(growth_rate_factor,ind,growth_rate_factor[ind],0.5+1.5*torch.rand(1))
F_inc = solver.lloyd_step(simu,
sinkhorn_algo=ot_algo,cap=cap,
tau=10.0/(R1**2),
to_bary=False,
show_progress=False,
default_init=False,
stopping_criterion="average",
tol=0.01)
simu.x += F_inc*dt
print(f"Maximal incompressibility force: {torch.max(torch.norm(F_inc,dim=1))}")
simu.labels[simu.labels==torch.max(simu.labels)] = -100.0
total_vol = simu.volumes[:-1].sum().item()
R_m = 1.1 * (total_vol/(4./3.*math.pi)) ** (1./3.)
ratio = min(1.0,0.3/R_m)
print(f"RATIO={ratio}",flush=True)
new_scale = min(1.0,1/ratio * scale)
ratio = scale/new_scale
simu.x = cnt + ratio*(simu.x - cnt)
vol0 *= ratio**3
vol1 *= ratio**3
R1 *= ratio
simu.R_mean *= ratio
simu.volumes[:-1] *= ratio**3
simu.volumes[-1] = 1.0 - simu.volumes[:-1].sum()
scale = new_scale
print(f"SCALE={scale}",flush=True)
if plotting_time:
tif.imsave(simu_name + "/frames/"+f"t_{t_plot}.tif", simu.labels.reshape(M,M,M).cpu().numpy(), bigtiff=True)
t_plot += 1
data["N"].append(simu.N_cells)
data["T"].append(time.time() - stime)
data["vol"].append(total_vol)
data["scale"].append(scale)
pickle.dump(data,open(simu_name+"/data.p","wb"))
if total_vol>0.9999 and simu.N_cells>Nmax:
with open(simu_name + "/data/data_final.pkl",'wb') as file:
pickle.dump(simu,file)
break
t += dt
t_iter += 1