"""Module describing the scattering matrix used in the PRISM algorithm."""

from __future__ import annotations

import inspect
import operator
import warnings
from abc import abstractmethod
from functools import partial, reduce

import dask.array as da
import numpy as np
from ase import Atoms
from dask.graph_manipulation import wait_on

from abtem.array import _validate_lazy, ArrayObject, ComputableList
from abtem.core.axes import (
    OrdinalAxis,
    AxisMetadata,
    ScanAxis,
    UnknownAxis,
    WaveVectorAxis,
)
from abtem.core.backend import get_array_module, cp, validate_device, copy_to_device
from abtem.core.chunks import chunk_ranges, validate_chunks, equal_sized_chunks, Chunks
from abtem.core.complex import complex_exponential
from abtem.core.energy import Accelerator
from abtem.core.ensemble import Ensemble, _wrap_with_array
from abtem.core.grid import Grid, GridUndefinedError
from abtem.core.utils import (
    safe_ceiling_int,
    expand_dims_to_broadcast,
    ensure_list,
    CopyMixin,
    EqualityMixin,
    tuple_range,
)
from abtem.detectors import (
    BaseDetector,
    _validate_detectors,
    WavesDetector,
    FlexibleAnnularDetector,
    AnnularDetector,
)
from abtem.measurements import BaseMeasurements
from abtem.multislice import (
    allocate_multislice_measurements,
    multislice_and_detect,
)
from abtem.potentials.iam import BasePotential, _validate_potential
from abtem.prism.utils import (
    plane_waves,
    wrapped_crop_2d,
    minimum_crop,
    batch_crop_2d,
)
from abtem.scan import BaseScan, _validate_scan, GridScan
from abtem.transfer import CTF
from abtem.waves import BaseWaves, _antialias_cutoff_gpts
from abtem.waves import Waves, Probe


def _extract_measurement(array, index):
    """
    Return the array of the `index`-th measurement stored in a single-element object block.

    Empty blocks (``size == 0``) are passed through unchanged.
    """
    if array.size == 0:
        return array
    # The block holds exactly one object: the list of measurements for that block.
    return array.item()[index].array


def _wrap_measurements(measurements):
    """Return a lone measurement unwrapped; otherwise wrap all of them in a ComputableList."""
    if len(measurements) == 1:
        return measurements[0]
    return ComputableList(measurements)


def _finalize_lazy_measurements(
    arrays, waves, detectors, extra_ensemble_axes_metadata=None, chunks=None
):
    """
    Split a lazy object array of per-block measurement lists into one measurement per detector.

    Parameters
    ----------
    arrays : dask array
        Object array whose blocks each hold the list of measurements produced for that block
        (unpacked with `_extract_measurement`).
    waves : BaseWaves
        Waves used to query each detector for its output shape, meta, axes, metadata and type.
    detectors : list of BaseDetector
        One output measurement is created per detector, in order.
    extra_ensemble_axes_metadata : list of AxisMetadata, optional
        Axis metadata prepended to the axes metadata of every measurement.
    chunks : tuple, optional
        Chunks of the leading axes of the outputs. Defaults to the chunks of `arrays`.

    Returns
    -------
    list
        The measurements, with `reduce_ensemble` applied to those that provide it.
    """
    leading_axes_metadata = (
        [] if extra_ensemble_axes_metadata is None else extra_ensemble_axes_metadata
    )
    leading_chunks = arrays.chunks if chunks is None else chunks

    results = []
    for detector_index, detector in enumerate(detectors):
        base_shape = detector._out_base_shape(waves)

        if isinstance(detector, AnnularDetector):
            # TODO: the base shape reported by AnnularDetector is overridden with () here.
            base_shape = ()

        meta = detector._out_meta(waves)

        first_new_axis = len(arrays.shape)
        new_axis = tuple(range(first_new_axis, first_new_axis + len(base_shape)))

        measurement_array = arrays.map_blocks(
            _extract_measurement,
            detector_index,
            chunks=leading_chunks + tuple((n,) for n in base_shape),
            new_axis=new_axis,
            meta=meta,
        )

        axes_metadata = leading_axes_metadata + (
            detector._out_ensemble_axes_metadata(waves)
            + detector._out_base_axes_metadata(waves)
        )
        metadata = detector._out_metadata(waves)
        measurement_type = detector._out_type(waves)

        measurement = measurement_type.from_array_and_metadata(
            measurement_array, axes_metadata=axes_metadata, metadata=metadata
        )

        if hasattr(measurement, "reduce_ensemble"):
            measurement = measurement.reduce_ensemble()

        results.append(measurement)

    return results


def _round_gpts_to_multiple_of_interpolation(
    gpts: tuple[int, int], interpolation: tuple[int, int]
) -> tuple[int, int]:
    """Round each entry of `gpts` up to a multiple of the matching interpolation factor."""
    rounded = []
    for factor, n in zip(interpolation, gpts):
        # (-n) % factor is the amount missing to reach the next multiple of factor.
        rounded.append(n + (-n) % factor)
    return tuple(rounded)  # noqa
