Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ lint.select = [
"ICN", # Follow import conventions
"ISC", # Implicit string concatenation
"N", # Naming conventions
"PD", # Pandas
"PERF", # Performance
"PIE", # Syntax simplifications
"PL", # Pylint
Expand Down
60 changes: 28 additions & 32 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,11 @@ def compute_association_matrix_of_groups(
if "?" in pred_group:
pred_group = str(ipred_group) # noqa: PLW2901
# starting from numpy version 1.13, subtractions of boolean arrays are deprecated
mask_pred = adata.obs[prediction].values == pred_group
mask_pred = adata.obs[prediction].to_numpy() == pred_group
mask_pred_int = mask_pred.astype(np.int8)
asso_matrix += [[]]
for ref_group in adata.obs[reference].cat.categories:
mask_ref = (adata.obs[reference].values == ref_group).astype(np.int8)
mask_ref = (adata.obs[reference].to_numpy() == ref_group).astype(np.int8)
mask_ref_or_pred = mask_ref.copy()
mask_ref_or_pred[mask_pred] = 1
# e.g. if the pred group is contained in mask_ref, mask_ref and
Expand Down Expand Up @@ -796,42 +796,38 @@ def select_groups(
groups_masks_obs = adata.uns[f"{key}_masks"]
else:
groups_masks_obs = np.zeros(
(len(adata.obs[key].cat.categories), adata.obs[key].values.size), dtype=bool
(len(adata.obs[key].cat.categories), adata.obs[key].size), dtype=bool
)
for iname, name in enumerate(adata.obs[key].cat.categories):
# if the name is not found, fall back to index retrieval
if name in adata.obs[key].values:
mask_obs = name == adata.obs[key].values
if name in adata.obs[key].array:
mask_obs = name == adata.obs[key].to_numpy()
else:
mask_obs = str(iname) == adata.obs[key].values
mask_obs = str(iname) == adata.obs[key].to_numpy()
groups_masks_obs[iname] = mask_obs
groups_ids = list(range(len(groups_order)))
if groups_order_subset != "all":
groups_ids = []
for name in groups_order_subset:
groups_ids.append(
np.where(adata.obs[key].cat.categories.values == name)[0][0]
)
if len(groups_ids) == 0:
# fallback to index retrieval
groups_ids = np.where(
np.isin(
np.arange(len(adata.obs[key].cat.categories)).astype(str),
np.array(groups_order_subset),
)
)[0]
if len(groups_ids) == 0:
logg.debug(
f"{np.array(groups_order_subset)} invalid! specify valid "
f"groups_order (or indices) from {adata.obs[key].cat.categories}",
)
from sys import exit
if groups_order_subset == "all":
return groups_order.to_numpy(), groups_masks_obs

exit(0)
groups_masks_obs = groups_masks_obs[groups_ids]
groups_order_subset = adata.obs[key].cat.categories[groups_ids].values
else:
groups_order_subset = groups_order.values
groups_ids = [
np.flatnonzero(adata.obs[key].cat.categories.array == name)[0]
for name in groups_order_subset
]
if len(groups_ids) == 0:
# fallback to index retrieval
groups_ids = np.flatnonzero(
np.isin(
np.arange(len(adata.obs[key].cat.categories)).astype(str),
np.array(groups_order_subset),
)
)
if len(groups_ids) == 0:
msg = (
f"{np.array(groups_order_subset)} invalid! specify valid "
f"groups_order (or indices) from {adata.obs[key].cat.categories}",
)
raise RuntimeError(msg)
groups_masks_obs = groups_masks_obs[groups_ids]
groups_order_subset = adata.obs[key].cat.categories[groups_ids].to_numpy()
return groups_order_subset, groups_masks_obs


Expand Down
25 changes: 11 additions & 14 deletions src/scanpy/experimental/pp/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _highly_variable_pearson_residuals( # noqa: PLR0912, PLR0915
if batch_key is None:
batch_info = np.zeros(adata.shape[0], dtype=int)
else:
batch_info = adata.obs[batch_key].values
batch_info = adata.obs[batch_key].to_numpy()
n_batches = len(np.unique(batch_info))

# Get pearson residuals for each batch separately
Expand Down Expand Up @@ -239,11 +239,10 @@ def _highly_variable_pearson_residuals( # noqa: PLR0912, PLR0915

# Sort genes by how often they were selected as hvg within each batch and
# break ties with median rank of residual variance across batches
df.sort_values(
df = df.sort_values(
["highly_variable_nbatches", "highly_variable_rank"],
ascending=[False, True],
na_position="last",
inplace=True,
)

high_var = np.zeros(df.shape[0], dtype=bool)
Expand All @@ -263,29 +262,27 @@ def _highly_variable_pearson_residuals( # noqa: PLR0912, PLR0915
" 'variances', float vector (adata.var)\n"
" 'residual_variances', float vector (adata.var)"
)
adata.var["means"] = df["means"].values
adata.var["variances"] = df["variances"].values
adata.var["residual_variances"] = df["residual_variances"]
adata.var["highly_variable_rank"] = df["highly_variable_rank"].values
adata.var["means"] = df["means"].array
adata.var["variances"] = df["variances"].array
adata.var["residual_variances"] = df["residual_variances"].array
adata.var["highly_variable_rank"] = df["highly_variable_rank"].array
if batch_key is not None:
adata.var["highly_variable_nbatches"] = df[
"highly_variable_nbatches"
].values
adata.var["highly_variable_nbatches"] = df["highly_variable_nbatches"].array
adata.var["highly_variable_intersection"] = df[
"highly_variable_intersection"
].values
adata.var["highly_variable"] = df["highly_variable"].values
].array
adata.var["highly_variable"] = df["highly_variable"].array

if subset:
adata._inplace_subset_var(df["highly_variable"].values)
adata._inplace_subset_var(df["highly_variable"].to_numpy())

else:
if batch_key is None:
df = df.drop(
["highly_variable_nbatches", "highly_variable_intersection"], axis=1
)
if subset:
df = df.iloc[df.highly_variable.values, :]
df = df.iloc[df["highly_variable"].to_numpy(), :]

return df

Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/external/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def _export_paga_to_spring(adata, paga_coords, outpath) -> None:
coords = [list(xy) for xy in paga_coords]

sizes = list(adata.uns[f"{group_key}_sizes"])
clus_labels = adata.obs[group_key].cat.codes.values
clus_labels = adata.obs[group_key].cat.codes.to_numpy()
cell_groups = [
[int(j) for j in np.nonzero(clus_labels == i)[0]] for i in range(len(names))
]
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/external/pp/_hashsolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def hashsolo(
"Please cite HashSolo paper:\nhttps://www.cell.com/cell-systems/fulltext/S2405-4712(20)30195-2"
)
adata = adata.copy() if not inplace else adata
data = adata.obs[cell_hashing_columns].values
data = adata.obs[cell_hashing_columns].to_numpy()
if not check_nonnegative_integers(data):
msg = "Cell hashing counts must be non-negative"
raise ValueError(msg)
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/get/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def rank_genes_groups_df(

d = [pd.DataFrame(adata.uns[key][c])[group] for c in colnames]
d = pd.concat(d, axis=1, names=[None, "group"], keys=colnames)
d = d.stack(level=1, future_stack=True).reset_index()
d = d.stack(level=1, future_stack=True).reset_index() # noqa: PD013
d["group"] = pd.Categorical(d["group"], categories=group)
d = d.sort_values(["group", "level_0"]).drop(columns="level_0")

Expand All @@ -106,7 +106,7 @@ def rank_genes_groups_df(

# remove group column for backward compat if len(group) == 1
if len(group) == 1:
d.drop(columns="group", inplace=True)
del d["group"]

return d.reset_index(drop=True)

Expand Down
6 changes: 2 additions & 4 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,7 @@ def add_centroid(centroids, name, xy, mask) -> None:
)
raise ValueError(msg)
else:
iname = np.flatnonzero(
adata.obs[key].cat.categories.values == name
)[0]
iname = np.flatnonzero(adata.obs[key].cat.categories == name)[0]
mask = scatter_group(
axs[ikey],
key,
Expand Down Expand Up @@ -1992,7 +1990,7 @@ def _prepare_dataframe( # noqa: PLR0912

if groupby_index is not None:
# reset index to treat all columns the same way.
obs_tidy.reset_index(inplace=True)
obs_tidy = obs_tidy.reset_index()
groupby.append(groupby_index)

if groupby is None:
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/plotting/_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ def _plot_stacked_colorbars(self, fig, colorbar_area_spec, normalize):
)

# Create a dedicated normalizer for the legend
vmin = self.dot_color_df.values.min()
vmax = self.dot_color_df.values.max()
vmin = self.dot_color_df.to_numpy().min()
vmax = self.dot_color_df.to_numpy().max()
legend_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

for i, group_name in enumerate(groups_to_plot):
Expand Down Expand Up @@ -799,8 +799,8 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915
y, x = np.indices(dot_color.shape)
y = y.flatten() + 0.5
x = x.flatten() + 0.5
frac = dot_size.values.flatten()
mean_flat = dot_color.values.flatten()
frac = dot_size.to_numpy().flatten()
mean_flat = dot_color.to_numpy().flatten()
cmap = colormaps.get_cmap(cmap)
if dot_max is None:
dot_max = np.ceil(max(frac) * 10) / 10
Expand Down
21 changes: 6 additions & 15 deletions src/scanpy/plotting/_stacked_violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from matplotlib import colormaps
from matplotlib.colors import is_color_like

Expand All @@ -25,6 +24,7 @@
from collections.abc import Mapping, Sequence
from typing import Literal, Self

import pandas as pd
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, Normalize
Expand Down Expand Up @@ -440,7 +440,7 @@ def _mainplot(self, ax: Axes):
def _make_rows_of_violinplots(
self,
ax,
_matrix,
_matrix: pd.DataFrame,
colormap_array,
_color_df,
x_spacer_size: float,
Expand All @@ -466,18 +466,9 @@ def _make_rows_of_violinplots(
# the expression value
# This format is convenient to aggregate per gene or per category
# while making the violin plots.
df = (
pd
.DataFrame(_matrix.stack(future_stack=True))
.reset_index()
.rename(
columns={
"level_1": "genes",
_matrix.index.name: "categories",
0: "values",
}
)
)
df = _matrix.melt(
var_name="genes", value_name="values", ignore_index=False
).reset_index(names="categories")
df["genes"] = (
df["genes"].astype("category").cat.reorder_categories(_matrix.columns)
)
Expand Down Expand Up @@ -514,7 +505,7 @@ def _make_rows_of_violinplots(

if not self.are_axes_swapped:
x = "genes"
_df = df[df.categories == row_label]
_df = df[df["categories"] == row_label]
else:
x = "categories"
# because of the renamed matrix columns here
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,15 @@ def dpt_timeseries(
if as_heatmap:
# plot time series as heatmap, as in Haghverdi et al. (2016), Fig. 1d
timeseries_as_heatmap(
adata.X[adata.obs["dpt_order_indices"].values],
adata.X[adata.obs["dpt_order_indices"].to_numpy()],
var_names=adata.var_names,
highlights_x=adata.uns["dpt_changepoints"],
color_map=color_map,
)
else:
# plot time series as gene expression vs time
timeseries(
adata.X[adata.obs["dpt_order_indices"].values],
adata.X[adata.obs["dpt_order_indices"].to_numpy()],
var_names=adata.var_names,
highlights_x=adata.uns["dpt_changepoints"],
xlim=[0, 1.3 * adata.X.shape[0]],
Expand Down Expand Up @@ -587,7 +587,7 @@ def _rank_genes_groups_plot( # noqa: PLR0912, PLR0913, PLR0915
if gene_symbols is not None:
df["names"] = df[gene_symbols]

genes_list = df.names[df.names.notnull()].tolist()
genes_list = df.names[df.names.notna()].tolist()

if len(genes_list) == 0:
logg.warning(f"No genes found for group {group}")
Expand Down Expand Up @@ -1740,7 +1740,7 @@ def _get_values_to_plot(
column = values_to_plot.replace("log10_", "")
else:
column = values_to_plot
values_df = pd.pivot(
values_df = pd.pivot_table(
values_df, index="names", columns="group", values=column
).fillna(1)

Expand Down
25 changes: 10 additions & 15 deletions src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def _paga_graph( # noqa: PLR0912, PLR0913, PLR0915
from io import StringIO

df = pd.read_csv(StringIO(s), header=-1)
pos_array = df[[4, 5]].values
pos_array = df[[4, 5]].to_numpy()

# convert to dictionary
pos = {n: [p[0], p[1]] for n, p in enumerate(pos_array)}
Expand All @@ -809,7 +809,7 @@ def _paga_graph( # noqa: PLR0912, PLR0913, PLR0915
x_color = []
cats = adata.obs[groups_key].cat.categories
for cat in cats:
subset = (cat == adata.obs[groups_key]).values
subset = (cat == adata.obs[groups_key]).to_numpy()
if adata.raw is not None and use_raw:
adata_gene = adata.raw[:, colors]
else:
Expand All @@ -826,7 +826,7 @@ def _paga_graph( # noqa: PLR0912, PLR0913, PLR0915
x_color = []
cats = adata.obs[groups_key].cat.categories
for cat in cats:
subset = (cat == adata.obs[groups_key]).values
subset = (cat == adata.obs[groups_key]).to_numpy()
x_color.append(adata.obs.loc[subset, colors].mean())
colors = x_color

Expand Down Expand Up @@ -1199,9 +1199,8 @@ def moving_average(a):
for ikey, key in enumerate(keys):
x = []
for igroup, group in enumerate(nodes_ints):
idcs = np.arange(adata.n_obs)[
adata.obs[groups_key].values == nodes_strs[igroup]
]
mask = (adata.obs[groups_key] == nodes_strs[igroup]).to_numpy()
idcs = np.flatnonzero(mask)
if len(idcs) == 0:
msg = (
"Did not find data points that match "
Expand All @@ -1210,15 +1209,11 @@ def moving_average(a):
"actually contains what you expect."
)
raise ValueError(msg)
idcs_group = np.argsort(
adata.obs["dpt_pseudotime"].values[
adata.obs[groups_key].values == nodes_strs[igroup]
]
)
idcs_group = np.argsort(adata.obs["dpt_pseudotime"].iloc[mask].to_numpy())
idcs = idcs[idcs_group]
values = (adata.obs[key].values if key in adata.obs else adata_x[:, key].X)[
idcs
]
values = (
adata.obs[key].to_numpy() if key in adata.obs else adata_x[:, key].X
)[idcs]
x += (values.toarray() if isinstance(values, CSBase) else values).tolist()
if ikey == 0:
groups += [group] * len(idcs)
Expand All @@ -1227,7 +1222,7 @@ def moving_average(a):
series = adata.obs[anno]
if isinstance(series.dtype, CategoricalDtype):
series = series.cat.codes
anno_dict[anno] += list(series.values[idcs])
anno_dict[anno] += series.iloc[idcs].to_list()
if n_avg > 1:
x = moving_average(x)
if ikey == 0:
Expand Down
Loading
Loading