Repository Analysis

jax-ml/jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Adjusted Score: 19.2 (Moderate AI signal)
Raw Score: 19.2
Time Factor: 100%
Last Push: 2026-05-30
Stars: 35,719
Language: Python
Lines of Code: 627,419
Files: 1348
Pattern Hits: 11935
Scan Date: 2026-05-31

Severity Breakdown

CRITICAL 3 · HIGH 620 · MEDIUM 299 · LOW 11013
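The four severity counts add up to exactly the 11935 pattern hits in the summary. This is consistent with each finding carrying a single severity label. LOW findings make up roughly 92% of the total. A short Python check of that arithmetic:

```python
# Severity counts copied from the Severity Breakdown above.
severity_counts = {"CRITICAL": 3, "HIGH": 620, "MEDIUM": 299, "LOW": 11013}

total = sum(severity_counts.values())
shares = {level: count / total for level, count in severity_counts.items()}

print(f"total findings: {total}")
for level, share in shares.items():
    # Left-align the label, then print the share as a percentage.
    print(f"{level:<8} {share:6.2%}")
```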

Pattern Findings

11935 matches across 19 categories.

Hyper-Verbose Identifiers (5566 hits · 4365 pts)
Severity  File:Line  Snippet
LOW  ci/parse_wheel_metadata.py:34  def _extract_wheel_package_name(wheel_name: str) -> str | None:
LOW  ci/parse_wheel_metadata.py:52  def parse_expected_wheel_versions(wheel_dir: Path) -> dict[str, str]:
LOW  tests/random_test.py:569  def test_key_construction_with_dtype(self, dtype_spec):
LOW  tests/random_test.py:573  def test_key_construction_with_both_impl_and_dtype(self):
LOW  tests/random_test.py:583  def test_wrap_key_data_with_dtype(self, dtype_spec):
LOW  tests/random_test.py:591  def test_wrap_key_data_with_both_impl_and_dtype(self):
LOW  tests/random_test.py:1215  def test_key_make_like_other_key_via_dtype(self, prng_name):
LOW  tests/random_test.py:1223  def test_key_wrap_like_other_key_via_dtype(self, prng_name):
LOW  tests/random_test.py:1231  def test_key_impl_from_string_error(self):
LOW  tests/random_test.py:1235  def test_key_impl_from_object_error(self):
LOW  tests/random_test.py:1242  def test_key_impl_builtin_is_string_name(self, name):
LOW  tests/random_test.py:199  def test_config_prngs_registered(self):
LOW  tests/random_test.py:412  def test_threefry_gpu_kernel_lowering(self):
LOW  tests/random_test.py:445  def test_threefry_split_fold_in_symmetry(self, make_key):
LOW  tests/random_test.py:458  def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
LOW  tests/random_test.py:472  def test_loggamma_nan_corner_case(self):
LOW  tests/random_test.py:545  def test_default_prng_selection(self, make_key, name, impl):
LOW  tests/random_test.py:557  def test_key_construction_with_explicit_impl_name(self, make_key, name, impl):
LOW  tests/random_test.py:609  def test_legacy_prng_key_flag(self):
LOW  tests/random_test.py:637  def test_seed_no_implicit_transfers(self, make_key):
LOW  tests/random_test.py:713  def test_construction_upgrade_flag(self):
LOW  tests/random_test.py:759  def test_key_dtype_attributes(self):
LOW  tests/random_test.py:814  def test_cpp_dispatch_aot_normal(self):
LOW  tests/random_test.py:827  def test_cpp_dispatch_aot_split(self):
LOW  tests/random_test.py:905  def test_eval_shape_keys_in_out(self):
LOW  tests/random_test.py:963  def test_dynamic_update_slice(self):
LOW  tests/random_test.py:1053  def test_device_put_replicated(self):
LOW  tests/random_test.py:1059  def test_make_array_from_callback(self):
LOW  tests/random_test.py:1071  def test_make_array_from_single_device_arrays(self):
LOW  tests/random_test.py:1081  def test_key_array_custom_jvp(self):
LOW  tests/random_test.py:1111  def test_key_array_indexing_nd(self):
LOW  tests/random_test.py:1136  def test_array_impl_attributes(self):
LOW  tests/random_test.py:1194  def test_key_make_like_other_key(self, prng_name):
LOW  tests/random_test.py:1204  def test_key_wrap_like_other_key(self, prng_name):
LOW  tests/random_test.py:1262  def test_keyarray_custom_vjp_symbolic_zeros(self):
LOW  tests/random_test.py:1277  def test_keyarray_array_conversion_fails(self):
LOW  tests/random_test.py:1301  def _double_threefry_random_bits(key, bit_width, shape):
LOW  tests/random_test.py:1565  def test_full_like_with_key_fillvalue(self):
LOW  tests/random_test.py:1586  def test_full_with_key_fillvalue(self):
LOW  tests/extend_test.py:84  def test_key_make_with_custom_impl(self):
LOW  tests/extend_test.py:92  def test_key_wrap_with_custom_impl(self):
LOW  tests/extend_test.py:102  def test_key_make_with_custom_impl_via_dtype(self):
LOW  tests/extend_test.py:109  def test_key_wrap_with_custom_impl_via_dtype(self):
LOW  tests/extend_test.py:118  def test_key_dtype_and_spec_with_custom_impl(self):
LOW  tests/extend_test.py:130  def test_unknown_platform_error(self):
LOW  tests/extend_test.py:139  def test_hlo_sharding_roundtrip(self):
LOW  tests/scheduling_groups_test.py:67  def test_xla_metadata_call_inlineable(self):
LOW  tests/scheduling_groups_test.py:88  def test_xla_metadata_call_inlineable_remat_in_scan(self):
LOW  tests/scheduling_groups_test.py:105  def test_xla_metadata_call_deduplication(self):
LOW  tests/scheduling_groups_test.py:133  def test_xla_metadata_call_deduplication_remat(self):
LOW  tests/scheduling_groups_test.py:162  def test_xla_metadata_call_deduplication_kwargs(self):
LOW  tests/jax_to_ir_test.py:65  def test_jax_to_hlo_with_constants(self):
LOW  tests/jax_to_ir_test.py:85  def test_parse_shape_str_invalid(self):
LOW  tests/ragged_collective_test.py:133  def test_ragged_all_to_all_grad(self, axis_name, mesh_axes):
LOW  tests/ragged_collective_test.py:207  def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes):
LOW  tests/ragged_collective_test.py:301  def test_ragged_all_to_all_degenerate_groups(self, axis_name, mesh_axes):
LOW  tests/ragged_collective_test.py:384  def test_ragged_all_to_all_vmap_multi_dim_operand(self):
LOW  tests/ragged_collective_test.py:480  def test_ragged_all_to_all_vmap(
LOW  tests/ragged_collective_test.py:581  def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self):
LOW  tests/ragged_collective_test.py:639  def test_ragged_all_to_all_errors(self):
5506 more matches not shown…
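The report does not say how Hyper-Verbose Identifiers are detected. One plausible heuristic counts the underscore-separated words in each function name, and it is offered here only as a sketch. A threshold of four words happens to match every function listed above. The function name `verbose_function_names` and its `min_words` parameter are hypothetical and are not part of the scanner:

```python
import ast


def verbose_function_names(source: str, min_words: int = 4) -> list[tuple[int, str]]:
    """Return (line, name) for functions whose names have >= min_words words.

    Hypothetical heuristic: the scanner's real rule is not stated in the report.
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Drop empty pieces produced by leading/trailing/double underscores.
            words = [w for w in node.name.split("_") if w]
            if len(words) >= min_words:
                hits.append((node.lineno, node.name))
    return sorted(hits)
```

The `test_` prefix alone contributes one word, so descriptive unittest method names like those in tests/random_test.py cross a threshold of this kind easily.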
Unused Imports (3476 hits · 2452 pts)
Severity  File:Line  Snippet
LOW  tests/random_test.py:15
LOW  tests/batching_test.py:15
LOW  tests/debug_info_test.py:15
LOW  tests/mutable_array_test.py:15
LOW  tests/lax_numpy_test.py:15
LOW  tests/shard_map_test.py:15
LOW  tests/export_harnesses_multi_platform_test.py:22
LOW  tests/shape_poly_test.py:15
LOW  tests/lax_vmap_test.py:15
LOW  tests/lax_numpy_indexing_test.py:15
LOW  tests/typing_test.py:21
LOW  tests/array_api_test.py:20
LOW  tests/state_test.py:15
LOW  tests/export_test.py:14
LOW  tests/export_test.py:46
LOW  tests/export_test.py:53
LOW  tests/pmap_test.py:15
LOW  tests/colocated_python_test.py:36
LOW  tests/lru_cache_test.py:15
LOW  tests/compilation_cache_test.py:15
LOW  tests/hijax_test.py:15
LOW  tests/export_serialization_back_compat_test.py:53
LOW  tests/api_test.py:15
LOW  tests/roofline_test.py:14
LOW  tests/lax_test.py:14
LOW  tests/mosaic/gpu_torch_test.py:29
LOW  tests/mosaic/flash_attention_test.py:24
LOW  tests/mosaic/gpu_test.py:55
LOW  tests/multiprocess/colocated_python_test.py:24
LOW  tests/pallas/tpu_all_gather_test.py:16
LOW  tests/pallas/tpu_splash_attention_mask_test.py:15
LOW  tests/pallas/tpu_pallas_pipeline_test.py:16
LOW  tests/pallas/pallas_shape_poly_test.py:31
LOW  tests/pallas/tpu_pallas_random_test.py:26
LOW  tests/pallas/indexing_test.py:15
LOW  tests/pallas/tpu_splash_attention_kernel_test.py:15
LOW  tests/pallas/pallas_test.py:14
LOW  docs/parallel.py:61
LOW  docs/autodidax.py:2147
LOW  jaxlib/xla_client.py:17
LOW  jaxlib/init.py:15
LOW  jaxlib/mosaic/gpu/wheel/__init__.py:17
LOW  jaxlib/mosaic/python/tpu.py:20
LOW  jaxlib/mosaic/python/tpu.py:22
LOW  jaxlib/mosaic/python/tpu.py:24
LOW  jaxlib/mosaic/python/mosaic_gpu.py:23
LOW  jaxlib/mosaic/python/mosaic_gpu.py:25
LOW  jaxlib/mosaic/python/mosaic_gpu.py:26
LOW  jaxlib/tools/build_utils.py:17
LOW  jaxlib/triton/dialect.py:19
LOW  jaxlib/triton/dialect.py:23
LOW  jaxlib/triton/dialect.py:23
LOW  jaxlib/triton/dialect.py:30
LOW  jaxlib/triton/dialect.py:31
LOW  jax/sharding.py:18
LOW  jax/sharding.py:19
LOW  jax/sharding.py:19
LOW  jax/sharding.py:19
LOW  jax/sharding.py:19
LOW  jax/sharding.py:19
3416 more matches not shown…
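The following is a minimal sketch of an unused-import check. It assumes the scanner compares imported names against names referenced elsewhere in the module, although the report does not document its actual method. `unused_imports` is a hypothetical helper built on the standard-library `ast` module:

```python
import ast


def unused_imports(source: str) -> list[tuple[int, str]]:
    """Return (line, name) for imported names never referenced as a bare Name."""
    tree = ast.parse(source)
    imported: dict[str, int] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # "import a.b" binds the top-level name "a".
                name = alias.asname or alias.name.split(".")[0]
                imported[name] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name == "*":
                    continue  # star imports cannot be checked this way
                imported[alias.asname or alias.name] = node.lineno
    # Attribute access such as os.path.join still contains Name(id="os").
    used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
    return sorted((line, name) for name, line in imported.items() if name not in used)
```

A check this simple also reports `from __future__` imports, names imported for side effects, and names imported only to be re-exported (for example through `__all__` or the `import X as X` idiom). Some of the hits above may therefore be intentional rather than dead code.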
Docstring Block Structure (566 hits · 2302 pts)
Severity  File:Line  Snippet
HIGH  jax/experimental/multihost_utils.py:549  A context manager for atomically running code on the set of live devices. THIS API IS UNDER ACTIVE DEVELOPMENT AND IS
HIGH  jax/experimental/mosaic/gpu/profiler.py:142  Measures the GPU runtime of a function using CUPTI. ``measure`` is a higher-order function that wraps a function ``f`
HIGH  …ops/tpu/splash_attention/splash_attention_mask_info.py:84  Downcast numpy array. If possible, downcast the data-type of the input array to the smallest numpy type (among np.i
HIGH  …ops/tpu/splash_attention/splash_attention_mask_info.py:326  Similar to `_process_mask` but the mask must be a dynamic array. Since the mask is dynamic, we can't know the exact n
HIGH  …ops/tpu/splash_attention/splash_attention_mask_info.py:528  Transform a dense mask into a sparse representation. The number of head and Q sequence shards are needed to create a
HIGH  …llas/ops/tpu/splash_attention/splash_attention_mask.py:97  Makes a chunked causal attention mask. Args: shape: The desired shape of the mask (q_seq_len, kv_seq_len). ch
HIGH  jax/experimental/sparse/linalg.py:43  Compute the top-k standard eigenvalues using the LOBPCG routine. LOBPCG [1] stands for Locally Optimal Block Precondi
HIGH  jax/experimental/sparse/bcoo.py:2437  Experimental batched COO matrix implemented in JAX Args: (data, indices) : data and indices in batched COO format
HIGH  jax/_src/callback.py:269  Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc. For more explanation, see `External Callback
HIGH  jax/_src/tree.py:26  Call all() over the leaves of a tree. Args: tree: the pytree to evaluate is_leaf : an optionally specified fu
HIGH  jax/_src/tree.py:55  Flattens a pytree. The flattening order (i.e. the order of elements in the output list) is deterministic, correspon
HIGH  jax/_src/tree.py:91  Gets the leaves of a pytree. Args: tree: the pytree for which to get the leaves is_leaf : an optionally speci
HIGH  jax/_src/tree.py:120  Maps a multi-input function over pytree args to produce a new pytree. Args: f: function that takes ``1 + len(rest
HIGH  jax/_src/tree.py:163  Call reduce() over the leaves of a tree. Args: function: the reduction function tree: the pytree to reduce ov
HIGH  jax/_src/tree.py:203  Perform a reduction over a pytree with an associative binary operation. This function exploits the fact that the oper
HIGH  jax/_src/tree.py:245  Gets the treedef for a pytree. Args: tree: the pytree for which to get the leaves is_leaf : an optionally spe
HIGH  jax/_src/tree.py:273  Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). Args: outer_treed
HIGH  jax/_src/tree.py:303  Reconstructs a pytree from the treedef and the leaves. The inverse of :func:`tree_flatten`. Args: treedef: the
HIGH  jax/_src/tree.py:335  Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. Args: tree: a pytree to flatten.
HIGH  jax/_src/tree.py:366  Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. Args: tree: a pytree. If it co
HIGH  jax/_src/tree.py:395  Maps a multi-input function over pytree key path and args to produce a new pytree. This is a more powerful alternativ
HIGH  jax/_src/tree.py:433  Broadcasts a tree prefix into the full structure of a given tree. Args: prefix_tree: a pytree that is a tree
HIGH  jax/_src/xla_bridge.py:317  Creates a CPU client with the requested collectives implementation. The implementation of CPU collectives used by the
HIGH  jax/_src/config.py:368  Set up thread-local state and return a contextmanager for managing it. This function is a convenience wrapper. It def
HIGH  jax/_src/core.py:2891  Invalidate a given reference and return its final value. For more information about mutable array references, refer t
HIGH  jax/_src/dtypes.py:887  Returns the type to which a binary operation should cast its arguments. JAX implementation of :func:`numpy.promote_ty
HIGH  jax/_src/dtypes.py:1135  Check if a dtype/value is safe to cast to another dtype/value Args: input_dtype_or_value: a dtype or value (to be
HIGH  jax/_src/api.py:208  Sets up ``fun`` for just-in-time compilation with XLA. Args: fun: Function to be jitted. ``fun`` should be a pure
HIGH  jax/_src/api.py:2203  Transfer array shards to specified devices and form Array(s). Args: shards: A sequence of arrays, scalars, or (ne
HIGH  jax/_src/api.py:2283  Transfer array(s) to each specified device and form Array(s). Args: x: an array, scalar, or (nested) standard Pyt
HIGH  jax/_src/api.py:2352  Transfer ``x`` to host. If ``x`` is a pytree, then the individual buffers are copied in parallel. Args: x: An
HIGH  jax/_src/api.py:2502  A context manager that adds a user specified name to the JAX name stack. When staging out computations for just-in-ti
HIGH  jax/_src/distributed.py:233  Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on m
HIGH  jax/_src/sharding_impls.py:650  Computes the global shape given the per process if possible. The returned shape will have the size of the global tens
HIGH  jax/_src/tree_util.py:125  Makes a tuple treedef from an iterable of child treedefs. Args: treedefs: iterable of PyTree structures Return
HIGH  jax/_src/tree_util.py:151  Return a list of treedefs for immediate children Args: treedef: a single PyTreeDef Returns: a list of PyTr
HIGH  jax/_src/tree_util.py:176  Return True if the treedef represents a leaf. Args: treedef: tree to check Returns: True if treedef is a l
HIGH  jax/_src/tree_util.py:204  Tests whether all elements in the given iterable are all leaves. This function is useful in advanced cases, for examp
HIGH  jax/_src/tree_util.py:336  Extends the set of types that are considered internal nodes in pytrees. This function is a thin wrapper around ``regi
HIGH  jax/_src/tree_util.py:688  Flatten the given pytree node by one level. Args: tree: A valid pytree node, either built-in or registered via
HIGH  jax/_src/tree_util.py:834  Helper to pretty-print a tuple of keys. Args: keys: A tuple of ``KeyEntry`` or any class that can be converted to
HIGH  jax/_src/tree_util.py:951  Extends the set of types that are considered internal nodes in pytrees. This function is similar to ``register_pytree
HIGH  jax/_src/tree_util.py:1005  Extends the set of types that are considered internal nodes in pytrees. This differs from ``register_pytree_with_keys
HIGH  jax/_src/tree_util.py:1191  Registers `cls` as a pytree with no leaves. Instances are treated as static by :func:`jax.jit`, :func:`jax.pmap`, etc
HIGH  jax/_src/custom_partitioning_sharding_rule.py:215  Parses the LHS or RHS of an Einsum notation like string. Converts each operand or result in the Einsum notation like
HIGH  jax/_src/array.py:691  Returns a ``jax.Array`` via data fetched from ``data_callback``. ``data_callback`` is used to fetch the data for each
HIGH  jax/_src/array.py:826  Creates distributed tensor using the data available in process. This function is a common special case of `make_array
HIGH  jax/_src/array.py:1026  Returns a ``jax.Array`` from a sequence of ``jax.Array``\s each on a single device. Every device in input ``shardi
HIGH  jax/_src/profiler.py:94  Registers a subprocess's profiler server to be profiled alongside the current process. When the current process colle
HIGH  jax/_src/mesh_utils.py:789  Creates a performant device mesh for jax.sharding.Mesh. Args: mesh_shape: shape of logical mesh, ordered by incre
HIGH  jax/_src/mesh_utils.py:870  Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. Args: mesh_shape: shape of the logical mesh for
HIGH  jax/_src/pmap.py:477  Infer axis size from the first mapped argument. shard_map already does a check on all arguments, so just look at firs
HIGH  jax/_src/pmap.py:528  Extract dynamic args and argnums after handling static args. Args: wrapped_f: The wrapped function. static_br
HIGH  jax/_src/pmap.py:566  Compute flat in_axes tuple from in_axes prefix and args structure. Args: in_axes: The original in_axes specificat
HIGH  jax/_src/pmap.py:673  Compute effective mesh devices based on context. Args: devices: The mesh devices tuple. backend: The backend
HIGH  jax/_src/custom_derivatives.py:190  Define a custom JVP rule for the function represented by this instance. Args: jvp: a Python callable represen
HIGH  jax/_src/custom_derivatives.py:239  Convenience wrapper for defining JVPs for each argument separately. This convenience wrapper cannot be used togethe
HIGH  jax/_src/custom_derivatives.py:624  Define a custom VJP rule for the function represented by this instance. Args: fwd: a Python callable represen
HIGH  jax/_src/interpreters/mlir.py:259  Translate a Python ``val`` to an IR constant. See https://docs.jax.dev/en/latest/internals/constants.html. Args:
HIGH  jax/_src/interpreters/mlir.py:291  Translate a Python ``val`` to a sequence of IR constants. See https://docs.jax.dev/en/latest/internals/constants.html
506 more matches not shown…
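Many of the snippets above show Google-style docstrings with an `Args:` section. One possible rule flags docstrings that contain several such section headers. The sketch below assumes that rule, uses a guessed threshold, and reads each docstring with `ast.get_docstring`. `structured_docstrings`, `SECTION_RE` and `min_sections` are hypothetical names:

```python
import ast
import re

# A section header sits alone on its own line, e.g. "Args:" or "Returns:".
SECTION_RE = re.compile(
    r"^[ \t]*(Args|Arguments|Returns|Raises|Yields|Examples?)[ \t]*:[ \t]*$",
    re.MULTILINE,
)


def structured_docstrings(source: str, min_sections: int = 2) -> list[tuple[int, list[str]]]:
    """Return (def line, section names) for docstrings with >= min_sections headers."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            doc = ast.get_docstring(node)  # indentation is already cleaned up
            if doc:
                sections = SECTION_RE.findall(doc)
                if len(sections) >= min_sections:
                    hits.append((node.lineno, sections))
    hits.sort(key=lambda hit: hit[0])
    return hits
```

The report appears to give the line where the docstring text begins, whereas this sketch reports the line of the `def` or `class` that owns it.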
Over-Commented Block (1399 hits · 1334 pts)
Severity  File:Line  Snippet
LOW  conftest.py:1  # Copyright 2021 The JAX Authors.
LOW  conftest.py:41  # The pytest_collection hook can be used to overwrite the collection logic, but
LOW  setup.py:1  # Copyright 2018 The JAX Authors.
LOW  build_wheel.py:1  # Copyright 2025 The JAX Authors.
LOW  ci/run_pytest_tpu.sh:1  #!/bin/bash
LOW  ci/run_pytest_rocm.sh:1  #!/bin/bash
LOW  ci/run_bazel_test_cuda_non_rbe.sh:1  #!/bin/bash
LOW  ci/run_bazel_test_cuda_non_rbe.sh:21  # -e: abort script if one command fails
LOW  ci/run_bazel_cuda_targeted_tests.sh:1  #!/bin/bash
LOW  ci/run_bazel_cuda_targeted_tests.sh:21  # JAXCI_BAZEL_TARGETS: newline-separated Bazel targets.
LOW  ci/run_bazel_test_rocm_rbe.sh:1  #!/bin/bash
LOW  ci/build_artifacts.sh:1  #!/bin/bash
LOW  ci/build_rocm_artifacts.sh:1  #!/bin/bash
LOW  ci/run_bazel_test_cpu_rbe.sh:1  #!/bin/bash
LOW  ci/upload_rocm_logs.sh:1  #!/bin/bash
LOW  ci/run_bazel_test_cuda_rbe.sh:1  #!/bin/bash
LOW  ci/run_pytest_cuda.sh:1  #!/bin/bash
LOW  ci/parse_wheel_metadata.py:1  # Copyright 2026 The JAX Authors.
LOW  ci/run_bazel_test_tpu.sh:1  #!/bin/bash
LOW  ci/run_bazel_test_tpu.sh:181  --test_env=JAX_PLATFORMS=tpu,cpu \
LOW  ci/run_pytest_cpu.sh:1  #!/bin/bash
LOW  ci/utilities/install_wheels_locally.sh:1  #!/bin/bash
LOW  ci/utilities/run_docker_container.sh:1  #!/bin/bash
LOW  ci/utilities/run_docker_container.sh:21  # Note: While GitHub action workflows use the same Docker images, they do not
LOW  ci/utilities/convert_msys_paths_to_win_paths.py:1  # Copyright 2024 The JAX Authors.
LOW  ci/utilities/collect_bazel_test_xmls.sh:1  #!/bin/bash
LOW  ci/utilities/collect_bazel_test_xmls.sh:21  # bazel-testlogs/tests/cpu_tests/test.xml). Pytest, by contrast, writes a
LOW  ci/utilities/setup_build_environment.sh:1  #!/bin/bash
LOW  ci/utilities/run_auditwheel.sh:1  #!/bin/bash
LOW  ci/postprocess/process_test_results.sh:1  #!/bin/bash
LOW  ci/postprocess/xml2json.py:1  # Copyright 2026 The JAX Authors.
LOW  tests/random_test.py:1  # Copyright 2018 The JAX Authors.
LOW  tests/random_test.py:641
LOW  tests/extend_test.py:1  # Copyright 2023 The JAX Authors.
LOW  tests/svd_test.py:1  # Copyright 2022 The JAX Authors.
LOW  tests/scipy_spatial_test.py:1  # Copyright 2023 The JAX Authors.
LOW  tests/scheduling_groups_test.py:1  # Copyright 2025 The JAX Authors.
LOW  tests/jax_to_ir_test.py:1  # Copyright 2019 The JAX Authors.
LOW  tests/ragged_collective_test.py:1  # Copyright 2025 The JAX Authors.
LOW  tests/pytorch_interoperability_test.py:1  # Copyright 2020 The JAX Authors.
LOW  tests/fft_test.py:1  # Copyright 2019 The JAX Authors.
LOW  tests/export_back_compat_test.py:1  # Copyright 2023 The JAX Authors.
LOW  tests/export_back_compat_test.py:1041  # stablehlo.dynamic_rbg_bit_generator is used temporarily for a
LOW  tests/mesh_utils_test.py:1  # Copyright 2021 The JAX Authors. All Rights Reserved.
LOW  tests/lax_vmap_op_test.py:1  # Copyright 2020 The JAX Authors.
LOW  tests/debug_nans_test.py:1  # Copyright 2019 The JAX Authors.
LOW  tests/lobpcg_test.py:1  # Copyright 2022 The JAX Authors.
LOW  tests/heap_profiler_test.py:1  # Copyright 2021 The JAX Authors.
LOW  tests/custom_root_test.py:1  # Copyright 2018 The JAX Authors.
LOW  tests/mosaic_test.py:1  # Copyright 2023 The JAX Authors.
LOW  tests/lax_control_flow_test.py:1  # Copyright 2018 The JAX Authors.
LOW  tests/lax_control_flow_test.py:2221
LOW  tests/array_interoperability_test.py:1  # Copyright 2020 The JAX Authors.
LOW  tests/lax_scipy_test.py:1  # Copyright 2018 The JAX Authors.
LOW  tests/errors_test.py:1  # Copyright 2020 The JAX Authors.
LOW  tests/distributed_initialize_test.py:1  # Copyright 2025 The JAX Authors.
LOW  tests/debugger_test.py:1  # Copyright 2022 The JAX Authors.
LOW  tests/multi_device_test.py:1  # Copyright 2019 The JAX Authors.
LOW  tests/magma_linalg_test.py:1  # Copyright 2024 The JAX Authors.
LOW  tests/lax_numpy_setops_test.py:1  # Copyright 2025 The JAX Authors.
1339 more matches not shown…
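Most Over-Commented Block hits shown above sit at line 1, where the snippet is a copyright line or a `#!/bin/bash` shebang. This suggests that license headers trigger the rule. The sketch below is a hypothetical run-length check over consecutive `#` comment lines. `comment_runs` and its `min_run` threshold are illustrative and are not the scanner's code:

```python
def comment_runs(lines: list[str], min_run: int = 10) -> list[tuple[int, int]]:
    """Return (start_line, length) for runs of >= min_run consecutive '#' lines.

    Line numbers are 1-based. Blank lines and code lines end a run.
    """
    runs = []
    start = None
    for i, line in enumerate(lines, start=1):
        if line.lstrip().startswith("#"):
            if start is None:
                start = i  # first comment line of a new run
        else:
            if start is not None and i - start >= min_run:
                runs.append((start, i - start))
            start = None
    # A run that continues to the end of the file.
    if start is not None and len(lines) + 1 - start >= min_run:
        runs.append((start, len(lines) + 1 - start))
    return runs
```

A shebang line starts with `#` as well, so in a shell script it merges with the license comment that follows it into a single run.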
Decorative Section Separators (125 hits · 375 pts)
Severity  File:Line  Snippet
MEDIUM  ci/run_pytest_tpu.sh:15  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:15  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:41  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:43  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:50  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:53  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:106  # ==============================================================================
MEDIUM  ci/run_pytest_rocm.sh:108  # ==============================================================================
MEDIUM  ci/run_bazel_test_cuda_non_rbe.sh:15  # ==============================================================================
MEDIUM  ci/run_bazel_cuda_targeted_tests.sh:15  # ==============================================================================
MEDIUM  ci/run_bazel_test_rocm_rbe.sh:15  # ==============================================================================
MEDIUM  ci/build_artifacts.sh:15  # ==============================================================================
MEDIUM  ci/build_rocm_artifacts.sh:15  # ==============================================================================
MEDIUM  ci/run_bazel_test_cpu_rbe.sh:15  # ==============================================================================
MEDIUM  ci/upload_rocm_logs.sh:15  # ==============================================================================
MEDIUM  ci/run_bazel_test_cuda_rbe.sh:15  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:15  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:45  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:47  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:54  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:57  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:107  # ==============================================================================
MEDIUM  ci/run_pytest_cuda.sh:109  # ==============================================================================
MEDIUM  ci/run_bazel_test_tpu.sh:15  # ==============================================================================
MEDIUM  ci/run_pytest_cpu.sh:15  # ==============================================================================
MEDIUM  ci/utilities/install_wheels_locally.sh:15  # ==============================================================================
MEDIUM  ci/utilities/run_docker_container.sh:15  # ==============================================================================
MEDIUM  ci/utilities/convert_msys_paths_to_win_paths.py:14  # ==============================================================================
MEDIUM  ci/utilities/collect_bazel_test_xmls.sh:15  # ==============================================================================
MEDIUM  ci/utilities/setup_build_environment.sh:15  # ==============================================================================
MEDIUM  ci/utilities/run_auditwheel.sh:15  # ==============================================================================
MEDIUM  ci/postprocess/process_test_results.sh:15  # ==============================================================================
MEDIUM  tests/mesh_utils_test.py:14  # ==============================================================================
MEDIUM  tests/hijax_test.py:2041  # ------------
MEDIUM  tests/hijax_test.py:2043  # ------------
MEDIUM  tests/hijax_test.py:2114  #------------
MEDIUM  tests/hijax_test.py:2116  #------------
MEDIUM  tests/mosaic/gpu_constraints_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_torch_distributed_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_dialect_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_torch_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/profiler_cupti_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/flash_attention_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_multidevice_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_layout_inference_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/matmul_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_test.py:14  # ==============================================================================
MEDIUM  tests/mosaic/gpu_distributed_test.py:14  # ==============================================================================
MEDIUM  tests/pallas/mgpu_ragged_dot_test.py:14  # ==============================================================================
MEDIUM  tests/pallas/mgpu_torch_test.py:14  # ==============================================================================
MEDIUM  tests/pallas/mgpu_examples_test.py:14  # ==============================================================================
MEDIUM  tests/pallas/mgpu_attention_test.py:14  # ==============================================================================
MEDIUM  tests/pallas/mgpu_collective_matmul_test.py:14  # ==============================================================================
MEDIUM  tests/pallas/mgpu_matmul_test.py:14  # ==============================================================================
MEDIUM  docs/notebooks/cute_dsl_jax/cute_dsl_jax_kernels.py:284  # ── Vector Add ────────────────────────────────────────────────────
MEDIUM  docs/notebooks/cute_dsl_jax/cute_dsl_jax_kernels.py:297  # ── SAXPY ─────────────────────────────────────────────────────────
MEDIUM  docs/notebooks/cute_dsl_jax/cute_dsl_jax_kernels.py:311  # ── ReLU ──────────────────────────────────────────────────────────
MEDIUM  docs/notebooks/cute_dsl_jax/cute_dsl_jax_kernels.py:323  # ── Fused Bias + ReLU ─────────────────────────────────────────────
MEDIUM  docs/notebooks/cute_dsl_jax/cute_dsl_jax_kernels.py:337  # ── GEMM ──────────────────────────────────────────────────────────
MEDIUM  docs/notebooks/cute_dsl_jax/cute_dsl_jax_kernels.py:351  # ── Elementwise Add (2-D) ─────────────────────────────────────────
65 more matches not shown…
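Most Decorative Section Separators hits are the `# ====` rule at line 14 or 15 that closes the license header. The rest are a few dashed banners and box-drawing (`── Title ──`) banners. The sketch below assumes that a separator is any `#` comment containing a run of at least ten identical rule characters. `separator_lines`, `RUN_RE` and the threshold of ten are hypothetical:

```python
import re

# Ten or more repeats of one rule character: '=', '-', U+2500 '─', '*' or '~'.
RUN_RE = re.compile(r"([=\-─*~])\1{9,}")


def separator_lines(lines: list[str]) -> list[int]:
    """Return 1-based line numbers of full-line comments that contain a rule run."""
    hits = []
    for i, line in enumerate(lines, start=1):
        stripped = line.lstrip()
        if stripped.startswith("#") and RUN_RE.search(stripped):
            hits.append(i)
    return hits
```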
Deep Nesting (437 hits · 346 pts)
Severity  File:Line  Snippet
LOW  conftest.py:55
LOW  ci/postprocess/xml2json.py:70
LOW  tests/random_test.py:690
LOW  tests/export_back_compat_test.py:557
LOW  tests/lax_control_flow_test.py:2376
LOW  tests/errors_test.py:44
LOW  tests/ffi_test.py:45
LOW  tests/lax_numpy_einsum_test.py:295
LOW  tests/debug_info_test.py:119
LOW  tests/api_util_test.py:28
LOW  tests/dtypes_test.py:489
LOW  tests/fused_attention_stablehlo_test.py:123
LOW  tests/lax_numpy_test.py:2551
LOW  tests/lax_numpy_test.py:6579
LOW  tests/lax_numpy_reducers_test.py:164
LOW  tests/linalg_test.py:765
LOW  tests/linalg_test.py:1843
LOW  tests/linalg_test.py:1945
LOW  tests/linalg_test.py:1846
LOW  tests/linalg_test.py:1953
LOW  tests/profiler_test.py:54
LOW  tests/profiler_test.py:133
LOW  tests/lax_numpy_indexing_test.py:1491
LOW  tests/export_test.py:2391
LOW  tests/control_deps_test.py:53
LOW  tests/pgle_test.py:135
LOW  tests/pgle_test.py:221
LOW  tests/pgle_test.py:367
LOW  tests/sparse_bcoo_bcsr_test.py:57
LOW  tests/sparse_bcoo_bcsr_test.py:101
LOW  tests/device_test.py:24
LOW  tests/device_test.py:46
LOW  tests/compilation_cache_test.py:530
LOW  tests/transfer_guard_test.py:190
LOW  tests/transfer_guard_test.py:201
LOW  tests/api_test.py:4955
LOW  tests/api_test.py:5278
LOW  tests/api_test.py:5286
LOW  tests/documentation_coverage_test.py:140
LOW  tests/sparse_test.py:1021
LOW  tests/lax_test.py:117
LOW  tests/lax_test.py:1612
LOW  tests/lax_test.py:4565
LOW  tests/mosaic/gpu_test.py:662
LOW  tests/mosaic/gpu_test.py:1112
LOW  tests/mosaic/gpu_test.py:4993
LOW  tests/mosaic/gpu_test.py:5678
LOW  tests/mosaic/gpu_test.py:6862
LOW  tests/multiprocess/thread_guard_test.py:42
LOW  tests/multiprocess/pgle_test.py:44
LOW  tests/pallas/tpu_pallas_interpret_test.py:313
LOW  tests/pallas/tpu_pallas_interpret_test.py:409
LOW  tests/pallas/triton_pallas_test.py:142
LOW  …allas/tpu_fusible_matmul_with_stateful_fusions_test.py:84
LOW  tests/pallas/einshape_test.py:60
LOW  tests/pallas/ops_test.py:80
LOW  tests/pallas/ops_test.py:724
LOW  tests/pallas/ops_test.py:2126
LOW  tests/pallas/mosaic_gpu_test.py:3352
LOW  tests/pallas/tpu_fusible_matmul_test.py:84
377 more matches not shown…
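The report gives no depth threshold for Deep Nesting. Suppose depth is measured as the number of enclosing control-flow blocks. Under that assumption, a sketch built on Python's `ast` module could report the line where depth first exceeds a limit. `deep_nesting`, `BLOCK_NODES` and the default `max_depth=4` are hypothetical:

```python
import ast

# Statements that open a nested block. An `elif` is parsed as an ast.If inside
# the previous If's orelse, so this simple walk counts elif chains as extra depth.
BLOCK_NODES = (ast.If, ast.For, ast.AsyncFor, ast.While, ast.With, ast.AsyncWith, ast.Try)


def deep_nesting(source: str, max_depth: int = 4) -> list[int]:
    """Return line numbers where control-flow nesting first exceeds max_depth."""
    hits: list[int] = []

    def visit(node: ast.AST, depth: int) -> None:
        for child in ast.iter_child_nodes(node):
            if isinstance(child, BLOCK_NODES):
                child_depth = depth + 1
                if child_depth == max_depth + 1:
                    hits.append(child.lineno)  # report only the first crossing
                visit(child, child_depth)
            else:
                visit(child, depth)

    visit(ast.parse(source), 0)
    return sorted(hits)
```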
Self-Referential Comments (73 hits · 202 pts)
Severity  File:Line  Snippet
MEDIUM  ci/utilities/run_docker_container.sh:55  # Create a temporary file to pass any user defined JAXCI_ / JAX_ / JAXLIB_
MEDIUM  ci/utilities/setup_build_environment.sh:105  # Create the output directory if it doesn't exist.
MEDIUM  tests/ragged_collective_test.py:769  # Define a mesh with PP + EP
MEDIUM  tests/lax_scipy_test.py:548  # This function is not defined for negative values, this makes sure they are nan
MEDIUM  tests/custom_api_test.py:2081  # Create the custom function
MEDIUM  tests/garbage_collection_guard_test.py:54  # Create a reference cycle of two jax.Arrays.
MEDIUM  tests/garbage_collection_guard_test.py:67  # Create a reference cycle of two jax.Arrays.
MEDIUM  tests/layout_test.py:361  # Create a custom layout instead of using `arr.layout` to test the API.
MEDIUM  tests/layout_test.py:384  # Create a custom layout instead of using `arr.layout` to test the API.
MEDIUM  tests/layout_test.py:665  # Create a custom layout instead of using `arr.layout` to test the API.
MEDIUM  tests/sparse_bcoo_bcsr_test.py:1153  # Create a matrix with duplicate indices
MEDIUM  tests/aot_test.py:337  # Create a compile-only topology, but DON'T switch to CPU so the real
MEDIUM  tests/array_test.py:666  # Create a few arrays
MEDIUM  tests/hijax_test.py:62  # Define a type
MEDIUM  tests/api_test.py:2053  # Creating an array from a numpy array with a fully-replicated sharding
MEDIUM  tests/api_test.py:2070  # Creating an array from a numpy array with a non-fully-replicated sharding
MEDIUM  tests/api_test.py:2086  # Creating an array from per-device JAX arrays calls internal
MEDIUM  tests/mosaic/gpu_layout_inference_test.py:1466  # Create a var to use in the constraint system.
MEDIUM  tests/multiprocess/array_test.py:613  # Create an array that is non-addressable in processes besides `pid`.
MEDIUM  tests/multiprocess/array_test.py:635  # Create a PRNG key array that is non-addressable in processes besides
MEDIUM  tests/multiprocess/array_test.py:669  # Create a sharding that is non-addressable in processes besides `pid`.
MEDIUM  tests/multiprocess/array_test.py:696  # Create a sharding that is non-addressable in processes besides `pid`.
MEDIUM  tests/multiprocess/array_test.py:723  # Create a single device sharding for a device local to process `pid`.
MEDIUM  tests/pallas/tpu_pallas_interpret_distributed_test.py:63  # Create an input array that shards the last dimension across
MEDIUM  tests/pallas/tpu_pallas_interpret_distributed_test.py:149  # Create an input array that shards the first dimension across
MEDIUM  tests/pallas/mosaic_gpu_test.py:6859  # Create an index-invariant output.
MEDIUM  tests/pallas/pallas_test.py:2509  # Create a validity mask for OOB values.
MEDIUM  docs/sphinxext/jax_list_config_options.py:38  # Create a field list item
MEDIUM  docs/sphinxext/jax_list_config_options.py:41  # Create the field name (label)
MEDIUM  docs/sphinxext/jax_list_config_options.py:46  # Create the field body (content)
MEDIUM  docs/sphinxext/jax_list_config_options.py:81  # Create a section for this option
MEDIUM  docs/sphinxext/jax_list_config_options.py:86  # Create a title with the option name (important for TOC)
MEDIUM  docs/sphinxext/jax_list_config_options.py:92  # Create a field list for side-by-side display
MEDIUM  jaxlib/tools/build_wheel.py:225  # This file is required by PEP-561. It marks jaxlib as package containing
MEDIUM  .github/workflows/bazel_cuda_b200_mosaic_presubmit.yml:87  # Create an array with the file patterns and
MEDIUM  jax/version.py:15  # This file is included as part of both jax and jaxlib. It is also
MEDIUM  …ops/tpu/splash_attention/splash_attention_mask_info.py:587  # Create a collection of the unique head masks in the input multi-head mask.
MEDIUM  …llas/ops/tpu/splash_attention/splash_attention_mask.py:374  # Define the mask function for chunk attention
MEDIUM  jax/experimental/pallas/ops/tpu/megablox/gmm.py:180  # Create the group ids for each grid index based on the tile counts for each
MEDIUM  jax/experimental/pallas/ops/tpu/megablox/gmm.py:230  # Create the m-dimension tile ids for each grid index based on the visit
MEDIUM  jax/experimental/pallas/ops/tpu/megablox/gmm.py:383  # Create the metadata we need for computation.
MEDIUM  jax/experimental/pallas/ops/tpu/megablox/gmm.py:632  # Create the metadata we need for computation.
MEDIUM  jax/experimental/jax2tf/call_tf.py:219  # Define the fwd and bwd custom_vjp functions
MEDIUM  jax/experimental/jax2tf/tests/jax2tf_test.py:349  x = tf.Variable(4.0, dtype=tf.float32) # Create a Tensorflow variable initialized to 4.0
MEDIUM  jax/experimental/jax2tf/examples/mnist_lib.py:225  # Create the model and save it
MEDIUM  jax/_src/test_util.py:2040  # The following function methods operate on mpmath number instances.
MEDIUM  jax/_src/custom_batching.py:178  # Define a class, instead of making a function closing over `rule`, so
MEDIUM  jax/_src/ad_checkpoint.py:400  # This function is similar to api_util.argnums_partial, except the error
MEDIUM  jax/_src/ad_checkpoint.py:926  >>> # Define a function where we explicitly name an intermediate value
MEDIUM  jax/_src/ffi.py:710  # This method is kept to support the behavior that was previously exposed
MEDIUM  jax/_src/interpreters/partial_eval.py:255  # Create the input tracers for the staged-out (unknown-value) call.
MEDIUMjax/_src/interpreters/partial_eval.py2449 # This function is conceptually the same thing as just calling eval_jaxpr,
MEDIUMjax/_src/interpreters/mlir.py1350 # Create a keepalives list that will be mutated during the lowering.
MEDIUMjax/_src/lax/utils.py15# This module contains utility functions split out of jax._src.lax.lax to
MEDIUMjax/_src/lax/control_flow/loops.py1333 # Create the staged eqn.
MEDIUMjax/_src/pallas/mosaic/lowering.py3849 # Create a scalar constant.
MEDIUMjax/_src/pallas/mosaic/lowering.py3852 # Create a vector constant.
MEDIUMjax/_src/pallas/mosaic_gpu/lowering.py590 # The below code emission relies on the assumption that the first scratch
MEDIUMjax/_src/lib/__init__.py15# This module is largely a wrapper around `jaxlib` that performs version
MEDIUMjax/_src/export/serialization_generated.py109 """This method is deprecated. Please switch to GetRootAs."""
13 more matches not shown…
Cross-File Repetition · 27 hits · 135 pts
Severity · File · Line · Snippet
HIGH · jaxlib/setup.py · 0 · this class makes 'bdist_wheel' include an abi tag on the wheel.
HIGH · jax_plugins/cuda/plugin_setup.py · 0 · this class makes 'bdist_wheel' include an abi tag on the wheel.
HIGH · jax_plugins/rocm/plugin_setup.py · 0 · this class makes 'bdist_wheel' include an abi tag on the wheel.
HIGH · jax_plugins/oneapi/plugin_setup.py · 0 · this class makes 'bdist_wheel' include an abi tag on the wheel.
HIGH · jax/experimental/mosaic/gpu/fragmented_array.py · 0 · returns the shape of the register array needed to represent an array of the given logical shape.
HIGH · jax/experimental/mosaic/gpu/fragmented_array.py · 0 · returns the shape of the register array needed to represent an array of the given logical shape.
HIGH · jax/experimental/mosaic/gpu/fragmented_array.py · 0 · returns the shape of the register array needed to represent an array of the given logical shape.
HIGH · …mental/jax2tf/tests/flax_models/transformer_nlp_seq.py · 0 · global hyperparameters used to minimize obnoxious kwarg plumbing.
HIGH · …erimental/jax2tf/tests/flax_models/transformer_lm1b.py · 0 · global hyperparameters used to minimize obnoxious kwarg plumbing.
HIGH · …perimental/jax2tf/tests/flax_models/transformer_wmt.py · 0 · global hyperparameters used to minimize obnoxious kwarg plumbing.
HIGH · jax/_src/export/shape_poly.py · 0 · returns -1 if self < other, 0 if self == other, 1 if self > other. the comparison is done lexicographically (syntactic),
HIGH · jax/_src/export/shape_poly.py · 0 · returns -1 if self < other, 0 if self == other, 1 if self > other. the comparison is done lexicographically (syntactic),
HIGH · jax/_src/export/shape_poly.py · 0 · returns -1 if self < other, 0 if self == other, 1 if self > other. the comparison is done lexicographically (syntactic),
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · jax/_src/export/serialization_generated.py · 0 · this method is deprecated. please switch to getrootas.
HIGH · …xport_back_compat_test_data/annotate_data_placement.py · 0 · #loc1 = loc("x") #loc2 = loc("y") module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions
HIGH · …xport_back_compat_test_data/annotate_data_placement.py · 0 · #loc1 = loc("x") #loc2 = loc("y") module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions
HIGH · …xport_back_compat_test_data/annotate_data_placement.py · 0 · #loc1 = loc("x") #loc2 = loc("y") module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions
Cross-Language Confusion · 24 hits · 130 pts
Severity · File · Line · Snippet
HIGH · tests/stack_test.py · 34 · stack = stack.push(jnp.int32(7))
HIGH · tests/stack_test.py · 36 · stack = stack.push(jnp.int32(8))
HIGH · tests/stack_test.py · 41 · stack = stack.push(jnp.int32(9))
HIGH · tests/export_test.py · 1117 · stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
HIGH · tests/export_test.py · 934 · stack.push(self.assertRaisesRegex(Exception, expect_error))
HIGH · tests/export_test.py · 1040 · stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
HIGH · tests/export_test.py · 1054 · stack.push(self.assertRaisesRegex(Exception, expect_error_run))
HIGH · jax/experimental/jax2tf/tests/shape_poly_test.py · 440 · stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
HIGH · jax/experimental/colocated_python/func.py · 363 · func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
HIGH · jax/_src/jaxpr_util.py · 627 · while ((match = anchorRegex.exec(line)) !== null) {{
HIGH · jax/_src/jaxpr_util.py · 672 · const isCurrent = currentMatchIdx !== -1 && lineAbsoluteIdx === matchingLines[currentMatchIdx];
HIGH · jax/_src/jaxpr_util.py · 705 · let selectedElement = null;
HIGH · jax/_src/jaxpr_util.py · 713 · if (lineIdx !== undefined) {{
HIGH · jax/_src/jaxpr_util.py · 738 · while (currentIdx !== null && currentIdx !== undefined) {{
HIGH · jax/_src/jaxpr_util.py · 738 · while (currentIdx !== null && currentIdx !== undefined) {{
HIGH · jax/_src/jaxpr_util.py · 743 · renderedFrames.push({{file: file, func: func, line: frame.line, col: frame.col}});
HIGH · jax/_src/jaxpr_util.py · 804 · matchingLines.push(idx);
HIGH · jax/_src/jaxpr_util.py · 808 · matchingLines.push(idx);
HIGH · jax/_src/jaxpr_util.py · 937 · if (offsetRight > 100 && offsetRight < document.body.clientWidth - 100) {{
HIGH · jax/_src/lax/lax.py · 9018 · b <= a, then the result is undefined, and different implementations may
HIGH · jax/_src/numpy/linalg.py · 529 · discontinuous at det=0 so standard autodiff is undefined; we use sign_dot=0
HIGH · jax/_src/tpu/linalg/eigh.py · 326 · agenda = agenda.push(_Subproblem(offset=jnp.array(0, np.int32), size=n))
HIGH · jax/_src/tpu/linalg/eigh.py · 417 · agenda.push(_Subproblem(offset, rank)),
HIGH · jax/_src/tpu/linalg/eigh.py · 436 · agenda.push(_Subproblem(offset + rank, (b - rank))),
AI Slop Vocabulary · 47 hits · 114 pts
Severity · File · Line · Snippet
MEDIUM · ci/utilities/run_docker_container.sh · 22 · # run this script as they leverage built-in containerization features to run
MEDIUM · tests/export_back_compat_test.py · 264 · # Compute the inputs to simplify the harness
MEDIUM · tests/lax_numpy_operators_test.py · 294 · # probably add a custom test harness for unwrap that tests the period
LOW · tests/custom_api_test.py · 366 · # useful either: instead of using nondiff_argnums here, a user can just pass
MEDIUM · tests/lax_numpy_test.py · 3350 · # permissive canonicalization logic in the test harness.
MEDIUM · tests/export_harnesses_multi_platform_test.py · 60 · # If you want to run this test for only one harness, add parameter
MEDIUM · tests/shape_poly_test.py · 3671 · # Pick the dtype with the most harnesses in this group. Some harness
MEDIUM · tests/shape_poly_test.py · 3743 · # If you want to run this test for only one harness that includes "foo"
MEDIUM · tests/shape_poly_test.py · 3800 · # Update this here rather than in harness object because vmap_random_gamma is derived
LOW · tests/shape_poly_test.py · 1862 · # It is not sufficient to just use the shape of an input; it is still unused
MEDIUM · tests/linalg_test.py · 984 · # This expresses identity function, which makes us robust to, e.g., the
MEDIUM · tests/lax_autodiff_test.py · 159 · # TODO(mattjj): make some-equal checks more robust, enable second-order
MEDIUM · tests/sparsify_test.py · 341 · # Note: more comprehensive tests in sparse_test.py:test_bcoo_squeeze
MEDIUM · tests/sparsify_test.py · 354 · # Note: more comprehensive tests in sparse_test.py:test_bcoo_rev
MEDIUM · tests/sparsify_test.py · 419 · # Note: more comprehensive tests in sparse_test.py:test_bcoo_conv_general_dilated
LOW · tests/api_test.py · 454 · # We can't just use `lambda x: x` because JAX simplifies this away to an
LOW · tests/mosaic/gpu_layout_inference_test.py · 90 · # TMEM reference, so we can just return a trivial mapping.
LOW · docs/array_refs.py · 115 · # `jax.ref.swap`, but usually you'd just use NumPy-style array indexing syntax:
LOW · docs/autodidax.py · 1475 · # call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form
MEDIUM · …ops/tpu/splash_attention/splash_attention_mask_info.py · 892 · # maintain the SPMD paradigm.
MEDIUM · …erimental/jax2tf/tests/jax_primitives_coverage_test.py · 50 · # If you want to run this test for only one harness, add parameter
MEDIUM · …erimental/jax2tf/tests/jax_primitives_coverage_test.py · 90 · # f"{[u.description for u in jax_unimpl]} in harness: {harness.fullname}"))
MEDIUM · jax/experimental/jax2tf/tests/shape_poly_test.py · 223 · # Makes and tests a harness. See PolyHarness documentation.
LOW · jax/experimental/jax2tf/tests/shape_poly_test.py · 587 · # It is not sufficient to just use the shape of an input; it is still unused
MEDIUM · jax/experimental/jax2tf/tests/model_harness.py · 77 · """Partially apply harness in order to create variables lazily.
MEDIUM · jax/experimental/jax2tf/tests/tf_test_util.py · 221 · # Run JAX. Should not fail, we assume that the harness has been filtered
MEDIUM · jax/experimental/jax2tf/tests/primitives_test.py · 91 · # If you want to run this test for only one harness, add parameter
LOW · jax/_src/compiler.py · 494 · # simply return the non-PGLE profiled module from the persistent cache if it
LOW · jax/_src/pjit.py · 1019 · # checks and just return the pjit_in_shardings directly. `shard_args` will
LOW · jax/_src/checkify.py · 348 · # Default non-HOP case: just call primitive and don't update error.
LOW · jax/_src/interpreters/batching.py · 632 · # if there's only agreeing batch dims and scalars, just call the primitive
LOW · jax/_src/lax/ann.py · 367 · # 1. ApproxTopK is internally a variadic reduce, so we can simply call
MEDIUM · jax/_src/lax/control_flow/loops.py · 2023 · # lowering. Fundamentally, we'd like to rewrite a while loop that looks like
MEDIUM · jax/_src/numpy/linalg.py · 1445 · # TODO: add custom jvp rule for more robust lstsq differentiation
LOW · jax/_src/pallas/pallas_call.py · 689 · # dimensions. For now, we just use 0.
LOW · jax/_src/pallas/mosaic/lowering.py · 1682 · # `jax.empty_ref`), but lowering expects them to exist---so we just return
MEDIUM · …/_src/pallas/mosaic/interpret/interpret_pallas_call.py · 1592 · # TODO(nrink): It would be more robust if the buffer id, i.e. `src`,
LOW · …/_src/pallas/mosaic/interpret/interpret_pallas_call.py · 636 · # callback to `get`. Should we just pass the shape to `get`?
LOW · jax/_src/pallas/mosaic_gpu/interpret/gpu_callbacks.py · 662 · # callback to `get`. Should we just pass the shape to `get`?
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 148 · # Descriptive name of the harness, used as a testcase_name. Unique in a group.
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 157 · # partially implemented in JAX for this harness.
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 386 · # Change the testcase name to include the harness name.
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 2003 · # This first harness runs the tests for all dtypes using default values for
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 2589 · # This first harness runs the tests for all dtypes using default values for
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 2894 · # This first harness runs the tests for all dtypes and precisions using
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 3095 · # This first harness runs the tests for all dtypes and precisions using
MEDIUM · jax/_src/internal_test_util/test_harnesses.py · 351 · self.enabled = enabled # Does it apply to the current harness?
Dead Code · 55 hits · 104 pts
Severity · File · Line · Snippet
MEDIUM · tests/custom_api_test.py · 374
MEDIUM · tests/custom_api_test.py · 376
MEDIUM · tests/custom_api_test.py · 379
MEDIUM · tests/custom_api_test.py · 382
MEDIUM · tests/custom_api_test.py · 385
MEDIUM · tests/custom_api_test.py · 386
MEDIUM · tests/custom_api_test.py · 387
MEDIUM · tests/custom_api_test.py · 1816
MEDIUM · tests/custom_api_test.py · 1828
MEDIUM · tests/custom_api_test.py · 1831
MEDIUM · tests/custom_api_test.py · 1832
MEDIUM · tests/custom_api_test.py · 1833
MEDIUM · tests/custom_api_test.py · 1835
MEDIUM · tests/custom_api_test.py · 1836
MEDIUM · tests/custom_api_test.py · 1837
MEDIUM · tests/lax_numpy_reducers_test.py · 907
MEDIUM · tests/lax_numpy_reducers_test.py · 908
MEDIUM · tests/lax_numpy_reducers_test.py · 909
MEDIUM · tests/lax_numpy_reducers_test.py · 919
MEDIUM · tests/lax_numpy_reducers_test.py · 920
MEDIUM · tests/lax_numpy_reducers_test.py · 921
MEDIUM · tests/pmap_test.py · 2067
MEDIUM · tests/pmap_test.py · 2069
MEDIUM · tests/pmap_test.py · 2075
MEDIUM · tests/pmap_test.py · 2085
MEDIUM · tests/pmap_test.py · 2089
MEDIUM · tests/pmap_test.py · 2104
MEDIUM · tests/pmap_test.py · 2110
MEDIUM · tests/pmap_test.py · 2111
MEDIUM · tests/pmap_test.py · 2112
MEDIUM · tests/pmap_test.py · 2113
MEDIUM · tests/pmap_test.py · 2124
MEDIUM · tests/pmap_test.py · 2130
MEDIUM · tests/pmap_test.py · 2131
MEDIUM · tests/pmap_test.py · 2132
MEDIUM · tests/pmap_test.py · 2133
MEDIUM · tests/pmap_test.py · 2139
MEDIUM · tests/pmap_test.py · 2145
MEDIUM · tests/pmap_test.py · 2295
MEDIUM · tests/pmap_test.py · 2297
MEDIUM · tests/pmap_test.py · 2300
MEDIUM · tests/pmap_test.py · 2301
MEDIUM · tests/pmap_test.py · 2303
MEDIUM · tests/pmap_test.py · 2304
MEDIUM · tests/pmap_test.py · 2305
MEDIUM · tests/pmap_test.py · 2306
MEDIUM · tests/pmap_test.py · 2318
MEDIUM · tests/lax_test.py · 1774
MEDIUM · tests/lax_test.py · 1806
MEDIUM · tests/mosaic/gpu_test.py · 949
MEDIUM · tests/mosaic/gpu_test.py · 951
MEDIUM · tests/mosaic/gpu_test.py · 952
MEDIUM · tests/mosaic/gpu_test.py · 953
MEDIUM · tests/mosaic/gpu_test.py · 960
MEDIUM · jax/_src/pallas/core.py · 1675
Excessive Try-Catch Wrapping · 86 hits · 95 pts
Severity · File · Line · Snippet
MEDIUM · ci/utilities/convert_msys_paths_to_win_paths.py · 40 · print("Error: cygpath not found. Make sure it's in your PATH.")
MEDIUM · ci/utilities/convert_msys_paths_to_win_paths.py · 43 · print(f"Error converting path: {e}")
LOW · ci/postprocess/xml2json.py · 87 · except Exception as e:
LOW · tests/array_interoperability_test.py · 155 · except Exception as e:
LOW · tests/debug_info_test.py · 111 · except Exception as e:
LOW · tests/shape_poly_test.py · 2041 · except Exception:
LOW · tests/profiler_test.py · 63 · except Exception as e:
LOW · tests/transfer_guard_test.py · 117 · except Exception as e:
LOW · tests/transfer_guard_test.py · 136 · except Exception as e:
LOW · tests/api_test.py · 5039 · except Exception as e:
LOW · tests/pallas/gpu_pallas_distributed_test.py · 600 · except Exception:
LOW · tests/pallas/gpu_pallas_distributed_test.py · 651 · except Exception:
LOW · tests/pallas/gpu_pallas_distributed_test.py · 1306 · except Exception:
LOW · tests/pallas/ops_test.py · 709 · except Exception as e:
LOW · tests/pallas/ops_test.py · 837 · except Exception as e:
LOW · tests/pallas/ops_test.py · 2688 · except Exception as e:
LOW · docs/array_refs.py · 82 · except Exception as e:
LOW · docs/array_refs.py · 222 · except Exception as e:
LOW · docs/array_refs.py · 231 · except Exception as e:
LOW · docs/array_refs.py · 240 · except Exception as e:
LOW · docs/array_refs.py · 249 · except Exception as e:
LOW · docs/array_refs.py · 289 · except Exception as e:
LOW · docs/parallel.md · 425 · except Exception as e:
LOW · docs/parallel.md · 464 · except Exception as e:
LOW · docs/array_refs.md · 87 · except Exception as e:
LOW · docs/array_refs.md · 222 · except Exception as e:
LOW · docs/array_refs.md · 231 · except Exception as e:
LOW · docs/array_refs.md · 240 · except Exception as e:
LOW · docs/array_refs.md · 249 · except Exception as e:
LOW · docs/array_refs.md · 288 · except Exception as e:
LOW · docs/parallel.py · 379 · except Exception as e:
LOW · docs/parallel.py · 417 · except Exception as e:
LOW · docs/sphinxext/source_include.py · 72 · except Exception as e:
MEDIUM · docs/sphinxext/source_include.py · 40 · def get_tagged_block(filepath, tag, lines_spec=None):
LOW · …atic/fault_tolerance/data_parallelism_with_recovery.py · 171 · except Exception as e:
LOW · docs/_static/fault_tolerance/live_devices.py · 56 · except Exception as e:
LOW · docs/_static/fault_tolerance/data_parallelism.py · 118 · except Exception as e:
MEDIUM · docs/pallas/tpu/sparse.md · 148 · print("Error |result - lax.dynamic_slice| =", diff)
LOW · jax/__init__.py · 28 · except Exception as exc:
LOW · jax/experimental/multihost_utils.py · 679 · except Exception as e:
LOW · jax/experimental/array_serialization/serialization.py · 238 · except Exception as e:
MEDIUM · jax/experimental/array_serialization/serialization.py · 197 · def _thread_func(self):
LOW · jax/experimental/jax2tf/call_tf.py · 294 · except Exception as e:
LOW · jax/experimental/jax2tf/call_tf.py · 591 · except Exception as e:
LOW · …erimental/jax2tf/tests/jax_primitives_coverage_test.py · 72 · except Exception as e:
LOW · jax/experimental/jax2tf/tests/models_test_main.py · 206 · except Exception as e:
LOW · jax/experimental/jax2tf/tests/tf_test_util.py · 246 · except Exception as e:
LOW · jax/experimental/jax2tf/tests/primitives_test.py · 135 · except Exception as e:
LOW · jax/experimental/jax2tf/examples/mnist_lib.py · 76 · except Exception as e:
LOW · jax/experimental/colocated_python/obj_backend.py · 107 · except Exception as exc:
LOW · jax/experimental/sparse/bcsr.py · 892 · except Exception:
LOW · jax/_src/environment_info.py · 30 · except Exception:
MEDIUM · jax/_src/environment_info.py · 27 · def try_nvidia_smi() -> str | None:
LOW · jax/_src/xla_bridge.py · 832 · except Exception as err:
LOW · jax/_src/compiler.py · 83 · except Exception:
LOW · jax/_src/compiler.py · 760 · except Exception as ex:
LOW · jax/_src/compiler.py · 779 · except Exception as ex:
LOW · jax/_src/compiler.py · 830 · except Exception as ex:
LOW · jax/_src/util.py · 540 · except Exception:
LOW · jax/_src/util.py · 635 · except Exception as e:
26 more matches not shown…
Redundant / Tautological Comments · 26 hits · 38 pts
Severity · File · Line · Snippet
LOW · tests/array_interoperability_test.py · 183 · # Check if the source device is preserved
LOW · tests/nn_test.py · 119 · # Check if float8_e8m0fnu is available
LOW · tests/nn_test.py · 327 · # Check if cudnn backend is called (only on CUDA).
LOW · tests/pgle_test.py · 277 · # Check if FDO profile file of the biggest module is not empty
LOW · tests/pgle_test.py · 285 · # Check if FDO profile file in dump directory is not empty
LOW · tests/multiprocess/array_test.py · 297 · # Check if we can specify that local input actually contains full-span
LOW · …allas/tpu_fusible_matmul_with_stateful_fusions_test.py · 145 · # Check if either input fusion reads from a Ref that is also written by the
LOW · tests/pallas/tpu_pallas_call_print_test.py · 149 · # Check if the numbers in the output match the values generated by `arange`.
LOW · .github/workflows/upload_metadata.yml · 66 · # Check if zip file is empty before unzipping
LOW · .github/workflows/verify-squash.yml · 31 · # Check if the skip label is present
LOW · .github/workflows/verify-squash.yml · 54 · # Check if squashed
LOW · .github/workflows/build_artifacts.yml · 171 · # Set shell to cmd to avoid path errors when using gcloud commands on Windows
LOW · jax/experimental/mosaic/gpu/fragmented_array.py · 4843 · # Check if input maps exactly to the end (prevents trailing dims).
LOW · jax/_src/test_util.py · 1276 · # Check if strict hermeticity is already satisfied
LOW · jax/_src/checkify.py · 865 · # Check if the first cond application will error.
LOW · jax/_src/pretty_printer.py · 48 · # Check if we're in IPython or Colab
LOW · jax/_src/lax/control_flow/loops.py · 2200 · # Check if the same Ref is written to in both cond and body.
LOW · jax/_src/state/primitives.py · 763 · # Check if start is static (which it can be)
LOW · jax/_src/state/primitives.py · 776 · # Check if we are indexing with a scalar or not. If we are indexing
LOW · jax/_src/numpy/setops.py · 177 · # Set mask to zero at locations corresponding to unique() padding.
LOW · jax/_src/numpy/array_constructors.py · 251 · # Check if object supports any of the data exchange protocols
LOW · jax/_src/pallas/core.py · 1231 · elif grid and isinstance(grid[0], tuple): # Check if we have a named grid
LOW · jax/_src/pallas/einshape.py · 230 · # Check if all RHS dims are known
LOW · jax/_src/scipy/signal.py · 681 · # Check if we can broadcast the outer axes together
LOW · jax/_src/scipy/signal.py · 730 · # Check if x and y are the same length, zero-pad if necessary
LOW · jax/_src/cudnn/fused_attention_stablehlo.py · 1260 · # Check if all required keys are present
Hallucination Indicators · 3 hits · 30 pts
Severity · File · Line · Snippet
CRITICAL · tests/mosaic/profiler_cupti_test.py · 96 · jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
CRITICAL · tests/filecheck/jax_mlir_ext.filecheck.py · 114 · return jax._src.lib._jax.Traceback.get_traceback()
CRITICAL · jax/_src/pallas/mosaic_gpu/lowering.py · 3910 · for k, v in ctx.launch_ctx.module.operation.attributes.items():
Synthetic Comment Markers · 3 hits · 20 pts
Severity · File · Line · Snippet
HIGH · docs/contributing.md · 26 · ## Can I contribute AI generated code?
HIGH · docs/contributing.md · 35 · You are responsible for any code you contribute to JAX, regardless of whether it was written manually or generated by AI
HIGH · jaxlib/py_executable.h · 287 · // Python objects to keep alive as requested by user.
Slop Phrases · 9 hits · 14 pts
Severity · File · Line · Snippet
MEDIUM · docs/parallel.py · 396 · # want the compiler to choose them automatically, you can use the `@auto_axes`
MEDIUM · jax/experimental/jax2tf/examples/saved_model_lib.py · 102 · # names, you can use `tree.map_structure_with_path` from the `dm-tree` package
LOW · jax/_src/core.py · 2821 · # but we make sure to reset it to Device because the Ref owns the memory space
MEDIUM · jax/_src/ad_checkpoint.py · 272 · Here is a simple example:
MEDIUM · jax/_src/errors.py · 201 · indices, such as with :code:`.at[...].set(...)`. Here is a simple example::
MEDIUM · jax/_src/errors.py · 666 · Here is a simple example of code that would lead to such an error::
MEDIUM · jax/_src/lax/windowed_reductions.py · 154 · Here is a simple example of a windowed product over pairs in a 1-dimensional array:
MEDIUM · jax/_src/lax/slicing.py · 373 · For example, here is how you can extract values at particular indices using
LOW · jax/_src/pallas/core.py · 782 · # If you hit this, make sure you take transforms into account and use either
Overly Generic Function Names · 9 hits · 10 pts
Severity · File · Line · Snippet
LOW · tests/core_test.py · 584 · def my_function():
LOW · tests/util_test.py · 80 · def my_function():
LOW · tests/aot_test.py · 122 · def my_function(x):
LOW · tests/aot_test.py · 143 · def my_function(x):
LOW · tests/api_test.py · 93 · def my_function():
LOW · tests/api_test.py · 107 · def my_function():
LOW · tests/api_test.py · 1632 · def my_function(x, flag):
LOW · jax/typing.py · 47 · def my_function(x: ArrayLike) -> Array:
LOW · jax/experimental/jax2tf/tests/jax2tf_test.py · 306 · def test_function(self):
Example Usage Blocks · 3 hits · 5 pts
Severity · File · Line · Snippet
LOW · ci/utilities/run_docker_container.sh · 24 · # Usage:
LOW · jaxlib/mosaic/dialect/tpu/util.h · 110 · // Example usage:
LOW · jax/_src/deprecations.py · 22 · # Example usage:
Verbosity Indicators · 1 hit · 2 pts
Severity · File · Line · Snippet
LOW · jaxlib/python_ref_manager.h · 85 · // The purpose of this function is to amortize lock acquisition costs over