← Back to index
enums_expansion.py
True Positive
False Positive
False Negative
Optional (detected)
Warning or Info
TP: 0
FP: 1
FN: 1
Optional: 0 / 2
1
"""
Tests that the type checker handles literal expansion of enum classes.
"""
4
5
# Specification: https://typing.readthedocs.io/en/latest/spec/enums.html#enum-literal-expansion
6
7
from
enum
import
Enum
,
Flag
8
from
typing
import
Literal
,
Never
,
assert_type
9
10
# > From the perspective of the type system, most enum classes are equivalent
# > to the union of the literal members within that enum. Type checkers may
# > therefore expand an enum type [...]
13
14
15
# Ordinary (non-Flag) enum with exactly three members. Per the spec excerpt
# above, a checker may treat this type as the union
# Literal[Color.RED] | Literal[Color.GREEN] | Literal[Color.BLUE],
# which is what the narrowing tests print_color1/print_color2 exercise.
class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3
19
20
21
def print_color1(c: Color):
    """Check literal expansion during narrowing via ``is`` comparisons.

    Once ``Color.RED`` and ``Color.BLUE`` are ruled out, a checker that
    expands the enum can narrow ``c`` to ``Literal[Color.GREEN]`` in the
    ``else`` branch. The assertion carries the optional expected-error
    marker because the spec says checkers *may* expand enums, not that they
    must.
    """
    if c is Color.RED or c is Color.BLUE:
        print("red or blue")
    else:
        assert_type(c, Literal[Color.GREEN])  # E?
26
27
28
def print_color2(c: Color):
    """Check exhaustiveness of a ``match`` that covers every enum member.

    The first two cases together name all three members of ``Color``. With
    literal expansion the wildcard case is therefore unreachable, and ``c``
    narrows to ``Never`` there. The assertion is optional because the spec
    permits, but does not require, this expansion.
    """
    match c:
        case Color.RED | Color.BLUE:
            print("red or blue")
        case Color.GREEN:
            print("green")
        case _:
            assert_type(c, Never)  # E?
36
37
38
# > This rule does not apply to classes that derive from enum.Flag because
# > these enums allow flags to be combined in arbitrary ways.
40
41
42
# Flag-derived enum whose members occupy distinct bits (1, 2, 4). Members can
# be OR-ed together (e.g. FLAG1 | FLAG3), so a value of this type is not
# limited to the three declared members. Per the spec excerpt above, literal
# expansion must not be applied to it.
class CustomFlags(Flag):
    FLAG1 = 1
    FLAG2 = 2
    FLAG3 = 4
46
47
48
def test1(f: CustomFlags):
    """Check that identity narrowing does not expand a ``Flag`` enum.

    Ruling out ``FLAG1`` and ``FLAG2`` by identity does not leave only
    ``FLAG3``: ``f`` could also be a combination such as ``FLAG1 | FLAG3``.
    In the ``else`` branch ``f`` must therefore remain ``CustomFlags``.
    Asserting the narrower ``Literal[CustomFlags.FLAG3]`` is expected to be
    reported as an error.
    """
    if f is CustomFlags.FLAG1 or f is CustomFlags.FLAG2:
        # Taken for either single member (not their combination), so the
        # message says "or", mirroring print_color1's "red or blue".
        print("flag1 or flag2")
    else:
        # NOTE(report): the conformance report attributes ty's unexpected
        # [type-assertion-failure] "Type `CustomFlags` does not match asserted
        # type `Literal[CustomFlags.FLAG3]`" to the assert_type call below,
        # although the message text describes the following assertion --
        # possibly a line-attribution issue in ty; TODO confirm.
        assert_type(f, CustomFlags)
        # NOTE(report): ty produced no diagnostic on the next line, where an
        # error is expected.
        assert_type(f, Literal[CustomFlags.FLAG3])  # E
54
55
56
def test2(f: CustomFlags):
    """Check that matching every declared member is not exhaustive for a ``Flag``.

    Combined values such as ``FLAG1 | FLAG3`` match none of the member
    patterns, so the wildcard case is reachable. There ``f`` must still be
    ``CustomFlags``; this line carries no expected-error marker.
    """
    match f:
        case CustomFlags.FLAG1 | CustomFlags.FLAG2:
            pass
        case CustomFlags.FLAG3:
            pass
        case _:
            assert_type(f, CustomFlags)
64
65
66
# > A type checker should treat a complete union of all literal members as
67
# > compatible with the enum type.
68
69
70
# Two-member enum used by test3 to check that the complete union
# Literal[Answer.Yes] | Literal[Answer.No] is accepted where Answer is
# expected.
class Answer(Enum):
    Yes = 1
    No = 2
73
74
75
def test3(val: object) -> list[Answer]:
    """Check that a complete union of ``Answer`` literals is compatible with ``Answer``.

    The ``assert`` narrows ``val`` from ``object`` to one of the two
    ``Answer`` members. The list built from it is then returned as
    ``list[Answer]``, which exercises the spec rule quoted above: a complete
    union of all literal members should be treated as compatible with the
    enum type. No error is expected anywhere in this function.
    """
    assert val is Answer.Yes or val is Answer.No
    x = [val]
    return x