Source code for shift_nth_row_n_steps.cli

import warnings

import numpy as np
import seaborn as sns
import typer
from aquarel import load_theme
from array_api_compat import numpy, torch
from cm_time import timer
from pandas import DataFrame
from rich import print

from ._main import (
    shift_nth_row_n_steps,
    shift_nth_row_n_steps_advanced_indexing,
)

# Typer application object; CLI commands in this module are registered on it
# through the ``@app.command()`` decorator.
app = typer.Typer()


@app.command()
def benchmark(
    dtype: str = typer.Option("float32"),
    n_end: int = typer.Option(10),
    n_iter: int = typer.Option(10),
) -> None:
    """
    Benchmark the two implementations of the function.

    Each backend (numpy, torch) is tried on each device ("cpu", "cuda")
    with random square matrices of size ``2**i`` for ``i`` in
    ``range(n_end)``. A backend/device/size combination that raises is
    reported with a warning and skipped. The mean time per call is plotted
    on log-log axes, one facet per backend and device, and saved to
    ``benchmark.webp``. If no combination succeeded, a warning is issued
    and no plot is written.

    Parameters
    ----------
    dtype : str, optional
        Name of the dtype attribute looked up on each backend namespace,
        by default typer.Option("float32")
    n_end : int, optional
        Exclusive upper bound of the exponent; matrix sizes run from
        ``2**0`` to ``2**(n_end - 1)``, by default typer.Option(10)
    n_iter : int, optional
        The number of timed calls per implementation and matrix size,
        by default typer.Option(10)

    """
    results = []
    # A single generator is enough: it is only used to draw the inputs.
    rng = np.random.default_rng()
    for name, xp in [("numpy", numpy), ("torch", torch)]:
        for device in ["cpu", "cuda"]:
            for i in range(n_end):
                # Assigned before the ``try`` so the warning below can
                # always report the matrix size.
                n = 2**i
                try:
                    array = xp.asarray(
                        rng.uniform(size=(n, n)),
                        dtype=getattr(xp, dtype),
                        device=device,
                    )
                    # NOTE(review): on "cuda" the calls may return before
                    # the device work has finished, so these timings may
                    # under-report without a synchronization - confirm.
                    with timer() as t1:
                        for _ in range(n_iter):
                            array = shift_nth_row_n_steps(
                                array, cut_padding=True
                            )
                    # The second loop starts from the array left behind by
                    # the first one, exactly as each call feeds the next.
                    with timer() as t2:
                        for _ in range(n_iter):
                            array = shift_nth_row_n_steps_advanced_indexing(
                                array
                            )
                    print(
                        f"{n}: proposed: {t1.elapsed:g}, "
                        f"advanced indexing: {t2.elapsed:g}"
                    )
                    results.append(
                        {
                            "n": n,
                            "type": "Proposed Method",
                            "time": t1.elapsed / n_iter,
                            "backend": name,
                            "device": device,
                        }
                    )
                    results.append(
                        {
                            "n": n,
                            "type": "Advanced Indexing",
                            "time": t2.elapsed / n_iter,
                            "backend": name,
                            "device": device,
                        }
                    )
                except Exception as e:
                    # Best effort: an unavailable backend/device pair must
                    # not abort the remaining measurements.
                    warnings.warn(
                        f"Failed to run on {name} with {device} for n={n}: {e}",
                        stacklevel=2,
                    )

    if not results:
        # An empty frame has no "backend"/"device" columns, so building the
        # facet grid would fail with a KeyError; stop here instead.
        warnings.warn(
            "No benchmark run succeeded; skipping the plot.",
            stacklevel=2,
        )
        return

    df = DataFrame(results)
    theme = load_theme("boxy_dark")
    theme.apply()
    g = sns.FacetGrid(df, row="backend", col="device", margin_titles=True)
    g.map_dataframe(
        sns.lineplot,
        x="n",
        y="time",
        hue="type",
        style="type",
        markers=["o", "X"],
    ).set(
        xscale="log",
        yscale="log",
        xlabel="Matrix size",
        ylabel="Time (s)",
    )
    g.add_legend()
    g.savefig("benchmark.webp", dpi=300)