Source code for kdtree

# -*- coding: utf-8 -*-
# SPDX-License-Identifier: CECILL-2.1
"""
Grid importance sampling using a k-d tree.

:copyright:
    2022-2024 Claudio Satriano <satriano@ipgp.fr>
:license:
    CeCILL Free Software License Agreement v2.1
    (http://www.cecill.info/licences.en.html)
"""
import itertools
import numpy as np
from scipy.interpolate import griddata


class KDTCell():
    """
    A cell of a k-d tree.

    The cell is a hyper-rectangle described by ``extent``. On construction
    the PDF is evaluated once, at the cell centre, and the cell probability
    is approximated as ``pdf * volume``.

    :param extent: cell bounds, one ``[min, max]`` row per dimension
        (array of shape ``(dim, 2)``)
    :param calc_pdf: callable receiving the coordinates of the cell centre
        and returning the PDF value at that point
    :param min_cell_prob: a cell whose probability is not larger than this
        value is flagged as not divisible
    :param ndiv: per-dimension division counter; a zero-filled array is
        created when ``None``
    :param maxdiv: per-dimension upper limit for ``ndiv``; ``None`` means
        that no limit is applied
    """

    def __init__(self, extent, calc_pdf, min_cell_prob=0,
                 ndiv=None, maxdiv=None):
        self.extent = extent
        # centre of the cell and its size along every dimension
        self.coords = np.mean(extent, axis=1)
        self.delta = np.diff(extent)
        self.calc_pdf = calc_pdf
        self.pdf = calc_pdf(self.coords)
        self.volume = np.prod(self.delta)
        self.prob = self.pdf * self.volume
        self.min_cell_prob = min_cell_prob
        # prob_divisible is the sorting key used to pick the next cell to
        # divide: it is zero for cells that must not be divided any further
        self.is_divisible = not self.prob <= min_cell_prob
        self.prob_divisible = self.prob if self.is_divisible else 0.
        # NOTE(review): ndiv is stored by reference and divide() hands the
        # very same array to every child, so the counter is shared between
        # parent and children -- confirm that this is intended
        self.ndiv = np.zeros(self.extent.shape[0]) if ndiv is None else ndiv
        self.maxdiv = maxdiv

    def divide(self, parts):
        """
        Divide the cell in `parts` parts along every dimension.

        Children are centred on a lattice of ``parts`` points per dimension,
        running from the lower to the upper bound of this cell (both
        included). Each child spans ``1/parts`` of the parent size along
        every dimension.

        When ``maxdiv`` is set, dimensions whose ``ndiv`` counter has reached
        ``maxdiv`` are not divided, and the counter of the other dimensions
        is increased by ``parts``.

        :param parts: number of lattice points per dimension (at least 2)
        :return: list of the child cells, or ``[self]`` if the cell is not
            divisible
        :raises ValueError: if the cell is divisible and ``parts`` is
            smaller than 2
        """
        if not self.is_divisible:
            return [self, ]
        if parts < 2:
            # (parts - 1) is used as a divisor below: parts == 1 would
            # produce NaN coordinates and parts <= 0 an empty list of cells
            raise ValueError(f'parts must be at least 2, got {parts}')
        # dim is the dimension of the parameter space
        dim = self.extent.shape[0]
        # Integer lattice steps for all the dimensions.
        # "np.mgrid...[1]" produces a 2D array of the type:
        #   [[0 1 ... parts-1]
        #    ... (dim rows) ...
        #    [0 1 ... parts-1]]
        steps = np.mgrid[0:dim, 0:parts][1]
        if self.maxdiv is not None:
            # dimensions that will not be divided
            steps[self.ndiv >= self.maxdiv] *= 0
            # dimensions that will be divided
            self.ndiv[self.ndiv < self.maxdiv] += parts
        # scale the steps so that the last one reaches the upper bound
        offsets = steps * self.delta / (parts - 1)
        # lower bounds as a column vector, plus the offsets: one row of
        # candidate centre coordinates per dimension
        lower = self.extent[:, 0].reshape(-1, 1)
        centres = lower + offsets
        # Size of a child cell and its bounds relative to the child centre.
        # The `2` below is the number of bounds (lower, upper) of an extent
        # row, not `parts`: the two columns evaluate to -size/2 and +size/2.
        child_size = self.delta / parts
        rel_bounds = np.mgrid[0:dim, 0:2][1] * child_size - child_size / 2
        # Loop over all the possible n-tuples of centre coordinates; set()
        # removes the repetitions coming from dimensions that are not divided
        return [
            KDTCell(
                np.array(centre).reshape(-1, 1) + rel_bounds,
                self.calc_pdf, self.min_cell_prob, self.ndiv, self.maxdiv)
            for centre in set(itertools.product(*centres))
        ]
