from keras.src import ops
from keras.src.losses.loss import Loss
from keras.src.saving import serialization_lib
from keras.src.backend.config import backend
from .ops import hard_lowpass, circular_hard_lowpass, squeeze_or_expand_to_same_rank
from keras_fft import fft, fft2, ifft, ifft2
from keras import KerasTensor
from functools import partial
class FrequencyLossFunctionWrapper1D(Loss):
    """
    Base class FrequencyLossFunctionWrapper1D to implement new loss functions with the option to apply a frequency filter to the ground truth.

    This frequency filter helps the model to focus on the relevant frequency range without the need to, e.g., remove HF noise in additional preprocessing steps.
    There are two lowpass filters to choose from, the `"static"` and the `"adaptive"` lowpass.
    The `"static"` lowpass defines a global cut-off frequency at `f_filter`.
    The `"adaptive"` lowpass analyzes the ground truth data, estimates the peak frequency, and sets a dynamic cut-off frequency for each sample.
    In that case, the parameter `f_filter` becomes a multiplier for the peak frequency, above which the frequency components are suppressed.
    The definition of a lowpass {`"static"`, `"adaptive"`} requires a frequency range `f`.
    It enables an additional step in the loss calculation, where
    (1) the ground truth is transformed via 1-D FFT,
    (2) a hard binary lowpass filter is applied to the Fourier spectrum to set all frequencies `f>f_filter` to (0+0j),
    (3) the filtered ground truth is transformed back to its initial space.
    If `lowpass==None`, the FFT calculation is skipped, and no `f` is required.
    This class inherits from keras.losses.Loss and can thus be used directly in keras.Model.compile().

    Parameters
    ----------
    fn : callable
        Definition of the loss function.
        The function has to accept two tensors (`y_true` and `y_pred`) and return a float.
    lowpass : str, optional {`None`, `"static"`, `"adaptive"`}
        Lowpass filter that is applied to the ground truth in order to suppress the higher frequency range `f>f_filter`.
        Defaults to `None`.
    f : KerasTensor, optional
        Frequency range for the data.
        Is required once a lowpass {`"static"`, `"adaptive"`} is used.
        Defaults to `None`.
    f_filter : float, optional
        Threshold for the lowpass filter.
        With the static lowpass, the ground truth spectrum is set to 0+0j for `f>f_filter`.
        With the adaptive lowpass, the ground truth spectrum is set to 0+0j for `f>f_filter*f_p`,
        where `f_p` is the peak frequency that is automatically derived from the ground truth spectrum.
        Defaults to 6.0.
    f_min : float, optional
        Lower bound for the estimated peak frequency: estimates below `f_min` or NaN
        (i.e., when the automatic estimation of the peak frequency fails) are replaced by `f_min`.
        Defaults to 0.0.
    p : float, optional
        Exponent to weigh the spectrum towards the peak frequency (for the estimation of the peak frequency), c.f.
        Mansard & Funke, "On the fitting of parametric models to measured wave spectra" (1988), and
        Sobey & Young, "Hurricane Wind Waves---A discrete spectral model" (1986), https://ascelibrary.org/doi/10.1061/%28ASCE%290733-950X%281986%29112%3A3%28370%29.
        Defaults to 7.0.
    decay_start : int, optional
        Epoch from which on the cut-off of the lowpass filter is linearly decreased from the full frequency range to `f_filter`.
        Defaults to 0.
        **Requires `UseLossLowpassDecay` callback to work, cf. Notes**
    decay_epochs : int, optional
        Number of epochs over which the cut-off of the lowpass filter is linearly decreased from the full frequency range to `f_filter`.
        With `decay_epochs=0` the filter switches directly to `f_filter` at epoch `decay_start`.
        Defaults to 50.
        **Requires `UseLossLowpassDecay` callback to work, cf. Notes**
    data_format : str, optional {`"channels_last"`, `"channels_first"`}
        The ordering of the dimensions in the inputs:
        `"channels_last"` corresponds to inputs with shape `(batch_size, *dims, channels)`,
        `"channels_first"` corresponds to inputs with shape `(batch_size, channels, *dims)`.
        Defaults to `"channels_last"`.
    reduction : str, optional {`"sum_over_batch_size"`, `None`, `"auto"`, `"sum"`}
        Type of reduction to apply to the loss.
        In almost all cases this should be `"sum_over_batch_size"`.
        Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
    name : str, optional
        Name of the loss function. The name is inherited from the class name if `name=None`.
        Defaults to `None`.
    **kwargs
        Additional keyword arguments for `fn`.

    Raises
    ------
    ValueError
        If `data_format` or `lowpass` is not supported, or if a lowpass is requested without a frequency range `f`.

    Notes
    -----
    Both the `"adaptive"` and `"static"` lowpass filter can be narrowed linearly from the full (positive) frequency range
    down to `f_filter` over `decay_epochs` epochs, starting at epoch `decay_start`.
    For this to work, the training has to be conducted using the `UseLossLowpassDecay` callback, which sets the attribute `self.epoch` to the current training epoch.
    See examples of SSP1D and SSP2D for a MWE.
    """

    def __init__(
        self,
        fn,
        lowpass=None,
        f=None,
        f_filter=6.0,
        f_min=0.0,
        p=7.0,
        decay_start=0,
        decay_epochs=50,
        data_format="channels_last",
        reduction="sum_over_batch_size",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, reduction=reduction, dtype="float32")
        self.data_format = data_format.lower()
        if self.data_format not in ("channels_first", "channels_last"):
            raise ValueError(f"Unsupported data format {self.data_format}. Please choose from 'channels_first' and 'channels_last'.")
        self.lowpass = lowpass
        if self.lowpass is not None and self.lowpass not in ("static", "adaptive"):
            raise ValueError(f"Unsupported filter type, choose from {['static', 'adaptive']}")
        if self.lowpass is not None and f is None:
            # without a frequency range the filter size cannot be estimated (nx would be 0)
            raise ValueError(f"The {self.lowpass!r} lowpass requires a frequency range `f`.")
        # For the frequency range calculations (mainly estimation of filter size) only the positive frequency range is required!
        f = [] if f is None else ops.convert_to_tensor(f)
        self.f = ops.cast(f, dtype=self.dtype)
        self.nx = ops.shape(self.f)[0]
        self.f_filter = ops.cast(f_filter, dtype=self.dtype)
        self.f_min = ops.cast(f_min, dtype=self.dtype)
        self.p = ops.cast(p, dtype=self.dtype)
        # decay for loss filter (spectral radius is gradually reduced from the full positive frequency range
        # to the spectral radius that corresponds to the calculated lowpass frequency)
        self.decay_epochs = decay_epochs
        self.decay_start = decay_start
        self.epoch = None  # has to be set via callback (UseLossLowpassDecay callback)
        # define callables (overwritten by the 2-D subclass)
        self.fn = fn
        self._fn_kwargs = kwargs
        self.fft: callable = fft
        self.ifft: callable = ifft
        # L2 norm along the last (frequency) axis
        self.norm: callable = partial(ops.linalg.norm, axis=-1)
        self.expand_dims: callable = ops.expand_dims

    def call(self, y_true, y_pred):
        """
        Call method of FrequencyLossFunctionWrapper1D.

        Parameters
        ----------
        y_true : KerasTensor
            Ground truth
        y_pred : KerasTensor
            Prediction

        Returns
        -------
        loss : KerasTensor
            The output of `self.fn(y_true, y_pred, **kwargs)`, where `y_true` was lowpass-filtered if `lowpass` is set.
            Any reduction according to `reduction` is handled by the Keras `Loss` base class.
        """
        y_true = ops.convert_to_tensor(y_true)  # shape = (batch, x, ch) if data_format == 'channels_last' else (batch, x) or (batch, 1, x)
        y_pred = ops.convert_to_tensor(y_pred)
        # squeeze along channel axis
        y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred, axis=-1 if self.data_format == 'channels_last' else 1)  # this adds channel dimension to y_true if there is none
        # make channels first for FFT application
        y_true = self.transpose_to_channels_first(y_true)  # shape = (batch, ch, x) or (batch, x)
        y_pred = self.transpose_to_channels_first(y_pred)
        if self.lowpass:  # i.e. lowpass is not None
            # apply frequency filter to ground truth data
            # note: we are working with (real, imag) tuples, since keras has no complex dtype
            y_true_real, y_true_imag = self.fft(y_true)
            y_true_real, y_true_imag = self.apply_filter(real=y_true_real, imag=y_true_imag)
            y_true, _ = self.ifft((y_true_real, y_true_imag))
        return self.fn(y_true, y_pred, **self._fn_kwargs)

    # === everything that has to do with filtering / Fourier domain ===

    def estimate_peak_frequency(self, power_spectrum):
        """
        Estimate peak frequency from spectrum.

        The estimate is `||f * S|| / ||S||` (norm `self.norm`), with NaN or values below `self.f_min` replaced by `self.f_min`.

        Parameters
        ----------
        power_spectrum : KerasTensor
            power spectrum of ground truth

        Returns
        -------
        peak_frequency: KerasTensor
            peak frequency (capped at `self.f_min`) for each sample;
            `shape==(batch_size,)` for inputs without channel dimension
        """
        peak_frequency = ops.divide_no_nan(
            self.norm(ops.abs(ops.multiply(self.f, power_spectrum))),
            self.norm(power_spectrum)
        )
        return ops.where(
            (peak_frequency < self.f_min) | ops.numpy.isnan(peak_frequency),
            self.f_min,
            peak_frequency
        )

    def get_f_hat(self, real, imag):
        """
        Normalize frequency range `self.f` by the estimated peak frequency.

        Parameters
        ----------
        real : KerasTensor
            real part of Fourier transform of signal `y`
        imag : KerasTensor
            imaginary part of Fourier transform of signal `y`

        Returns
        -------
        f_hat : KerasTensor
            frequency range normalized by estimated peak frequency (one row per sample)
        """
        ps = self.magnitude(real=real, imag=imag)
        # min-max normalization (over the whole tensor) before weighting with exponent p
        normalized_ps = (ps - ops.min(ps)) / (ops.max(ps) - ops.min(ps))
        ps = ops.cast(normalized_ps ** self.p, dtype=real.dtype)
        peak_frequency = self.estimate_peak_frequency(power_spectrum=ps)
        return ops.divide_no_nan(
            ops.expand_dims(self.f, axis=0),  # add batch dimension
            self.expand_dims(peak_frequency, axis=-1)  # add space dimension
        )

    def get_frequency_filter(self, real, imag):
        """
        Build the hard binary lowpass window that is multiplied with the Fourier spectrum.

        The window requires
        (1) the overall length (int), and
        (2) the length of the window, which is here given by the index where the filter frequency exceeds the frequency vector.
        The radius of the filter is calculated based on the frequency grid `f` and `f_filter`.
        The closest frequency component to `f_filter` is found by `diff = abs(f - f_filter)`.
        The index of the minimum entry in `diff` is used as the radius of the filter.
        All calculations are performed on the positive half of `f`.
        For the `"adaptive"` filter, the frequency range scaled by the peak frequency is used.

        Parameters
        ----------
        real : KerasTensor
            real part of Fourier transform of signal `y`
        imag : KerasTensor
            imaginary part of Fourier transform of signal `y`

        Returns
        -------
        lowpass_filter : KerasTensor
            Binary lowpass filter in FFT (unshifted) order; one filter per sample for the `"adaptive"` lowpass.
        """
        nx = self.nx // 2
        if self.lowpass == 'adaptive':
            # work with the frequency range normalized by the peak frequency
            # NOTE this is of shape (b, N)!
            f = self.get_f_hat(real=real, imag=imag)
            # use only the positive half of f
            f = f[tuple([slice(None), slice(None, nx)])]
            spectral_radius = ops.argmin(ops.abs(f - self.f_filter), axis=-1)
            freq_filter = ops.vectorized_map(self.lowpass_fn, spectral_radius)
            # apply fft shift to filter
            return self.fftshift(freq_filter, axis=-1)
        # work with the default frequency range
        # NOTE this is of shape (N,)
        # use only the positive half of f
        f = self.f[slice(None, nx)]
        spectral_radius = ops.argmin(ops.abs(f - self.f_filter))
        freq_filter = self.lowpass_fn(spectral_radius=spectral_radius)
        # apply fft shift
        return self.fftshift(freq_filter)

    def lowpass_decay(self, spectral_radius):
        """
        Linearly shrink the spectral radius over the epochs, from the full frequency range to `spectral_radius`.

        The radius is `s = int(max_radius + alpha * (spectral_radius - max_radius))`,
        where `max_radius` is the last index of the positive half of the spectrum and
        `alpha = (epoch - self.decay_start) / self.decay_epochs` is the progress of the decay.

        Parameters
        ----------
        spectral_radius : int
            Desired radius of the binary window.

        Returns
        -------
        radius : int
            Actual radius at `self.epoch`
        """
        if self.epoch is None:
            # callback to steer decay is not active, just return spectral radius
            return spectral_radius
        if self.epoch >= self.decay_start + self.decay_epochs:
            # decay finished; `>=` also avoids the 0/0 division for decay_epochs == 0
            return spectral_radius
        max_radius = self.nx // 2 - 1  # last (zero-based) index of the positive frequencies
        if self.epoch < self.decay_start:
            return max_radius
        alpha = ops.divide(self.epoch - self.decay_start, self.decay_epochs)
        return ops.cast(max_radius + alpha * ops.cast(spectral_radius - max_radius, dtype=alpha.dtype), dtype=spectral_radius.dtype)

    def apply_filter(self, real, imag):
        """
        Apply frequency filter.

        Parameters
        ----------
        real : KerasTensor
            real part of Fourier transform of signal `y`
        imag : KerasTensor
            imaginary part of Fourier transform of signal `y`

        Returns
        -------
        y_real, y_imag : (KerasTensor, KerasTensor)
            Fourier transform of signal `y` with `f>f_filter` set to (0 + 0j)
        """
        frequency_filter = self.get_frequency_filter(real, imag)
        return ops.multiply(real, frequency_filter), ops.multiply(imag, frequency_filter)

    def lowpass_fn(self, spectral_radius):
        """
        Lowpass function.

        Parameters
        ----------
        spectral_radius : int
            Desired radius of the binary window (adjusted by `lowpass_decay`).

        Returns
        -------
        hard_lowpass : KerasTensor
            binary lowpass filter of length `self.nx`
        """
        spectral_radius = self.lowpass_decay(spectral_radius=spectral_radius)
        return hard_lowpass(n=self.nx, spectral_radius=spectral_radius)

    @staticmethod
    def magnitude(real, imag):
        """
        Magnitude `sqrt(real**2 + imag**2)` of a signal.

        Parameters
        ----------
        real : KerasTensor
            real part of Fourier transform of signal `y`
        imag : KerasTensor
            imaginary part of Fourier transform of signal `y`

        Returns
        -------
        magnitude : KerasTensor
            magnitude of spectrum
        """
        return ops.sqrt(ops.square(real) + ops.square(imag))

    # === serialization ===

    def get_config(self):
        """
        Get config of loss function.

        Required for model serialization once a model compiled with this loss function is saved with `keras.Model.save()`
        and loaded with `keras.models.load_model()`.

        Returns
        -------
        config : dict
            Dictionary holding the configuration of the loss function.
        """
        base_config: dict = super().get_config()
        config = {
            "lowpass": self.lowpass,
            "f": self.tolist(self.f),
            "f_filter": float(self.f_filter),
            "f_min": float(self.f_min),
            "p": float(self.p),
            "decay_epochs": self.decay_epochs,
            "decay_start": self.decay_start,
            "data_format": str(self.data_format)
        }
        # keyword arguments forwarded to `fn`
        config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        """
        Restore loss function instance from config.

        Required for model serialization once a model compiled with loss function is saved with `keras.Model.save()`
        and loaded with `keras.models.load_model()`.

        Parameters
        ----------
        cls : FrequencyLossFunctionWrapper1D
            The class itself
        config : dict
            Dictionary holding the configuration of the loss function obtained from `self.get_config()`.
            It is not modified.

        Returns
        -------
        cls : FrequencyLossFunctionWrapper1D
            An instance of FrequencyLossFunctionWrapper1D

        Notes
        -----
        This function is never called by the user but internally by Keras.
        So *don't panic*.
        """
        config = dict(config)  # do not mutate the caller's dictionary
        f = ops.array(config.pop("f"))
        return cls(f=f, **config)

    # === helper routines ===

    def transpose_to_channels_first(self, inputs):
        """
        Transpose input data to data format `"channels_first"`.

        The FFT is by default applied along the last dimension of the data.
        Therefore, we have to transpose the data from `"channels_last"` (default) to `"channels_first"`.

        Parameters
        ----------
        inputs : KerasTensor
            input tensor to transpose

        Returns
        -------
        transposed_inputs : KerasTensor
            input tensor in data format `"channels_first"`
        """
        if self.data_format == "channels_first":
            return inputs
        rank = len(ops.shape(inputs))
        # the rank check must come first: rank-2 inputs have no channel dimension to move
        if rank == 2:
            return inputs
        transpose_axes = list(range(rank))
        # move channel_dimension to first position after batch size ('channels_first')
        ch_dim = transpose_axes.pop(-1)
        transpose_axes.insert(-1, ch_dim)  # NOTE this is only for 1d data; -1 inserts at index -2!
        return ops.transpose(inputs, axes=transpose_axes)

    def fftshift(self, x, axis=None):
        """
        Shift the FFT spectrum by `self.nx // 2` (roll).

        Parameters
        ----------
        x : KerasTensor
            signal
        axis : int, optional
            axis to shift along.
            Defaults to None.

        Returns
        -------
        x_shifted : KerasTensor
            fft-shifted version of `x`
        """
        return ops.roll(x, shift=self.nx // 2, axis=axis)

    @staticmethod
    def tolist(arr):
        """
        Cast a KerasTensor to a (nested) list.

        Required for serialization of a KerasTensor.

        Parameters
        ----------
        arr : KerasTensor
            array to be cast to list

        Returns
        -------
        l : list
            arr as (nested) list

        Raises
        ------
        RuntimeError
            If a backend other than `jax` or `tensorflow` (e.g. `torch`) is used.
        """
        if backend() == "tensorflow":
            shape = arr.get_shape()
            if len(shape) == 2:
                return list([item.numpy().tolist() for item in arr])
            return arr.numpy().tolist()
        if backend() == "jax":
            if arr.ndim == 2:
                return list([item.tolist() for item in arr])
            return arr.tolist()
        raise RuntimeError(f"{backend()} is not supported! SSP currently only works with 'jax' and 'tensorflow' backends.")
class FrequencyLossFunctionWrapper2D(FrequencyLossFunctionWrapper1D):
    """
    Base class FrequencyLossFunctionWrapper2D to implement new loss functions with the option to apply a frequency filter to the ground truth.

    This frequency filter helps the model to focus on the relevant frequency range without the need to, e.g., remove HF noise in additional preprocessing steps.
    There are two lowpass filters to choose from, the `"static"` and the `"adaptive"` lowpass.
    The `"static"` lowpass defines a global cut-off frequency at `f_filter`.
    The `"adaptive"` lowpass analyzes the ground truth data, estimates the peak frequency, and sets a dynamic cut-off frequency for each sample.
    In that case, the parameter `f_filter` becomes a multiplier for the peak frequency, above which the frequency components are suppressed.
    The definition of a lowpass {`"static"`, `"adaptive"`} requires a frequency range `f`.
    It enables an additional step in the loss calculation, where
    (1) the ground truth is transformed via 2-D FFT,
    (2) a hard binary lowpass filter is applied to the Fourier spectrum to set all frequencies `f>f_filter` to (0+0j),
    (3) the filtered ground truth is transformed back to its initial space.
    If `lowpass==None`, the FFT calculation is skipped, and no `f` is required.
    This class inherits (via FrequencyLossFunctionWrapper1D) from keras.losses.Loss and can thus be used directly in keras.Model.compile().

    Parameters
    ----------
    fn : callable
        Definition of the loss function.
        The function has to accept two tensors (`y_true` and `y_pred`) and return a float.
    lowpass : str, optional {`None`, `"static"`, `"adaptive"`}
        Lowpass filter that is applied to the ground truth in order to suppress the higher frequency range `f>f_filter`.
        Defaults to `None`.
    f : KerasTensor, optional
        Frequency range for the data.
        Is required once a lowpass {`"static"`, `"adaptive"`} is used.
        A 1-D `f` is automatically converted to the 2-D radial frequency grid `sqrt(fx**2 + fy**2)`.
        Defaults to `None`.
    f_filter : float, optional
        Threshold for the lowpass filter.
        With the static lowpass, the ground truth spectrum is set to 0+0j for `f>f_filter`.
        With the adaptive lowpass, the ground truth spectrum is set to 0+0j for `f>f_filter*f_p`,
        where `f_p` is the peak frequency that is automatically derived from the ground truth spectrum.
        Defaults to 6.0.
    f_min : float, optional
        Lower bound for the estimated peak frequency: estimates below `f_min` or NaN
        (i.e., when the automatic estimation of the peak frequency fails) are replaced by `f_min`.
        Defaults to 0.0.
    p : float, optional
        Exponent to weigh the spectrum towards the peak frequency (for the estimation of the peak frequency), c.f.
        Mansard & Funke, "On the fitting of parametric models to measured wave spectra" (1988), and
        Sobey & Young, "Hurricane Wind Waves---A discrete spectral model" (1986), https://ascelibrary.org/doi/10.1061/%28ASCE%290733-950X%281986%29112%3A3%28370%29.
        Defaults to 7.0.
    decay_start : int, optional
        Epoch from which on the cut-off of the lowpass filter is linearly decreased from the full frequency range to `f_filter`.
        Defaults to 0.
        **Requires `UseLossLowpassDecay` callback to work, cf. Notes**
    decay_epochs : int, optional
        Number of epochs over which the cut-off of the lowpass filter is linearly decreased from the full frequency range to `f_filter`.
        With `decay_epochs=0` the filter switches directly to `f_filter` at epoch `decay_start`.
        Defaults to 50.
        **Requires `UseLossLowpassDecay` callback to work, cf. Notes**
    data_format : str, optional {`"channels_last"`, `"channels_first"`}
        The ordering of the dimensions in the inputs:
        `"channels_last"` corresponds to inputs with shape `(batch_size, *dims, channels)`,
        `"channels_first"` corresponds to inputs with shape `(batch_size, channels, *dims)`.
        Defaults to `"channels_last"`.
    reduction : str, optional {`"sum_over_batch_size"`, `None`, `"auto"`, `"sum"`}
        Type of reduction to apply to the loss.
        In almost all cases this should be `"sum_over_batch_size"`.
        Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
    name : str, optional
        Name of the loss function. The name is inherited from the class name if `name=None`.
        Defaults to `None`.
    **kwargs
        Additional keyword arguments for `fn`.

    Notes
    -----
    Both the `"adaptive"` and `"static"` lowpass filter can be narrowed linearly from the full (positive) frequency range
    down to `f_filter` over `decay_epochs` epochs, starting at epoch `decay_start`.
    For this to work, the training has to be conducted using the `UseLossLowpassDecay` callback, which sets the attribute `self.epoch` to the current training epoch.
    See examples of SSP1D and SSP2D for a MWE.
    """

    def __init__(
        self,
        fn,
        lowpass=None,
        f=None,
        f_filter=6.0,
        f_min=0.0,
        p=7.0,
        decay_start=0,
        decay_epochs=50,
        data_format="channels_last",
        reduction="sum_over_batch_size",
        name=None,
        **kwargs
    ):
        super().__init__(
            fn=fn,
            lowpass=lowpass,
            f=f,
            f_filter=f_filter,
            f_min=f_min,
            p=p,
            decay_start=decay_start,
            decay_epochs=decay_epochs,
            data_format=data_format,
            reduction=reduction,
            name=name,
            **kwargs
        )
        # overwrite callables for 2D
        self.fft: callable = fft2
        self.ifft: callable = ifft2
        # L2 norm over the last two (frequency) axes
        self.norm: callable = partial(ops.linalg.norm, axis=(-2, -1))
        # insert two singleton axes at `axis`, so per-sample values broadcast against a 2-D grid
        self.expand_dims: callable = lambda x, axis: ops.expand_dims(ops.expand_dims(x, axis=axis), axis=axis)
        # For the frequency range calculations (mainly estimation of filter size) only the positive frequency range is required!
        # Here, f may be a (potentially empty) 1-D tensor, which is turned into a radial 2-D grid.
        if ops.convert_to_numpy(self.f).ndim == 1:
            self.f: KerasTensor = self.magnitude(*ops.meshgrid(self.f, self.f))
        self.ny, self.nx = ops.shape(self.f)

    # === everything that has to do with filtering / Fourier domain adjusted for 2D ===

    def get_frequency_filter(self, real, imag):
        """
        Build the hard binary (circular) lowpass window that is multiplied with the Fourier spectrum.

        The window requires
        (1) the overall length (int), and
        (2) the radius of the window, which is here given by the position where the filter frequency is closest to the frequency grid.
        The radius of the filter is calculated based on the frequency grid `f` and `f_filter`.
        The closest frequency component to `f_filter` is found by `diff = abs(f - f_filter)`.
        The (row, column) position of the minimum entry in `diff` is used to calculate the radius of the filter.
        All calculations are performed on the positive quadrant of `f`.
        For the `"adaptive"` filter, the frequency grid scaled by the peak frequency is used.

        Parameters
        ----------
        real : KerasTensor
            real part of Fourier transform of signal `y`
        imag : KerasTensor
            imaginary part of Fourier transform of signal `y`

        Returns
        -------
        lowpass_filter : KerasTensor
            Binary lowpass filter in FFT (unshifted) order; one filter per sample for the `"adaptive"` lowpass.
        """
        nx = self.nx // 2
        ny = self.ny // 2
        if self.lowpass == 'adaptive':
            # work with the frequency range normalized by the peak frequency
            # NOTE this is of shape (b, N, N)!
            f = self.get_f_hat(real=real, imag=imag)
            # reduce to positive frequencies
            f = f[tuple([slice(None), slice(None, ny), slice(None, nx)])]
            diff = ops.abs(f - self.f_filter)
            # flatten the (ny, nx) quadrant of every sample
            diff = ops.reshape(diff, newshape=(-1, ny * nx))
            min_idx = ops.argmin(diff, axis=-1)
            spectral_radius = self._spectral_radius_from_flat_index(min_idx, nx=nx)
            freq_filter = ops.vectorized_map(self.lowpass_fn, spectral_radius)
            # apply fft shift to filter
            return self.fftshift(freq_filter)
        # work with the default frequency range
        # NOTE this is of shape (N, N)
        # reduce to positive frequencies
        f = self.f[tuple([slice(None, ny), slice(None, nx)])]
        # argmin without axis returns an index into the flattened (ny, nx) quadrant
        min_idx = ops.argmin(ops.abs(f - self.f_filter))
        spectral_radius = self._spectral_radius_from_flat_index(min_idx, nx=nx)
        freq_filter = self.lowpass_fn(spectral_radius=spectral_radius)
        # apply fft shift
        return self.fftshift(freq_filter)

    def _spectral_radius_from_flat_index(self, flat_idx, nx):
        """Convert an index into a row-major flattened grid with `nx` columns into a rounded int32 radius."""
        # row-major flattening: flat_idx = row * nx + col (dividing by the number of rows is wrong for ny != nx)
        rows = flat_idx // nx
        cols = flat_idx % nx
        return ops.cast(ops.round(self.magnitude(rows, cols)), dtype='int32')

    def lowpass_decay(self, spectral_radius: int) -> int:
        """
        Linearly shrink the spectral radius over the epochs, from the full frequency range to `spectral_radius`.

        The radius is `s = int(max_radius + alpha * (spectral_radius - max_radius))`,
        where `max_radius` is the rounded radius of the last positive (row, column) index and
        `alpha = (epoch - self.decay_start) / self.decay_epochs` is the progress of the decay.

        Parameters
        ----------
        spectral_radius : int
            Desired radius of the binary window.

        Returns
        -------
        radius : int
            Actual radius at `self.epoch`
        """
        if self.epoch is None:
            # callback to steer decay is not active, just return spectral radius
            return spectral_radius
        if self.epoch >= self.decay_start + self.decay_epochs:
            # decay finished; `>=` also avoids the 0/0 division for decay_epochs == 0
            return spectral_radius
        # zero indexing, positive frequencies only
        max_radius = ops.cast(ops.round(self.magnitude(self.nx // 2 - 1, self.ny // 2 - 1)), dtype='int32')
        if self.epoch < self.decay_start:
            return max_radius
        alpha = ops.divide(self.epoch - self.decay_start, self.decay_epochs)
        return ops.cast(ops.cast(max_radius, dtype=alpha.dtype) + alpha * ops.cast(spectral_radius - max_radius, dtype=alpha.dtype), dtype=spectral_radius.dtype)

    def lowpass_fn(self, spectral_radius):
        """
        Lowpass function.

        Parameters
        ----------
        spectral_radius : int
            Desired radius of the binary window (adjusted by `lowpass_decay`).

        Returns
        -------
        hard_lowpass : KerasTensor
            binary circular lowpass filter
        """
        spectral_radius = self.lowpass_decay(spectral_radius=spectral_radius)
        # NOTE(review): only `self.nx` is passed; presumably a square (ny == nx) grid is assumed — confirm
        return circular_hard_lowpass(n=self.nx, spectral_radius=spectral_radius)

    # === helper routines adjusted for 2D ===

    def transpose_to_channels_first(self, inputs):
        """
        Transpose input data to data format `"channels_first"`.

        The FFT is by default applied along the last dimensions of the data.
        Therefore, we have to transpose the data from `"channels_last"` (default) to `"channels_first"`.

        Parameters
        ----------
        inputs : KerasTensor
            input tensor to transpose

        Returns
        -------
        transposed_inputs : KerasTensor
            input tensor in data format `"channels_first"`
        """
        if self.data_format == 'channels_first':
            return inputs
        shape = ops.shape(inputs)
        if len(shape) == 3:  # there is no channel dimension!
            return inputs
        transpose_axes = list(range(len(shape)))
        # move channel_dimension to first position after batch size ('channels_first')
        ch_dim = transpose_axes.pop(-1)
        transpose_axes.insert(-2, ch_dim)  # NOTE this is only for 2d data
        return ops.transpose(inputs, axes=transpose_axes)

    def fftshift(self, x):
        """
        FFT shift along the last 2 axes (roll by `self.ny // 2` and `self.nx // 2`).

        Parameters
        ----------
        x : KerasTensor
            signal

        Returns
        -------
        x_shifted : KerasTensor
            fft-shifted version of 'x'
        """
        return ops.roll(ops.roll(x, shift=self.ny // 2, axis=-2), shift=self.nx // 2, axis=-1)