# jax_models.py
import jax
import jax.numpy as jnp
from jax import random, jit, lax
from jax.scipy.linalg import inv, det, svd
from functools import partial

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.datasets import make_spd_matrix

import pyqg_jax
from pyqg_jax import qg_model, state, parameterizations, steppers


@jit
def rk4_step_lorenz96(x, F, dt):
    """Advance a Lorenz-96 state one step with classical fourth-order Runge-Kutta."""
    f = lambda y: (jnp.roll(y, 1) - jnp.roll(y, -2)) * jnp.roll(y, -1) - y + F
    k1 = dt * f(x)
    k2 = dt * f(x + k1 / 2)
    k3 = dt * f(x + k2 / 2)
    k4 = dt * f(x + k3)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) / 6


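# Minimal usage sketch; the state size, forcing, and step size below are
# illustrative choices, not values fixed by this file:
# x0 = 8.0 * jnp.ones(40) + 0.01 * random.normal(random.key(0), (40,))
# x1 = rk4_step_lorenz96(x0, F=8.0, dt=0.01)

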
@jit
def kuramoto_sivashinsky_step(x, dt, E, E2, Q, f1, f2, f3, g):
    """One ETDRK4 step of the Kuramoto-Sivashinsky equation (Kassam-Trefethen scheme)."""
    v = jnp.fft.fft(x)
    Nv = g * jnp.fft.fft(jnp.real(jnp.fft.ifft(v))**2)
    a = E2 * v + Q * Nv
    Na = g * jnp.fft.fft(jnp.real(jnp.fft.ifft(a))**2)
    b = E2 * v + Q * Na
    Nb = g * jnp.fft.fft(jnp.real(jnp.fft.ifft(b))**2)
    c = E2 * a + Q * (2 * Nb - Nv)
    Nc = g * jnp.fft.fft(jnp.real(jnp.fft.ifft(c))**2)
    v_next = E * v + Nv * f1 + 2 * (Na + Nb) * f2 + Nc * f3
    x_next = jnp.real(jnp.fft.ifft(v_next))
    return x_next


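# The coefficient arrays (E, E2, Q, f1, f2, f3, g) are produced by
# KuramotoSivashinsky.precompute_constants below; in normal use call
# KuramotoSivashinsky.step rather than assembling them by hand.

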
@jit
def pyqg_step(x, init_state, param_model, stepped_model):
    # Reshape x back to the model's q shape
    q_shape = init_state.state.model_state.q.shape
    q = x.reshape(q_shape)
    # Create a new model state with q using state.update
    base_state = init_state.state.model_state.update(q=q)
    # Wrap the new model state through the parameterization
    wrapped_in_param = param_model.initialize_param_state(base_state)
    # Initialize the stepper state with the wrapped parameterized state
    integrator_state = stepped_model.initialize_stepper_state(wrapped_in_param)
    # Perform one step
    next_state = stepped_model.step_model(integrator_state)
    # Extract q from the next state and flatten it
    q_next = next_state.state.model_state.q
    x_next = q_next.reshape(-1)
    return x_next


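# pyqg_step is consumed by PyQGModel.step below; flattening q to a 1-D vector
# lets the QG state pass through the same vector-state interface as the other
# models in this file.

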
# Models as classes
class BaseModel:
    def __init__(self, dt=0.01):
        self.dt = dt

    def step(self, x):
        raise NotImplementedError("The step method must be implemented by subclasses.")


class Lorenz63(BaseModel):
    """Lorenz-63 system advanced with a forward-Euler step."""

    def __init__(self, dt=0.01, sigma=10.0, rho=28.0, beta=8.0/3):
        super().__init__(dt)
        self.sigma = sigma
        self.rho = rho
        self.beta = beta

    def step(self, x):
        x_dot = self.sigma * (x[1] - x[0])
        y_dot = x[0] * (self.rho - x[2]) - x[1]
        z_dot = x[0] * x[1] - self.beta * x[2]
        return x + self.dt * jnp.array([x_dot, y_dot, z_dot])


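# Sketch of a short Lorenz-63 trajectory; the initial condition is an
# illustrative choice:
# l63 = Lorenz63()
# x = jnp.array([1.0, 1.0, 1.0])
# for _ in range(100):
#     x = l63.step(x)

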
class Lorenz96(BaseModel):
    def __init__(self, dt=0.01, F=8.0):
        super().__init__(dt)
        self.F = F

    def step(self, x):
        return rk4_step_lorenz96(x, self.F, self.dt)


class KuramotoSivashinsky(BaseModel):
    def __init__(self, dt=0.25, s=128, l=22, M=16):
        super().__init__(dt)
        # s: spatial discretization points, l: domain length,
        # M: number of quadrature points for the ETDRK4 contour integrals
        self.s, self.l, self.M = s, l, M
        self.k, self.E, self.E2, self.Q, self.f1, self.f2, self.f3, self.g = self.precompute_constants()

    def precompute_constants(self):
        """Precompute the ETDRK4 coefficients via contour-integral quadrature."""
        k = (2 * jnp.pi / self.l) * jnp.concatenate(
            [jnp.arange(0, self.s // 2), jnp.array([0]), jnp.arange(-self.s // 2 + 1, 0)]
        )
        L = k**2 - k**4
        E = jnp.exp(self.dt * L)
        E2 = jnp.exp(self.dt * L / 2)
        # Points on the unit circle, shifted off the real axis, for the quadrature
        r = jnp.exp(1j * jnp.pi * (jnp.arange(1, self.M + 1) - 0.5) / self.M)
        LR = self.dt * jnp.tile(L, (self.M, 1)).T + jnp.tile(r, (self.s, 1))
        Q = self.dt * jnp.real(jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=1))
        f1 = self.dt * jnp.real(jnp.mean((-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / LR**3, axis=1))
        f2 = self.dt * jnp.real(jnp.mean((2 + LR + jnp.exp(LR) * (-2 + LR)) / LR**3, axis=1))
        f3 = self.dt * jnp.real(jnp.mean((-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / LR**3, axis=1))
        g = -0.5j * k
        return k, E, E2, Q, f1, f2, f3, g

    def step(self, x):
        return kuramoto_sivashinsky_step(x, self.dt, self.E, self.E2, self.Q, self.f1, self.f2, self.f3, self.g)


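# Sketch of one KS step on a random initial field; the amplitude and key are
# illustrative:
# ks = KuramotoSivashinsky()
# x = 0.1 * random.normal(random.key(1), (ks.s,))
# x = ks.step(x)

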
class PyQGModel(BaseModel):
    def __init__(self, dt=14400.0, nx=64, ny=64):
        super().__init__(dt)
        self.nx = nx
        self.ny = ny
        self.base_model = qg_model.QGModel(
            nx=nx,
            ny=ny,
            precision=pyqg_jax.state.Precision.DOUBLE,
        )
        self.param_model = parameterizations.smagorinsky.apply_parameterization(
            self.base_model, constant=0.08,
        )
        self.stepper = steppers.AB3Stepper(dt=dt)
        self.stepped_model = steppers.SteppedModel(self.param_model, self.stepper)
        self.init_state = self.stepped_model.create_initial_state(jax.random.key(0))

    def step(self, x):
        return pyqg_step(x, self.init_state, self.param_model, self.stepped_model)


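# Sketch: the flattened state vector for PyQGModel is the potential vorticity
# field q of the initial state raveled to 1-D, mirroring the reshape inside
# pyqg_step:
# qg = PyQGModel()
# x = qg.init_state.state.model_state.q.reshape(-1)
# x = qg.step(x)

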
@jit
def step_function(carry, _):
    key, x, observation_interval, H, Q, R, model_step, counter = carry
    n = len(x)
    # Draw fresh subkeys so the carried key is never reused for sampling
    key, key_q, key_r = random.split(key, 3)
    x_j = model_step(x)

    # Add process noise Q and emit an observation only at observation times
    def update_observation():
        x_noise = x_j + random.multivariate_normal(key_q, jnp.zeros(n), Q)
        obs_state = jnp.dot(H, x_noise)
        # Observation noise has the dimension of the observed states
        obs_noise = random.multivariate_normal(key_r, jnp.zeros(H.shape[0]), R)
        return x_noise, obs_state + obs_noise

    def no_update():
        # Return a vector of NaNs matching the number of observed states
        return x_j, jnp.nan * jnp.ones(H.shape[0])

    # Conditional update based on the observation interval
    x_j, obs = lax.cond(counter % observation_interval == 0,
                        update_observation,
                        no_update)
    counter += 1
    carry = (key, x_j, observation_interval, H, Q, R, model_step, counter)
    output = (x_j, obs)
    return carry, output


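# step_function is the lax.scan body consumed by generate_true_states below;
# the observation operator H, the covariances Q and R, the interval, and the
# step callable ride along in the carry unchanged from step to step.

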
@partial(jit, static_argnums=(1, 2, 7))
def generate_true_states(key, num_steps, n, x0, H, Q, R, model_step, observation_interval):
    key, scan_key, obs_key = random.split(key, 3)
    # Wrap the step callable in tree_util.Partial so it is a valid pytree leaf
    # inside the scan carry
    initial_carry = (scan_key, x0, observation_interval, H, Q, R,
                     jax.tree_util.Partial(model_step), 1)
    _, (xs, observations) = lax.scan(step_function, initial_carry, None, length=num_steps - 1)
    initial_observation = H @ x0 + random.multivariate_normal(obs_key, jnp.zeros(H.shape[0]), R)
    xs = jnp.vstack([x0[jnp.newaxis, :], xs])
    observations = jnp.vstack([initial_observation[jnp.newaxis, :], observations])
    return observations, xs


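# Sketch of a typical call; the sizes, covariances, and interval below are
# illustrative, not values fixed by this file:
# model = Lorenz96(dt=0.01, F=8.0)
# n = 40
# H = jnp.eye(n)           # observe every state
# Q = 0.01 * jnp.eye(n)    # process noise covariance
# R = 0.5 * jnp.eye(n)     # observation noise covariance
# x0 = 8.0 * jnp.ones(n)
# obs, xs = generate_true_states(random.key(0), 1000, n, x0, H, Q, R,
#                                model.step, 5)

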
def visualize_observations(observations):
    observation_values = observations.T  # Transpose for plotting
    cmap = LinearSegmentedColormap.from_list('CustomColormap', [(0, 'blue'), (0.5, 'white'), (1, 'red')])
    plt.figure(figsize=(12, 6))
    plt.imshow(observation_values, cmap=cmap, aspect='auto',
               extent=[0, observations.shape[0], 0, observations.shape[1]])
    plt.colorbar(label='Observation Value')
    plt.xlabel('Time Step')
    plt.ylabel('State/Variable Number')
    plt.title('Observations Over Time')
    plt.show()


def plot_ensemble_mean_and_variance(states, observations, state_index, observation_interval, title_suffix=''):
    time_steps = jnp.arange(states.shape[0])
    state_mean = jnp.mean(states[:, :, state_index], axis=1)
    state_std = jnp.std(states[:, :, state_index], axis=1)
    plt.figure(figsize=(12, 8))
    plt.plot(time_steps, state_mean, label='State Mean', color='orange')
    plt.fill_between(time_steps,
                     state_mean - 1.96 * state_std,
                     state_mean + 1.96 * state_std,
                     color='orange', alpha=0.3, label='95% Confidence Interval')
    observed_time_steps = jnp.arange(0, len(observations), observation_interval)
    observed_values = observations[observed_time_steps, state_index]
    plt.scatter(observed_time_steps, observed_values, label='Observation', color='red', marker='x')
    plt.title(f'State {state_index+1} Ensemble Mean and Variance {title_suffix}')
    plt.xlabel('Time Step')
    plt.ylabel(f'State {state_index+1} Value')
    plt.legend()
    plt.show()


@partial(jit, static_argnums=(0,))
def generate_localization_matrix(n, localization_radius):
    """
    Generate a Gaussian localization matrix for a cyclic domain of size n,
    using the given localization radius.
    """
    i = jnp.arange(n)[:, None]
    j = jnp.arange(n)
    # Distance on a ring: the shorter way around the periodic domain
    min_modulo_distance = jnp.minimum(jnp.abs(i - j), n - jnp.abs(i - j))
    r = min_modulo_distance / localization_radius
    localization_matrix = jnp.exp(-(r**2))
    return localization_matrix
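

# Minimal smoke test, assuming a fully observed 40-variable Lorenz-96 truth
# run; the sizes, covariances, and observation interval are illustrative
# choices, not fixed by the original file.
if __name__ == "__main__":
    n = 40
    model = Lorenz96(dt=0.01, F=8.0)
    H = jnp.eye(n)
    Q = 0.01 * jnp.eye(n)
    R = 0.5 * jnp.eye(n)
    x0 = 8.0 * jnp.ones(n) + 0.01 * random.normal(random.key(0), (n,))
    observations, xs = generate_true_states(
        random.key(1), 500, n, x0, H, Q, R, model.step, 10
    )
    print("states:", xs.shape, "observations:", observations.shape)
    visualize_observations(observations)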