class KDTree():
    """
    A k-d tree.

    :param extent: bounds of the search hypervolume; it is reshaped to
        ``(dim, 2)``, i.e. one ``[min, max]`` row per dimension
    :param init_parts: number of parts used for the first division
    :param calc_pdf: callable passed to the cells to evaluate the PDF
    :param min_cell_prob: minimum cell probability, expressed as a fraction
        of the probability of the initial cell
    :param maxdiv: passed to the initial cell as its ``maxdiv`` argument
    """

    def __init__(self, extent, init_parts, calc_pdf,
                 min_cell_prob=0., maxdiv=None):
        # extent defines the size of the search hypervolume:
        # reshape it to (dim, 2), where dim is the arbitrary dimension
        # of the parameter space
        self.extent = np.array(extent).reshape(-1, 2)
        # the first cell has the same size of the search hypervolume
        root = KDTCell(self.extent, calc_pdf, maxdiv=maxdiv)
        self.init_prob = root.prob
        # the threshold is set after construction, since it is relative
        # to the probability of the first cell
        root.min_cell_prob = root.prob * min_cell_prob
        self.cells = root.divide(init_parts)
        self.ncells = len(self.cells)

    def divide(self):
        """
        Find the cell with highest probability and divide it in 2 parts
        along every dimension.

        Cells are ranked by ``prob_divisible``. The selected cell is removed
        from ``self.cells`` and replaced by the cells returned by its
        ``divide()`` method; ``self.ncells`` is updated accordingly.
        """
        # in-place ascending sort: the best candidate ends up last
        self.cells.sort(key=lambda c: c.prob_divisible)
        best = self.cells.pop()
        self.cells += best.divide(2)
        self.ncells = len(self.cells)

    def get_pdf(self, deltas):
        """
        Calculate the probability density function (PDF) on a grid
        defined by deltas.

        The PDF values stored in the cells are linearly interpolated at
        the nodes of a regular grid covering ``self.extent``.

        :param deltas: grid step for each dimension
        :return: tuple ``(pdf, extent_new)``: the interpolated PDF and the
            flat list ``[min_0, max_0, min_1, max_1, ...]`` of the first and
            last grid node along each dimension
        """
        deltas = np.array(deltas).reshape(-1, 1)
        ranges = []
        extent_new = []
        for start, stop, step in np.hstack((self.extent, deltas)):
            # Pad the end by a small fraction of the step to make sure that
            # the end value is included. The padding must not depend on
            # `stop` itself: it would vanish for stop == 0, have the wrong
            # sign for stop < 0 and add spurious nodes for stop >> step.
            rng = np.arange(start, stop + step / 1e5, step)
            ranges.append(rng)
            extent_new += [rng[0], rng[-1]]
        xi = np.meshgrid(*ranges, indexing='ij')
        coords = np.array([cell.coords for cell in self.cells])
        pdf_values = np.array([cell.pdf for cell in self.cells])
        pdf = griddata(coords, pdf_values, tuple(xi), method='linear')
        return pdf, extent_new