Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/test_mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Mypy type hint checks

on:
pull_request:
push:
branches:
- main

jobs:
run:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Setup Micromamba env
uses: mamba-org/setup-micromamba@v1
with:
environment-name: TEST
create-args: >-
python=3
--file requirements.txt
--file requirements-dev.txt
- name: Install branca from source
shell: bash -l {0}
run: |
python -m pip install -e . --no-deps --force-reinstall
- name: Mypy test
shell: bash -l {0}
run: |
mypy branca
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha! You are always one (many actually) step ahead of me.

176 changes: 101 additions & 75 deletions branca/colormap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,67 @@
import json
import math
import os
from typing import Dict, List, Optional, Sequence, Tuple, Union

from jinja2 import Template

from branca.element import ENV, Figure, JavascriptLink, MacroElement
from branca.utilities import legend_scaler

rootpath = os.path.abspath(os.path.dirname(__file__))
rootpath: str = os.path.abspath(os.path.dirname(__file__))

with open(os.path.join(rootpath, "_cnames.json")) as f:
_cnames = json.loads(f.read())
_cnames: Dict[str, str] = json.loads(f.read())

with open(os.path.join(rootpath, "_schemes.json")) as f:
_schemes = json.loads(f.read())
_schemes: Dict[str, List[str]] = json.loads(f.read())


def _is_hex(x):
TypeRGBInts = Tuple[int, int, int]
TypeRGBFloats = Tuple[float, float, float]
TypeRGBAInts = Tuple[int, int, int, int]
TypeRGBAFloats = Tuple[float, float, float, float]
TypeAnyColorType = Union[TypeRGBInts, TypeRGBFloats, TypeRGBAInts, TypeRGBAFloats, str]


def _is_hex(x: str) -> bool:
return x.startswith("#") and len(x) == 7


def _parse_hex(color_code):
def _parse_hex(color_code: str) -> TypeRGBAFloats:
return (
int(color_code[1:3], 16),
int(color_code[3:5], 16),
int(color_code[5:7], 16),
_color_int_to_float(int(color_code[1:3], 16)),
_color_int_to_float(int(color_code[3:5], 16)),
_color_int_to_float(int(color_code[5:7], 16)),
1.0,
)


def _parse_color(x):
def _color_int_to_float(x: int) -> float:
"""Convert an integer between 0 and 255 to a float between 0. and 1.0"""
return x / 255.0


def _color_float_to_int(x: float) -> int:
"""Convert a float between 0. and 1.0 to an integer between 0 and 255"""
return int(x * 255.9999)


def _parse_color(x: Union[tuple, list, str]) -> TypeRGBAFloats:
if isinstance(x, (tuple, list)):
color_tuple = tuple(x)[:4]
elif isinstance(x, (str, bytes)) and _is_hex(x):
color_tuple = _parse_hex(x)
elif isinstance(x, (str, bytes)):
return tuple(tuple(x) + (1.0,))[:4] # type: ignore
elif isinstance(x, str) and _is_hex(x):
return _parse_hex(x)
elif isinstance(x, str):
cname = _cnames.get(x.lower(), None)
if cname is None:
raise ValueError(f"Unknown color {cname!r}.")
color_tuple = _parse_hex(cname)
return _parse_hex(cname)
else:
raise ValueError(f"Unrecognized color code {x!r}")
if max(color_tuple) > 1.0:
color_tuple = tuple(u / 255.0 for u in color_tuple)
return tuple(map(float, (color_tuple + (1.0,))[:4]))


