Commit c0fc42a5 authored by Sayah-Sel's avatar Sayah-Sel
Browse files

several epsilon using jax

parent ecfeaf43
No preview for this file type
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap,jacfwd
from jax import random
import scipy.spatial as sp
from scipy.linalg import cholesky, cho_solve
from numpy import fft as f
from tqdm import tqdm
import multiprocessing as mp
from tqdm.contrib.concurrent import process_map
from jax.ops import index, index_add, index_update
from jax.config import config
import time
import pandas as pd
config.update('jax_platform_name', 'cpu')
config.update("jax_enable_x64", True)
d=2
L=1
N=64
dx=L/N
mu=0
corr_L=["norm1","norm2"]
corr="norm2"
t1=time.time()
l_cV=0.04
l_cA=0.04
e_L=[0.0,0.01,0.1,0.2,0.3,0.5,0.7,0.8,0.9,0.99]
#e_L=[0.01,0.1,0.2,0.3,0.5,0.7,0.8,0.9,0.99]
XI=[]
for k in range(1,d):
XI.append(f.fftfreq(2*N)*2*N*2*np.pi/(2*L))
XI.append(f.rfftfreq(2*N)*(2*N)*2*np.pi/(2*L))
xi=np.array(np.meshgrid(*XI,indexing='ij'))
X= np.mgrid[tuple(slice(- L, L, dx) for _ in range(d))].T
Y=X.reshape((-1,d))
I=np.ones_like(Y.sum(axis=-1))
if corr=="norm2":
p=2
elif corr=="norm1":
p=1
AX=[-k for k in range(1,d+1)]
dist=sp.distance_matrix(Y,Y,p=p)
samples=6 #several disorder per value of epsilon to reduce the error
th=24
eps=np.identity(len(Y))*10**-12
def gamma_V_func(r,corr=corr):
if corr=="norm2":
return np.exp(-1/2*(r/l_cV)**2)
elif corr=="norm1":
return np.exp(-r/l_cV)
def gamma_A_func(r,corr=corr):
if corr=="norm2":
return np.exp(-1/2*(r/l_cA)**2)
elif corr=="norm1":
return np.exp(-r/l_cA)
def Gamma_VX(x,Y,corr=corr):
if corr=="norm2":
r=jnp.sqrt(((Y-x)**2).sum(axis=-1))
return jnp.exp(-1/2*(r/l_cV)**2)
elif corr=="norm1":
r=jnp.abs(Y-x).sum(axis=-1)
return jnp.exp(-r/l_cV)
def Gamma_AX(x,Y):
if corr=="norm2":
r=jnp.sqrt(((Y-x)**2).sum(axis=-1))
return jnp.exp(-1/2*(r/l_cA)**2)
elif corr=="norm1":
r=jnp.abs(Y-x).sum(axis=-1)
return jnp.exp(-r/l_cA)
def make_V(size=1,d=d,l_c=l_cV,xi=xi,N=N,corr=corr,eps=eps,nb=len(e_L)):
L=cholesky(gamma_V_func(dist)+eps,lower=True)
if corr=="norm2":
k2=(xi**2).sum(axis=0)
covVH=(2*N)**d*(2*np.pi*l_c**2)**(d/2)*np.exp(-l_c**2*k2/4)
elif corr=="norm1":
truc=2*l_c/(1+(l_c*xi)**2)
covVH=(2*N)**d*2*truc.prod(axis=0)
if size>1:
list_V=[]
list_KV=[]
for k in range(size):
Vh=covVH*np.random.standard_normal(covVH.shape)
V=1/2*f.irfftn(Vh)
V2=V.reshape(Y[:,0].shape)
KV=cho_solve((L,True),V2)
list_V.append(V)
list_KV.append(KV)
list_V=np.array(list_V)
list_KV=np.array(list_KV)
return list_V,list_KV
else:
Vh=covVH*np.random.standard_normal(covVH.shape)
V=1/2*f.irfftn(Vh)
V2=V.reshape(Y[:,0].shape)
KV=cho_solve((L,True),V2)
return V,KV
def make_A(size=1,d=d,l_c=l_cA,xi=xi,N=N,eps=eps,corr=corr,nb=len(e_L)):
L=cholesky(gamma_A_func(dist)+eps,lower=True)
A=np.zeros((d,d)+(2*N,)*d)
KA=np.zeros((d,d,(2*N)**d))
if corr=="norm2":
k2=(xi**2).sum(axis=0)
covAH=(2*N)**d*(2*np.pi*l_c**2)**(d/2)*np.exp(-l_c**2*k2/4)
elif corr=="norm1":
truc=2*l_c/(1+(l_c*xi)**2)
covAH=(2*N)**d*2*truc.prod(axis=0)
if size>1:
list_A=[]
list_KA=[]
for k in range(size):
for i in range(1,d):
for j in range(i):
AH=covAH*np.random.standard_normal(covAH.shape)
A[i,j,:]=1/2*f.irfftn(AH)
A[j,i,:]=-A[i,j,:]
A2=A[i,j,:].reshape(Y[:,0].shape)
KA[i,j,:]=cho_solve((L,True),A2)
KA[j,i,:]=-KA[i,j,:]
list_A.append(A)
list_KA.append(KA)
listA=np.array(list_A)
listKA=np.array(list_KA)
return listA,listKA
else:
for i in range(1,d):
for j in range(i):
AH=covAH*np.random.standard_normal(covAH.shape)
A[i,j,:]=1/2*f.irfftn(AH)
A[j,i,:]=-A[i,j,:]
A2=A[i,j,:].reshape(Y[:,0].shape)
KA[i,j,:]=cho_solve((L,True),A2)
KA[j,i,:]=-KA[i,j,:]
return A,KA
Vl,KVl=make_V(samples*len(e_L))
Al,KAl=make_A(samples*len(e_L))
del dist #the distance matrix can be quite heavy, so better delete it once it is not useful anymore
eL=np.array([np.ones(samples)*e for e in e_L]).flatten()
def mean_on_disorder(V,KV,A,KA,e):
v=1-e
a=e
np.random.seed()
#key=random.PRNGKey(0)
def V_func(x,Y=Y,v=v):
return v*jnp.dot(Gamma_VX(x,Y),KV)
nablaV=jit(vmap(grad(V_func,argnums=0)))
def A_func(x,Y=Y,a=a):
return a*(Gamma_AX(x,Y)*KA).sum(axis=-1)
jacobA=jacfwd(A_func)
cste=1/np.sqrt(d)
def rotA(x,d=d,cste=cste):
jacA=jacobA(x)*cste
return jnp.einsum('ij...j -> ...i',jacA)
rotA=jit(vmap(rotA))
Time=2*10**3
dt=2*10**-3
T=0.0
threshold=dt*(T+0.1)
#@jit(nopython=True)
def run(Time=Time,d=d,dt=dt,T=T,mu=mu,threshold=threshold):
x=np.random.uniform(-1,1,size=(1,d))
points=np.empty((Time,d))
fixe=0
for k in range(Time):
points[k]=x
u=-nablaV(x)+np.random.standard_normal(size=(1,d))*T+rotA(x)#-mu*x
x_new=x+dt*u
x=(1+x_new)%2-1
if k>50 and np.std(points[k-20:k,:],axis=0).sum()<threshold: #stop condition: stop trajectories which have converged to something
fixe=1 #if the trajectory found a fixed point
return x,fixe,k
# TODO: add another stop condition to tackle the case where the trajectory is obviously not converging to anything
return x,fixe,k
Nt=600
def dist(Nt,d=d,Time=Time,dt=dt,T=T,mu=mu):
points=np.empty((Nt,d))
fixed=np.empty((Nt,))
nb_it=np.empty((Nt,))
for k in range(Nt):
points[k],fixed[k],nb_it[k]=run()
return points, fixed, nb_it
points,fixed,K=dist(Nt)
points_eq=points[fixed==1] #the points at an equilibrium are only those where fixe=1...
if len(points_eq)==0:
return [0,e,Time]
hi=np.histogramdd(points_eq,bins=20) #the unique are count with an appropriate binning using numpy histograms, not absolutely necessary at T=0, but it is for non-vanishing temperature
nb=hi[0]
nb_eq=len(nb[nb>0])
return [nb_eq,e,K[fixed==1].mean()]
nb_eq_moy=[]
nb_eq_std=[]
nb_eq_time=[]
t1=time.time()
if __name__=='__main__':
P=mp.Pool(th)
argg=[(Vl[k],KVl[k],Al[k],KAl[k],eL[k]) for k in range(samples*len(e_L))]
out=P.starmap_async(mean_on_disorder,argg) #the computations are parallelized on disorders: each process/thread take one disorder at a time
P.close()
P.join()
nb_eq,e_out,K_max=np.array(out.get()).T
t2=time.time()
print(t2-t1)
for e in e_L:
nb_eq_moy.append(nb_eq[e_out==e].mean())
nb_eq_std.append(nb_eq[e_out==e].std())
nb_eq_time.append(K_max[(e_out==e) & (K_max!=None)].mean())
result=pd.DataFrame(data=[nb_eq_moy,nb_eq_std,nb_eq_time],index=e_L,columns=["mean eq","std eq", "mean time"])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment