Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve pyright compatibility by using TypeVar's default argument #1246

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

stubs_deps = [
"mypy==1.11.2",
"typing-extensions",
"typing-extensions>=4.4",
]

def install_rustworkx(session):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
fixes:
- |
Enhanced the compatibility of the type annotations with `pyright` in strict
mode. See `issue 1242 <https://github.com/Qiskit/rustworkx/issues/1242>`__ for
more details.
12 changes: 9 additions & 3 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
# This file contains only type annotations for PyO3 functions and classes
# For implementation details, see __init__.py and src/lib.rs

import sys
import numpy as np

from typing import Generic, TypeVar, Any, Callable, overload
from typing import Generic, Any, Callable, overload
from collections.abc import Iterator, Sequence

if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar

# Re-Exports of rust native functions in rustworkx.rustworkx
# To workaround limitations in mypy around re-exporting objects from the inner
# rustworkx module we need to explicitly re-export every inner function from
Expand Down Expand Up @@ -270,8 +276,8 @@ from .rustworkx import AllPairsMultiplePathMapping as AllPairsMultiplePathMappin
from .rustworkx import PyGraph as PyGraph
from .rustworkx import PyDiGraph as PyDiGraph

_S = TypeVar("_S")
_T = TypeVar("_T")
_S = TypeVar("_S", default=Any)
_T = TypeVar("_T", default=Any)
_BFSVisitor = TypeVar("_BFSVisitor", bound=visit.BFSVisitor)
_DFSVisitor = TypeVar("_DFSVisitor", bound=visit.DFSVisitor)
_DijkstraVisitor = TypeVar("_DijkstraVisitor", bound=visit.DijkstraVisitor)
Expand Down
13 changes: 9 additions & 4 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from .visit import BFSVisitor, DFSVisitor, DijkstraVisitor
from typing import (
TypeVar,
Callable,
final,
Any,
Expand All @@ -35,9 +34,15 @@ from rustworkx import generators # noqa
from typing_extensions import Self

import numpy as np
import sys

_S = TypeVar("_S")
_T = TypeVar("_T")
if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar

_S = TypeVar("_S", default=Any)
_T = TypeVar("_T", default=Any)

class DAGHasCycle(Exception): ...
class DAGWouldCycle(Exception): ...
Expand Down Expand Up @@ -1059,7 +1064,7 @@ def dominance_frontiers(graph: PyDiGraph[_S, _T], start_node: int, /) -> dict[in

# Iterators

_T_co = TypeVar("_T_co", covariant=True)
_T_co = TypeVar("_T_co", covariant=True, default=Any)

class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC):
def __init__(self) -> None: ...
Expand Down
Loading