def _base(x):
def _base(x: float) -> float:
if x > 0:
base = pow(10, math.floor(math.log10(x)))
return round(x / base) * base
Expand All @@ -78,15 +94,15 @@ class ColorMap(MacroElement):
Maximum number of legend tick labels
"""

_template = ENV.get_template("color_scale.js")
_template: Template = ENV.get_template("color_scale.js")

def __init__(
self,
vmin=0.0,
vmax=1.0,
caption="",
text_color="black",
max_labels=10,
vmin: float = 0.0,
vmax: float = 1.0,
caption: str = "",
text_color: str = "black",
max_labels: int = 10,
):
super().__init__()
self._name = "ColorMap"
Expand All @@ -95,9 +111,9 @@ def __init__(
self.vmax = vmax
self.caption = caption
self.text_color = text_color
self.index = [vmin, vmax]
self.index: List[float] = [vmin, vmax]
self.max_labels = max_labels
self.tick_labels = None
self.tick_labels: Optional[Sequence[Union[float, str]]] = None

self.width = 450
self.height = 40
Expand Down Expand Up @@ -127,7 +143,7 @@ def render(self, **kwargs):
name="d3",
) # noqa

def rgba_floats_tuple(self, x):
def rgba_floats_tuple(self, x: float) -> TypeRGBAFloats:
"""
This class has to be implemented for each class inheriting from
Colormap. This has to be a function of the form float ->
Expand All @@ -137,37 +153,37 @@ def rgba_floats_tuple(self, x):
"""
raise NotImplementedError

def rgba_bytes_tuple(self, x):
def rgba_bytes_tuple(self, x: float) -> TypeRGBAInts:
"""Provides the color corresponding to value `x` in the
form of a tuple (R,G,B,A) with int values between 0 and 255.
"""
return tuple(int(u * 255.9999) for u in self.rgba_floats_tuple(x))
return tuple(_color_float_to_int(u) for u in self.rgba_floats_tuple(x)) # type: ignore

def rgb_bytes_tuple(self, x):
def rgb_bytes_tuple(self, x: float) -> TypeRGBInts:
"""Provides the color corresponding to value `x` in the
form of a tuple (R,G,B) with int values between 0 and 255.
"""
return self.rgba_bytes_tuple(x)[:3]

def rgb_hex_str(self, x):
def rgb_hex_str(self, x: float) -> str:
"""Provides the color corresponding to value `x` in the
form of a string of hexadecimal values "#RRGGBB".
"""
return "#%02x%02x%02x" % self.rgb_bytes_tuple(x)

def rgba_hex_str(self, x):
def rgba_hex_str(self, x: float) -> str:
"""Provides the color corresponding to value `x` in the
form of a string of hexadecimal values "#RRGGBBAA".
"""
return "#%02x%02x%02x%02x" % self.rgba_bytes_tuple(x)

def __call__(self, x):
def __call__(self, x: float) -> str:
"""Provides the color corresponding to value `x` in the
form of a string of hexadecimal values "#RRGGBBAA".
"""
return self.rgba_hex_str(x)

def _repr_html_(self):
def _repr_html_(self) -> str:
"""Display the colormap in a Jupyter Notebook.

Does not support all the class arguments.
Expand Down Expand Up @@ -264,14 +280,14 @@ class LinearColormap(ColorMap):

