import * as scran from "scran.js";
import * as utils from "./utils/general.js";
import * as rna_pca_module from "./rna_pca.js";
import * as adt_pca_module from "./adt_pca.js";
import * as crispr_pca_module from "./crispr_pca.js";
// Identifier of this analysis step.
// NOTE(review): presumably used as the key for this step's parameters/results
// in the analysis runner - confirm against callers.
export const step_name = "combine_embeddings";
/**
 * Select the upstream PCA states that should contribute to the combined embedding.
 *
 * @param {object} pca_states - Upstream states, forwarded as-is to `utils.findValidUpstreamStates`.
 * @param {object} weights - Weight for each entry, looked up by the names that
 * `utils.findValidUpstreamStates` yields.
 *
 * @return {Array} Names reported by `utils.findValidUpstreamStates` whose weight is strictly positive,
 * in the order in which they were reported.
 */
function find_nonzero_upstream_states(pca_states, weights) {
    const candidates = utils.findValidUpstreamStates(pca_states);
    // A missing weight is `undefined`, which fails the `> 0` comparison and is dropped.
    return [...candidates].filter((name) => weights[name] > 0);
}
/**
 * This step combines multiple embeddings from different modalities into a single matrix for downstream analysis.
 * It wraps the [`scaleByNeighbors`](https://kanaverse.github.io/scran.js/global.html#scaleByNeighbors) function
 * from [**scran.js**](https://kanaverse.github.io/scran.js).
 *
 * Methods not documented here are not part of the stable API and should not be used by applications.
 * @hideconstructor
 */
export class CombineEmbeddingsState {
    #pca_states;
    #parameters;
    #cache;

    constructor(pca_states, parameters = null, cache = null) {
        if (!(pca_states.RNA instanceof rna_pca_module.RnaPcaState)) {
            throw new Error("'pca_states.RNA' should be an RnaPcaState object");
        }
        if (!(pca_states.ADT instanceof adt_pca_module.AdtPcaState)) {
            throw new Error("'pca_states.ADT' should be an AdtPcaState object");
        }
        if (!(pca_states.CRISPR instanceof crispr_pca_module.CrisprPcaState)) {
            throw new Error("'pca_states.CRISPR' should be an CrisprPcaState object");
        }

        this.#pca_states = pca_states;
        this.#parameters = (parameters === null ? {} : parameters);
        this.#cache = (cache === null ? {} : cache);
        this.changed = false;
    }

