import * as utils from "./utils.js";
import { RunPcaResults } from "./runPca.js";
import * as wasm from "./wasm.js";
/**
* Perform mutual nearest neighbor (MNN) correction on a low-dimensional representation to remo
* This is used to remove batch effects prior to downstream analyses like clustering,
* check out the [**mnncorrect**](https://github.com/libscran/mnncorrect) for details.
*
* @param {(RunPcaResults|TypedArray|Array|Float64WasmArray)} x - A matrix of low-dimensional results where rows are dimensions and columns are cells.
* If this is a {@linkplain RunPcaResults} object, the PCs are automatically extracted.
* Otherwise, the matrix should be provided as an array in column-major form, with specification of `numberOfDims` and `numberOfCells`.
* @param {(Int32WasmArray|Array|TypedArray)} block - Array containing the block assignment for each cell.
* This should have length equal to the number of cells and contain all values from 0 to `n - 1` at least once, where `n` is the number of blocks.
* This is used to segregate cells in order to perform normalization within each block.
* @param {object} [options={}] - Further optional parameters.
* @param {boolean} [options.asTypedArray=true] - Whether to return a Float64Array.
* If `false`, a Float64WasmArray is returned instead.
* @param {?Float64WasmArray} [options.buffer=null] - Buffer of length equal to the product of the number of cells and dimensions,
* to be used to store the corrected coordinates for each cell.
* If `null`, this is allocated and returned by the function.
* @param {?number} [options.numberOfDims=null] - Number of dimensions in `x`.
* This should be specified if an array-like object is provided, otherwise it is ignored.
* @param {?number} [options.numberOfCells=null] - Number of cells in `x`.
* This should be specified if an array-like object is provided, otherwise it is ignored.
* @param {number} [options.k=15] - Number of neighbors to use in the MNN search.
* @param {number} [options.steps=1] - Number of steps to take in the nearest neighbor graph when computing the center of mass for each cell in an MNN pair.
* @param {string} [options.mergePolicy="rss"] - What policy to use for ordering the batches to be merged.
* Options are to use the size (`"size"`), the variance (`"variance"`), the residual sum of squares (`"rss"`) or the input order (`"input"`).
* @param {boolean} [options.approximate=true] - Whether to perform an approximate nearest neighbor search.
* @param {?number} [options.numberOfThreads=null] - Number of threads to use.
* If `null`, defaults to {@linkcode maximumThreads}.
*
* @return {Float64Array|Float64WasmArray} Array of length equal to `x`, containing the batch-corrected low-dimensional coordinates for all cells.
* Corrected values are organized using the column-major layout, where rows are dimensions and columns are cells.
* If `buffer` is supplied, the function returns `buffer` if `asTypedArray = false`, or a view on `buffer` if `asTypedArray = true`.
*/
export function mnnCorrect(x, block, options = {}) {
let {
asTypedArray = true,
buffer = null,
numberOfDims = null,
numberOfCells = null,
k = 15,
steps = 1,
numberOfMADs = 3, // back-compatibility
robustIterations = null, // back-compatibility
robustTrim = null, // back-compatibility
referencePolicy = null, // back-compatibility
mergePolicy = "rss",
approximate = true,
numberOfThreads = null,
...others
} = options;
utils.checkOtherOptions(others);
let local_buffer = null;
let x_data;
let block_data;
let nthreads = utils.chooseNumberOfThreads(numberOfThreads);
if (referencePolicy !== null) {
console.warning("'referencePolicy=' is deprecated, use 'mergePolicy' instead");
mergePolicy = referencePolicy.replace(/^max-/, "");
}
try {
if (x instanceof RunPcaResults) {
numberOfDims = x.numberOfPCs();
numberOfCells = x.numberOfCells();
x = x.principalComponents({ copy: "view" });
} else {
if (numberOfDims === null || numberOfCells === null || numberOfDims * numberOfCells !== x.length) {
throw new Error("length of 'x' must be equal to the product of 'numberOfDims' and 'numberOfCells'");
}
x_data = utils.wasmifyArray(x, "Float64WasmArray");
x = x_data;
}
if (buffer == null) {
local_buffer = utils.createFloat64WasmArray(numberOfCells * numberOfDims);
buffer = local_buffer;
} else if (buffer.length !== x.length) {
throw new Error("length of 'buffer' must be equal to the product of the number of dimensions and cells");
}
block_data = utils.wasmifyArray(block, "Int32WasmArray");
if (block_data.length != numberOfCells) {
throw new Error("'block' must be of length equal to the number of cells in 'x'");
}
wasm.call(module => module.mnn_correct(
numberOfDims,
numberOfCells,
x.offset,
block_data.offset,
buffer.offset,
k,
steps,
mergePolicy,
approximate,
nthreads
));
} catch (e) {
utils.free(local_buffer);
throw e;
} finally {
utils.free(x_data);
}
return utils.toTypedArray(buffer, local_buffer == null, asTypedArray);
}