# Source code for traja.plotting

import logging
from collections import OrderedDict
from datetime import timedelta
import os
from typing import Union, Optional, Tuple, List

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from matplotlib import dates as md
from matplotlib.axes import Axes
from matplotlib.collections import PathCollection
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D
from pandas.core.dtypes.common import (

import traja
from traja.frame import TrajaDataFrame
from traja.trajectory import coords_to_flow

# NOTE(review): the original list was truncated; reconstructed from the
# definitions in this module -- confirm against the package's public API.
__all__ = [
    "_get_after_plot_args",
    "_label_axes",
    "_polar_bar",
    "_process_after_plot_args",
    "_rolling",
    "bar_plot",
    "color_dark",
    "fill_ci",
    "find_runs",
    "plot",
    "plot_3d",
    "plot_actogram",
    "plot_autocorrelation",
    "plot_clustermap",
    "plot_collection",
    "plot_contour",
    "plot_flow",
    "plot_pca",
    "plot_period",
    "plot_periodogram",
    "plot_prediction",
    "plot_quiver",
    "plot_rolling_hull",
    "plot_rolling_hull_3d",
    "plot_stream",
    "plot_surface",
    "plot_transition_graph",
    "plot_transition_matrix",
    "plot_xy",
    "polar_bar",
    "sans_serif",
    "stylize_axes",
    "trip_grid",
]

logger = logging.getLogger("traja")

def stylize_axes(ax):
    """Hide ticks on the top and right edges and point remaining ticks outward.

    Args:
        ax: Matplotlib axes (anything exposing ``xaxis``/``yaxis`` objects with
            a ``set_tick_params`` method).

    Returns:
        The same ``ax``, to allow call chaining.
    """
    # Matplotlib expects booleans for these flags; the legacy "on"/"off"
    # strings are deprecated, and a non-empty string such as "off" is truthy.
    ax.xaxis.set_tick_params(top=False, direction="out", width=1)
    ax.yaxis.set_tick_params(right=False, direction="out", width=1)
    return ax

def sans_serif():
    """Switch the global matplotlib font family to ``"serif"``.

    NOTE(review): despite its name, this selects the *serif* family (kept as
    in the original); confirm which family callers actually expect.
    """
    font_settings = {"family": "serif"}
    plt.rc("font", **font_settings)

def _rolling(df, window, step):
    count = 0
    df_length = len(df)
    while count < (df_length - window):
        yield count, df[count : window + count]
        count += step

def plot_prediction(model, dataloader, index, scaler=None):
    """Plot a model's prediction for one batch next to its history and ground truth.

    Args:
        model: torch model exposing ``batch_size``, ``num_past`` and
            ``input_size`` attributes and accepting ``latent=False``.
        dataloader: Iterable yielding batches whose first two items are
            ``(data, target)`` (the original unpacked
            ``data, target, category, parameters, classes``).
        index (int): Position of the batch to plot within ``dataloader``.
        scaler (optional): Object with ``inverse_transform`` used to map the
            values back to their original scale.

    Returns:
        tuple: ``(fig, ax)`` with the two stacked subplots (line plot on top,
        scatter plot below).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fig, ax = plt.subplots(2, 1, figsize=(10, 10))
    model = model.to(device)
    batch_size = model.batch_size
    num_past = model.num_past
    input_size = model.input_size

    # Materializes every batch so negative indices keep working.
    data, target, *_ = list(iter(dataloader))[index]
    data = data.float().to(device)
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        prediction = model(data, latent=False)

    # Keep the last sample of the batch and send it to CPU so numpy can use it
    pred = prediction[batch_size - 1 : batch_size, :].cpu().squeeze().detach().numpy()
    real = (
        target.clone().detach()[batch_size - 1 : batch_size, :].squeeze().cpu().numpy()
    )

    data = data.cpu().reshape(batch_size * num_past, input_size).detach().numpy()

    if scaler:
        data = scaler.inverse_transform(data)
        real = scaler.inverse_transform(real)
        pred = scaler.inverse_transform(pred)

    ax[0].plot(data[:, 0], data[:, 1], label="History")
    ax[0].plot(real[:, 0], real[:, 1], label="Real")
    ax[0].plot(pred[:, 0], pred[:, 1], label="Pred")

    ax[1].scatter(real[:, 0], real[:, 1], label="Real")
    ax[1].scatter(pred[:, 0], pred[:, 1], label="Pred")

    for a in ax:
        a.legend()

    return fig, ax

def bar_plot(trj: TrajaDataFrame, bins: Union[int, tuple] = None, **kwargs) -> Axes:
    """Plot a 3D bar chart of the number of frames spent in each grid cell.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory
        bins (int or tuple): number of bins for x and y
        **kwargs: additional keyword arguments to
            :meth:`mpl_toolkits.mplot3d.Axes3D.bar3d`

    Returns:
        ax (:class:`~mpl_toolkits.mplot3d.Axes3D`): Axes of plot

    """
    # TODO: Add time component
    bins = traja.trajectory._bins_to_tuple(trj, bins)

    X, Y, _, _ = coords_to_flow(trj, bins)

    # Bar heights are the trip-grid counts
    hist, _ = trip_grid(trj, bins, hist_only=True)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    try:
        ax.set_aspect("equal")
    except NotImplementedError:
        # Some matplotlib versions reject an "equal" aspect on 3D axes
        pass
    X = X.flatten("F")
    Y = Y.flatten("F")
    ax.bar3d(
        X,
        Y,
        np.zeros_like(X),
        1,
        1,
        hist.flatten(),
        zsort="average",
        shade=True,
        **kwargs,
    )
    ax.set(xlabel="x", ylabel="y", zlabel="Frames")

    return ax
def plot_rolling_hull(trj: TrajaDataFrame, window=100, step=20, areas=False, **kwargs):
    """Plot rolling convex hull of trajectory.

    If `areas` is True, only areas over time is plotted.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory
        window (int): number of rows in each rolling window
        step (int): number of rows between the starts of consecutive windows
        areas (bool): plot hull area per window instead of the hull outlines
        **kwargs: additional keyword arguments to :func:`matplotlib.pyplot.plot`

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of plot
    """
    hulls = []
    # Loop variable is `chunk` so `window` still holds the window size for titles
    for _, chunk in _rolling(trj, window=window, step=step):
        if chunk.dropna().empty:
            continue
        hulls.append(chunk.traja.to_shapely().convex_hull)

    units = trj.__dict__.get("spatial_units", "m")
    if areas:
        plt.plot([hull.area for hull in hulls], **kwargs)
        plt.title(f"Rolling Trajectory Convex Hull Area\nWindow={window},Step={step}")
        plt.ylabel(f"Area {units}")
        plt.xlabel("Frame")
        return plt.gca()

    xlim, ylim = traja.trajectory._get_xylim(trj)
    # Call the pyplot setters; assigning to `plt.xlim` would replace the
    # function module-wide instead of setting the limits.
    plt.xlim(xlim)
    plt.ylim(ylim)
    n_hulls = len(hulls)
    for idx, hull in enumerate(hulls):
        # Occasionally a Point object without an exterior reaches here
        if hasattr(hull, "exterior"):
            plt.plot(*hull.exterior.xy, alpha=idx / n_hulls, c="k", **kwargs)
    ax = plt.gca()
    ax.set_aspect("equal")
    ax.set(
        xlabel=f"x ({units})",
        ylabel=f"y ({units})",
        title=f"Rolling Trajectory Convex Hull\nWindow={window},Step={step}",
    )
    return ax


def plot_period(trj: TrajaDataFrame, col="x", dark=(7, 19), **kwargs):
    """Plot a column over time with the dark hours of each day shaded.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory with a datetime time
            column (or datetime index)
        col (str): column to plot
        dark (tuple): ``(end_hour, start_hour)`` of darkness; each day the
            spans midnight -> ``dark[0]`` and ``dark[1]`` -> next midnight are shaded
        interactive (bool): show the plot immediately

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of plot

    Raises:
        ValueError: if the trajectory has no time column or `col` is missing
    """
    time_col = traja.trajectory._get_time_col(trj)
    if not time_col:
        raise ValueError("Trajectory has no time column")
    # A datetime index is reported as "index" and is used as-is
    _trj = trj if time_col == "index" else trj.set_index(time_col)
    if col not in _trj:
        raise ValueError(f"{col} not a column in dataframe")
    series = _trj[col]
    fig, ax = plt.subplots()
    # `x_compat` keeps matplotlib date units so `axvspan` and `DateFormatter`
    # line up with the plotted data.
    series.plot(ax=ax, x_compat=True)

    # Midnight of every day present in the data (Timestamps keep hour offsets,
    # unlike `datetime.date`, which silently drops sub-day timedeltas).
    days = pd.DatetimeIndex(series.index).normalize().unique()

    intervals = []
    for day in days:
        intervals.append((day, day + timedelta(hours=dark[0])))
        intervals.append((day + timedelta(hours=dark[1]), day + timedelta(days=1)))

    for t0, t1 in intervals:
        ax.axvspan(t0, t1, color="gray", alpha=0.2)

    # Format date displayed on the x axis
    xfmt = md.DateFormatter("%H:%M\n%m-%d-%y")
    ax.xaxis.set_major_formatter(xfmt)

    if kwargs.get("interactive"):
        plt.show()
    return ax


def plot_rolling_hull_3d(trj: TrajaDataFrame, window=100, step=20, **kwargs):
    """Plot rolling convex hull outlines stacked along the z axis.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory
        window (int): number of rows in each rolling window
        step (int): number of rows between the starts of consecutive windows
        cmap (str, optional): colormap for the outlines (default "plasma")
        interactive (bool, optional): show the plot immediately

    Returns:
        ax (:class:`~mpl_toolkits.mplot3d.Axes3D`): Axes of plot
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    hulls = []
    for _, chunk in _rolling(trj, window=window, step=step):
        if chunk.dropna().empty:
            continue
        hulls.append(chunk.traja.to_shapely().convex_hull)

    xlim, ylim = traja.trajectory._get_xylim(trj)
    # Set limits on the axes; assigning to `plt.xlim` replaced the function
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Occasionally a Point object without an exterior reaches here
    outlines = [
        np.array(hull.exterior.xy) for hull in hulls if hasattr(hull, "exterior")
    ]

    # Add plots to axes, one color per outline
    n_lines = len(outlines)
    cm = plt.get_cmap(kwargs.get("cmap", "plasma"))
    if n_lines:
        ax.set_prop_cycle(color=[cm(i / n_lines) for i in range(n_lines)])
    for z, xy in enumerate(outlines):
        ax.plot(*xy, z)
    ax.set(
        xlabel=f"{trj.__dict__.get('spatial_units', 'm')}",
        ylabel=f"{trj.__dict__.get('spatial_units', 'm')}",
        title=f"Rolling Trajectory Convex Hull\nWindow={window},Step={step}",
    )
    if kwargs.get("interactive"):
        plt.show()
    return ax


def plot_3d(trj: TrajaDataFrame, **kwargs) -> Axes:
    """Plot 3D trajectory for single identity over period.

    The index is used as the z (time) axis; consecutive segments are colored
    along the colormap.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory
        title (str, optional): plot title (default "Trajectory")
        cmap (str, optional): colormap for the segments (default "winter")
        dist (float, optional): value assigned to ``ax.dist``
        labelpad (float, optional): sets the global ``axes.labelpad`` rcParam

    Returns:
        ax (:class:`~mpl_toolkits.mplot3d.Axes3D`): Axes of plot

    .. note::

        Takes a while to plot large trajectories. Consider using first::

            rt = trj.traja.rediscretize(R=1.) # Replace R with appropriate step length
            rt.traja.plot_3d()

    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.set_xlabel("x", fontsize=15)
    ax.set_zlabel("time", fontsize=15)
    ax.set_ylabel("y", fontsize=15)
    title = kwargs.pop("title", "Trajectory")
    ax.set_title(f"{title}", fontsize=20)

    cm = plt.get_cmap(kwargs.pop("cmap", "winter"))
    xs = trj.x.to_numpy()
    ys = trj.y.to_numpy()
    zs = trj.index
    n_segments = len(trj) - 1
    if n_segments > 0:
        ax.set_prop_cycle(color=[cm(i / n_segments) for i in range(n_segments)])
    # Each two-point segment is drawn separately so it takes the next color
    # in the cycle; positional slicing avoids label-based Series indexing.
    for i in range(n_segments):
        ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2])

    dist = kwargs.pop("dist", None)
    if dist:
        ax.dist = dist
    labelpad = kwargs.pop("labelpad", None)
    if labelpad:
        from matplotlib import rcParams

        rcParams["axes.labelpad"] = labelpad

    return ax
def plot(
    trj: TrajaDataFrame,
    n_coords: Optional[int] = None,
    show_time: bool = False,
    accessor: Optional[traja.TrajaAccessor] = None,
    ax=None,
    **kwargs,
) -> matplotlib.collections.PathCollection:
    """Plot trajectory for single animal over period.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory
        n_coords (int, optional): Number of coordinates to plot
        show_time (bool): Show colormap as time
        accessor (:class:`~traja.accessor.TrajaAccessor`, optional): TrajaAccessor instance
        ax (:class:`~matplotlib.axes.Axes`): axes for plotting
        interactive (bool): show plot immediately
        filepath (str, optional): save the figure to this path
        xlim, ylim (tuple, optional): axis limits (both needed to take effect)
        title (str, optional): plot title
        time_units (str, optional): units shown in the colorbar label (default "s")
        fps (float, optional): frames per second used to convert frame labels
        figsize (tuple, optional): size of a newly created figure
        invert_yaxis (bool, optional): invert the y axis of ``ax``
        s (float, optional): marker size (default 1)
        cmap (str, optional): colormap name (default "winter")
        **kwargs: additional keyword arguments to :meth:`matplotlib.axes.Axes.scatter`

    Returns:
        collection (:class:`~matplotlib.collections.PathCollection`): collection that was plotted

    """
    import matplotlib.patches as patches
    from matplotlib.path import Path

    after_plot_args, kwargs = _get_after_plot_args(**kwargs)

    GRAY = "#999999"

    xlim = kwargs.pop("xlim", None)
    ylim = kwargs.pop("ylim", None)
    if not xlim or not ylim:
        xlim, ylim = traja.trajectory._get_xylim(trj)

    title = kwargs.pop("title", None)
    time_units = kwargs.pop("time_units", "s")
    fps = kwargs.pop("fps", None)
    figsize = kwargs.pop("figsize", None)
    # Pop options that `ax.scatter` would reject before forwarding kwargs.
    invert_yaxis = kwargs.pop("invert_yaxis", None)
    marker_size = kwargs.pop("s", 1)
    cmap = kwargs.pop("cmap", "winter")

    coords = trj[["x", "y"]]
    time_col = traja.trajectory._get_time_col(trj)

    if time_col == "index":
        is_datetime = True
    else:
        is_datetime = (
            pd.api.types.is_datetime64_any_dtype(trj[time_col]) if time_col else False
        )

    # Plot all coords when `n_coords` is None, otherwise the first `n_coords`
    verts = coords.iloc[:n_coords].values
    n_coords = len(verts)

    codes = [Path.MOVETO] + [Path.LINETO] * (len(verts) - 1)
    path = Path(verts, codes)

    if not ax:
        fig, ax = plt.subplots(figsize=figsize)
        fig.canvas.draw()
    patch = patches.PathPatch(path, edgecolor=GRAY, facecolor="none", lw=3, alpha=0.3)
    ax.add_patch(patch)

    xs, ys = zip(*verts)

    if time_col:
        # Position along the trajectory determines color (datetime index or
        # `time_col`)
        colors = list(range(n_coords))
    else:
        # Frame count determines color
        colors = trj.index[:n_coords]

    if time_col:
        # TODO: Calculate fps if not in datetime
        vmin = min(colors)
        vmax = max(colors)
        if is_datetime:
            # Show timestamps without units
            time_units = ""
    else:
        # Index/frame count is our only reference
        vmin = trj.index[0]
        vmax = trj.index[n_coords - 1]
        if not show_time:
            time_units = ""
    label = f"Time ({time_units})" if time_units else ""

    collection = ax.scatter(
        xs,
        ys,
        c=colors,
        s=marker_size,
        cmap=cmap,
        alpha=0.7,
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    if invert_yaxis:
        ax.invert_yaxis()

    _label_axes(trj, ax)
    ax.set_title(title)
    ax.set_aspect("equal")

    # Number of color bar ticks
    CBAR_TICKS = 10 if n_coords > 20 else n_coords
    indices = np.linspace(0, n_coords - 1, CBAR_TICKS, endpoint=True, dtype=int)
    cbar = plt.colorbar(
        collection,
        ax=ax,
        fraction=0.046,
        pad=0.04,
        orientation="vertical",
        label=label,
    )

    # Get colorbar labels from time
    if time_col == "index":
        if pd.api.types.is_datetime64_any_dtype(trj.index):
            cbar_labels = (
                trj.index[indices].strftime("%Y-%m-%d %H:%M:%S").values.astype(str)
            )
        elif pd.api.types.is_timedelta64_dtype(trj.index):
            if time_units not in ("s", "", None):
                logger.error("Time unit {} not yet implemented".format(time_units))
            # Seconds are used in every case so `cbar_labels` is always defined
            cbar_labels = [round(x, 2) for x in trj.index[indices].total_seconds()]
        else:
            raise NotImplementedError(
                "Indexing on {} is not yet implemented".format(type(trj.index))
            )
    elif time_col and pd.api.types.is_timedelta64_dtype(trj[time_col]):
        cbar_labels = trj[time_col].iloc[indices].dt.total_seconds().values
        cbar_labels = ["%.2f" % number for number in cbar_labels]
    elif time_col and is_datetime:
        cbar_labels = (
            trj[time_col]
            .iloc[indices]
            .dt.strftime("%Y-%m-%d %H:%M:%S")
            .values.astype(str)
        )
    else:
        # Convert frames to time
        if time_col:
            cbar_labels = trj[time_col].iloc[indices].values
        else:
            cbar_labels = trj.index[indices].values
        cbar_labels = np.round(cbar_labels, 6)
        if fps is not None and fps > 0 and fps != 1 and show_time:
            cbar_labels = cbar_labels / fps

    cbar.set_ticks(indices)
    cbar.set_ticklabels(cbar_labels)
    plt.tight_layout()

    _process_after_plot_args(**after_plot_args)
    return collection
def plot_periodogram(trj, coord: str = "y", fs: int = 1, interactive: bool = True):
    """Plot power spectral density of ``coord`` timeseries using a periodogram.

    Args:
        trj - Trajectory
        coord - choice of 'x' or 'y'
        fs - Sampling frequency
        interactive - if True, return the current figure

    Returns:
        Figure when ``interactive`` is True, otherwise None

    .. plot::

        import matplotlib.pyplot as plt

        trj = traja.generate()
        trj.traja.plot_periodogram()

    .. note::

        Convenience wrapper for :meth:`scipy.signal.periodogram`.

    """
    from scipy import signal

    vals = trj[coord].values
    # "hann" is the canonical window name documented for
    # `scipy.signal.get_window`; "hanning" is only a legacy alias.
    f, Pxx = signal.periodogram(vals, fs=fs, window="hann", scaling="spectrum")
    plt.title("Power Spectrum")
    plt.plot(f, Pxx)
    if interactive:
        return plt.gcf()
def plot_autocorrelation(
    trj: TrajaDataFrame,
    coord: str = "y",
    unit: str = "Days",
    xmax: int = 1000,
    interactive: bool = True,
):
    """Draw the autocorrelation of one coordinate of a trajectory.

    Args:
        trj - Trajectory
        coord - column to analyse, 'x' or 'y'
        unit - lag unit shown in the x-axis label, eg, 'Days'
        xmax - upper limit of the x axis
        interactive - if True, return the current figure

    Returns:
        Matplotlib Figure when ``interactive`` is True, otherwise None

    .. plot::

        import traja

        df = traja.generate()

        df.traja.plot_autocorrelation()

    .. note::

        Convenience wrapper for pandas :meth:`~pandas.plotting.autocorrelation_plot`.

    """
    series = trj[coord]
    pd.plotting.autocorrelation_plot(series)
    plt.xlim((0, xmax))
    plt.xlabel(f"Lags ({unit})")
    plt.ylabel("Autocorrelation")
    if not interactive:
        return None
    return plt.gcf()
def plot_pca(
    trj: TrajaDataFrame,
    id_col: str = "id",
    bins: tuple = (8, 8),
    three_dims: bool = False,
    ax=None,
):
    """Plot PCA comparing animals ids by trip grids.

    Args:
        trj - Trajectory
        id_col - column representing animal IDs
        bins - shape for binning trajectory into a trip grid
        three_dims - 3D plot. Default: False (2D plot)
        ax - Matplotlib axes (optional); must be a 3D axes when
            ``three_dims`` is True

    Returns:
        fig - Figure containing ``ax``

    Example::

        # Given a trajectory with several animals identified by an "ID"
        # column (e.g. the sample jaguar dataset with 9 animals), bin each
        # trajectory into a trip grid then perform PCA
        traja.plotting.plot_pca(df, id_col="ID", bins=(8, 8))

    """
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    DIMS = 3 if three_dims else 2

    # Bin each animal's trajectory into a flattened trip grid
    grids = []
    ids = trj[id_col].unique()
    for animal_id in ids:
        animal = trj[trj[id_col] == animal_id].drop(columns=[id_col])
        grid = animal.traja.trip_grid(bins=bins, hist_only=True)[0]
        grids.append(grid.flatten())

    # Standardize the data, then project onto the principal components
    X = StandardScaler().fit_transform(np.array(grids))
    X_r = PCA(n_components=DIMS).fit_transform(X)

    # Create plot axes unless the caller supplied them
    if ax is None:
        if DIMS == 3:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
        else:
            _, ax = plt.subplots()

    # One point per animal, labelled by its id
    for idx, animal_id in enumerate(ids):
        ax.scatter(
            *X_r[idx, :DIMS], color=f"C{idx}", alpha=0.8, lw=2, label=animal_id
        )

    ax.set_title("PCA")
    ax.legend(title=id_col, loc="best", shadow=False, scatterpoints=1)
    ax.set_xlabel("Principal Component 1")
    ax.set_ylabel("Principal Component 2")
    if DIMS == 3:
        ax.set_zlabel("Principal Component 3")

    return ax.figure
def plot_collection(
    trjs: Union[pd.DataFrame, TrajaDataFrame],
    id_col: str = "id",
    colors: Optional[Union[dict, List[str]]] = None,
    **kwargs,
):
    """Plot trajectories of multiple subjects identified by `id`.

    Args:
        trjs: dataframe with multiple trajectories
        id_col: name of id_col, default is "id"
        colors (Optional): color lookup matching substrings to discrete colors. Possible values are, eg:
                - {"car0":"red","car1":"blue"}
                - {"car":"red","person":"blue"}
                - ["car", "person"]
        kwargs: kwargs to :meth:`matplotlib.axes.Axes.plot`

    Returns:
        lines (list of `~matplotlib.lines.Line2D` objects): lines of plot

    Raises:
        ValueError: if ``colors`` is a list and an id matches none of its substrings

    """
    ids = trjs[id_col].unique()

    # Get plot keyword args
    colormap = kwargs.pop("cmap", "hsv")
    alpha = kwargs.pop("alpha", 0.2)
    linestyle = kwargs.pop("linestyle", "-")
    marker = kwargs.pop("marker", "o")

    labels = [None] * len(ids)

    if not colors:
        cmap = plt.get_cmap(colormap, lut=len(ids) if len(ids) > 1 else None)
        colors = [cmap(idx) for idx in range(len(ids))]
    elif isinstance(colors, list):
        # Each id gets the color of the first substring it contains
        cmap = plt.get_cmap(colormap, lut=len(colors))
        color_lookup = []
        for ind, traj_id in enumerate(ids):
            for idx, substring in enumerate(colors):
                if substring in str(traj_id):
                    color_lookup.append(cmap(idx))
                    labels[ind] = substring
                    break
            else:
                raise ValueError(f"No substring matching {traj_id} in {colors}.")
        colors = color_lookup
    elif isinstance(colors, dict):
        colors = [colors.get(traj_id) for traj_id in ids]
        labels = ids

    _, ax = plt.subplots()
    lines = []
    for idx, traj_id in enumerate(ids):
        trj = trjs[trjs[id_col] == traj_id]
        lines.extend(
            ax.plot(
                trj.x,
                trj.y,
                linestyle=linestyle,
                marker=marker,
                c=colors[idx],
                alpha=alpha,
                label=labels[idx],
                **kwargs,
            )
        )

    # One legend entry per distinct label (several ids may share a substring)
    handles, legend_labels = ax.get_legend_handles_labels()
    by_label = dict(zip(legend_labels, handles))
    ax.legend(
        list(by_label.values()),
        list(by_label.keys()),
        bbox_to_anchor=(1.05, 1),
        loc=2,
        borderaxespad=0.0,
    )

    plt.tight_layout()
    return lines


def _label_axes(trj: TrajaDataFrame, ax) -> Axes:
    """Label both axes with the trajectory's ``spatial_units``, if it has one.

    Returns:
        The same ``ax``.
    """
    if "spatial_units" in trj.__dict__:
        units = trj.__dict__["spatial_units"]
        ax.set_xlabel(units)
        ax.set_ylabel(units)
    return ax
def plot_quiver(
    trj: TrajaDataFrame,
    bins: Optional[Union[int, tuple]] = None,
    quiverplot_kws: Optional[dict] = None,
    ax: Optional[Axes] = None,
    **kwargs,
) -> Axes:
    """Plot average flow from each grid cell to neighbor.

    Args:
        bins (int or tuple): Tuple of x,y bin counts; if `bins` is int, bin count of x,
                                with y inferred from aspect ratio
        quiverplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.quiver`
        ax (optional): Matplotlib Axes; a new figure is created when omitted
        interactive, filepath: post-plot options (see ``_get_after_plot_args``)

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of quiver plot
    """
    after_plot_args, _ = _get_after_plot_args(**kwargs)
    # None default avoids sharing one mutable dict between calls
    quiverplot_kws = quiverplot_kws or {}

    X, Y, U, V = coords_to_flow(trj, bins)

    if ax is None:
        _, ax = plt.subplots()

    ax.quiver(X, Y, U, V, units="width", **quiverplot_kws)
    ax = _label_axes(trj, ax)
    ax.set_aspect("equal")

    _process_after_plot_args(**after_plot_args)
    return ax
def plot_contour(
    trj: TrajaDataFrame,
    bins: Optional[Union[int, tuple]] = None,
    filled: bool = True,
    quiver: bool = True,
    contourplot_kws: Optional[dict] = None,
    contourfplot_kws: Optional[dict] = None,
    quiverplot_kws: Optional[dict] = None,
    ax: Axes = None,
    **kwargs,
) -> Axes:
    """Plot average flow from each grid cell to neighbor.

    Args:
        trj: Traja DataFrame
        bins (int or tuple): Tuple of x,y bin counts; if `bins` is int, bin count of x,
                                with y inferred from aspect ratio
        filled (bool): Contours filled (adds a colorbar)
        quiver (bool): Quiver plot
        contourplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.contour`
        contourfplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.contourf`
        quiverplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.quiver`
        ax (optional): Matplotlib Axes

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of quiver plot
    """
    after_plot_args, _ = _get_after_plot_args(**kwargs)
    # None defaults avoid sharing mutable dicts between calls
    contourplot_kws = contourplot_kws or {}
    contourfplot_kws = contourfplot_kws or {}
    quiverplot_kws = quiverplot_kws or {}

    X, Y, U, V = coords_to_flow(trj, bins)
    # Flow magnitude per grid cell
    Z = np.sqrt(U * U + V * V)

    if ax is None:
        _, ax = plt.subplots()

    # Draw on `ax` explicitly: pyplot functions target the *current* axes,
    # which need not be the axes passed in by the caller.
    if filled:
        cfp = ax.contourf(X, Y, Z, **contourfplot_kws)
        plt.colorbar(cfp, ax=ax)
    ax.contour(
        X, Y, Z, colors="k", linewidths=1, linestyles="solid", **contourplot_kws
    )
    if quiver:
        ax.quiver(X, Y, U, V, units="width", **quiverplot_kws)

    ax = _label_axes(trj, ax)
    ax.set_aspect("equal")

    _process_after_plot_args(**after_plot_args)
    return ax
def plot_surface(
    trj: TrajaDataFrame,
    bins: Optional[Union[int, tuple]] = None,
    cmap: str = "viridis",
    **surfaceplot_kws: dict,
) -> Axes:
    """Plot surface of flow from each grid cell to neighbor in 3D.

    Args:
        bins (int or tuple): Tuple of x,y bin counts; if `bins` is int, bin count of x,
                                with y inferred from aspect ratio
        cmap (str): color map
        surfaceplot_kws: Additional keyword arguments for :meth:`~mpl_toolkits.mplot3D.Axes3D.plot_surface`
            (``interactive``/``filepath`` are treated as post-plot options)

    Returns:
        ax (:class:`~mpl_toolkits.mplot3d.Axes3D`): Axes of surface plot
    """

    after_plot_args, surfaceplot_kws = _get_after_plot_args(**surfaceplot_kws)

    X, Y, U, V = coords_to_flow(trj, bins)
    # Flow magnitude per grid cell is the surface height
    Z = np.sqrt(U * U + V * V)

    fig = plt.figure()
    # `Figure.gca` no longer accepts `projection=` in recent matplotlib
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(X, Y, Z, cmap=cmap, linewidth=0, **surfaceplot_kws)

    ax = _label_axes(trj, ax)
    try:
        ax.set_aspect("equal")
    except NotImplementedError:
        # 3D
        pass

    _process_after_plot_args(**after_plot_args)
    return ax
def plot_stream(
    trj: TrajaDataFrame,
    bins: Optional[Union[int, tuple]] = None,
    cmap: str = "viridis",
    contourfplot_kws: Optional[dict] = None,
    contourplot_kws: Optional[dict] = None,
    streamplot_kws: Optional[dict] = None,
    **kwargs,
) -> Axes:
    """Plot average flow from each grid cell to neighbor.

    Args:
        bins (int or tuple): Tuple of x,y bin counts; if `bins` is int, bin count of x,
                                with y inferred from aspect ratio
        cmap (str): colormap for the streamlines (colored by flow magnitude)
        contourplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.contour`
        contourfplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.contourf`
        streamplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.streamplot`

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of stream plot

    """
    after_plot_args, _ = _get_after_plot_args(**kwargs)
    # None defaults avoid sharing mutable dicts between calls
    contourfplot_kws = contourfplot_kws or {}
    contourplot_kws = contourplot_kws or {}
    streamplot_kws = streamplot_kws or {}

    X, Y, U, V = coords_to_flow(trj, bins)
    # Flow magnitude per grid cell
    Z = np.sqrt(U * U + V * V)

    _, ax = plt.subplots()

    ax.contourf(X, Y, Z, **contourfplot_kws)
    ax.contour(
        X, Y, Z, colors="k", linewidths=1, linestyles="solid", **contourplot_kws
    )
    ax.streamplot(X, Y, U, V, color=Z, cmap=cmap, **streamplot_kws)

    ax = _label_axes(trj, ax)
    ax.set_aspect("equal")

    _process_after_plot_args(**after_plot_args)
    return ax
def plot_flow(
    trj: TrajaDataFrame,
    kind: str = "quiver",
    *args,
    contourplot_kws: Optional[dict] = None,
    contourfplot_kws: Optional[dict] = None,
    streamplot_kws: Optional[dict] = None,
    quiverplot_kws: Optional[dict] = None,
    surfaceplot_kws: Optional[dict] = None,
    **kwargs,
) -> Axes:
    """Plot average flow from each grid cell to neighbor.

    Args:
        bins (int or tuple): Tuple of x,y bin counts; if `bins` is int, bin count of x,
                                with y inferred from aspect ratio (passed positionally via ``*args``)
        kind (str): Choice of 'quiver','contour','contourf','stream','surface'. Default is 'quiver'.
        contourplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.contour`
        contourfplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.contourf`
        streamplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.streamplot`
        quiverplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.quiver`
        surfaceplot_kws: Additional keyword arguments for :meth:`~matplotlib.axes.Axes.plot_surface`

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of plot

    Raises:
        NotImplementedError: for an unknown ``kind``
    """
    # None defaults avoid sharing mutable dicts between calls
    contourplot_kws = contourplot_kws or {}
    contourfplot_kws = contourfplot_kws or {}
    streamplot_kws = streamplot_kws or {}
    quiverplot_kws = quiverplot_kws or {}
    surfaceplot_kws = surfaceplot_kws or {}

    # Forward each *_kws dict under its own keyword; unpacking it into
    # **kwargs made the plotting helpers silently drop it.
    if kind == "quiver":
        return plot_quiver(trj, *args, quiverplot_kws=quiverplot_kws, **kwargs)
    elif kind in ("contour", "contourf"):
        if kind == "contour":
            kwargs.setdefault("filled", False)
        return plot_contour(
            trj,
            *args,
            contourplot_kws=contourplot_kws,
            contourfplot_kws=contourfplot_kws,
            quiverplot_kws=quiverplot_kws,
            **kwargs,
        )
    elif kind == "stream":
        return plot_stream(
            trj,
            *args,
            contourplot_kws=contourplot_kws,
            contourfplot_kws=contourfplot_kws,
            streamplot_kws=streamplot_kws,
            **kwargs,
        )
    elif kind == "surface":
        # `plot_surface` takes its options as **surfaceplot_kws
        return plot_surface(trj, *args, **surfaceplot_kws, **kwargs)
    else:
        raise NotImplementedError(f"Kind {kind} is not implemented.")
def _get_after_plot_args(**kwargs: dict) -> Tuple[dict, dict]:
    """Split post-plot options out of ``kwargs``.

    Returns:
        tuple: ``(after_plot_args, remaining_kwargs)``; ``after_plot_args`` holds
        ``interactive`` (default True) and ``filepath`` (default None).
    """
    after_plot_args = dict(
        interactive=kwargs.pop("interactive", True),
        filepath=kwargs.pop("filepath", None),
    )
    return after_plot_args, kwargs


def trip_grid(
    trj: TrajaDataFrame,
    bins: Union[tuple, int] = 10,
    log: bool = False,
    spatial_units: str = None,
    normalize: bool = False,
    hist_only: bool = False,
    **kwargs,
) -> Tuple[np.ndarray, PathCollection]:
    """Generate a heatmap of time spent by point-to-cell gridding.

    Args:
        bins (int, optional): Number of bins (Default value = 10)
        log (bool): log scale histogram (Default value = False)
        spatial_units (str): units for plotting
        normalize (bool): normalize histogram into density plot
        hist_only (bool): return histogram without plotting
        xlim, ylim (tuple, optional): histogram range (used only if both are given)

    Returns:
        hist (:class:`numpy.ndarray`): 2D histogram as array, indexed
            ``hist[x_bin, y_bin]``
        image (:class:`matplotlib.image.AxesImage`): image of histogram
            (None when ``hist_only``)

    """
    after_plot_args, kwargs = _get_after_plot_args(**kwargs)
    bins = traja.trajectory._bins_to_tuple(trj, bins)
    # TODO: Add kde-based method for line-to-cell gridding
    df = trj[["x", "y"]].dropna()

    # Set aspect if `xlim` and `ylim` set.
    if "xlim" in kwargs and "ylim" in kwargs:
        xlim, ylim = kwargs.pop("xlim"), kwargs.pop("ylim")
    else:
        xlim, ylim = traja.trajectory._get_xylim(df)
    xmin, xmax = xlim
    ymin, ymax = ylim

    # FIXME: Remove redundant histogram calculation
    # `density` replaces the `normed` argument, which newer NumPy rejects.
    hist, x_edges, y_edges = np.histogram2d(
        df["x"].values,
        df["y"].values,
        bins,
        range=((xmin, xmax), (ymin, ymax)),
        density=normalize,
    )

    if log:
        hist = np.log(hist + np.e)
    if hist_only:  # TODO: Evaluate potential use cases or remove
        return (hist, None)
    fig, ax = plt.subplots()

    # histogram2d bins x along rows and y along columns; transpose and use a
    # lower origin so x runs horizontally and y increases upward.
    image = ax.imshow(
        hist.T,
        interpolation="bilinear",
        aspect="equal",
        origin="lower",
        extent=[xmin, xmax, ymin, ymax],
    )
    # TODO: Adjust colorbar ytick_labels to correspond with time
    label = "Frames" if not log else "$ln(frames)$"
    plt.colorbar(image, ax=ax, label=label)

    _label_axes(trj, ax)

    plt.title("Time spent{}".format(" (Logarithmic)" if log else ""))

    _process_after_plot_args(**after_plot_args)
    # TODO: Add method for most common locations in grid
    # peak_index = unravel_index(hist.argmax(), hist.shape)
    return hist, image


def _process_after_plot_args(**after_plot_args):
    """Save the current figure to ``filepath`` when one is given.

    NOTE(review): the ``interactive`` option is accepted but not acted on here.
    """
    filepath = after_plot_args.get("filepath")
    if filepath:
        plt.savefig(filepath)


def color_dark(
    series: pd.Series, ax: matplotlib.axes.Axes = None, start: int = 19, end: int = 7
):
    """Color dark phase in plot.

    Args:
        series (pd.Series) - Time-series variable with a datetime index
        ax (:class: `~matplotlib.axes.Axes`): axis to plot on (eg, `plt.gca()`)
        start (int): hour at which the dark period/night starts
        end (int): hour at which the dark period/night ends

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of plot

    """
    # Index-type check instead of the pandas helper removed in pandas 2.
    assert isinstance(
        series.index, (pd.DatetimeIndex, pd.TimedeltaIndex)
    ), f"Series must have datetime index but has {type(series.index)}"

    pd.plotting.register_matplotlib_converters()  # prevents type error with axvspan

    if not ax:
        ax = plt.gca()

    # get boundaries for dark times
    hours = series.index.hour
    if start > end:
        # Night wraps around midnight, eg 19:00 -> 07:00
        dark_mask = (hours >= start) | (hours < end)
    else:
        dark_mask = (hours >= start) & (hours < end)

    run_values, run_starts, run_lengths = find_runs(dark_mask)

    for is_dark, run_start, run_length in zip(run_values, run_starts, run_lengths):
        if is_dark:
            first = run_start
            last = run_start + run_length - 1
            ax.axvspan(series.index[first], series.index[last], alpha=0.5, color="gray")

    ax.figure.autofmt_xdate()
    return ax


def find_runs(x: pd.Series) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Find runs of consecutive items in an array.

    Adapted from a publicly shared NumPy recipe (original link not preserved).

    Args:
        x: 1D array-like

    Returns:
        tuple: ``(run_values, run_starts, run_lengths)`` arrays

    Raises:
        ValueError: if ``x`` is not one-dimensional
    """

    # ensure array
    x = np.asanyarray(x)
    if x.ndim != 1:
        raise ValueError("only 1D array supported")
    n = x.shape[0]

    # handle empty array
    if n == 0:
        return np.array([]), np.array([]), np.array([])

    # find run starts: first element, and every element differing from its predecessor
    loc_run_start = np.empty(n, dtype=bool)
    loc_run_start[0] = True
    np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
    run_starts = np.nonzero(loc_run_start)[0]

    # find run values
    run_values = x[loc_run_start]

    # find run lengths
    run_lengths = np.diff(np.append(run_starts, n))

    return run_values, run_starts, run_lengths


def fill_ci(series: pd.Series, window: Union[int, str]) -> Axes:
    """Plot the rolling mean of `window` with a band of +/- 2 rolling std.

    Window can be interval or offset, eg, '30s'.

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of plot
    """
    assert isinstance(
        series.index, (pd.DatetimeIndex, pd.TimedeltaIndex)
    ), f"Series index must be datetime but is {type(series.index)}"

    smooth_path = series.rolling(window).mean()
    path_deviation = series.rolling(window).std()

    fig, ax = plt.subplots()

    ax.plot(smooth_path.index, smooth_path, "b")
    ax.fill_between(
        path_deviation.index,
        (smooth_path - 2 * path_deviation),
        (smooth_path + 2 * path_deviation),
        color="b",
        alpha=0.2,
    )

    fig.autofmt_xdate()
    return ax


def plot_xy(xy: np.ndarray, *args, **kwargs):
    """Plot trajectory from xy values.

    Args:

        xy (np.ndarray) : xy values of dimensions N x 2
        *args           : Plot args
        **kwargs        : Plot kwargs
    """
    trj = traja.from_xy(xy)
    trj.traja.plot(*args, **kwargs)
def plot_actogram(
    series: pd.Series, dark=(19, 7), ax: matplotlib.axes.Axes = None, **kwargs
):
    """Plot activity or displacement as an actogram.

    Args:
        series: activity/displacement values with a datetime index
        dark (tuple): ``(start_hour, end_hour)`` of the dark phase to shade
        ax (optional): Matplotlib axes to draw on
        interactive, filepath: post-plot options (see ``_get_after_plot_args``)

    Returns:
        ax (:class:`~matplotlib.axes.Axes`): Axes of plot

    .. note::

       For published example see Eckel-Mahan K, Sassone-Corsi P. Phenotyping Circadian Rhythms in Mice.
       Curr Protoc Mouse Biol. 2015;5(3):271-281. Published 2015 Sep 1. doi:10.1002/9780470942390.mo140229

    """
    assert isinstance(series, pd.Series)
    # Index-type check instead of the pandas helper removed in pandas 2.
    assert isinstance(
        series.index, (pd.DatetimeIndex, pd.TimedeltaIndex)
    ), f"Series must have datetime index but has {type(series.index)}"

    after_plot_args, _ = _get_after_plot_args(**kwargs)

    ax = series.plot(ax=ax)
    # NOTE(review): the original call was truncated; the series name is used as label
    ax.set_ylabel(series.name)

    color_dark(series, ax, start=dark[0], end=dark[1])

    _process_after_plot_args(**after_plot_args)
    return ax
def _polar_bar(
    radii: np.ndarray,
    theta: np.ndarray,
    bin_size: int = 2,
    ax: Optional[matplotlib.axes.Axes] = None,
    overlap: bool = True,
    **kwargs: str,
) -> Axes:
    """Draw a polar bar chart of angles ``theta`` (degrees, in [-180, 180]).

    Args:
        radii: bar heights used when ``overlap`` is True
        theta: angles in degrees
        bin_size (int): bin width in degrees
        ax (optional): polar axes; a new polar subplot is created when omitted
        overlap (bool): draw one bar per value; if False, draw a histogram
        title (str, optional): plot title
        **kwargs: additional keyword arguments to :meth:`~matplotlib.axes.Axes.bar`

    Returns:
        ax: the polar axes
    """
    after_plot_args, kwargs = _get_after_plot_args(**kwargs)

    title = kwargs.pop("title", None)
    ax = ax or plt.subplot(111, projection="polar")

    hist, bin_edges = np.histogram(
        theta, bins=np.arange(-180, 180 + bin_size, bin_size)
    )
    # True division so odd bin sizes still yield the exact bin centers
    centers = np.deg2rad(np.ediff1d(bin_edges) / 2 + bin_edges[:-1])

    radians = np.deg2rad(theta)
    width = np.deg2rad(bin_size)

    angle = radians if overlap else centers
    height = radii if overlap else hist

    max_height = max(height)
    # Avoid dividing by zero when every bar has zero height
    scale = max_height if max_height else 1

    bars = ax.bar(angle, height, width=width, bottom=0.0, **kwargs)
    # NOTE(review): the original colormap name was lost; viridis assumed
    cmap = plt.get_cmap("viridis")
    for h, bar in zip(height, bars):
        bar.set_facecolor(cmap(h / scale))
        bar.set_alpha(0.5)

    # Only polar axes support a theta zero location
    if hasattr(ax, "set_theta_zero_location"):
        ax.set_theta_zero_location("N")
        # Fix the 8 tick positions so the labels below map onto them
        ax.set_xticks(np.deg2rad(np.arange(0, 360, 45)))
        ax.set_xticklabels(["0", "45", "90", "135", "180", "-135", "-90", "-45"])
    if title:
        ax.set_title(title + "\n", y=1.08)
    plt.tight_layout()

    _process_after_plot_args(**after_plot_args)
    return ax
def polar_bar(
    trj: TrajaDataFrame,
    feature: str = "turn_angle",
    bin_size: int = 2,
    threshold: float = 0.001,
    overlap: bool = True,
    ax: Optional[matplotlib.axes.Axes] = None,
    **plot_kws: str,
) -> Axes:
    """Plot polar bar chart.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory (not modified)
        feature (str): Options: 'turn_angle', 'heading'
        bin_size (int): width of bins
        threshold (float): filter for step distance
        overlap (bool): Overlapping shows all values, if set to false is a histogram

    Returns:
        ax (:class:`~matplotlib.collections.PathCollection`): Axes of plot

    """
    # Work on a copy so the caller's dataframe does not gain new columns.
    trj = trj.copy()
    trj["displacement"] = traja.trajectory.calc_displacement(trj)
    # `.copy()` so the columns added below are not written to a slice view
    trj = trj.loc[trj.displacement > threshold].copy()
    if feature == "turn_angle":
        trj["turn_angle"] = traja.trajectory.calc_turn_angle(trj)
        trj["turn_angle"] = trj["turn_angle"].shift(-1)
    elif feature == "heading":
        trj[feature] = traja.trajectory.calc_heading(trj)

    trj = trj[pd.notnull(trj[feature])]
    trj = trj[pd.notnull(trj.displacement)]

    assert (
        len(trj) > 0
    ), f"Dataframe is empty after filtering for step distance threshold {threshold}"

    ax = _polar_bar(
        trj.displacement,
        trj[feature],
        bin_size=bin_size,
        overlap=overlap,
        ax=ax,
        **plot_kws,
    )
    return ax
def plot_clustermap(
    displacements: List[pd.Series],
    rule: Optional[str] = None,
    nr_steps=None,
    colors: Optional[List[Union[int, str]]] = None,
    **kwargs,
):
    """Plot cluster map / dendrogram of trajectories with DatetimeIndex.

    Args:
        displacements: list of pd.Series, outputs of :func:`traja.calc_displacement()`
        rule: how to resample series, eg '30s' for 30-seconds
        nr_steps: select first N samples for clustering
        colors: list of colors (eg, 'b','r') to map to each trajectory
        kwargs: keyword arguments for :func:`seaborn.clustermap`; they override
            the defaults used here (``interactive``/``filepath`` are post-plot options)

    Returns:
        cg: a :func:`seaborn.matrix.ClusterGrid` instance, or None if seaborn
        is not installed

    .. note::

        Requires seaborn to be installed. Install it with 'pip install seaborn'.

    """
    try:
        import seaborn as sns
    except ImportError:
        logger.error("seaborn is not installed. Install it with 'pip install seaborn'")
        return

    # Strip post-plot options so they are not forwarded to seaborn
    after_plot_args, kwargs = _get_after_plot_args(**kwargs)

    series_lst = []
    for disp in displacements:
        if rule:
            disp = disp.resample(rule).sum()
        series_lst.append(disp)

    # One row per trajectory, columns are sample positions
    df = pd.DataFrame(series_lst)
    df.columns = range(len(df.columns))
    df.reset_index(drop=True, inplace=True)

    if not nr_steps:
        nr_steps = df.shape[1]

    clustermap_kws = dict(
        xticklabels=False,
        col_cluster=False,
        figsize=(16, 6),
        cmap="Greys",
        row_colors=colors,
    )
    clustermap_kws.update(kwargs)
    cg = sns.clustermap(df.fillna(0).iloc[:, :nr_steps], **clustermap_kws)
    plt.setp(cg.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
    _process_after_plot_args(**after_plot_args)
    return cg
def _get_markov_edges(Q: pd.DataFrame, greater_than=0.1):
    """Select edges greater than a threshold of weight.

    Args:
        Q: transition matrix; row labels are edge origins and column labels
            are edge destinations.
        greater_than: exclusive lower bound on an edge's weight. ``None``
            keeps every edge.

    Returns:
        dict mapping ``(origin, destination)`` to the weight, built column by
        column (columns outer loop, rows inner loop).
    """
    edges = {}
    for col in Q.columns:
        for idx in Q.index:
            weight = Q.loc[idx, col]
            # `is None` instead of truthiness: a threshold of 0 used to drop
            # every edge rather than keep the positive ones.
            if greater_than is None or weight > greater_than:
                edges[(idx, col)] = weight
    return edges


def plot_transition_graph(
    data: Union[pd.DataFrame, traja.TrajaDataFrame, np.ndarray],
    outpath="",
    interactive=True,
):
    """Plot transition graph with networkx.

    Args:
        data (trajectory or transition_matrix): either a trajectory (a
            :class:`traja.TrajaDataFrame`, or a DataFrame with an ``"x"``
            column), whose transition matrix is computed with
            :func:`traja.transitions`; or a square transition matrix given
            as an ndarray or DataFrame.
        outpath (str): path of the DOT file to write. An empty value falls
            back to ``"markov.dot"``.
        interactive (bool): open the written DOT file with graphviz's viewer.

    Raises:
        ImportError: if networkx, pydot or graphviz is not installed.
        ValueError: if a transition-matrix input is not square.

    .. note::
        NOTE(review): adapted from an external example whose attribution
        link is missing here — restore it.

    """
    try:
        import networkx as nx
        import pydot  # noqa: F401 - imported only to fail early if missing
        import graphviz  # noqa: F401 - imported only to fail early if missing
    except ImportError as e:
        raise ImportError(f"{e} - please install it with pip") from e

    if isinstance(data, traja.TrajaDataFrame) or (
        isinstance(data, pd.DataFrame) and "x" in data
    ):
        transition_matrix = pd.DataFrame(traja.transitions(data))
    else:
        # Previously this case raised NameError although the docstring
        # advertises transition-matrix input.
        transition_matrix = pd.DataFrame(data)
        if transition_matrix.shape[0] != transition_matrix.shape[1]:
            raise ValueError(
                f"Transition matrix must be square, shape is {transition_matrix.shape}"
            )

    edges_wts = _get_markov_edges(transition_matrix)

    # create graph object; nodes correspond to states (the matrix row labels,
    # i.e. the same labels used as edge endpoints)
    G = nx.MultiDiGraph()
    G.add_nodes_from(list(transition_matrix.index))

    # edges represent transition probabilities
    for (origin, destination), weight in edges_wts.items():
        rounded = round(float(weight), 4)
        G.add_edge(origin, destination, weight=rounded, label=rounded)

    pos = nx.drawing.nx_pydot.graphviz_layout(G, prog="dot")
    nx.draw_networkx(G, pos)

    # create edge labels for jupyter plot but is not necessary
    edge_labels = {(n1, n2): d["label"] for n1, n2, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

    # An empty path cannot be written to, so use a concrete default.
    outpath = outpath or "markov.dot"
    # The existence check lives in os.path (`os.exists` raised AttributeError).
    if os.path.exists(outpath):
        logger.info("Overwriting %s", outpath)
    nx.drawing.nx_pydot.write_dot(G, outpath)

    if interactive:
        # Plot
        from graphviz import Source

        s = Source.from_file(outpath)
        s.view()


def plot_transition_matrix(
    data: Union[pd.DataFrame, traja.TrajaDataFrame, np.ndarray],
    interactive=True,
    **kwargs,
) -> matplotlib.image.AxesImage:
    """Plot transition matrix.

    Args:
        data (trajectory or square transition matrix)
        interactive (bool): show plot
        kwargs: forwarded to :func:`traja.transitions` (documented upstream
            as kwargs to :func:`traja.grid_coordinates`); only used for
            DataFrame input

    Returns:
        axesimage (matplotlib.image.AxesImage)

    Raises:
        ValueError: if an ndarray input is not a square 2-D matrix.
        TypeError: if ``data`` is neither an ndarray nor a DataFrame.
    """
    if isinstance(data, np.ndarray):
        # Check ndim first: `shape[1]` on a 1-D array raised IndexError.
        if data.ndim != 2 or data.shape[0] != data.shape[1]:
            raise ValueError(
                f"Ndarray input must be square transition matrix, shape is {data.shape}"
            )
        transition_matrix = data
    elif isinstance(data, (pd.DataFrame, traja.TrajaDataFrame)):
        transition_matrix = traja.transitions(data, **kwargs)
    else:
        # Previously fell through to an UnboundLocalError.
        raise TypeError(
            f"data must be a trajectory DataFrame or ndarray, got {type(data).__name__}"
        )

    img = plt.imshow(transition_matrix)
    if interactive:
        plt.show()
    return img
def animate(trj: TrajaDataFrame, polar: bool = True, save: bool = False):
    """Animate trajectory.

    The top panel scatters the most recent positions (fading trail of
    ``XY_STEPS`` points) and titles each frame with its index, step distance
    (bold when at or above ``DISPLACEMENT_THRESH``), position and turn angle.
    The optional bottom panel is a polar bar chart of recent turn angles.
    The first ``XY_STEPS + 2`` frames leave the plot unchanged.

    Args:
        trj (:class:`traja.TrajaDataFrame`): trajectory with ``x`` and ``y`` columns
        polar (bool): include polar bar chart with turn angle
        save (bool): save video to ``trajectory.mp4``

    Returns:
        anim (matplotlib.animation.FuncAnimation): animation. It is returned
        whether or not ``save`` is set (previously ``None`` when saving).

    Raises:
        Exception: if saving fails with FileNotFoundError (FFmpeg missing).
    """
    from matplotlib import animation
    from matplotlib.animation import FuncAnimation

    displacement = traja.trajectory.calc_displacement(trj).reset_index(drop=True)
    turn_angle = traja.trajectory.calc_turn_angle(trj).reset_index(drop=True)
    xy = trj[["x", "y"]].reset_index(drop=True)

    POLAR_STEPS = XY_STEPS = 20
    DISPLACEMENT_THRESH = 0.025
    bin_size = 2
    overlap = True
    compass_labels = ["0", "45", "90", "135", "180", "-135", "-90", "-45"]

    # plt.subplot already adds each axes to the current (new) figure.
    fig = plt.figure(figsize=(8, 6))
    ax1 = plt.subplot(211)

    if polar:
        ax2 = plt.subplot(212, projection="polar")
        ax2.set_theta_zero_location("N")
        ax2.set_xticklabels(compass_labels)
        # Placeholder bars so the polar panel exists before the first update.
        ax2.bar(
            np.zeros(XY_STEPS), np.zeros(XY_STEPS), width=np.zeros(XY_STEPS), bottom=0.0
        )

    xlim, ylim = traja.trajectory._get_xylim(trj)
    ax1.set(
        xlim=xlim,
        ylim=ylim,
        ylabel=trj.__dict__.get("spatial_units", "m"),
        xlabel=trj.__dict__.get("spatial_units", "m"),
        aspect="equal",
    )

    # Red points whose alpha ramps from 0.1 (oldest) to 1 (newest).
    rgba_colors = np.zeros((XY_STEPS, 4))
    rgba_colors[:, 0] = 1.0  # red
    rgba_colors[:, 3] = np.linspace(0.1, 1, XY_STEPS)
    scat = ax1.scatter(range(XY_STEPS), range(XY_STEPS), marker=".", color=rgba_colors)

    def update(frame_number):
        if frame_number < XY_STEPS + 2:
            return

        ind = frame_number % len(xy)
        # Trail of at most XY_STEPS points ending just before `ind`.
        scat.set_offsets(xy[max(ind - XY_STEPS, 0) : ind])

        step = displacement[ind]
        displacement_str = (
            rf"$\bf{step:.2f}$" if step >= DISPLACEMENT_THRESH else f"{step:.2f}"
        )
        x, y = xy.iloc[ind]
        ax1.set_title(
            f"frame {ind} - distance (cm/0.25s): {displacement_str}\n"
            f"x: {x:.2f}, y: {y:.2f}\n"
            f"turn_angle: {turn_angle[ind]:.2f}"
        )

        if polar and ind > 1:
            ax2.clear()
            start_index = max(ind - POLAR_STEPS, 0)
            theta = turn_angle[start_index:ind]
            radii = displacement[start_index:ind]
            hist, bin_edges = np.histogram(
                theta, bins=np.arange(-180, 180 + bin_size, bin_size)
            )
            centers = np.deg2rad(np.ediff1d(bin_edges) // 2 + bin_edges[:-1])
            width = np.deg2rad(bin_size)
            if overlap:
                angle, height, max_height = np.deg2rad(theta), radii, displacement.max()
            else:
                angle, height, max_height = centers, hist, max(hist)
            bars = ax2.bar(angle, height, width=width, bottom=0.0)
            # Color by relative height; older bars are more transparent.
            for idx, (h, bar) in enumerate(zip(height, bars)):
                bar.set_facecolor(plt.cm.viridis(h / max_height))
                bar.set_alpha(0.8 * (idx / POLAR_STEPS))
            ax2.set_theta_zero_location("N")
            ax2.set_xticklabels(compass_labels)

    anim = FuncAnimation(fig, update, interval=10, frames=len(xy))
    if save:
        try:
            anim.save("trajectory.mp4", writer=animation.FFMpegWriter(fps=10))
        except FileNotFoundError as err:
            raise Exception("FFmpeg not installed, please install it.") from err
    return anim