Source code for fplore.plot

# -*- coding: utf-8 -*-
import numpy as np
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from matplotlib.collections import PolyCollection, LineCollection
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib import colors as mcolors

from .logging import log
from .util import wigner_seitz_neighbours


def projected_area(xyz, axis):
    """
    Area of quadrilaterals after projecting them along `axis`.

    xyz: array whose second-to-last axis enumerates the four corners of a
        quadrilateral and whose last axis holds (x, y, z); this is the layout
        produced by `make_quadrilaterals`.
    axis: 1 projects along y (x is kept as in-plane coordinate), any other
        value takes the "project along x" branch (coordinate index
        `1 - axis` is kept).

    Returns: array of non-negative areas, one per quadrilateral.
    """
    height = xyz[..., 2]
    in_plane = xyz[..., 1 - axis]

    # corner order: 0: orig, 1: +dx, 2: +dy, 3: -dx
    if axis == 1:  # project along y
        first_edge, second_edge = (3, 0), (2, 1)
    else:  # project along x
        first_edge, second_edge = (1, 0), (2, 3)

    y1 = height[..., first_edge[0]] - height[..., first_edge[1]]
    y2 = height[..., second_edge[0]] - height[..., second_edge[1]]
    dx = in_plane[..., 2] - in_plane[..., 0]

    same_sign = np.sign(y1) == np.sign(y2)
    with np.errstate(invalid='ignore'):
        # non-intersecting quadrilateral
        simple = same_sign * (y1 + y2)
        # self-intersecting quadrilateral
        crossed = (1 - same_sign) * (y1 ** 2 + y2 ** 2) / (y1 - y2)

    # y1 == y2 gives 0 / 0 above; that case is fully covered by `simple`
    crossed[np.isnan(crossed) & same_sign] = 0.

    return np.abs(.5 * dx * (simple + crossed))
def make_quadrilaterals(x, y, z):
    """
    Split a 2D grid of points into its quadrilateral cells.

    x, y, z: arrays of identical 2D shape (M, N)

    Returns: array of shape ((M-1)*(N-1), 4, 3); for every grid cell the four
    corner points (x, y, z) in cyclic order.
    """
    # Corner numbering within one cell:
    #
    #           y -> (2nd axis)
    #   o - 0 - 3 - o   x (1st axis)
    #   |   |   |   |   |
    #   o - 1 - 2 - o   v
    #   |   |   |   |
    #   o - o - o - o
    points = np.stack([x, y, z]).transpose(1, 2, 0)  # (3, M, N) -> (M, N, 3)

    corner_0 = points[:-1, :-1]  # orig
    corner_1 = points[1:, :-1]   # +dx
    corner_2 = points[1:, 1:]    # +dy
    corner_3 = points[:-1, 1:]   # -dx

    # axes: 0, 1 -> cell index (M-1 x N-1), 2 -> corner (4), 3 -> x/y/z (3)
    cells = np.stack([corner_0, corner_1, corner_2, corner_3], axis=2)

    log.debug("Generating {} polygons", np.prod(cells.shape[:2]))
    return cells.reshape(-1, 4, 3)
def project(x, y, z, axis=1, color=(0., 0., 0., 1.)):
    """
    Projects z(x,y) along an axis. Useful for example for showing bulk states
    in slab calculations.

    x, y: MxN (meshgrid with 'ij' indexing)
    z: MxN
    axis: int (axis along which to project)
    color: (r, g, b, a) base color; the alpha of every polygon is `a` scaled
        by the ratio of its footprint in the x/y plane to its projected area
        (times 0.1, capped at 1), so steep faces come out more transparent

    Returns: PolyCollection of the projected quadrilaterals
    """
    assert x.shape == y.shape == z.shape
    if np.isnan(z).any():
        log.warning("NaN values in projection input")

    r, g, b, a = color

    quads = make_quadrilaterals(x, y, z)

    # footprint of each cell in the x/y plane (valid for a rectilinear grid)
    xs = quads[:, 2, 0] - quads[:, 0, 0]
    ys = quads[:, 2, 1] - quads[:, 0, 1]
    # Take the magnitude: if exactly one of the grid axes is descending the
    # signed product is negative, which would yield negative alpha values
    # that matplotlib rejects as out-of-range RGBA.
    areas = np.abs(xs * ys)

    proj_areas = projected_area(quads, axis=axis)

    # a vanishing projected area gives inf, which is capped to 1 below
    with np.errstate(divide='ignore'):
        alphas = 0.1 * areas / proj_areas
    alphas[alphas > 1] = 1.

    # drop the coordinate we project along, keep the other two
    idx_visible_axes = [True, True, True]
    idx_visible_axes[axis] = False
    polygons = quads[..., idx_visible_axes]

    fcs = np.zeros((len(alphas), 4), dtype=np.float32)
    fcs[:, :3] = r, g, b
    fcs[:, 3] = a * alphas

    return PolyCollection(polygons, facecolors=fcs, rasterized=True)
