Note
Go to the end to download the full example code.
Growth of a 2D cell aggregate
We consider a 2D cell aggregate growing according to a basic somatic cell cycle. Starting from a single cell, each cell grows at a constant rate until it reaches a target volume; it then divides after an exponentially distributed random time, producing two daughter cells that each receive half of the parent's volume.
# sphinx_gallery_thumbnail_path = '_static/TissueGrowth_t442.png'
import os
import sys
sys.path.append("..")
import pickle
import math
import torch
import numpy as np
from matplotlib import colors
from matplotlib.colors import ListedColormap
from iceshot import cells
from iceshot import costs
from iceshot import OT
from iceshot.OT import OT_solver
from iceshot import plot_cells
from iceshot import sample
from iceshot import utils
# Run on the GPU when one is available, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
# `device` is always defined, so it can be used on CPU-only machines as well.
device = "cuda" if use_cuda else "cpu"
if use_cuda:
    # Tensors created without an explicit device become CUDA float tensors.
    torch.set_default_tensor_type("torch.cuda.FloatTensor")
# Value stored under the "p" key of the cost parameters given to the solver.
p = 2

# Algorithm handed to the solver as `sinkhorn_algo`.
# Alternative: ot_algo = OT.sinkhorn_zerolast
ot_algo = OT.LBFGSB

# Output layout: <simu_name>/frames receives the figures and
# <simu_name>/data the pickled simulation states.
# os.mkdir raises FileExistsError when a directory already exists, so an
# earlier run with the same name is never overwritten.
simu_name = "simu_TissueGrowth"
for sub_dir in ("", "/frames", "/data"):
    os.mkdir(simu_name + sub_dir)

# Colormap built from the single colour "tab:blue", used for the cell plot.
cmap = utils.cmap_from_list(1000, color_names=["tab:blue"])
# Presumably the initial number of cells (not used below) -- TODO confirm.
N = 1
# Argument of sample.sample_grid, which builds the source points.
M = 800
# Divisions are only allowed while simu.N_cells <= Nmax.
Nmax = 400

# vol1 is the volume at which growth is capped; vol0 is half of it.
vol0 = 0.5*0.75/Nmax
vol1 = 0.75/Nmax
# Radius of a disc of area vol1.
R1 = math.sqrt(vol1/math.pi)

# A single seed at (0.5, 0.5), starting at volume vol1.
seeds = torch.tensor([[0.5, 0.5]])
source = sample.sample_grid(M)
vol_x = torch.tensor([vol1])

simu = cells.Cells(
    seeds=seeds,
    source=source,
    vol_x=vol_x,
    extra_space="void",
    bc=None,
)

# Parameters forwarded to the cost function through the solver.
cost_params = {
    "p": p,
    "scaling": "volume",
    "R": simu.R_mean,
    "C": 0.1,
}

solver = OT_solver(
    n_sinkhorn=300,
    n_sinkhorn_last=1000,
    n_lloyds=4,
    s0=2.0,
    cost_function=costs.l2_cost,
    cost_params=cost_params,
)
# Final time of the simulation (a shorter run would use T = 5.0).
T = 30.0
# Time step.
dt = 0.01
# A frame and a data file are written every `plot_every` iterations.
plot_every = 1

# Current time, iteration counter and index of the next saved frame.
t, t_iter, t_plot = 0.0, 0, 0

# Volume gained per unit of time: with a factor of 1, growing from vol0 to
# vol1 takes 0.5 time units.
growth_rate = (vol1-vol0)/0.5
# One random multiplier per cell, drawn uniformly in [0.5, 2).
growth_rate_factor = 0.5 + 1.5*torch.rand(simu.N_cells)
# Rate of the exponential waiting time before a division.
div_rate = 5.0
# Forwarded unchanged to the solver as `cap`.
cap = None
def insert(x, ind, elem1, elem2):
    """Return a copy of ``x`` where entry ``ind`` is replaced by two entries.

    The result has one more entry than ``x`` along the first dimension:
    entries before ``ind`` are kept, ``elem1`` and ``elem2`` are stored at
    positions ``ind`` and ``ind + 1``, and the entries that followed ``ind``
    are shifted by one position. ``x`` itself is not modified.

    Args:
        x: Tensor to expand along its first dimension.
        ind: Index of the entry to replace.
        elem1: Value stored at position ``ind``.
        elem2: Value stored at position ``ind + 1``.

    Returns:
        A new tensor with the same dtype and device as ``x``.
    """
    new_shape = list(x.shape)
    new_shape[0] += 1
    # new_zeros keeps the dtype and device of x, whereas torch.zeros would
    # fall back to the global defaults (e.g. float64 data became float32).
    new_x = x.new_zeros(new_shape)
    new_x[:ind] = x[:ind]
    new_x[(ind + 2):] = x[(ind + 1):]
    new_x[ind] = elem1
    new_x[ind + 1] = elem2
    return new_x