class BaseSMatrix(BaseWaves):
    """Base class for scattering matrices."""

    _device: str
    ensemble_axes_metadata: list[AxisMetadata]
    ensemble_shape: tuple[int, ...]
    # The three base dimensions are: plane-wave index, x and y.
    _base_dims = 3

    @property
    def device(self):
        """The device where the S-Matrix is created and reduced."""
        return self._device

    @property
    @abstractmethod
    def interpolation(self):
        """Interpolation factor in the `x` and `y` directions."""
        pass

    @property
    @abstractmethod
    def wave_vectors(self) -> np.ndarray:
        """The wave vectors corresponding to each plane wave."""
        pass

    @property
    @abstractmethod
    def semiangle_cutoff(self) -> float:
        """The radial cutoff of the plane-wave expansion [mrad]."""
        pass

    @property
    @abstractmethod
    def window_extent(self):
        """The cropping window extent of the waves."""
        pass

    @property
    @abstractmethod
    def window_gpts(self):
        """The number of grid points describing the cropping window of the wave functions."""
        pass

    def __len__(self) -> int:
        """The number of plane waves in the expansion."""
        return len(self.wave_vectors)

    @property
    def base_axes_metadata(self) -> list[AxisMetadata]:
        """Axis metadata of the base axes: the wave-vector axis ("q") followed by the two spatial axes."""
        wave_axes_metadata = super().base_axes_metadata
        return [
            WaveVectorAxis(
                label="q",
                values=tuple(tuple(value) for value in self.wave_vectors),
            ),
            wave_axes_metadata[0],
            wave_axes_metadata[1],
        ]

    def dummy_probes(
        self, scan: BaseScan = None, ctf: CTF = None, plane: str = "entrance", **kwargs
    ) -> Probe:
        """
        A probe, or an ensemble of probes, equivalent to reducing the SMatrix at a single position.

        Parameters
        ----------
        scan : BaseScan, optional
            If given, it is assigned as the positions of the returned probes.
        ctf : CTF or dict, optional
            Contrast transfer function of the probes. A dict is passed as keyword arguments to `CTF` (together
            with the energy and semiangle cutoff of the SMatrix). A CTF is copied, so the caller's object is not
            modified. If not given, a CTF with the energy and semiangle cutoff of the SMatrix is used. In all
            cases, the semiangle cutoff is limited to that of the SMatrix.
        plane : str, optional
            If "exit", the defocus of the CTF is reduced by the thickness of the object's `potential` (when it
            has a non-None one). If the object has no `potential` attribute, the "accumulated_defocus" entry of
            its metadata is used instead, when present. Any other value, including the default "entrance",
            leaves the defocus unchanged.
        **kwargs
            Additional keyword arguments for `Probe._from_ctf`. These override the defaults for `device` and
            `metadata`, which come from the SMatrix.

        Returns
        -------
        dummy_probes : Probe

        Raises
        ------
        ValueError
            If `ctf` is not None, a dict or a CTF.
        """
        # TODO: review which kwargs should be forwarded to the probes.
        if ctf is None:
            ctf = CTF(energy=self.energy, semiangle_cutoff=self.semiangle_cutoff)
        elif isinstance(ctf, dict):
            ctf = CTF(energy=self.energy, semiangle_cutoff=self.semiangle_cutoff, **ctf)
        elif isinstance(ctf, CTF):
            ctf = ctf.copy()
        else:
            raise ValueError(
                f"ctf must be None, a dict or a CTF, not {type(ctf).__name__}"
            )

        if plane == "exit":
            defocus = 0.0
            if hasattr(self, "potential"):
                if self.potential is not None:
                    defocus = self.potential.thickness
            elif "accumulated_defocus" in self.metadata:
                defocus = self.metadata["accumulated_defocus"]

            ctf.defocus = ctf.defocus - defocus

        # The probe can never have a larger aperture than the plane-wave expansion supports.
        ctf.semiangle_cutoff = min(ctf.semiangle_cutoff, self.semiangle_cutoff)

        default_kwargs = {"device": self.device, "metadata": {**self.metadata}}
        kwargs = {**default_kwargs, **kwargs}

        probes = Probe._from_ctf(
            extent=self.window_extent,
            gpts=self.window_gpts,
            ctf=ctf,
            energy=self.energy,
            **kwargs,
        )

        if scan is not None:
            probes._positions = scan

        return probes
