This commit is contained in:
Daria Sushnikova 2023-01-12 15:53:10 +03:00
parent 2a1b1fef27
commit c41e4890e7
5 changed files with 2715 additions and 0 deletions

1935
code/FMM_LU.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,99 @@
import sys
sys.path.insert(0,'..')
import numpy as np
from numpy.random import default_rng
import math
from time import time
from functions import test_funcs
from FMM_LU import FMM_LU as fmm_lu
from problem_tools import problem
import fmm3dbie as h3
import fmm3dpy as fmm3d
verbose = 1
np.random.seed(0)
bs = 64
func = test_funcs.double_layer
close_r = 1.1
num_child_tree = 'hyper'
point_based_tree = 0
eps = 0.51e-4
zk = 1.0 + 1j*0
alpha = 0.0
beta = 1.0
proxy_p = (2, 200)
proxy_r = 1.
csc_fun = 0
symmetric_fun = 0
half_sym = 1
t0 = time()
order = 3
nu = 20
q_fun = 1
pr = fmm_lu.build_problem(geom_type='wtorus', block_size=bs, func=func,
point_based_tree=point_based_tree, close_r=close_r,
num_child_tree=num_child_tree, eps=eps,
zk=zk, alpha=alpha, beta=beta, wtd_T=1, half_sym=half_sym,
csc_fun=csc_fun, q_fun=q_fun, nu=nu, order=order)
pr.coef = coef = -1
print(f'n = {pr.shape[0]}\n')
print(f'problem-build time: {time() - t0}\n')
# FMM-LU solver
xyz_out = np.array([[31.17,-0.03,3.15],[6.13,-4.1,22.2]]).transpose()
xyz_in = np.array([[0.11,-2.13,0.05],[0.13,2.1,-0.01]]).transpose()
#coef = -1
# comparizon to FMM
c = np.array([1 + 1j*0,1+1.1j])
out = fmm3d.h3ddir(zk=zk, sources=xyz_out, targets=pr.srcvals[0:3,:], charges=c, pgt=1)
rhs = out.pottarg
sigma, factor, tf, ts = fmm_lu.fmm_lu_solve(pr, eps, rhs, proxy_p, proxy_r, verbose=verbose)
ntarg = np.shape(xyz_in)[1]
ipatch_id = coef*np.ones(2)
uvs_targ = np.zeros((2,ntarg))
norders = pr.order*np.ones(pr.npatches)
iptype = np.ones(pr.npatches)
pot_comp = h3.lpcomp_helm_comb_dir(norders, pr.ixyzs, iptype, pr.srccoefs, pr.srcvals,
xyz_in, ipatch_id, uvs_targ, eps, pr.zpars, sigma)
out = fmm3d.h3ddir(zk=zk, sources=xyz_out, targets=xyz_in, charges=c,pgt=1)
pot_ex = out.pottarg
erra = np.linalg.norm(pot_ex-pot_comp)
print(f'Fact time: {tf}\n')
print(f'Sol time: {ts}\n')
print(f"error in solution = {erra}\n")
if verbose:
tree = pr.row_tree
level_count = len(tree.level) - 2
print(f'Compression on levels. \nlevel_count: {level_count-2}')
for l in range(level_count-1, factor.tail_lvl-1, -1):
job = [j for j in
range(tree.level[l], tree.level[l+1])]
print(f'Level: {l}')
proc = 0
mean_b = 0
mean_ind = 0
nindl = 0
mean_other_lvl_close = 0
for i in job:
mean_other_lvl_close +=len(pr.other_lvl_close[i])
if factor.index_lvl[i].shape[0] != 0:
proc += factor.basis[i].shape[0]/factor.index_lvl[i].shape[0]*100
nindl += 1
mean_b += factor.basis[i].shape[0]
mean_ind += factor.index_lvl[i].shape[0]
print(f' Mean other lvl close: {mean_other_lvl_close/nindl}')
print(f' Mean compression: {proc/nindl:.2f}%, mean basis: {mean_b/nindl:.2f}, mean index: {mean_ind/nindl:.2f}')