    /**
     * Release the cached combined buffer via `utils.freeCache`.
     */
    free() {
        utils.freeCache(this.#cache.combined_buffer);
    }

    /***************************
     ******** Getters **********
     ***************************/

    /**
     * @return {Float64WasmArray} Buffer containing the combined embeddings as a column-major dense matrix,
     * where the rows are the dimensions and the columns are the cells.
     * This is available after running {@linkcode CombineEmbeddingsState#compute compute}.
     */
    fetchCombined() {
        return this.#cache.combined_buffer;
    }

    /**
     * @return {number} Number of cells in {@linkcode CombineEmbeddingsState#fetchCombined fetchCombined},
     * available after running {@linkcode CombineEmbeddingsState#compute compute}.
     */
    fetchNumberOfCells() {
        return this.#cache.num_cells;
    }

    /**
     * @return {number} Number of dimensions in {@linkcode CombineEmbeddingsState#fetchCombined fetchCombined},
     * available after running {@linkcode CombineEmbeddingsState#compute compute}.
     */
    fetchNumberOfDimensions() {
        return this.#cache.total_dims;
    }

    /**
     * @return {object} Object containing the parameters.
     */
    fetchParameters() {
        // Shallow copy so that callers cannot modify the internal parameter object.
        return { ...this.#parameters };
    }

    /***************************
     ******** Compute **********
     ***************************/

    /**
     * @return {object} Object containing default parameters,
     * see the `parameters` argument in {@linkcode CombineEmbeddingsState#compute compute} for details.
     */
    static defaults() {
        return {
            rna_weight: 1,
            adt_weight: 1,
            crispr_weight: 0,
            approximate: true
        };
    }

    // Expose a single upstream embedding directly: any previously cached buffer is released
    // and the cache is pointed at a view of the upstream principal components.
    static #createPcsView(cache, upstream) {
        utils.freeCache(cache.combined_buffer);
        cache.combined_buffer = upstream.principalComponents({ copy: "view" }).view();
        cache.num_cells = upstream.numberOfCells();
        cache.total_dims = upstream.numberOfPCs();
    }

    /**
     * This method should not be called directly by users, but is instead invoked by {@linkcode runAnalysis}.
     *
     * @param {object} parameters - Parameter object, equivalent to the `combine_embeddings` property of the `parameters` of {@linkcode runAnalysis}.
     * @param {number} [parameters.rna_weight] - Relative weight of the RNA embeddings.
     * @param {number} [parameters.adt_weight] - Relative weight of the ADT embeddings.
     * @param {number} [parameters.crispr_weight] - Relative weight of the CRISPR embeddings.
     * @param {boolean} [parameters.approximate] - Whether an approximate nearest neighbor search should be used by `scaleByNeighbors`.
     *
     * @return The object is updated with new results.
     * @throws {Error} If no valid upstream embedding has a positive weight,
     * or if the embeddings to be combined disagree on the number of cells.
     */
    compute(parameters) {
        const { rna_weight, adt_weight, crispr_weight, approximate } =
            utils.defaultizeParameters(parameters, CombineEmbeddingsState.defaults());

        // Any change in an upstream PCA step invalidates the combined embedding.
        this.changed = Object.values(this.#pca_states).some((state) => state.changed);

        if (approximate !== this.#parameters.approximate) {
            this.#parameters.approximate = approximate;
            this.changed = true;
        }

        if (rna_weight !== this.#parameters.rna_weight ||
            adt_weight !== this.#parameters.adt_weight ||
            crispr_weight !== this.#parameters.crispr_weight)
        {
            this.#parameters.rna_weight = rna_weight;
            this.#parameters.adt_weight = adt_weight;
            this.#parameters.crispr_weight = crispr_weight;
            this.changed = true;
        }

        if (!this.changed) {
            return;
        }

        const weights = { RNA: rna_weight, ADT: adt_weight, CRISPR: crispr_weight };
        const to_use = find_nonzero_upstream_states(this.#pca_states, weights);

        // Without this guard, the single-embedding path below would dereference
        // `this.#pca_states[undefined]` and fail with an uninformative TypeError.
        if (to_use.length === 0) {
            throw new Error("expected at least one valid embedding with a positive weight");
        }

        if (to_use.length === 1) {
            // With a single embedding there is nothing to scale, so the weights and
            // 'approximate' have no effect on the result; just expose the upstream PCs.
            CombineEmbeddingsState.#createPcsView(this.#cache, this.#pca_states[to_use[0]].fetchPCs());
            return;
        }

        const collected = [];
        let total = 0;
        let ncells = null;

        for (const k of to_use) {
            const curpcs = this.#pca_states[k].fetchPCs();
            collected.push(curpcs.principalComponents({ copy: "view" }));

            if (ncells === null) {
                ncells = curpcs.numberOfCells();
            } else if (ncells !== curpcs.numberOfCells()) {
                throw new Error("number of cells should be consistent across all embeddings");
            }

            total += curpcs.numberOfPCs();
        }

        // The output holds every dimension of every embedding for each cell.
        const buffer = utils.allocateCachedArray(ncells * total, "Float64Array", this.#cache, "combined_buffer");
        const weight_arr = to_use.map((k) => weights[k]);
        scran.scaleByNeighbors(collected, ncells, { buffer: buffer, weights: weight_arr, approximate: approximate });

        this.#cache.num_cells = ncells;
        this.#cache.total_dims = total;
        return;
    }
}