def_validate_interpolation(interpolation:int|tuple[int,int]):ifisinstance(interpolation,int):interpolation=(interpolation,)*2elifnotlen(interpolation)==2:raiseValueError("Interpolation factor must be an integer.")returntuple(interpolation)def_common_kwargs(a,b):a_kwargs=inspect.signature(a).parameters.keys()b_kwargs=inspect.signature(b).parameters.keys()returnset(a_kwargs).intersection(b_kwargs)def_pack_wave_vectors(wave_vectors):returntuple((float(wave_vector[0]),float(wave_vector[1]))forwave_vectorinwave_vectors)def_chunked_axis(s_matrix_array):window_margin=s_matrix_array._window_marginargsort=np.argsort((-s_matrix_array.gpts[0]//window_margin[0],-s_matrix_array.gpts[1]//window_margin[1],))returnint(argsort[0]),int(argsort[1])def_chunks_for_multiple_rechunk_reduce(partitions):chunks_1=()chunk_indices_1=()foriinrange(1,len(partitions)-1,3):chunks_1+=(sum(partitions[i-1:i+2]),)chunk_indices_1+=(i-1,)chunks_1=chunks_1+(sum(partitions[i+2:]),)assertsum(chunks_1)==sum(partitions)chunks_2=(sum(partitions[:1]),)chunk_indices_2=()foriinrange(2,len(partitions)-1,3):chunks_2+=(sum(partitions[i-1:i+2]),)chunk_indices_2+=(i-1,)chunks_2=chunks_2+(sum(partitions[i+2:]),)assertsum(chunks_2)==sum(partitions)chunks_3=(sum(partitions[:2]),)chunk_indices_3=()foriinrange(3,len(partitions)-1,3):chunks_3+=(sum(partitions[i-1:i+2]),)chunk_indices_3+=(i-1,)chunks_3=chunks_3+(sum(partitions[i+2:]),)assertsum(chunks_3)==sum(partitions)assertlen(chunk_indices_1+chunk_indices_2+chunk_indices_3)==(len(partitions)-2)return(chunks_1,chunks_2,chunks_3),(chunk_indices_1,chunk_indices_2,chunk_indices_3,)def_lazy_reduce(array,waves_partial,ensemble_axes_metadata,from_waves_kwargs,scan,ctf,detectors,max_batch_reduction,):args=(array,ensemble_axes_metadata)waves=waves_partial(args).item()s_matrix=SMatrixArray._from_waves(waves,**from_waves_kwargs)measurements=s_matrix._batch_reduce_to_measurements(scan,ctf,detectors,max_batch_reduction)arr=np.zeros((1,)*(len(array.shape)-1),dtype=object)arr.itemset(
measurements)returnarrdef_map_blocks(array,scans,block_indices,window_offset=(0,0),**kwargs):ctf_chunks=tuple((n,)forninkwargs["ctf"].ensemble_shape)blocks=()fori,scaninzip(block_indices,scans):block=array.blocks[(slice(None),)*(len(array.shape)-2)+i]new_chunks=array.chunks[:-3]+ctf_chunks+scan.shapekwargs["from_waves_kwargs"]["window_offset"]=(window_offset[0]+sum(array.chunks[-2][:i[0]]),window_offset[1]+sum(array.chunks[-1][:i[1]]),)iflen(scan.shape)==1:drop_axis=(len(array.shape)-3,len(array.shape)-1)eliflen(scan.shape)==2:drop_axis=(len(array.shape)-3,)else:raiseNotImplementedErrorblock=da.map_blocks(_lazy_reduce,block,scan=scan,drop_axis=drop_axis,chunks=new_chunks,**kwargs,meta=np.array((),dtype=np.complex64),)iflen(scan)==0:block=da.zeros((0,)*len(block.shape),dtype=np.complex64,)blocks+=(block,)returnblocksdef_tuple_from_indices(*args):temp_list=[None]*(len(args)//2)forarg1,arg2inzip(args[::2],args[1::2]):temp_list[arg1]=arg2returntuple(temp_list)def_multiple_rechunk_reduce(s_matrix_array,scan,detectors,ctf,max_batch_reduction):assertnp.all(s_matrix_array.periodic)window_margin=s_matrix_array._window_marginchunked_axis,nochunks_axis=_chunked_axis(s_matrix_array)pad_amounts=_tuple_from_indices(chunked_axis,(window_margin[chunked_axis],)*2,nochunks_axis,(0,0))s_matrix_array=s_matrix_array._pad(pad_amounts)chunk_size=window_margin[chunked_axis]size=s_matrix_array.shape[-2:][chunked_axis]-window_margin[chunked_axis]*2num_chunks=-(size//-chunk_size)partitions=_tuple_from_indices(chunked_axis,(chunk_size,)*num_chunks,nochunks_axis,(s_matrix_array.shape[len(s_matrix_array.shape)-2+chunked_axis],),)chunk_extents=tuple(tuple(((cc[0])*d,(cc[1])*d)forccinc)forc,dinzip(chunk_ranges(partitions),s_matrix_array.sampling))scan,scan_chunks=scan._sort_into_extents(chunk_extents)scans=[(indices,scan.item())forindices,_,scaninscan.generate_blocks(scan_chunks)]partitions=(pad_amounts[chunked_axis][0],)+partitions[chunked_axis]partitions=partitions+(s_matrix_array.shape[len(s_ma
trix_array.shape)-2+chunked_axis]-sum(partitions),)(chunks_1,chunks_2,chunks_3),(scan_indices_1,scan_indices_2,scan_indices_3,)=_chunks_for_multiple_rechunk_reduce(partitions)chunks_1=(s_matrix_array.array.chunks[:-3]+(-1,)+_tuple_from_indices(chunked_axis,chunks_1,nochunks_axis,-1))chunks_2=(s_matrix_array.array.chunks[:-3]+(-1,)+_tuple_from_indices(chunked_axis,chunks_2,nochunks_axis,-1))chunks_3=(s_matrix_array.array.chunks[:-3]+(-1,)+_tuple_from_indices(chunked_axis,chunks_3,nochunks_axis,-1))shape=tuple(len(c)forcinscan_chunks)blocks=np.zeros(shape,dtype=object)kwargs={"waves_partial":s_matrix_array.waves._from_partitioned_args(),"ensemble_axes_metadata":s_matrix_array.waves.ensemble_axes_metadata,"from_waves_kwargs":s_matrix_array._copy_kwargs(exclude=("array","extent")),"ctf":ctf,"detectors":detectors,"max_batch_reduction":max_batch_reduction,}array=s_matrix_array.array.rechunk(chunks_1)window_offset=s_matrix_array.window_offsetblock_indices=[_tuple_from_indices(chunked_axis,i,nochunks_axis,0)foriinrange(len(scan_indices_1))]new_blocks=_map_blocks(array,[scans[i][1]foriinscan_indices_1],block_indices,window_offset=window_offset,**kwargs,)fori,blockinzip(scan_indices_1,new_blocks):blocks.itemset(scans[i][0],block)ifs_matrix_array.ensemble_shape:fp_arrays=[]foriinnp.ndindex(s_matrix_array.ensemble_shape):try:fp_new_blocks=tuple(block[i]forblockinnew_blocks)fp_array=wait_on(array[i],*fp_new_blocks)[0]fp_arrays.append(fp_array)exceptIndexError:fp_arrays.append(array[i])array=da.stack(fp_arrays,axis=0)array=array.rechunk(chunks_2)block_indices=[_tuple_from_indices(chunked_axis,i,nochunks_axis,0)foriinrange(1,len(scan_indices_2)+1)]new_blocks=_map_blocks(array,[scans[i][1]foriinscan_indices_2],block_indices,window_offset=window_offset,**kwargs,)fori,blockinzip(scan_indices_2,new_blocks):blocks.itemset(scans[i][0],block)ifs_matrix_array.ensemble_shape:fp_arrays=[]foriinnp.ndindex(s_matrix_array.ensemble_shape):try:fp_new_blocks=tuple(block[i]forblockinnew_blocks)fp_
array=wait_on(array[i],*fp_new_blocks)[0]fp_arrays.append(fp_array)exceptIndexError:fp_arrays.append(array[i])array=da.stack(fp_arrays,axis=0)array=array.rechunk(chunks_3)block_indices=[_tuple_from_indices(chunked_axis,i,nochunks_axis,0)foriinrange(1,len(scan_indices_3)+1)]new_blocks=_map_blocks(array,[scans[i][1]foriinscan_indices_3],block_indices,window_offset=window_offset,**kwargs,)fori,blockinzip(scan_indices_3,new_blocks):blocks.itemset(scans[i][0],block)array=da.block(blocks.tolist())dummy_probes=s_matrix_array.dummy_probes(scan=scan,ctf=ctf)measurements=_finalize_lazy_measurements(array,waves=dummy_probes,detectors=detectors,extra_ensemble_axes_metadata=s_matrix_array.ensemble_axes_metadata,)returnmeasurementsdef_single_rechunk_reduce(s_matrix_array:"SMatrixArray",scan:BaseScan,detectors:list[BaseDetector],ctf:CTF,max_batch_reduction:int,):chunked_axis,nochunks_axis=_chunked_axis(s_matrix_array)num_chunks=(s_matrix_array.gpts[chunked_axis]//s_matrix_array._window_margin[chunked_axis])chunks=equal_sized_chunks(s_matrix_array.shape[-2:][chunked_axis],num_chunks=num_chunks)assertnp.all(np.array(chunks)>s_matrix_array._window_margin[chunked_axis])chunks=(s_matrix_array.array.chunks[:-3]+(-1,)+_tuple_from_indices(chunked_axis,chunks,nochunks_axis,-1))array=s_matrix_array._array.rechunk(chunks)assertall(s_matrix_array.periodic)# chunk_extents = tuple(# tuple(((cc[0]) * d, (cc[1]) * d) for cc in c)# for c, d in zip(chunk_ranges(array.chunks[-2:]), s_matrix_array.sampling)# 
)chunk_extents_x=tuple(((cc[0])*s_matrix_array.sampling[0],(cc[1])*s_matrix_array.sampling[0])forccinarray.chunks[-2])chunk_extents_y=tuple(((cc[0])*s_matrix_array.sampling[1],(cc[1])*s_matrix_array.sampling[1])forccinarray.chunks[-1])chunk_extents=(chunk_extents_x,chunk_extents_y)scan,scan_chunks=scan._sort_into_extents(chunk_extents)ctf_chunks=tuple((n,)forninctf.ensemble_shape)chunks=array.chunks[:-3]+ctf_chunksshape=tuple(len(c)forc,pinzip(scan_chunks,s_matrix_array.periodic))blocks=np.zeros((1,)*len(array.shape[:-3])+shape,dtype=object)kwargs={"waves_partial":s_matrix_array.waves._from_partitioned_args(),"ensemble_axes_metadata":s_matrix_array.waves.ensemble_axes_metadata,"from_waves_kwargs":s_matrix_array._copy_kwargs(exclude=("array","extent")),"ctf":ctf,"detectors":detectors,"max_batch_reduction":max_batch_reduction,}forindices,_,sub_scaninscan.generate_blocks(scan_chunks):sub_scan=sub_scan.item()iflen(sub_scan)==0:blocks.itemset((0,)*len(array.shape[:-3])+indices,da.zeros((0,)*len(blocks.shape),dtype=np.complex64,),)continueslics=(slice(None),)*(len(array.shape)-2)window_offset=()fori,kinenumerate(indices):iflen(array.chunks[-2:][i])>1:slics+=([k-1,k,(k+1)%len(array.chunks[-2:][i])],)window_offset+=(sum(array.chunks[-2:][i][:k])-array.chunks[-2:][i][k-1],)else:slics+=(slice(None),)window_offset+=(0,)new_block=array.blocks[slics]new_block=new_block.rechunk(array.chunks[:-2]+(-1,-1))new_chunks=chunks+sub_scan.shapekwargs["from_waves_kwargs"]["window_offset"]=tuple(window_offset)iflen(scan.shape)==1:drop_axis=(len(array.shape)-3,len(array.shape)-1)eliflen(scan.shape)==2:drop_axis=(len(array.shape)-3,)else:raiseNotImplementedErrornew_block=da.map_blocks(_lazy_reduce,new_block,scan=sub_scan,drop_axis=drop_axis,chunks=new_chunks,**kwargs,meta=np.array((),dtype=np.complex64),)blocks.itemset((0,)*len(array.shape[:-3])+indices,new_block)array=da.block(blocks.tolist())dummy_probes=s_matrix_array.dummy_probes(scan=scan,ctf=ctf)measurements=_finalize_lazy_measurements(
array,waves=dummy_probes,detectors=detectors,extra_ensemble_axes_metadata=s_matrix_array.ensemble_axes_metadata,)returnmeasurementsdef_no_chunks_reduce(s_matrix_array:"SMatrixArray",scan:BaseScan,detectors:list[BaseDetector],ctf:CTF,max_batch_reduction:int,):array=s_matrix_array._array.rechunk(s_matrix_array.array.chunks[:-3]+(-1,-1,-1))kwargs={"waves_partial":s_matrix_array.waves._from_partitioned_args(),"ensemble_axes_metadata":s_matrix_array.waves.ensemble_axes_metadata,"from_waves_kwargs":s_matrix_array._copy_kwargs(exclude=("array","extent")),"ctf":ctf,"detectors":detectors,"max_batch_reduction":max_batch_reduction,}ctf_chunks=tuple((n,)forninctf.ensemble_shape)chunks=array.chunks[:-3]+ctf_chunks+scan.shapeiflen(scan.shape)==1:drop_axis=(len(array.shape)-3,len(array.shape)-1)eliflen(scan.shape)==2:drop_axis=(len(array.shape)-3,)else:raiseNotImplementedErrorarray=da.map_blocks(_lazy_reduce,array,scan=scan,drop_axis=drop_axis,chunks=chunks,**kwargs,meta=np.array((),dtype=np.complex64),)dummy_probes=s_matrix_array.dummy_probes(scan=scan,ctf=ctf)measurements=_finalize_lazy_measurements(array,waves=dummy_probes,detectors=detectors,extra_ensemble_axes_metadata=s_matrix_array.ensemble_axes_metadata,)returnmeasurements
class SMatrixArray(BaseSMatrix, ArrayObject):
    """
    A scattering matrix defined by a given array of dimension 3 or higher. The third-to-last dimension indexes
    the probe plane waves, and the last two dimensions are the `x` and `y` spatial dimensions.

    Parameters
    ----------
    array : np.ndarray
        Array defining the scattering matrix. Must be 3D or higher. Dimensions before the last three
        represent ensemble dimensions, the next dimension indexes the plane waves, and the last two
        dimensions represent the spatial extent of the plane waves.
    wave_vectors : np.ndarray
        Array defining the wave vectors corresponding to each plane wave. Must have shape Nx2, where N is equal
        to the number of plane waves.
    semiangle_cutoff : float
        The radial cutoff of the plane-wave expansion [mrad].
    energy : float
        Electron energy [eV].
    sampling : one or two float, optional
        Lateral sampling of wave functions [Å]. Provide only if potential is not given. Will be ignored if
        'gpts' is also provided.
    extent : one or two float, optional
        Lateral extent of wave functions [Å]. Provide only if potential is not given.
    interpolation : one or two int, optional
        Interpolation factor in the `x` and `y` directions (default is 1, i.e. no interpolation). If a single
        value is provided, it is used for both directions.
    window_gpts : tuple of int
        The number of grid points describing the cropping window of the wave functions.
    window_offset : tuple of int
        The number of grid points by which the cropping windows of the wave functions are displaced from the
        origin.
    periodic : tuple of bool
        Specifies whether the SMatrix should be assumed to be periodic along the x- and y-axis.
    device : str, optional
        The calculations will be carried out on this device ('cpu' or 'gpu'). If not given, the default is
        determined by the user configuration.
    ensemble_axes_metadata : list of AxisMetadata
        Axis metadata for each ensemble axis. The axis metadata must be compatible with the shape of the array.
    metadata : dict
        A dictionary defining wave function metadata. All items will be added to the metadata of measurements
        derived from the waves.
    """

    def copy_to_device(self, device: str) -> "SMatrixArray":
        """Copy the SMatrixArray, including its wave vectors, to the specified device."""
        s_matrix = super().copy_to_device(device)
        # Module-level `copy_to_device` from abtem.core.backend, not this method.
        s_matrix._wave_vectors = copy_to_device(self._wave_vectors, device)
        return s_matrix

    @staticmethod
    def _packed_wave_vectors(wave_vectors):
        return _pack_wave_vectors(wave_vectors)

    @property
    def device(self):
        """The device on which the SMatrixArray is reduced."""
        return self._device

    @property
    def storage_device(self):
        """The device on which the SMatrixArray is stored."""
        return super().device

    @classmethod
    def _from_waves(cls, waves: Waves, **kwargs):
        """Build an SMatrixArray from Waves whose last ensemble axis is the plane-wave axis."""
        common_kwargs = _common_kwargs(cls, Waves)
        kwargs.update({key: getattr(waves, key) for key in common_kwargs})
        # Drop the plane-wave axis, which is a base axis of the SMatrixArray.
        kwargs["ensemble_axes_metadata"] = kwargs["ensemble_axes_metadata"][:-1]
        return cls(**kwargs)

    @property
    def waves(self) -> Waves:
        """The scattering matrix as Waves, with the plane-wave (wave-vector) axis as the last ensemble axis."""
        kwargs = {
            key: getattr(self, key) for key in _common_kwargs(self.__class__, Waves)
        }
        kwargs["ensemble_axes_metadata"] = (
            kwargs["ensemble_axes_metadata"] + self.base_axes_metadata[:-2]
        )
        return Waves(**kwargs)

    def _copy_with_new_waves(self, waves):
        """Copy with the Waves-related parameters taken from `waves`; other parameters come from self."""
        keys = set(
            inspect.signature(self.__class__).parameters.keys()
        ) - _common_kwargs(self.__class__, Waves)
        kwargs = {key: getattr(self, key) for key in keys}
        return self._from_waves(waves, **kwargs)

    @property
    def periodic(self) -> tuple[bool, bool]:
        """If True the SMatrix is assumed to be periodic along corresponding axis."""
        return self._periodic

    @property
    def metadata(self) -> dict:
        """Metadata dictionary; the "energy" entry is refreshed on each access."""
        self._metadata["energy"] = self.energy
        return self._metadata

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """Axis metadata for each ensemble axis."""
        return self._ensemble_axes_metadata

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """Shape of the ensemble axes (all but the last three axes of the array)."""
        return self.array.shape[:-3]

    @property
    def interpolation(self) -> tuple[int, int]:
        """Interpolation factor in the `x` and `y` directions."""
        return self._interpolation

    @property
    def semiangle_cutoff(self) -> float:
        """The cutoff semiangle of the plane wave expansion."""
        return self._semiangle_cutoff

    @property
    def wave_vectors(self) -> np.ndarray:
        """The wave vectors corresponding to each plane wave."""
        return self._wave_vectors

    @property
    def window_gpts(self) -> tuple[int, int]:
        """The number of grid points describing the cropping window of the wave functions."""
        return self._window_gpts

    @property
    def window_extent(self) -> tuple[float, float]:
        """The extent of the cropping window [Å] (window grid points times sampling)."""
        return (
            self.window_gpts[0] * self.sampling[0],
            self.window_gpts[1] * self.sampling[1],
        )

    @property
    def window_offset(self) -> tuple[int, int]:
        """The number of grid points by which the cropping windows of the wave functions are displaced from the origin."""
        return self._window_offset

    def reduce(
        self,
        scan: BaseScan = None,
        ctf: CTF = None,
        detectors: BaseDetector | list[BaseDetector] = None,
        max_batch_reduction: int | str = "auto",
        reduction_scheme: str = "auto",
    ) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]:
        """
        Scan the probe across the potential and record a measurement for each detector.

        Parameters
        ----------
        scan : BaseScan, optional
            Scan defining the positions of the probe wave functions. Other values are validated with
            `_validate_scan`. If not given, a single probe position at the center of the SMatrix is used. For
            inputs that are not a BaseScan, axis -3 of each measurement is squeezed.
        ctf : CTF, optional
            The probe contrast transfer function. It is copied before use, so the caller's object is not
            modified. Default is None (aperture is set by the plane-wave cutoff). An infinite semiangle cutoff
            is also replaced by the plane-wave cutoff.
        detectors : BaseDetector or list of BaseDetector, optional
            The detectors recording the measurements; validated with `_validate_detectors`.
        max_batch_reduction : int or str, optional
            Number of positions per reduction operation. A large number of positions better utilizes thread
            parallelization, but requires more memory and floating point operations. If 'auto' (default), the
            batch size is automatically chosen based on the abtem user configuration settings
            "dask.chunk-size" and "dask.chunk-size-gpu".
        reduction_scheme : str, optional
            How a lazy SMatrix is reduced: 'multiple-rechunk', 'single-rechunk' or 'no-chunks'. The value is
            validated by `_validate_reduction_scheme`, which presumably resolves 'auto' (default) to one of
            these. The value has no effect on a SMatrix that is not lazy.

        Returns
        -------
        measurements : BaseMeasurements or Waves or list of BaseMeasurements or Waves
            One entry per detector; a single entry is returned on its own.

        Raises
        ------
        ValueError
            If the validated reduction scheme is not one of the supported schemes.
        """
        self.accelerator.check_is_defined()

        if ctf is None:
            ctf = CTF(semiangle_cutoff=self.semiangle_cutoff)
        else:
            # Matching the grid and energy below mutates the CTF, so work on a copy.
            ctf = ctf.copy()

        ctf.grid.match(self.dummy_probes())
        ctf.accelerator.match(self)

        if ctf.semiangle_cutoff == np.inf:
            ctf.semiangle_cutoff = self.semiangle_cutoff

        squeeze = () if isinstance(scan, BaseScan) else (-3,)

        if scan is None:
            scan = self.extent[0] / 2, self.extent[1] / 2

        scan = _validate_scan(
            scan, Probe._from_ctf(extent=self.extent, ctf=ctf, energy=self.energy)
        )

        detectors = _validate_detectors(detectors)

        max_batch_reduction = self._validate_max_batch_reduction(
            scan, max_batch_reduction
        )

        reduction_scheme = self._validate_reduction_scheme(reduction_scheme)

        if self.is_lazy:
            if reduction_scheme == "multiple-rechunk":
                measurements = _multiple_rechunk_reduce(
                    self, scan, detectors, ctf, max_batch_reduction
                )
            elif reduction_scheme == "single-rechunk":
                measurements = _single_rechunk_reduce(
                    self, scan, detectors, ctf, max_batch_reduction
                )
            elif reduction_scheme == "no-chunks":
                measurements = _no_chunks_reduce(
                    self, scan, detectors, ctf, max_batch_reduction
                )
            else:
                raise ValueError(
                    f"unknown reduction scheme '{reduction_scheme}'; expected "
                    "'multiple-rechunk', 'single-rechunk' or 'no-chunks'"
                )
        else:
            measurements = self._batch_reduce_to_measurements(
                scan, ctf, detectors, max_batch_reduction
            )

        measurements = [measurement.squeeze(squeeze) for measurement in measurements]

        return _wrap_measurements(measurements)

    def scan(
        self,
        scan: BaseScan = None,
        detectors: BaseDetector | list[BaseDetector] = None,
        ctf: CTF = None,
        max_batch_reduction: int | str = "auto",
        rechunk: tuple[int, int] | str = "auto",
        reduction_scheme: str = "auto",
    ):
        """
        Reduce the SMatrix using coefficients calculated by a BaseScan and a CTF, to obtain the exit wave
        functions at given initial probe positions and aberrations.

        Parameters
        ----------
        scan : BaseScan, optional
            Positions of the probe wave functions. If not given, a default `GridScan()` is used.
        detectors : BaseDetector, list of BaseDetector, optional
            A detector or a list of detectors defining how the wave functions should be converted to
            measurements after running the multislice algorithm. If not given, a `FlexibleAnnularDetector()`
            is used.
        ctf : CTF, optional
            Contrast transfer function used for calculating the expansion coefficients in the reduction of
            the SMatrix.
        max_batch_reduction : int or str, optional
            Number of positions per reduction operation. A large number of positions better utilizes thread
            parallelization, but requires more memory and floating point operations. If 'auto' (default), the
            batch size is automatically chosen based on the abtem user configuration settings
            "dask.chunk-size" and "dask.chunk-size-gpu".
        rechunk : str or tuple of int, optional
            Not used; the rechunking is controlled by `reduction_scheme`. Passing anything other than 'auto'
            emits a warning.
        reduction_scheme : str, optional
            How a lazy SMatrix is reduced; passed to `reduce` (default is 'auto').

        Returns
        -------
        detected_waves : BaseMeasurements or list of BaseMeasurement
            The detected measurement (if detector(s) given).
        exit_waves : Waves
            Wave functions at the exit plane(s) of the potential (if no detector(s) given).
        """
        if scan is None:
            scan = GridScan()

        if detectors is None:
            detectors = [FlexibleAnnularDetector()]

        # `reduce` has no `rechunk` parameter; forwarding it used to raise a TypeError.
        if not (isinstance(rechunk, str) and rechunk == "auto"):
            warnings.warn(
                "The 'rechunk' argument of SMatrixArray.scan is ignored; "
                "use 'reduction_scheme' to control how the SMatrix is reduced.",
                stacklevel=2,
            )

        return self.reduce(
            scan=scan,
            ctf=ctf,
            detectors=detectors,
            max_batch_reduction=max_batch_reduction,
            reduction_scheme=reduction_scheme,
        )
