Commit 20f612f8 authored by Kashif Rasul

initial multistep ancestral sampling

parent c2b7567a
Branch: main
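For context on the commit message: rather than sampling a single step per forward pass, multi-step ancestral sampling draws the next n_predictions - 1 steps from the model's per-horizon distribution heads in one pass, feeds them back as context, and repeats. A minimal self-contained sketch of that loop, using a toy Gaussian model in place of Lag-Llama (all names below are illustrative, not the repo's API):

import torch

# Toy stand-in: given a context, emit mean/std for each of the next k steps,
# mimicking a model with k distribution heads.
def toy_multistep_model(context, k):
    base = context[:, -1:]  # last observed value, shape (batch, 1)
    means = torch.cat([base + 0.1 * (i + 1) for i in range(k)], dim=1)
    stds = torch.full_like(means, 0.5)
    return means, stds

def multistep_ancestral_sample(context, prediction_length, n_predictions):
    samples, cur_pos = [], 0
    while cur_pos < prediction_length:
        # Never sample past the forecast horizon.
        k = min(n_predictions - 1, prediction_length - cur_pos)
        means, stds = toy_multistep_model(context, k)
        # Ancestral draw: one sample per head from a single forward pass.
        step_samples = torch.normal(means, stds)
        samples.append(step_samples)
        # Feed the draws back as context for the next block of steps.
        context = torch.cat([context, step_samples], dim=1)
        cur_pos += k
    return torch.cat(samples, dim=1)[:, :prediction_length]

ctx = torch.randn(4, 32)  # (batch, context_length)
out = multistep_ancestral_sample(ctx, prediction_length=24, n_predictions=4)
print(out.shape)          # torch.Size([4, 24])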
@@ -113,8 +113,7 @@ class LagLlamaLightningModule(LightningModule):
def forward(self, *args, **kwargs):
"""
Self-speculative decoding for continuous distributions where model predicts
distribution parameters for multiple future steps.
Multi-step ancestral sampling: at each position, sample the next n_predictions - 1 steps in a single forward pass, then continue from step n_predictions onward until prediction_length steps have been generated.
"""
past_target = kwargs["past_target"]
past_observed_values = kwargs["past_observed_values"]
@@ -139,7 +138,8 @@
while cur_pos < self.prediction_length:
remaining_len = self.prediction_length - cur_pos
steps_to_predict = min(self.n_predictions - 1, remaining_len)
# Get multi-step predictions from the model
params, loc, scale = self.model(
past_time_feat=repeated_past_time_feat,
@@ -151,86 +151,161 @@
# Sample proposed values from each distribution
proposed_samples = []
for i in range(min(self.n_predictions -1, remaining_len)):
for i in range(steps_to_predict):
# Get distribution for this step using parameters from the i-th prediction head
sliced_params = [p[:, -1:] for p in params[i]]
distr = self.model.distr_output.distribution(sliced_params, loc, scale)
sample = distr.sample()
proposed_samples.append(sample)
# Concatenate sampled steps
proposed_samples = torch.cat(proposed_samples, dim=1)
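# proposed_samples now holds one ancestral draw per prediction head,
# shape (batch * num_parallel_samples, steps_to_predict) for a univariate target.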
# verify the proposed samples by passing them to the model in parallel:
proposed_params, proposed_loc, proposed_scale = self.model(
past_time_feat=repeated_past_time_feat,
future_time_feat=repeated_future_time_feat[..., :cur_pos + self.n_predictions - 1, :],
past_target=repeated_past_target,
past_observed_values=repeated_past_observed_values,
future_target=proposed_samples,
)
# get the last "self.n_predictions - 1" parameters from proposed_params[0]
proposed_sliced_params = [p[:, -self.n_predictions + 1:] for p in proposed_params[0]]
proposed_distr = self.model.distr_output.distribution(proposed_sliced_params, proposed_loc, proposed_scale)
proposed_nll = - proposed_distr.log_prob(proposed_samples)
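# proposed_nll scores each proposed step under the model re-run with the
# proposals as future_target; per the TODO below, the accept/reject rule
# that was meant to consume it was never finished.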
import pdb; pdb.set_trace()
# TODO The rest of the code is not correct yet
# Verify proposals using base model
accepted_samples = []
for i, proposal in enumerate(proposed_samples):
# Add proposal to sequence temporarily
test_target = torch.cat([repeated_past_target,
torch.cat(accepted_samples + [proposal], dim=1)], dim=1)
# Get base model prediction for this position
base_params, base_loc, base_scale = self.model(
past_time_feat=repeated_past_time_feat,
future_time_feat=repeated_future_time_feat[..., :cur_pos + i + 1, :],
past_target=test_target,
past_observed_values=repeated_past_observed_values,
is_test=False
)
# Get distribution from base model (first head/distribution)
base_distr = self.model.distr_output.distribution(base_params[0], base_loc, base_scale)
# Calculate log probability of proposal under base distribution
log_prob = base_distr.log_prob(proposal)
# Accept if log probability is above threshold
if torch.all(log_prob > -5.0): # Threshold can be tuned
accepted_samples.append(proposal)
else:
break
if len(accepted_samples) == 0:
# If no proposals accepted, sample one step from base distribution
distr = self.model.distr_output.distribution(params[0], loc, scale)
sample = distr.sample()
future_samples.append(sample)
repeated_past_target = torch.cat([repeated_past_target, sample], dim=1)
cur_pos += 1
else:
# Add all accepted proposals
future_samples.extend(accepted_samples)
repeated_past_target = torch.cat([repeated_past_target] + accepted_samples, dim=1)
cur_pos += len(accepted_samples)
# Append to future_samples
future_samples.append(proposed_samples)
# Update past_target and past_observed_values
# Assuming the target dimension is at dim=1
repeated_past_target = torch.cat(
[repeated_past_target, proposed_samples], dim=1
)
repeated_past_observed_values = torch.cat(
(repeated_past_observed_values, torch.ones_like(repeated_past_target[:, -len(accepted_samples) or -1:])),
[repeated_past_observed_values, torch.ones_like(proposed_samples)],
dim=1
)
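# Mark the newly appended samples as observed (ones) so the model's lag
# features treat them as valid history on the next iteration.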
# Update current position
cur_pos += steps_to_predict
self.model.reset_cache()
# Concatenate and reshape samples
concat_future_samples = torch.cat(future_samples, dim=1)
concat_future_samples = torch.cat(future_samples, dim=1)[:, :self.prediction_length]
return concat_future_samples.reshape(
(-1, self.model.num_parallel_samples, self.prediction_length)
+ self.model.distr_output.event_shape
)
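The reshape at the end regroups the repeat_interleave'd batch into per-series sample paths. A small illustration with made-up sizes (not the repo's API), assuming an empty event_shape for a univariate target:

import torch

batch, n_par, pred_len = 2, 3, 5
flat = torch.arange(batch * n_par * pred_len, dtype=torch.float32)
flat = flat.reshape(batch * n_par, pred_len)  # as produced by the sampling loop
paths = flat.reshape(-1, n_par, pred_len)     # (batch, num_parallel_samples, prediction_length)
print(paths.shape)                            # torch.Size([2, 3, 5])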
# def forward(self, *args, **kwargs):
# """
# Self-speculative decoding for continuous distributions where model predicts
# distribution parameters for multiple future steps.
# """
# past_target = kwargs["past_target"]
# past_observed_values = kwargs["past_observed_values"]
# past_time_feat = kwargs["past_time_feat"]
# future_time_feat = kwargs["future_time_feat"]
# repeated_past_target = past_target.repeat_interleave(
# self.model.num_parallel_samples, 0
# )
# repeated_past_observed_values = past_observed_values.repeat_interleave(
# self.model.num_parallel_samples, 0
# )
# repeated_past_time_feat = past_time_feat.repeat_interleave(
# self.model.num_parallel_samples, 0
# )
# repeated_future_time_feat = future_time_feat.repeat_interleave(
# self.model.num_parallel_samples, 0
# )[:, :self.prediction_length]
# future_samples = []
# cur_pos = 0
# while cur_pos < self.prediction_length:
# remaining_len = self.prediction_length - cur_pos
# # Get multi-step predictions from the model
# params, loc, scale = self.model(
# past_time_feat=repeated_past_time_feat,
# future_time_feat=repeated_future_time_feat[..., :cur_pos + 1, :],
# past_target=repeated_past_target,
# past_observed_values=repeated_past_observed_values,
# is_test=False
# )
# # Sample proposed values from each distribution
# proposed_samples = []
# for i in range(min(self.n_predictions -1, remaining_len)):
# # Get distribution for this step using parameters from the i-th prediction head
# sliced_params = [p[:, -1:] for p in params[i]]
# distr = self.model.distr_output.distribution(sliced_params, loc, scale)
# sample = distr.sample()
# proposed_samples.append(sample)
# proposed_samples = torch.cat(proposed_samples, dim=1)
# # verify the proposed samples by passing them to the model in parallel:
# proposed_params, proposed_loc, proposed_scale = self.model(
# past_time_feat=repeated_past_time_feat,
# future_time_feat=repeated_future_time_feat[..., :cur_pos + self.n_predictions - 1, :],
# past_target=repeated_past_target,
# past_observed_values=repeated_past_observed_values,
# future_target=proposed_samples,
# )
# # get the last "self.n_predictions - 1" parameters from proposed_params[0]
# proposed_sliced_params = [p[:, -self.n_predictions + 1:] for p in proposed_params[0]]
# proposed_distr = self.model.distr_output.distribution(proposed_sliced_params, proposed_loc, proposed_scale)
# proposed_nll = - proposed_distr.log_prob(proposed_samples)
# import pdb; pdb.set_trace()
# # TODO The rest of the code is not correct yet
# # Verify proposals using base model
# accepted_samples = []
# for i, proposal in enumerate(proposed_samples):
# # Add proposal to sequence temporarily
# test_target = torch.cat([repeated_past_target,
# torch.cat(accepted_samples + [proposal], dim=1)], dim=1)
# # Get base model prediction for this position
# base_params, base_loc, base_scale = self.model(
# past_time_feat=repeated_past_time_feat,
# future_time_feat=repeated_future_time_feat[..., :cur_pos + i + 1, :],
# past_target=test_target,
# past_observed_values=repeated_past_observed_values,
# is_test=False
# )
# # Get distribution from base model (first head/distribution)
# base_distr = self.model.distr_output.distribution(base_params[0], base_loc, base_scale)
# # Calculate log probability of proposal under base distribution
# log_prob = base_distr.log_prob(proposal)
# # Accept if log probability is above threshold
# if torch.all(log_prob > -5.0): # Threshold can be tuned
# accepted_samples.append(proposal)
# else:
# break
# if len(accepted_samples) == 0:
# # If no proposals accepted, sample one step from base distribution
# distr = self.model.distr_output.distribution(params[0], loc, scale)
# sample = distr.sample()
# future_samples.append(sample)
# repeated_past_target = torch.cat([repeated_past_target, sample], dim=1)
# cur_pos += 1
# else:
# # Add all accepted proposals
# future_samples.extend(accepted_samples)
# repeated_past_target = torch.cat([repeated_past_target] + accepted_samples, dim=1)
# cur_pos += len(accepted_samples)
# repeated_past_observed_values = torch.cat(
# (repeated_past_observed_values, torch.ones_like(repeated_past_target[:, -len(accepted_samples) or -1:])),
# dim=1
# )
# self.model.reset_cache()
# # Concatenate and reshape samples
# concat_future_samples = torch.cat(future_samples, dim=1)
# return concat_future_samples.reshape(
# (-1, self.model.num_parallel_samples, self.prediction_length)
# + self.model.distr_output.event_shape
# )
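For reference, the commented-out forward above verified proposals speculatively: each proposed step was accepted only if its log-probability under the base one-step distribution cleared a threshold, and the loop stopped at the first rejection. A minimal sketch of that accept/reject rule with a Gaussian stand-in (threshold and names are illustrative assumptions):

import torch
from torch.distributions import Normal

def accept_proposals(proposals, base_means, base_stds, threshold=-5.0):
    # Accept proposed steps left to right; stop at the first step whose
    # log-probability under the base distribution falls below the threshold.
    accepted = []
    for t in range(proposals.shape[1]):
        base = Normal(base_means[:, t], base_stds[:, t])
        log_prob = base.log_prob(proposals[:, t])
        if torch.all(log_prob > threshold):
            accepted.append(proposals[:, t : t + 1])
        else:
            break
    return accepted

props = torch.randn(4, 3)  # (batch, proposed steps)
acc = accept_proposals(props, torch.zeros(4, 3), torch.ones(4, 3))
print(len(acc))            # number of accepted steps, 0..3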
# train matryoshka loss
def _compute_loss(self, batch):
past_target = batch["past_target"]
......