def sample_unit(N, d):
    """Draw ``N`` random vectors of dimension ``d`` with unit Euclidean norm.

    Each row is a standard normal sample divided by its own norm.

    Returns:
        A tensor of shape ``(N, d)``.
    """
    gauss = torch.randn((N, d))
    lengths = torch.norm(gauss, dim=1).reshape((N, 1))
    gauss /= lengths
    return gauss
def divide(simu, ind, R1):
    """Split cell ``ind`` of ``simu`` into two daughter cells, in place.

    The daughters take positions ``ind`` and ``ind + 1`` in every per-cell
    tensor, so all later cells are shifted by one index.

    Args:
        simu: Simulation object whose ``x``, ``axis``, ``ar``,
            ``orientation``, ``volumes``, ``f_x`` and ``N_cells`` attributes
            are updated.
        ind: Index of the cell that divides.
        R1: Length scale; the daughters are placed ``0.5 * R1`` away from the
            mother's position, on either side along the mother's axis.
    """
    # Positions: symmetric around the mother, along its current axis.
    centre = simu.x[ind]
    half_shift = 0.5*R1*simu.axis[ind]
    simu.x = insert(simu.x, ind, centre - half_shift, centre + half_shift)

    # Each daughter gets a fresh random unit axis and the value 1.0 in `ar`
    # (presumably the aspect ratio -- TODO confirm).
    first_axis = sample_unit(1, simu.d)
    second_axis = sample_unit(1, simu.d)
    simu.axis = insert(simu.axis, ind, first_axis, second_axis)
    simu.ar = insert(simu.ar, ind, 1.0, 1.0)
    simu.orientation = simu.orientation_from_axis()

    simu.N_cells += 1

    # The mother's volume is shared equally; her f_x entry is duplicated.
    half_volume = 0.5*simu.volumes[ind]
    simu.volumes = insert(simu.volumes, ind, half_volume, half_volume)
    mother_f = simu.f_x[ind]
    simu.f_x = insert(simu.f_x, ind, mother_f, mother_f)
def kill(simu, who, solver=solver, cost_matrix=None):
    """Remove the cells flagged in ``who`` from ``simu``, in place.

    Args:
        simu: Simulation object to update.
        who: Boolean mask with one entry per cell; ``True`` marks a cell to
            remove.
        solver: Not used by this function.
        cost_matrix: Not used by this function.
    """
    # volumes and f_x carry one extra trailing entry, which is always kept.
    keep_last = torch.zeros(1, dtype=bool, device=who.device)
    who_p = torch.cat((who, keep_last))
    keep = ~who
    keep_p = ~who_p

    simu.x = simu.x[keep]
    simu.f_x = simu.f_x[keep_p]

    # The volume of the removed cells is transferred to the trailing entry.
    simu.volumes[-1] += simu.volumes[who_p].sum()
    simu.volumes = simu.volumes[keep_p]

    simu.axis = simu.axis[keep]
    simu.ar = simu.ar[keep]
    simu.orientation = simu.orientation[keep]
    simu.N_cells -= int(who.sum().item())

    # Labels that pointed to a removed cell get a value beyond the remaining
    # cell indices. NOTE(review): labels of surviving cells are not
    # renumbered here -- confirm they are recomputed by the next solve.
    removed_ids = torch.where(who)[0]
    orphaned = torch.isin(simu.labels, removed_ids)
    simu.labels[orphaned] = simu.x.shape[0] + 42
# Point towards which the cells are pulled during the run (domain centre).
exit = torch.tensor([[0.5, 0.5]])

#======================= INITIALISE ========================#

# First solve, using the solver settings given at construction.
solver.solve(
    simu,
    sinkhorn_algo=ot_algo,
    cap=cap,
    tau=0.0,
    to_bary=True,
    show_progress=False,
)

simu_plot = plot_cells.CellPlot(
    simu,
    figsize=8,
    cmap=cmap,
    plot_pixels=True,
    plot_scat=True,
    plot_quiv=False,
    plot_boundary=True,
    scat_size=5,
    scat_color='k',
    r=None,
    K=5,
    boundary_color='k',
    plot_type="imshow",
    void_color='w',
)