240
functions/test_funcs.py Normal file
View File

@ -0,0 +1,240 @@
from __future__ import print_function, absolute_import, division
__all__ = ['Particles', 'inv_distance', 'log_distance']
import numpy as np
from time import time
from numba import jit
import math
import cmath
from scipy import integrate as intg
def log_dist_int(x,y):
return -1 / (2 * np.pi) * np.log(np.sqrt(x ** 2 + y ** 2))
def log_dist_2d(xd,yd):
return -1 / (2 * np.pi) * np.log(np.sqrt((xd[0] - yd[0]) ** 2 + (xd[1] - yd[1]) ** 2))
def log_distance(data1, list1, data2, list2):
ans = np.ndarray((list1.size, list2.size), dtype=np.float64)
vertex1 = data1.vertex
vertex2 = data2.vertex
n = list1.size
m = list2.size
N = data1.vertex.shape[1]
for i in range(n):
for j in range(m):
if (vertex1[:,list1[i]] == vertex2[:,list2[j]]).all():
ans[i, j] = intg.dblquad(log_dist_int,0,1/(2*np.sqrt(N)),lambda x: 0, lambda x: 1/(2*np.sqrt(N)))[0]*4
else:
ans[i, j] = log_dist_2d(vertex1[:,list1[i]],vertex2[:,list2[j]])/N
return ans
###############################################################################
### interactions for Particles ###
###############################################################################
def inv_distance(data1, list1, data2, list2):
"""
Returns 1/r for each pair of particles from two sets.
Function 1/r is used as interaction between two particles.
Parameters
----------
data1 : Python object
Destination of interactions
list1 : array
Indices of particles from `data1` to compute interactions
data2 : Python object
Source of interactions
list2 : array
Indices of particles from `data1` to compute interactions
Returns
-------
numpy.ndarray(ndim=2)
Array of interactions of corresponding particles.
"""
ans = np.ndarray((list1.size, list2.size), dtype=np.float64)
return inv_distance_numba(data1.ndim, data1.vertex, list1, data2.vertex,
list2, ans)
@jit(nopython=True, parallel=True)
def inv_distance_numba(ndim, vertex1, list1, vertex2, list2, ans):
n = list1.size
m = list2.size
for i in range(n):
for j in range(m):
tmp_l = 0.0
for k in range(ndim):
tmp_v = vertex1[k, list1[i]]-vertex2[k, list2[j]]
tmp_l += tmp_v*tmp_v
if tmp_l <= 0:
ans[i, j] = 0
else:
ans[i, j] = 1./math.sqrt(tmp_l)
return ans
def log_distance_h2t(data1, list1, data2, list2):
"""
Returns -log(r) for each pair of particles from two sets.
Function -log(r) is used as interaction between two particles.
Parameters
----------
data1 : Python object
Destination of interactions
list1 : array
Indices of particles from `data1` to compute interactions
data2 : Python object
Source of interactions
list2 : array
Indices of particles from `data1` to compute interactions
Returns
-------
numpy.ndarray(ndim=2)
Array of interactions of corresponding particles.
"""
ans = np.ndarray((list1.size, list2.size), dtype=np.float64)
return log_distance_numba(data1.ndim, data1.vertex, list1, data2.vertex,
list2, ans)
@jit(nopython=True)
def log_distance_numba(ndim, vertex1, list1, vertex2, list2, ans):
n = list1.size
m = list2.size
for i in range(n):
for j in range(m):
tmp_l = 0.0
for k in range(ndim):
tmp_v = vertex1[k, list1[i]]-vertex2[k, list2[j]]
tmp_l += tmp_v*tmp_v
if tmp_l <= 0:
ans[i, j] = 0
else:
ans[i, j] = -0.5*math.log(tmp_l)
if list1[i] == list2[j]:
ans[i, j] = 15
return ans
def exp_distance_h2t(data1, list1, data2, list2):
ans = np.ndarray((list1.size, list2.size), dtype=np.cdouble)
return exp_distance_numba(data1.ndim, data1.vertex, list1, data2.vertex,
list2, data1.k, ans)
@jit(nopython=True)
def exp_distance_numba(ndim, vertex1, list1, vertex2, list2, kz, ans):
n = list1.size
m = list2.size
for i in range(n):
for j in range(m):
tmp_l = 0.0
for k in range(ndim):
tmp_v = vertex1[k, list1[i]] - vertex2[k, list2[j]]
tmp_l += tmp_v*tmp_v
if tmp_l <= 0:
ans[i, j] = 0
else:
r = math.sqrt(tmp_l)
ans[i, j] = cmath.exp(1j * kz * r)/ r
if list1[i] == list2[j]:
ans[i, j] = 6 + 1j*0
return ans
# def test_fun(data1, list1, data2, list2):
# ans = np.ndarray((list1.size, list2.size), dtype=np.float64)
# # ans = np.ndarray((list1.size, list2.size), dtype=np.float64)
# return test_fun_numba(data1.ndim, data1.vertex, list1, data2.vertex,
# list2, ans)
# @jit(nopython=True)
# def test_fun_numba(ndim, vertex1, list1, vertex2, list2, ans):
# n = list1.size
# m = list2.size
# for i in range(n):
# for j in range(m):
# tmp_l = 0.0
# for k in range(ndim):
# tmp_v = vertex1[k, list1[i]]-vertex2[k, list2[j]]
# tmp_l += tmp_v*tmp_v
# if tmp_l <= 0:
# ans[i, j] = 0
# else:
# # r = math.sqrt(tmp_l)
# ans[i, j] = 1./ (tmp_l)
# if list1[i] == list2[j]:
# ans[i, j] = 1000
# return ans
def double_layer(data1, list1, data2, list2):
ans = np.ndarray((list1.size, list2.size), dtype=np.cdouble)
return double_layer_numba(data1.ndim, data1.vertex, list1, data2.vertex,
list2, data1.k, data1.norms, ans)
@jit(nopython=True)
def double_layer_numba(ndim, vertex1, list1, vertex2, list2, kz, norms, ans):
n = list1.size
m = list2.size
for i in range(n):
for j in range(m):
tmp_l = 0.0
tetha = 0.0
len_norm = 0.0
for k in range(ndim):
tmp_v = vertex1[k, list1[i]] - vertex2[k, list2[j]]
tmp_l += tmp_v * tmp_v
tetha += tmp_v * norms[k,list1[i]]
len_norm += norms[k,list1[i]] * norms[k,list1[i]]
# print (tetha)
if tmp_l <= 0:
ans[i, j] = 0
else:
r = math.sqrt(tmp_l)
len_norm = math.sqrt(len_norm)
tetha = tetha / (r * len_norm)
ans[i, j] = (cmath.exp(1j * kz * r)/ r) * (1j * kz - 1/r)* math.cos(tetha)
if list1[i] == list2[j]:
ans[i, j] = 6 + 1j*0
return ans
@jit(nopython=True)
def comp_sph_numba(ndim, vertex1, list1, vertex2, list2, ans):
n = list1.size
m = list2.size
for i in range(n):
for j in range(m):
tmp_l = 0.0
for k in range(ndim):
tmp_v = vertex1[k, list1[i]]-vertex2[k, list2[j]]
tmp_l += tmp_v*tmp_v
if tmp_l <= 0:
ans[i, j] = 0
else:
ans[i, j] = 1/(4 * np.pi * math.sqrt(tmp_l))
if list1[i] == list2[j]:
ans[i, j] = 150
return ans
def comp_sph(data1, list1, data2, list2):
ans = np.ndarray((list1.size, list2.size))
return comp_sph_numba(data1.ndim, data1.vertex, list1, data2.vertex,
list2, ans)