def __init__(
self,
colors,
index=None,
vmin=0.0,
vmax=1.0,
caption="",
text_color="black",
max_labels=10,
tick_labels=None,
colors: Sequence[TypeAnyColorType],
index: Optional[Sequence[float]] = None,
vmin: float = 0.0,
vmax: float = 1.0,
caption: str = "",
text_color: str = "black",
max_labels: int = 10,
tick_labels: Optional[Sequence[float]] = None,
):
super().__init__(
vmin=vmin,
Expand All @@ -280,7 +296,7 @@ def __init__(
text_color=text_color,
max_labels=max_labels,
)
self.tick_labels = tick_labels
self.tick_labels: Optional[Sequence[float]] = tick_labels

n = len(colors)
if n < 2:
Expand All @@ -289,9 +305,9 @@ def __init__(
self.index = [vmin + (vmax - vmin) * i * 1.0 / (n - 1) for i in range(n)]
else:
self.index = list(index)
self.colors = [_parse_color(x) for x in colors]
self.colors: List[TypeRGBAFloats] = [_parse_color(x) for x in colors]

def rgba_floats_tuple(self, x):
def rgba_floats_tuple(self, x: float) -> TypeRGBAFloats:
"""Provides the color corresponding to value `x` in the
form of a tuple (R,G,B,A) with float values between 0. and 1.
"""
Expand All @@ -308,20 +324,20 @@ def rgba_floats_tuple(self, x):
else:
raise ValueError("Thresholds are not sorted.")

return tuple(
return tuple( # type: ignore
(1.0 - p) * self.colors[i - 1][j] + p * self.colors[i][j] for j in range(4)
)

def to_step(
self,
n=None,
index=None,
data=None,
method=None,
quantiles=None,
round_method=None,
max_labels=10,
):
n: Optional[int] = None,
index: Optional[Sequence[float]] = None,
data: Optional[Sequence[float]] = None,
method: str = "linear",
quantiles: Optional[Sequence[float]] = None,
round_method: Optional[str] = None,
max_labels: int = 10,
) -> "StepColormap":
"""Splits the LinearColormap into a StepColormap.

Parameters
Expand Down Expand Up @@ -382,11 +398,7 @@ def to_step(
max_ = max(data)
min_ = min(data)
scaled_cm = self.scale(vmin=min_, vmax=max_)
method = (
"quantiles"
if quantiles is not None
else method if method is not None else "linear"
)
method = "quantiles" if quantiles is not None else method
if method.lower().startswith("lin"):
if n is None:
raise ValueError(msg)
Expand Down Expand Up @@ -454,7 +466,12 @@ def to_step(
tick_labels=self.tick_labels,
)

def scale(self, vmin=0.0, vmax=1.0, max_labels=10):
def scale(
self,
vmin: float = 0.0,
vmax: float = 1.0,
max_labels: int = 10,
) -> "LinearColormap":
"""Transforms the colorscale so that the minimal and maximal values
fit the given parameters.
"""
Expand Down Expand Up @@ -510,14 +527,14 @@ class StepColormap(ColorMap):

def __init__(
self,
colors,
index=None,
vmin=0.0,
vmax=1.0,
caption="",
text_color="black",
max_labels=10,
tick_labels=None,
colors: Sequence[TypeAnyColorType],
index: Optional[Sequence[float]] = None,
vmin: float = 0.0,
vmax: float = 1.0,
caption: str = "",
text_color: str = "black",
max_labels: int = 10,
tick_labels: Optional[Sequence[float]] = None,
):
super().__init__(
vmin=vmin,
Expand All @@ -535,9 +552,9 @@ def __init__(
self.index = [vmin + (vmax - vmin) * i * 1.0 / n for i in range(n + 1)]
else:
self.index = list(index)
self.colors = [_parse_color(x) for x in colors]
self.colors: List[TypeRGBAFloats] = [_parse_color(x) for x in colors]

def rgba_floats_tuple(self, x):
def rgba_floats_tuple(self, x: float) -> TypeRGBAFloats:
"""
Provides the color corresponding to value `x` in the
form of a tuple (R,G,B,A) with float values between 0. and 1.
Expand All @@ -549,9 +566,13 @@ def rgba_floats_tuple(self, x):
return self.colors[-1]

i = len([u for u in self.index if u <= x]) # 0 < i < n.
return tuple(self.colors[i - 1])
return self.colors[i - 1]

def to_linear(self, index=None, max_labels=10):
def to_linear(
self,
index: Optional[Sequence[float]] = None,
max_labels: int = 10,
) -> LinearColormap:
"""
Transforms the StepColormap into a LinearColormap.

Expand Down Expand Up @@ -584,7 +605,12 @@ def to_linear(self, index=None, max_labels=10):
max_labels=max_labels,
)

def scale(self, vmin=0.0, vmax=1.0, max_labels=10):
def scale(
self,
vmin: float = 0.0,
vmax: float = 1.0,
max_labels: int = 10,
) -> "StepColormap":
"""Transforms the colorscale so that the minimal and maximal values
fit the given parameters.
"""
Expand All @@ -611,7 +637,7 @@ def __init__(self):
for key, val in _schemes.items():
setattr(self, key, LinearColormap(val))

def _repr_html_(self):
def _repr_html_(self) -> str:
return Template(
"""
<table>
Expand All @@ -634,7 +660,7 @@ def __init__(self):
for key, val in _schemes.items():
setattr(self, key, StepColormap(val))

def _repr_html_(self):
def _repr_html_(self) -> str:
return Template(
"""
<table>
Expand Down
Loading