class Arrow3D(FancyArrowPatch):
    """
    FancyArrowPatch whose two end points are given in 3D coordinates.

    xs, ys, zs: pairs (start, end) of the respective coordinate.
    Remaining arguments are passed on to FancyArrowPatch.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D end points are placeholders; the real positions are set
        # in do_3d_projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """
        Project the stored 3D end points with the projection matrix of the
        parent axes and update the 2D arrow accordingly.

        Returns: smallest projected z value of the two end points.
        """
        x3d, y3d, z3d = self._verts3d
        px, py, pz = proj3d.proj_transform(x3d, y3d, z3d, self.axes.M)
        tail = (px[0], py[0])
        head = (px[1], py[1])
        self.set_positions(tail, head)
        if renderer is not None:  # matplotlib<3.6
            super().draw(renderer)
        return min(pz)

    draw = do_3d_projection  # for matplotlib<3.5
# from http://stackoverflow.com/questions/23840756/
def orthogonal_proj(zfront, zback):
    """
    Build the 4x4 matrix used as (nearly) orthogonal projection for 3D axes.

    zfront, zback: z positions of the front and back clipping planes

    Returns: 4x4 numpy array
    """
    depth = zfront - zback
    scale = (zfront + zback) / depth
    shift = -2 * (zfront * zback) / depth
    rows = (
        (1, 0, 0, 0),
        (0, 1, 0, 0),
        (0, 0, scale, shift),
        (0, 0, -1e-4, zback),
    )
    return np.array(rows)
def plot_structure(run, ax):
    """Placeholder: not implemented, always raises NotImplementedError."""
    raise NotImplementedError()
def plot_wigner_seitz(ax, lattice, **kwargs):
    """
    Add the Wigner-Seitz cell of `lattice` to 3D axes `ax`.

    The facets returned by `lattice.get_wigner_seitz_cell()` are drawn as one
    Poly3DCollection; `kwargs` are passed on to Poly3DCollection.
    """
    cell_facets = lattice.get_wigner_seitz_cell()
    collection = Poly3DCollection(cell_facets, **kwargs)
    ax.add_collection3d(collection)
def _plot_bz(ax, lattice, **kwargs):
    """
    Draw the Wigner-Seitz cell of the reciprocal lattice of `lattice`
    (see `plot_wigner_seitz`); `kwargs` are passed through.
    """
    reciprocal = lattice.reciprocal_lattice
    return plot_wigner_seitz(ax, reciprocal, **kwargs)
def plot_bz(run, ax, vectors='primitive', k_points=False, use_symmetry=False,
            high_symm_points=True, rot=None, offset=(0, 0, 0)):
    """
    Draw the Brillouin zone of `run` into the 3D axes `ax`.

    run: object providing `primitive_lattice`, `lattice`, `brillouin_zone`,
        `band`, `high_symm_kpoints` and `high_symm_kpaths`
    ax: 3D axes to draw into
    vectors: 'primitive' or 'conventional' draws the reciprocal lattice
        vectors of `run.primitive_lattice` or `run.lattice` as arrows labelled
        a, b, c; any other value draws no vectors
    k_points: if true, plot the sampled k-points of `run.band`
    use_symmetry: take the k-points from `run.band.symm_data` instead of
        `run.band.data` (only relevant with `k_points`)
    high_symm_points: if true, mark and label the high symmetry points and
        draw the high symmetry paths
    rot: 3x3 matrix right-multiplied onto all coordinates (default: identity)
    offset: shift applied to everything drawn, in units of the primitive
        reciprocal lattice vectors
    """
    rotation = np.eye(3) if rot is None else rot
    shift = np.array(offset) @ run.primitive_lattice.reciprocal_lattice.matrix

    def place(coords):
        # shift first, then rotate
        return (coords + shift) @ rotation

    if k_points:
        band = run.band
        k_source = band.symm_data if use_symmetry else band.data
        sampled = place(k_source['k'])
        ax.plot(*sampled.T, marker='.', ls='', label='sample k-point', ms=1)

    if vectors in ('primitive', 'conventional'):
        if vectors == 'primitive':
            lattice = run.primitive_lattice
        else:
            lattice = run.lattice

        tail = shift @ rotation  # common origin of all arrows
        for basis_vec, name in zip(lattice.reciprocal_lattice.matrix, 'abc'):
            tip = place(basis_vec)
            arrow = Arrow3D(*zip(tail, tip), mutation_scale=20,
                            lw=3, arrowstyle="-|>", color="r")
            ax.add_artist(arrow)
            ax.text(*tip, s=name, color="r")

    zone_facets = [[place(corner) for corner in facet]
                   for facet in run.brillouin_zone]
    face_rgba = mcolors.to_rgba('k', alpha=0.1)  # nearly transparent faces
    ax.add_collection3d(
        Poly3DCollection(zone_facets, facecolors=face_rgba, edgecolors='k'))

    if high_symm_points:
        marked = {name: place(k_vec)
                  for name, k_vec in run.high_symm_kpoints.items()}
        ax.plot(*zip(*marked.values()), marker='o', ls='',
                label='high symmetry point', color='k', ms='1')

        for kpath in run.high_symm_kpaths:
            corners = [marked[name] for name in kpath]
            ax.plot(*zip(*corners), ls='-', color='k', alpha=0.5)

        for name, position in marked.items():
            ax.text(*position, s='${}$'.format(name), color='k')

    ax.set_xlabel('$k_x$')
    ax.set_ylabel('$k_y$')
    ax.set_zlabel('$k_z$')
def plot_bz_proj(run, ax, neighbours=False, rot=None, axis=-1,
                 vectors=True, **kwargs):
    """
    Projects along given axis (default: last axis) after applying rotation
    matrix rot

    run: object providing `brillouin_zone` (facets of the Brillouin zone) and
        `primitive_lattice`
    ax: 2D axes to draw into
    neighbours: False/None/empty draws only the central zone; True
        additionally draws the zones shifted by the result of
        `wigner_seitz_neighbours`; otherwise an array-like of shifts in units
        of the primitive reciprocal lattice vectors (assumed shape (n, 3))
    rot: 3x3 rotation matrix (default: identity)
    axis: index of the axis that is projected out
    vectors: if true, draw and label the projected (rotated) unit vectors
        100, 010 and 001
    kwargs: passed on to LineCollection
    """
    if rot is None:
        rot = np.eye(3)

    visible_axes = [True] * 3
    visible_axes[axis] = False

    if neighbours is True:
        shifts = wigner_seitz_neighbours(
            run.primitive_lattice.reciprocal_lattice)
    elif not isinstance(neighbours, np.ndarray) and not neighbours:
        # False, None, empty list/tuple, ...
        shifts = []
    else:
        # ndarrays must not be truth-tested (ambiguous for more than one
        # element), so they are handled here and checked via their size
        fractional = np.asarray(neighbours)
        if fractional.size:
            reciprocal = run.primitive_lattice.reciprocal_lattice
            shifts = fractional @ reciprocal.matrix
        else:
            shifts = []

    lines = []
    for facet in run.brillouin_zone:
        corners = np.stack(facet)
        # pair every corner with its successor -> edges, shape (n, 2, 3)
        edges = np.stack([corners, np.roll(corners, -1, axis=0)], axis=1)
        lines.extend(edges)
        for shift in shifts:
            lines.extend(edges + shift)

    # project facets
    P = rot[:, visible_axes]
    lines = [edge @ P for edge in lines]

    # todo: remove duplicate lines

    if vectors:
        x, y, z = P
        for vec, label in ((x, '100'), (y, '010'), (z, '001')):
            ax.arrow(0, 0, *vec)
            ax.text(*vec, label)

    collection = LineCollection(lines, label='Brillouin zone', **kwargs)
    ax.add_collection(collection)
    ax.set_aspect('equal')
    ax.autoscale_view(tight=True)