View File

@ -0,0 +1,297 @@
import numpy as np
from numba import jit
# from h2tools.cluster_tree import SmartIndex
from copy import deepcopy as dc
class SmartIndex(object):
"""
Stores only view to index and information about each node.
It is only used in `ClusterTree` class for convenient work with
indexes. Main reason this is implemented separately from
`ClusterTree` is easily readable syntax: `index[key]` returns view
to subarray of array `index`, corresponding to indices of node
`key`.
Parameters
----------
size : integer
Number of objects in cluster
Attributes
----------
index: 1-dimensional array
Permutation array such, that indexes of objects, corresponding
to the same subcluster, are located one after each other.
node: list of tuples
Indexes of `i`-th node of cluster tree are
`index[node[i][0]:node[i][1]]`.
"""
def __init__(self, size):
self.index = np.arange(size, dtype=np.uint64)
self.node = [(0, size)]
def __getitem__(self, key):
"""Get indices for cluster `key`."""
return self.index[slice(*self.node[key])]
def __setitem__(self, key, value):
"""
Set indices for cluster `key`.
Changes only main index array.
"""
self.index[slice(*self.node[key])] = value
def add_node(self, parent, node):
"""Add node, that corresponds to `index[node[0]:node[1]]`."""
start = self.node[parent][0]+node[0]
stop = self.node[parent][0]+node[1]
self.node.append((start, stop))
def __len__(self):
return len(self.node)
class Data(object):
def __init__(self, ndim, count, vertex, close_r='1box'):
self.ndim = ndim
self.count = count
self.vertex = vertex
self.close_r = close_r
def check_far(self, self_aux, other_aux):
return Data.fast_check_far_ndim(self_aux, other_aux, self.ndim, self.close_r)
def fast_check_far_ndim(self_aux, other_aux, ndim, close_r):
if close_r == '1box':
if ndim == 2:
corners_self = [np.array([self_aux[0,0], self_aux[0,1]]),
np.array([self_aux[1,0], self_aux[0,1]]),
np.array([self_aux[0,0], self_aux[1,1]]),
np.array([self_aux[1,0], self_aux[1,1]])]
corners_other = [np.array([other_aux[0,0], other_aux[0,1]]),
np.array([other_aux[1,0], other_aux[0,1]]),
np.array([other_aux[0,0], other_aux[1,1]]),
np.array([other_aux[1,0], other_aux[1,1]])]
for i in corners_self:
for j in corners_other:
if np.array_equal(i,j):
return False
return True
elif ndim == 3:
corners_self = [np.array([self_aux[0,0], self_aux[0,1], self_aux[0,2]]),
np.array([self_aux[0,0], self_aux[0,1], self_aux[1,2]]),
np.array([self_aux[0,0], self_aux[1,1], self_aux[0,2]]),
np.array([self_aux[0,0], self_aux[1,1], self_aux[1,2]]),
np.array([self_aux[1,0], self_aux[0,1], self_aux[0,2]]),
np.array([self_aux[1,0], self_aux[0,1], self_aux[1,2]]),
np.array([self_aux[1,0], self_aux[1,1], self_aux[0,2]]),
np.array([self_aux[1,0], self_aux[1,1], self_aux[1,2]])]
corners_other = [np.array([other_aux[0,0], other_aux[0,1], other_aux[0,2]]),
np.array([other_aux[0,0], other_aux[0,1], other_aux[1,2]]),
np.array([other_aux[0,0], other_aux[1,1], other_aux[0,2]]),
np.array([other_aux[0,0], other_aux[1,1], other_aux[1,2]]),
np.array([other_aux[1,0], other_aux[0,1], other_aux[0,2]]),
np.array([other_aux[1,0], other_aux[0,1], other_aux[1,2]]),
np.array([other_aux[1,0], other_aux[1,1], other_aux[0,2]]),
np.array([other_aux[1,0], other_aux[1,1], other_aux[1,2]])]
for i in corners_self:
for j in corners_other:
if np.allclose(i,j):
return False
return True
elif type(close_r) == float:
diam0 = 0.
diam1 = 0.
dist = 0.
for i in range(ndim):
tmp = self_aux[0, i]-self_aux[1, i]
diam0 += tmp*tmp
tmp = other_aux[0, i]-other_aux[1, i]
diam1 += tmp*tmp
tmp = self_aux[0, i]+self_aux[1, i]-other_aux[0, i]-other_aux[1, i]
dist += tmp*tmp
dist *= 0.25
return dist > diam0 * close_r and dist > diam1*close_r
else:
raise NameError('Wrong close_r')
def compute_aux(self, index):
tmp_particles = self.vertex[:,index]
xyzmin = np.min(tmp_particles,axis=1)
xyzmax = np.max(tmp_particles,axis=1)
bs = xyzmax-xyzmin
bs = np.max(bs)*1.01 # Fudge factor to account for some rounding issue
xyzmaxuse = xyzmin + bs
return np.array([xyzmin,xyzmaxuse])
def divide(self, index):
vertex = self.vertex[:, index]
center = vertex.mean(axis=1)
vertex -= center.reshape(-1, 1)
normal = np.linalg.svd(vertex, full_matrices=0)[0][:,0]
scal_dot = normal.dot(vertex)
scal_sorted = scal_dot.argsort()
scal_dot = scal_dot[scal_sorted]
k = scal_dot.searchsorted(0)
return scal_sorted, [0, k, scal_sorted.size]
def half_box(self, index, ax, mid_point):
ndim = self.ndim
vertex = self.vertex[:, index]
center = mid_point#vertex.mean(axis=1)
vertex -= center.reshape(-1, 1)
normal = np.zeros(ndim)
normal[ax] = 1.
scal_dot = normal.dot(vertex)
scal_sorted = scal_dot.argsort()
scal_dot = scal_dot[scal_sorted]
k = scal_dot.searchsorted(0)
return scal_sorted, [0, k, scal_sorted.size]
def __len__(self):
return self.count
class Tree(object):
def __init__(self, data, block_size, point_based_tree = True, num_child_tree = 'hyper'):
self.block_size = block_size
self.data = data
self.index = SmartIndex(len(data))
self.parent = [-1]
self.child = [[]]
self.leaf = [0]
self.level = [0, 1]
self.num_levels = 0
self.num_leaves = 1
self.num_nodes = 1
self.point_based_tree = point_based_tree
self.num_child_tree = num_child_tree
if num_child_tree == 'hyper':
self.nchild = 2 ** data.ndim
elif num_child_tree == 2:
self.nchild = num_child_tree
else:
print(f'Number of children = {num_child_tree} is not suported, # children changed to 2')
self.nchild = 2
def divide_space(self, key):
ndim = self.data.ndim
index = self.index[key]
box_list = []
for i in range(self.nchild):
box_list.append(dc(self.aux[key]))
if self.num_child_tree == 'hyper':
for i in range(ndim):
mid_point = (self.aux[key][0, i] + self.aux[key][1, i]) / 2
for ii in range(len(box_list)):
if self.check(ii, i, ndim):
box_list[ii][1,i] = mid_point
else:
box_list[ii][0,i] = mid_point
index_list_old = []
for i in range(2**ndim):
self.aux.append(box_list[i])
index_list_old.append([])
vertex = self.data.vertex[:, index]
for i_v in range(vertex.shape[1]):
v_in_b = 0
for i_aux in range(2**ndim):
vertex_in_box = 1
for nd in range(ndim):
vertex_in_box = vertex_in_box and (vertex[nd,i_v] >= box_list[i_aux][0,nd]) and (vertex[nd,i_v] <= box_list[i_aux][1,nd])
if vertex_in_box and v_in_b == 0:
index_list_old[i_aux].append(index[i_v])
v_in_b += 1
if v_in_b != 1:
for i_aux in range(2**ndim):
print(f'n box: {i_aux}, box: {box_list[i_aux]} \nlast ind: {index_list_old[i_aux]}')
raise NameError(f'{i_v, v_in_b}, v:{vertex[:,i_v]}')
index_res = np.array([])
list_k = [0]
for local_index in index_list_old:
local_index = np.array(local_index)
list_k.append(list_k[-1]+local_index.shape[0])
index_res = np.hstack((index_res,local_index))
return index_res.astype(int), list_k
else:
ax = len(self.level)%2
mid_point = (self.aux[key][0, ax] + self.aux[key][1, ax]) / 2
box_list[0][1,ax] = mid_point
box_list[1][0,ax] = mid_point
for i in range(2):
self.aux.append(box_list[i])
l = len(self.level)
new_index, subclusters = self.data.half_box(index, l%2, mid_point)
new_index = index[new_index]
return new_index, subclusters
def divide_point(self, key):
ndim = self.data.ndim
index = self.index[key]
if self.num_child_tree == 'hyper':
index_list_old = [self.index[key]]
list_k = [0]
ndim = self.data.ndim
for i in range(ndim):
index_list_new = []
for index in index_list_old:
new_index, subclusters = self.data.divide(index)
new_index = index[new_index]
index_list_new.append(new_index[:subclusters[1]])
index_list_new.append(new_index[subclusters[1]:])
index_list_old = dc(index_list_new)
index_res = np.array([])
for local_index in index_list_old:
list_k.append(list_k[-1]+np.array(local_index).shape[0])
index_res = np.hstack((index_res,local_index))
new_index = dc(index_res.astype(int))
subclusters = dc(list_k)
else:
index = self.index[key]
new_index, subclusters = self.data.divide(index)
new_index = index[new_index]
last_ind = subclusters[0]
for i in range(len(subclusters)-1):
next_ind = subclusters[i+1]
self.aux.append(self.data.compute_aux(new_index[last_ind:next_ind]))
last_ind = next_ind
return new_index, subclusters
def check(self, n, dim, ndim):
for _ in range(ndim - dim):
res = n % 2
n = n // 2
return res==0
def divide(self, key):
ndim = self.data.ndim
index = self.index[key]
# d = 1/0
if self.point_based_tree:
new_index, subclusters = self.divide_point(key)
else:
new_index, subclusters = self.divide_space(key)
test_index = new_index.copy()
test_index.sort()
last_ind = subclusters[0]
for i in range(len(subclusters)-1):
next_ind = subclusters[i+1]
if next_ind < last_ind:
raise NameError("children indices must be one after other")
self.index.add_node(key, (last_ind, next_ind))
last = len(self.parent)
self.parent.append(key)
if self.child[key]:
self.num_leaves += 1
self.num_nodes += 1
self.child[key].append(last)
self.child.append([])
last_ind = next_ind
if next_ind != test_index.size:
raise Error("Sum of sizes of children must be the same as"
" size of the parent")
self.index[key] = new_index
def is_far(self, i, other_tree, j):
if i <= j:
result = self.data.check_far(self.aux[i], other_tree.aux[j])
else:
result = other_tree.data.check_far(other_tree.aux[j], self.aux[i])
return result
def __len__(self):
return len(self.parent)

