Skip to content

Commit

Permalink
Implement simpler logic for edit predictions prompt byte limits
Browse files Browse the repository at this point in the history
Happily this could be done by copy-modifying some of the code from #23814
  • Loading branch information
mgsloan committed Jan 30, 2025
1 parent 36a4732 commit 77cb5cd
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 43 deletions.
9 changes: 6 additions & 3 deletions crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ use util::ResultExt;

pub use token::*;

const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

/// Output token limit. A copy of this constant is also in `crates/zeta/src/zeta.rs`.
const MAX_OUTPUT_TOKENS: usize = 2048;

pub struct LlmState {
pub config: Config,
pub executor: Executor,
Expand All @@ -52,8 +57,6 @@ pub struct LlmState {
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}

const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let database_url = config
Expand Down Expand Up @@ -488,7 +491,7 @@ async fn predict_edits(
fireworks::CompletionRequest {
model: model.to_string(),
prompt: prompt.clone(),
max_tokens: 2048,
max_tokens: MAX_OUTPUT_TOKENS,
temperature: 0.,
prediction: Some(fireworks::Prediction::Content {
content: params.input_excerpt.clone(),
Expand Down
155 changes: 115 additions & 40 deletions crates/zeta/src/zeta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str =
/// intentionally low to err on the side of underestimating limits.
const BYTES_PER_TOKEN_GUESS: usize = 3;

/// This is based on the output token limit `max_tokens: 2048` in `crates/collab/src/llm.rs`. Number
/// of output tokens is relevant to the size of the input excerpt because the model is tasked with
/// outputting a modified excerpt. `2/3` is chosen so that there are some output tokens remaining
/// for the model to specify insertions.
const BUFFER_EXCERPT_BYTE_LIMIT: usize = (2048 * 2 / 3) * BYTES_PER_TOKEN_GUESS;
/// Output token limit, used to inform the size of the input. A copy of this constant is also in
/// `crates/collab/src/llm.rs`.
const MAX_OUTPUT_TOKENS: usize = 2048;

/// Total bytes limit for editable region of buffer excerpt.
///
/// The number of output tokens is relevant to the size of the input excerpt because the model is
/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
/// remaining for the model to specify insertions.
const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;

/// Note that this is not the limit for the overall prompt, just for the inputs to the template
/// instantiated in `crates/collab/src/llm.rs`.
Expand Down Expand Up @@ -345,9 +350,8 @@ impl Zeta {
let snapshot = self.report_changes_for_buffer(&buffer, cx);
let cursor_point = cursor.to_point(&snapshot);
let cursor_offset = cursor_point.to_offset(&snapshot);
let excerpt_range = excerpt_range_for_position(cursor_point, &snapshot);
let events = self.events.clone();
let path = snapshot
let path: Arc<Path> = snapshot
.file()
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
Expand All @@ -360,24 +364,38 @@ impl Zeta {
cx.spawn(|_, cx| async move {
let request_sent_at = Instant::now();

let (input_events, input_excerpt, input_outline) = cx
let (input_events, input_excerpt, excerpt_range, input_outline) = cx
.background_executor()
.spawn({
let snapshot = snapshot.clone();
let excerpt_range = excerpt_range.clone();
let path = path.clone();
async move {
let input_excerpt =
prompt_for_excerpt(&snapshot, &excerpt_range, cursor_offset);

let bytes_remaining = todo!();
let path = path.to_string_lossy();
let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
cursor_point,
BUFFER_EXCERPT_BYTE_LIMIT,
&path,
&snapshot,
)?;
let input_excerpt = prompt_for_excerpt(
cursor_offset,
&excerpt_range,
excerpt_len_guess,
&path,
&snapshot,
);

let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
let input_events = prompt_for_events(events.iter(), bytes_remaining);

// Note that input_outline is not currently used in prompt generation and so
// is not counted towards TOTAL_BYTE_LIMIT.
let input_outline = prompt_for_outline(&snapshot);

(input_events, input_excerpt, input_outline)
anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
}
})
.await;
.await?;

log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);

Expand Down Expand Up @@ -999,21 +1017,14 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
}

fn prompt_for_excerpt(
snapshot: &BufferSnapshot,
excerpt_range: &Range<usize>,
offset: usize,
excerpt_range: &Range<usize>,
len_guess: usize,
path: &str,
snapshot: &BufferSnapshot,
) -> String {
let mut prompt_excerpt = String::new();
writeln!(
prompt_excerpt,
"```{}",
snapshot
.file()
.map_or(Cow::Borrowed("untitled"), |file| file
.path()
.to_string_lossy())
)
.unwrap();
let mut prompt_excerpt = String::with_capacity(len_guess);
writeln!(prompt_excerpt, "```{}", path).unwrap();

if excerpt_range.start == 0 {
writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
Expand Down Expand Up @@ -1050,25 +1061,89 @@ fn prompt_for_excerpt(
}

write!(prompt_excerpt, "\n```").unwrap();
debug_assert!(
prompt_excerpt.len() <= len_guess,
"Excerpt length {} exceeds estimated length {}",
prompt_excerpt.len(),
len_guess
);
prompt_excerpt
}

fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
const CONTEXT_LINES: u32 = 32;

let mut context_lines_before = CONTEXT_LINES;
let mut context_lines_after = CONTEXT_LINES;
if point.row < CONTEXT_LINES {
context_lines_after += CONTEXT_LINES - point.row;
} else if point.row + CONTEXT_LINES > snapshot.max_point().row {
context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
fn excerpt_range_for_position(
cursor_point: Point,
byte_limit: usize,
path: &str,
snapshot: &BufferSnapshot,
) -> Result<(Range<usize>, usize)> {
let cursor_row = cursor_point.row;
let last_buffer_row = snapshot.max_point().row;

// This is an overestimate because it includes parts of prompt_for_excerpt which are
// conditionally skipped.
let mut len_guess = 0;
len_guess += "```".len() + path.len() + 1;
len_guess += START_OF_FILE_MARKER.len() + 1;
len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
len_guess += CURSOR_MARKER.len();
len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
len_guess += "```".len() + 1;

len_guess += usize::try_from(snapshot.line_len(cursor_row)).unwrap();

if len_guess > byte_limit {
return Err(anyhow!("Current line too long to send to model."));
}

let mut excerpt_start_row = cursor_row;
let mut excerpt_end_row = cursor_row;
let mut no_more_before = cursor_row == 0;
let mut no_more_after = cursor_row >= last_buffer_row;
let mut row_delta = 1;
loop {
if !no_more_before {
let row = cursor_point.row - row_delta;
let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
let mut new_len_guess = len_guess + line_len;
if row == 0 {
new_len_guess += START_OF_FILE_MARKER.len() + 1;
}
if new_len_guess <= byte_limit {
len_guess = new_len_guess;
excerpt_start_row = row;
if row == 0 {
no_more_before = true;
}
} else {
no_more_before = true;
}
}
if !no_more_after {
let row = cursor_point.row + row_delta;
let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
let new_len_guess = len_guess + line_len;
if new_len_guess <= byte_limit {
len_guess = new_len_guess;
excerpt_end_row = row;
if row >= last_buffer_row {
no_more_after = true;
}
} else {
no_more_after = true;
}
}
if no_more_before && no_more_after {
break;
}
row_delta += 1;
}

let excerpt_start_row = point.row.saturating_sub(context_lines_before);
let excerpt_start = Point::new(excerpt_start_row, 0);
let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
Ok((
excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
len_guess,
))
}

fn prompt_for_events<'a>(
Expand Down

0 comments on commit 77cb5cd

Please sign in to comment.