class SMatrix(BaseSMatrix, Ensemble, CopyMixin, EqualityMixin):
    """
    The scattering matrix is used for simulating STEM experiments using the PRISM algorithm.

    Parameters
    ----------
    semiangle_cutoff : float
        The radial cutoff of the plane-wave expansion [mrad]. Must be positive.
    energy : float
        Electron energy [eV].
    potential : Atoms or BasePotential, optional
        Atoms or a potential that the scattering matrix represents. If given as atoms, a default potential will
        be created. If nothing is provided, the scattering matrix will represent a vacuum potential, in which
        case the grid must be fully defined by the grid arguments (e.g. 'extent' and 'gpts').
    gpts : one or two int, optional
        Number of grid points describing the scattering matrix. Provide only if potential is not given.
    sampling : one or two float, optional
        Lateral sampling of scattering matrix [Å]. Provide only if potential is not given. Will be ignored if
        'gpts' is also provided.
    extent : one or two float, optional
        Lateral extent of scattering matrix [Å]. Provide only if potential is not given.
    interpolation : one or two int, optional
        Interpolation factor in the `x` and `y` directions (default is 1, i.e. no interpolation). If a single
        value is provided, it is used for both directions. A warning is issued if the factor does not exactly
        divide 'gpts'.
    downsample : {'cutoff', 'valid'} or float or bool
        Controls whether to downsample the scattering matrix after running the multislice algorithm.

            ``cutoff`` :
                Downsample to the antialias cutoff scattering angle (default). ``True`` is equivalent.

            ``valid`` :
                Downsample to the largest rectangle that fits inside the circle with a radius defined by the
                antialias cutoff scattering angle.

            float :
                Downsample to a specified maximum scattering angle [mrad].
    device : str, optional
        The calculations will be carried out on this device ('cpu' or 'gpu'). If not given, the default is
        determined by the user configuration.
    store_on_host : bool, optional
        If True, store the scattering matrix in host (cpu) memory so that the necessary memory is transferred as
        chunks to the device to run calculations (default is False).

    Raises
    ------
    ValueError
        If no potential is given and the grid is not defined, or if `semiangle_cutoff` is not positive.
    """

    def __init__(
        self,
        semiangle_cutoff: float,
        energy: float,
        potential: Atoms | BasePotential = None,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
        extent: float | tuple[float, float] = None,
        interpolation: int | tuple[int, int] = 1,
        downsample: bool | str = "cutoff",
        device: str = None,
        store_on_host: bool = False,
    ):
        if downsample is True:
            downsample = "cutoff"

        self._device = validate_device(device)

        self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling)

        if potential is None:
            try:
                self.grid.check_is_defined()
            except GridUndefinedError:
                raise ValueError("Provide a potential or provide 'extent' and 'gpts'.")
        else:
            potential = _validate_potential(potential)
            self.grid.match(potential)
            self._grid = potential.grid

        self._potential = potential
        self._interpolation = _validate_interpolation(interpolation)
        self._semiangle_cutoff = semiangle_cutoff
        self._downsample = downsample

        self._accelerator = Accelerator(energy=energy)
        self._store_on_host = store_on_host

        # Explicit check instead of `assert`, which is stripped when running with -O.
        # Written as `not ... > 0.0` so that NaN is rejected as well.
        if not semiangle_cutoff > 0.0:
            raise ValueError("semiangle_cutoff must be positive.")

        if not all(n % f == 0 for f, n in zip(self.interpolation, self.gpts)):
            warnings.warn(
                "The interpolation factor does not exactly divide 'gpts', normalization may not be exactly "
                "preserved."
            )

    @property
    def base_shape(self) -> tuple[int, int, int]:
        """Shape of the base axes of the SMatrix: (number of plane waves, gpts x, gpts y)."""
        return len(self), self.gpts[0], self.gpts[1]

    @property
    def tilt(self):
        """The small-angle tilt applied to the Fresnel propagator [mrad]; always (0, 0) for an SMatrix."""
        return 0.0, 0.0

    def round_gpts_to_interpolation(self) -> SMatrix:
        """
        Round the gpts of the SMatrix up to the nearest multiple of the interpolation factor.

        The SMatrix is modified in place and returned for convenience.

        Returns
        -------
        s_matrix_with_rounded_gpts : SMatrix
        """
        rounded = _round_gpts_to_multiple_of_interpolation(
            self.gpts, self.interpolation
        )
        if rounded != self.gpts:
            self.gpts = rounded
        return self

    @property
    def downsample(self) -> str | bool:
        """How to downsample the scattering matrix after running the multislice algorithm."""
        return self._downsample

    @property
    def store_on_host(self) -> bool:
        """Store the SMatrix in host memory. The reduction may still be calculated on the device."""
        return self._store_on_host

    @property
    def metadata(self):
        """Metadata of the SMatrix (only the electron energy)."""
        return {"energy": self.energy}

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the SMatrix: ensemble shape, number of plane waves, then gpts."""
        return self.ensemble_shape + (len(self),) + self.gpts

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """Shape of the SMatrix ensemble axes (those of the potential, or empty without one)."""
        if self.potential is None:
            return ()
        else:
            return self.potential.ensemble_shape

    @property
    def ensemble_axes_metadata(self):
        """Axis metadata for each ensemble axis."""
        if self.potential is None:
            return []
        else:
            return self.potential.ensemble_axes_metadata

    @property
    def wave_vectors(self) -> np.ndarray:
        """
        The wave vectors [1/Å] of the plane waves inside the expansion aperture, as an (N, 2) array.

        Only every `interpolation`-th spatial frequency is included, hence the scaling by the interpolation
        factor. The array is allocated with the array module of the SMatrix device.
        """
        self.grid.check_is_defined()
        self.accelerator.check_is_defined()
        dummy_probes = self.dummy_probes(device="cpu")

        aperture = dummy_probes.aperture._evaluate_kernel(dummy_probes)

        indices = np.where(aperture > 0.0)

        n = np.fft.fftfreq(aperture.shape[0], d=1 / aperture.shape[0])[indices[0]]
        m = np.fft.fftfreq(aperture.shape[1], d=1 / aperture.shape[1])[indices[1]]

        w, h = self.extent

        kx = n / w * np.float32(self.interpolation[0])
        ky = m / h * np.float32(self.interpolation[1])

        xp = get_array_module(self.device)
        return xp.asarray([kx, ky]).T

    @property
    def potential(self) -> BasePotential:
        """The potential described by the SMatrix."""
        return self._potential

    @potential.setter
    def potential(self, potential: BasePotential):
        self._potential = potential
        # A None potential (vacuum) keeps the current grid instead of failing halfway through the update.
        if potential is not None:
            self._grid = potential.grid

    @property
    def semiangle_cutoff(self) -> float:
        """Plane-wave expansion cutoff."""
        return self._semiangle_cutoff

    @semiangle_cutoff.setter
    def semiangle_cutoff(self, value: float):
        self._semiangle_cutoff = value

    @property
    def interpolation(self) -> tuple[int, int]:
        """Interpolation factor in the `x` and `y` directions."""
        return self._interpolation

    def _wave_vector_chunks(self, max_batch):
        """Chunks of the (plane wave, x, y) axes: automatic along the plane waves, whole spatial axes."""
        if isinstance(max_batch, int):
            # An integer is taken per plane wave; scale it by the number of grid points per wave.
            max_batch = max_batch * reduce(operator.mul, self.gpts)

        chunks = validate_chunks(
            shape=(len(self),) + self.gpts,
            chunks=("auto", -1, -1),
            limit=max_batch,
            dtype=np.dtype("complex64"),
            device=self.device,
        )
        return chunks

    @property
    def downsampled_gpts(self) -> tuple[int, int]:
        """The gpts of the SMatrix after downsampling, rounded up to multiples of the interpolation factor."""
        if self.downsample:
            downsampled_gpts = self._gpts_within_angle(self.downsample)
            rounded = _round_gpts_to_multiple_of_interpolation(
                downsampled_gpts, self.interpolation
            )
            return rounded
        else:
            return self.gpts

    @property
    def window_gpts(self):
        """The number of grid points of the cropping window: downsampled gpts divided by the interpolation."""
        return (
            safe_ceiling_int(self.downsampled_gpts[0] / self.interpolation[0]),
            safe_ceiling_int(self.downsampled_gpts[1] / self.interpolation[1]),
        )

    @property
    def window_extent(self):
        """The extent of the cropping window [Å], using the sampling after downsampling."""
        sampling = (
            self.extent[0] / self.downsampled_gpts[0],
            self.extent[1] / self.downsampled_gpts[1],
        )

        return (
            self.window_gpts[0] * sampling[0],
            self.window_gpts[1] * sampling[1],
        )
[docs]defmultislice(self,potential=None,lazy:bool=None,max_batch:int|str="auto",):""" Parameters ---------- potential lazy : bool, optional If True, create the wave functions lazily, otherwise, calculate instantly. If not given, defaults to the setting in the user configuration file. max_batch : int or str, optional The number of expansion plane waves in each run of the multislice algorithm. Returns ------- """s_matrix=self.__class__(potential=potential,**self._copy_kwargs(exclude=("potential",)))returns_matrix.build(lazy=lazy,max_batch=max_batch)
[docs]defbuild(self,lazy:bool=None,max_batch:int|str="auto",bound:bool=None)->SMatrixArray:""" Build the plane waves of the scattering matrix and propagate them through the potential using the multislice algorithm. Parameters ---------- lazy : bool, optional If True, create the wave functions lazily, otherwise, calculate instantly. If not given, defaults to the setting in the user configuration file. max_batch : int or str, optional The number of expansion plane waves in each run of the multislice algorithm. Returns ------- s_matrix_array : SMatrixArray The built scattering matrix. """lazy=_validate_lazy(lazy)downsampled_gpts=self.downsampled_gptss_matrix_blocks=self.ensemble_blocks(1)xp=get_array_module(self.device)wave_vector_chunks=self._wave_vector_chunks(max_batch)iflazy:wave_vector_blocks=self._wave_vector_blocks(wave_vector_chunks,lazy=False)wave_vector_blocks=np.tile(wave_vector_blocks[None],(len(s_matrix_blocks),1))wave_vector_blocks=da.from_array(wave_vector_blocks,chunks=1)fromdask.graph_manipulationimportbindifboundisnotNone:wave_vector_blocks=bind(wave_vector_blocks,bound)adjust_chunks={1:wave_vector_chunks[0],2:(downsampled_gpts[0],),3:(downsampled_gpts[1],),}symbols=(0,1,2,3)ifself.potentialisNoneornotself.potential.ensemble_shape:symbols=symbols[1:]array=da.blockwise(self._build_s_matrix,symbols,s_matrix_blocks,(0,),wave_vector_blocks[...,None,None],(0,1,2,3),concatenate=True,adjust_chunks=adjust_chunks,meta=xp.array((),dtype=np.complex64),)else:wave_vector_blocks=self._wave_vector_blocks(wave_vector_chunks,lazy=False)ifself.store_on_host:array=np.zeros(self.ensemble_shape+(len(self),)+self.downsampled_gpts,dtype=np.complex64,)else:array=xp.zeros(self.ensemble_shape+(len(self),)+self.downsampled_gpts,dtype=np.complex64,)fori,_,s_matrixinself.generate_blocks(1):s_matrix=s_matrix.item()forstart,stopinwave_vector_blocks:items=(slice(start,stop),)ifself.ensemble_shape:items=i+itemsnew_array=self._build_s_matrix(s_matrix,slice(start,stop))ifself.store_on_
host:new_array=xp.asnumpy(new_array)array[items]=new_arraywaves=Waves(array,energy=self.energy,extent=self.extent,ensemble_axes_metadata=self.ensemble_axes_metadata+self.base_axes_metadata[:1],)ifself.downsampled_gpts!=self.gpts:waves.metadata["adjusted_antialias_cutoff_gpts"]=_antialias_cutoff_gpts(self.window_gpts,self.sampling)s_matrix_array=SMatrixArray._from_waves(waves,wave_vectors=self.wave_vectors,interpolation=self.interpolation,semiangle_cutoff=self.semiangle_cutoff,window_gpts=self.window_gpts,device=self.device,)returns_matrix_array
[docs]defscan(self,scan:np.ndarray|BaseScan=None,detectors:BaseDetector|list[BaseDetector]=None,ctf:CTF|dict=None,max_batch_multislice:str|int="auto",max_batch_reduction:str|int="auto",reduction_scheme:str="auto",disable_s_matrix_chunks:bool="auto",lazy:bool=None,)->BaseMeasurements|Waves|list[BaseMeasurements|Waves]:""" Run the multislice algorithm, then reduce the SMatrix using coefficients calculated by a BaseScan and a CTF, to obtain the exit wave functions at given initial probe positions and aberrations. Parameters ---------- scan : BaseScan Positions of the probe wave functions. If not given, scans across the entire potential at Nyquist sampling. detectors : BaseDetector, list of BaseDetector, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. See abtem.measurements.detect for a list of implemented detectors. ctf : CTF Contrast transfer function from used for calculating the expansion coefficients in the reduction of the SMatrix. max_batch_multislice : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". max_batch_reduction : int or str, optional Number of positions per reduction operation. A large number of positions better utilize thread parallelization, but requires more memory and floating point operations. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". reduction_scheme : str or tuple of int, optional Parallel reduction of the SMatrix requires rechunking the Dask array from chunking along the expansion axis to chunking over the spatial axes. If given as a tuple of int of length the SMatrix is rechunked to have those chunks. 
If 'auto' (default) the chunks are taken to be identical to the interpolation factor. disable_s_matrix_chunks : bool, optional If True, each S-Matrix is kept as a single chunk, thus lowering the communication overhead, but providing fewer opportunities for parallelization. lazy : bool, optional If True, create the measurements lazily, otherwise, calculate instantly. If None, this defaults to the value set in the configuration file. Returns ------- detected_waves : BaseMeasurements or list of BaseMeasurement The detected measurement (if detector(s) given). exit_waves : Waves Wave functions at the exit plane(s) of the potential (if no detector(s) given). """ifscanisNone:scan=GridScan(start=(0,0),end=self.extent,sampling=self.dummy_probes().aperture.nyquist_sampling,)ifdetectorsisNone:detectors=FlexibleAnnularDetector()returnself.reduce(scan=scan,detectors=detectors,max_batch_reduction=max_batch_reduction,max_batch_multislice=max_batch_multislice,ctf=ctf,reduction_scheme=reduction_scheme,disable_s_matrix_chunks=disable_s_matrix_chunks,lazy=lazy,)
[docs]defreduce(self,scan:np.ndarray|BaseScan=None,detectors:BaseDetector|list[BaseDetector]=None,ctf:CTF|dict=None,reduction_scheme:str="auto",max_batch_multislice:str|int="auto",max_batch_reduction:str|int="auto",disable_s_matrix_chunks:bool="auto",lazy:bool=None,)->BaseMeasurements|Waves|list[BaseMeasurements|Waves]:""" Run the multislice algorithm, then reduce the SMatrix using coefficients calculated by a BaseScan and a CTF, to obtain the exit wave functions at given initial probe positions and aberrations. Parameters ---------- scan : BaseScan Positions of the probe wave functions. If not given, scans across the entire potential at Nyquist sampling. detectors : BaseDetector, list of BaseDetector, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. See abtem.measurements.detect for a list of implemented detectors. ctf : CTF Contrast transfer function from used for calculating the expansion coefficients in the reduction of the SMatrix. max_batch_multislice : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". max_batch_reduction : int or str, optional Number of positions per reduction operation. A large number of positions better utilize thread parallelization, but requires more memory and floating point operations. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". reduction_scheme : str, optional Parallel reduction of the SMatrix requires rechunking the Dask array from chunking along the expansion axis to chunking over the spatial axes. If given as a tuple of int of length the SMatrix is rechunked to have those chunks. 
If 'auto' (default) the chunks are taken to be identical to the interpolation factor. disable_s_matrix_chunks : bool, optional If True, each S-Matrix is kept as a single chunk, thus lowering the communication overhead, but providing fewer opportunities for parallelization. lazy : bool, optional If True, create the measurements lazily, otherwise, calculate instantly. If None, this defaults to the value set in the configuration file. Returns ------- measurements : BaseMeasurements or Waves or list of BaseMeasurements or list of Waves The detected measurement (if detector(s) given). """detectors=_validate_detectors(detectors)ifscanisNone:scan=(self.extent[0]/2,self.extent[1]/2)lazy=_validate_lazy(lazy)ifself.device=="gpu"anddisable_s_matrix_chunks=="auto":disable_s_matrix_chunks=Trueelifdisable_s_matrix_chunks=="auto":disable_s_matrix_chunks=Falseifnotlazy:scan=_validate_scan(scan,self)measurements=self._eager_build_s_matrix_detect(scan,ctf,detectors,squeeze=True)return_wrap_measurements(measurements)ifdisable_s_matrix_chunks:scan=_validate_scan(scan,self)blocks=self.ensemble_blocks(1)chunks=blocks.chunks+scan.shapenew_axis=tuple_range(offset=len(blocks.shape),length=len(scan.shape))drop_axis=()iflen(self.ensemble_shape)==0:drop_axis=(0,)chunks=chunks[1:]new_axis=tuple_range(offset=len(blocks.shape)-1,length=len(scan.shape))arrays=blocks.map_blocks(self._lazy_build_s_matrix_detect,drop_axis=drop_axis,new_axis=new_axis,chunks=chunks,scan=scan,ctf=ctf,detectors=detectors,meta=np.array((),dtype=object),)waves=self.build(lazy=True).dummy_probes(scan=scan)extra_axes_metadata=[]ifself.potentialisnotNone:extra_axes_metadata=self.potential.ensemble_axes_metadatameasurements=_finalize_lazy_measurements(arrays,waves,detectors,extra_axes_metadata)return_wrap_measurements(measurements)s_matrix_array=self.build(max_batch=max_batch_multislice,lazy=lazy)returns_matrix_array.reduce(scan=scan,detectors=detectors,reduction_scheme=reduction_scheme,max_batch_reduction=max_batch_reduction,
ctf=ctf,)