144
problem_tools/problem.py Normal file
View File

@ -0,0 +1,144 @@
import numpy as np
from copy import deepcopy as dc
from collections import defaultdict
from itertools import product
from numba import jit
class Problem(object):
def __init__(self, func, row_tree, col_tree, symmetric, verbose=False):
self._func = func
if symmetric and row_tree is not col_tree:
raise ValueError("row_tree and col_tree parameters must be the "
"same (as Python objects) if flag symmetric is `True`")
self.symmetric = symmetric
self.row_tree = row_tree
self.col_tree = col_tree
self.row_data = row_tree.data
self.col_data = col_tree.data
self.shape = (len(row_tree.data), len(col_tree.data))
l = np.arange(1, dtype=np.uint64)
tmp = self.func(l, l)
self.func_shape = tmp.shape[1:-1]
self.dtype = tmp.dtype
self._build(verbose)
def _build(self, verbose=False):
row_check = [[0]]
self.row_far = []
self.row_close = []
self.row_notransition = []
self.col_far = self.row_far
self.col_close = self.row_close
self.col_notransition = self.row_notransition
self.row_tree.aux =[self.row_data.compute_aux(self.row_tree.index[0])]
print(self.row_tree.aux)
cur_level = 0
while (self.row_tree.level[cur_level] < self.row_tree.level[cur_level+1]):
# print (f' level {cur_level}')
for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
self.row_far.append([])
self.row_close.append([])
for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
for j in row_check[i]:
if self.row_tree.is_far(i, self.col_tree, j):
self.row_far[i].append(j)
if self.row_tree is not self.col_tree:
self.col_far[j].append(i)
else:
self.row_close[i].append(j)
if self.row_tree is not self.col_tree:
self.col_close[j].append(i)
# print (i, self.row_close[i])
for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
if i == 0:
self.row_notransition.append(not self.row_far[i])
else:
self.row_notransition.append(not(self.row_far[i] or
not self.row_notransition[self.row_tree.parent[i]]))
for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
if(cur_level == 1):
self.row_tree.divide(i)
else:
if (self.row_close[i] and not self.row_tree.child[i] and
self.row_tree.index[i].size >
self.row_tree.block_size):
nonzero_close = False
for j in self.row_close[i]:
if (self.col_tree.index[j].size >
self.col_tree.block_size):
nonzero_close = True
break
if nonzero_close:
self.row_tree.divide(i)
for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
whom_to_check = []
for j in self.row_close[i]:
whom_to_check.extend(self.col_tree.child[j])
for j in self.row_tree.child[i]:
row_check.append(whom_to_check)
# print (f' 3:')
# for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
# print (i, self.row_close[i])
# for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
# tmp_close = []
# if self.row_tree.child[i]:
# for j in self.row_close[i]:
# if not self.col_tree.child[j]:
# tmp_close.append(j)
# self.row_close[i] = tmp_close
self.row_tree.level.append(len(self.row_tree))
# print (f' End:')
# for i in range(self.row_tree.level[cur_level],self.row_tree.level[cur_level+1]):
# print (i, self.row_close[i])
cur_level += 1
# update number of levels
self.num_levels = len(self.row_tree.level)-1
self.row_tree.num_levels = self.num_levels
self.col_tree.num_levels = self.num_levels
def func(self, row, col):
return self._func(self.row_data, row, self.col_data, col)
def multilevel_close(self):
self.lvl_close = dc(self.row_close)
tree = self.row_tree
close = self.row_close
level_count = len(tree.level)-2
row_size = tree.level[-1]
self.other_lvl_close = [[] for i in range(row_size)]
if self.add_up_level_close:
for i in range(level_count-1, 0, -1):
job = [j for j in range(tree.level[i], tree.level[i+1])]
for ind in job:
if tree.child[ind] == []:
for cl in close[ind]:
for ch_cl in tree.child[cl]:
if not tree.is_far(ch_cl, tree, ind):
self.add_child_to_close(ind, ch_cl)
def add_child_to_close(self, main_node, node):
close = self.row_close
tree = self.row_tree
self.other_lvl_close[node].append(main_node)
if not tree.child[node] == [] :
for ch_cl in tree.child[node]:
if not tree.is_far(ch_cl, tree, main_node):
self.add_child_to_close(main_node, ch_cl)
def schur_precompute(self):
row_tree = self.row_tree
N = row_tree.level[-1]
self.schur_list = [set() for i in range(N)] # list of sets
tmp_schur_dict = defaultdict(list) # like a dict, but if element is not exist, empty list are genereted
for ind in range(N):
close = self.lvl_close[ind]
other_close = self.other_lvl_close[ind]
# if self.symmetric:
# col_close = self.lvl_close[ind]
# else:
# col_close = self.col_lvl_close[ind]
for c1, c2 in product(close, close):
self.schur_list[c1].add(c2)
tmp_schur_dict[c1, c2].append(ind)
for c1,c2 in product(close, other_close):
self.schur_list[c1].add(c2)
tmp_schur_dict[c1, c2].append(ind)
self.schur_dict = dict(tmp_schur_dict)