# Save the initial frame and the matching simulation state.
simu_plot.fig.savefig(simu_name + "/frames/" + f"t_{t_plot}.png")
with open(simu_name + "/data/" + f"data_{t_plot}.pkl", 'wb') as data_file:
    pickle.dump(simu, data_file)

t += dt
t_iter += 1
t_plot += 1

# Settings used during the run; the solver is pickled with them.
solver.n_lloyds = 1
solver.cost_params["p"] = p

with open(simu_name + f"/params.pkl", 'wb') as params_file:
    pickle.dump(solver, params_file)
#=========================== RUN ===========================#

while t < T:
    print("--------------------------", flush=True)
    print(f"t={t}", flush=True)
    print("--------------------------", flush=True)

    # Steps that are plotted and saved use more solver iterations.
    plotting_time = t_iter % plot_every == 0
    if plotting_time:
        print("I plot.", flush=True)
        solver.n_sinkhorn_last = 2000
        solver.n_sinkhorn = 2000
        solver.s0 = 2.0
    else:
        print("I do not plot.", flush=True)
        solver.n_sinkhorn_last = 250
        solver.n_sinkhorn = 250
        solver.s0 = 2*simu.R_mean

    # Growth: every cell gains volume at its own rate, capped at vol1; the
    # trailing entry of `volumes` holds whatever is left of the unit volume.
    simu.volumes[:-1] += growth_rate_factor * growth_rate*dt
    simu.volumes[:-1] = torch.minimum(simu.volumes[:-1], torch.tensor([vol1]))
    simu.volumes[-1] = 1.0 - simu.volumes[:-1].sum()

    # Division: a cell larger than 0.8*vol1 divides during this step with
    # probability 1 - exp(-dt*div_rate).
    who_divide = (simu.volumes[:-1] > 0.8*vol1) & (torch.rand(simu.N_cells) > math.exp(-dt*div_rate))

    # `who_divide` is indexed by the cells as they were BEFORE any division.
    # Each division inserts one extra cell, shifting every later cell by one,
    # so the running offset maps an original index to its current position.
    n_inserted = 0
    for ind, who in enumerate(who_divide):
        if who:
            # NOTE(review): `<=` lets the population reach Nmax + 1 cells --
            # confirm whether `<` was intended.
            if simu.N_cells <= Nmax:
                pos = ind + n_inserted
                divide(simu, pos, R1)
                # First daughter keeps the mother's factor, second draws a new one.
                growth_rate_factor = insert(growth_rate_factor, pos, growth_rate_factor[pos], 0.5+1.5*torch.rand(1))
                n_inserted += 1

    F_inc = solver.lloyd_step(simu,
                              sinkhorn_algo=ot_algo, cap=cap,
                              tau=1.0/torch.sqrt(simu.volumes[:-1]/math.pi),
                              to_bary=False,
                              show_progress=False,
                              default_init=False)

    # Unit vector (up to the 1e-6 regularisation) from each cell to `exit`.
    F_evacuation = (exit - simu.x)/(torch.norm(exit - simu.x, dim=1).reshape((simu.N_cells, 1)) + 1e-6)

    simu.x += F_inc*dt + 0.2*F_evacuation*dt

    # Align each cell axis with the principal direction of its covariance
    # matrix. This update is best effort: a failure keeps the previous axes.
    try:
        cov = simu.covariance_matrix()
        cov /= torch.sqrt(torch.det(cov).reshape((simu.N_cells, 1, 1)))
        L, Q = torch.linalg.eigh(cov)
        # eigh sorts eigenvalues in ascending order: the last eigenvector
        # belongs to the largest one.
        axis = Q[:, :, -1]
        # Eigenvectors are defined up to sign; keep the sign that agrees
        # with the previous axis.
        axis = (axis * simu.axis).sum(1).sign().reshape((simu.N_cells, 1)) * axis
        simu.axis = axis
        simu.orientation = simu.orientation_from_axis()
    except Exception as err:
        # Only ordinary errors are skipped; KeyboardInterrupt and SystemExit
        # still stop the simulation.
        print(f"Orientation update skipped: {err!r}", flush=True)

    print(f"Maximal incompressibility force: {torch.max(torch.norm(F_inc,dim=1))}")

    if plotting_time:
        simu_plot.update_plot(simu)
        simu_plot.fig.savefig(simu_name + "/frames/" + f"t_{t_plot}.png")
        with open(simu_name + "/data/" + f"data_{t_plot}.pkl", 'wb') as file:
            pickle.dump(simu, file)
        t_plot += 1

    t += dt
    t_iter += 1

utils.make_video(simu_name=simu_name, video_name=simu_name)