Skip to content

Commit 0c1391a

Browse files
feat: support data generation (#2067)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent fbe1874 commit 0c1391a

14 files changed

Lines changed: 296 additions & 17 deletions

File tree

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/forge_api/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ forge_stream.workspace = true
1212
forge_services.workspace = true
1313
forge_repo.workspace = true
1414
forge_infra.workspace = true
15+
futures.workspace = true
1516

1617

1718

@@ -21,6 +22,7 @@ forge_infra.workspace = true
2122

2223

2324
forge_app.workspace = true
25+
serde_json.workspace = true
2426

2527
[dev-dependencies]
2628

crates/forge_api/src/api.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use forge_app::dto::ToolsOverview;
55
use forge_app::{User, UserUsage};
66
use forge_domain::{AgentId, InitAuth, ModelId};
77
use forge_stream::MpscStream;
8+
use futures::stream::BoxStream;
89
use url::Url;
910

1011
use crate::*;
@@ -211,4 +212,9 @@ pub trait API: Sync + Send {
211212
/// credentials. This is a one-time migration that runs only if the
212213
/// credentials file doesn't exist.
213214
async fn migrate_env_credentials(&self) -> Result<Option<forge_domain::MigrationResult>>;
215+
216+
async fn generate_data(
217+
&self,
218+
data_parameters: DataGenerationParameters,
219+
) -> Result<BoxStream<'static, Result<serde_json::Value, anyhow::Error>>>;
214220
}

crates/forge_api/src/forge_api.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ use anyhow::Result;
66
use forge_app::dto::ToolsOverview;
77
use forge_app::{
88
AgentProviderResolver, AgentRegistry, AppConfigService, AuthService, CommandInfra,
9-
CommandLoaderService, ContextEngineService, ConversationService, EnvironmentInfra,
10-
EnvironmentService, FileDiscoveryService, ForgeApp, GitApp, GrpcInfra, McpConfigManager,
11-
McpService, ProviderAuthService, ProviderService, Services, User, UserUsage, Walker,
9+
CommandLoaderService, ContextEngineService, ConversationService, DataGenerationApp,
10+
EnvironmentInfra, EnvironmentService, FileDiscoveryService, ForgeApp, GitApp, GrpcInfra,
11+
McpConfigManager, McpService, ProviderAuthService, ProviderService, Services, User, UserUsage,
12+
Walker,
1213
};
1314
use forge_domain::{Agent, InitAuth, LoginInfo, *};
1415
use forge_infra::ForgeInfra;
1516
use forge_repo::ForgeRepo;
1617
use forge_services::ForgeServices;
1718
use forge_stream::MpscStream;
19+
use futures::stream::BoxStream;
1820
use url::Url;
1921

2022
use crate::API;
@@ -364,6 +366,14 @@ impl<
364366
Ok(self.services.migrate_env_credentials().await?)
365367
}
366368

369+
async fn generate_data(
370+
&self,
371+
data_parameters: DataGenerationParameters,
372+
) -> Result<BoxStream<'static, Result<serde_json::Value, anyhow::Error>>> {
373+
let app = DataGenerationApp::new(self.services.clone());
374+
app.execute(data_parameters).await
375+
}
376+
367377
async fn get_default_provider(&self) -> Result<Provider<Url>> {
368378
self.services.get_default_provider().await
369379
}

crates/forge_app/src/data_gen.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use std::path::PathBuf;
2+
use std::sync::Arc;
3+
4+
use anyhow::{Context as _, Result};
5+
use forge_domain::{
6+
Context, ContextMessage, DataGenerationParameters, ResultStreamExt, Template, ToolDefinition,
7+
};
8+
use futures::StreamExt;
9+
use futures::stream::{self, BoxStream};
10+
use schemars::schema::RootSchema;
11+
use tracing::{debug, info};
12+
13+
use crate::{
14+
AppConfigService, EnvironmentService, FsReadService, ProviderService, Services, TemplateEngine,
15+
};
16+
17+
pub struct DataGenerationApp<A> {
18+
services: Arc<A>,
19+
}
20+
21+
type JsonSchema = String;
22+
type SystemPrompt = String;
23+
type UserPrompt = String;
24+
type Input = Vec<serde_json::Value>;
25+
26+
impl<A: Services> DataGenerationApp<A> {
27+
pub fn new(services: Arc<A>) -> Self {
28+
Self { services }
29+
}
30+
31+
/// Helper function to read a file from a path, resolving it relative to cwd
32+
/// if necessary
33+
async fn read_file(&self, path: PathBuf) -> Result<String> {
34+
let resolved_path = if path.is_absolute() {
35+
path
36+
} else {
37+
let cwd = self.services.get_environment().cwd;
38+
cwd.join(path)
39+
};
40+
41+
let content = self
42+
.services
43+
.read(resolved_path.display().to_string(), None, None)
44+
.await?
45+
.content
46+
.file_content()
47+
.to_owned();
48+
49+
Ok(content)
50+
}
51+
52+
async fn read_file_opt(&self, path: Option<PathBuf>) -> Result<Option<String>> {
53+
match path {
54+
Some(path) => self.read_file(path).await.map(Some),
55+
None => Ok(None),
56+
}
57+
}
58+
59+
async fn load_parameters(
60+
&self,
61+
params: DataGenerationParameters,
62+
) -> Result<(JsonSchema, Option<SystemPrompt>, Option<UserPrompt>, Input)> {
63+
debug!("Loading data generation parameters");
64+
65+
// Read all files in parallel
66+
let (schema, system_prompt, user_prompt, input) = tokio::join!(
67+
self.read_file(params.schema.clone()),
68+
self.read_file_opt(params.system_prompt),
69+
self.read_file_opt(params.user_prompt),
70+
self.read_file(params.input)
71+
);
72+
73+
let input: Vec<serde_json::Value> = input?
74+
.lines()
75+
.map(|text| {
76+
serde_json::from_str(text).with_context(|| "Could not parse the input file")
77+
})
78+
.collect::<Result<Vec<_>>>()?;
79+
80+
debug!("Loaded {} input items", input.len());
81+
82+
Ok((schema?, system_prompt?, user_prompt?, input))
83+
}
84+
85+
pub async fn execute(
86+
&self,
87+
params: DataGenerationParameters,
88+
) -> Result<BoxStream<'static, Result<serde_json::Value>>> {
89+
let concurrency = params.concurrency;
90+
let (schema, system_prompt, user_prompt, input) = self.load_parameters(params).await?;
91+
92+
info!(
93+
"Starting data generation with {} items (concurrency: {})",
94+
input.len(),
95+
concurrency
96+
);
97+
98+
let provider = self.services.get_default_provider().await?;
99+
let model_id = self.services.get_provider_model(Some(&provider.id)).await?;
100+
debug!("Using provider: {}, model: {}", provider.id, model_id);
101+
let schema: RootSchema =
102+
serde_json::from_str(&schema).with_context(|| "Could not parse the JSON schema")?;
103+
let mut context =
104+
Context::default().add_tool(ToolDefinition::new("output").input_schema(schema));
105+
106+
if let Some(content) = system_prompt {
107+
context = context.add_message(ContextMessage::system(content))
108+
}
109+
110+
let services = self.services.clone();
111+
112+
let json_stream = input.into_iter().map(move |input| {
113+
let provider = provider.clone();
114+
let context = context.clone();
115+
let user_prompt = user_prompt.clone();
116+
let model_id = model_id.clone();
117+
let services = services.clone();
118+
119+
async move {
120+
debug!("Processing data generation request");
121+
122+
let provider = provider.clone();
123+
let mut context = context.clone();
124+
let content = if let Some(ref content) = user_prompt {
125+
TemplateEngine::default().render_template(Template::new(content), &input)?
126+
} else {
127+
serde_json::to_string(&input)?
128+
};
129+
130+
context =
131+
context.add_message(ContextMessage::user(content, Some(model_id.clone())));
132+
133+
let stream = services.chat(&model_id, context, provider.clone()).await?;
134+
let response = stream.into_full(false).await?;
135+
136+
anyhow::Ok((input, response))
137+
}
138+
});
139+
140+
let json_stream = stream::iter(json_stream)
141+
.buffer_unordered(concurrency)
142+
.map(|result| {
143+
result.and_then(|(input, response)| {
144+
response
145+
.tool_calls
146+
.into_iter()
147+
.map(|tool| {
148+
let output = tool.arguments.parse()?;
149+
let mut value = serde_json::Map::new();
150+
value.insert("input".to_string(), input.clone());
151+
value.insert("output".to_string(), output);
152+
Ok(serde_json::Value::from(value))
153+
})
154+
.collect::<Result<Vec<_>>>()
155+
})
156+
})
157+
.flat_map(|data| match data {
158+
Ok(data) => stream::iter(data).map(Ok).boxed(),
159+
Err(err) => stream::iter(Err(err)).boxed(),
160+
})
161+
.boxed();
162+
163+
Ok(json_stream)
164+
}
165+
}

crates/forge_app/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
mod agent;
22
mod agent_executor;
3-
43
mod agent_provider_resolver;
54
mod app;
65
mod apply_tunable_parameters;
76
mod authenticator;
87
mod changed_files;
98
mod command_generator;
109
mod compact;
10+
mod data_gen;
1111
pub mod dto;
1212
mod error;
1313
mod file_tracking;
@@ -41,6 +41,7 @@ pub use agent::*;
4141
pub use agent_provider_resolver::*;
4242
pub use app::*;
4343
pub use command_generator::*;
44+
pub use data_gen::*;
4445
pub use error::*;
4546
pub use git_app::*;
4647
pub use infra::*;

crates/forge_app/src/services.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ pub struct PolicyDecision {
129129
pub trait ProviderService: Send + Sync {
130130
async fn chat(
131131
&self,
132-
id: &ModelId,
132+
model_id: &ModelId,
133133
context: Context,
134134
provider: Provider<Url>,
135135
) -> ResultStream<ChatCompletionMessage, anyhow::Error>;
@@ -606,11 +606,13 @@ impl<I: Services> ConversationService for I {
606606
impl<I: Services> ProviderService for I {
607607
async fn chat(
608608
&self,
609-
id: &ModelId,
609+
model_id: &ModelId,
610610
context: Context,
611611
provider: Provider<Url>,
612612
) -> ResultStream<ChatCompletionMessage, anyhow::Error> {
613-
self.provider_service().chat(id, context, provider).await
613+
self.provider_service()
614+
.chat(model_id, context, provider)
615+
.await
614616
}
615617

616618
async fn models(&self, provider: Provider<Url>) -> anyhow::Result<Vec<Model>> {
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use std::path::PathBuf;
2+
3+
use derive_setters::Setters;
4+
use serde::{Deserialize, Serialize};
5+
6+
/// Parameters for data generation operations
7+
///
8+
/// This struct encapsulates the configuration parameters needed for generating
9+
/// data in various contexts. It provides control over the amount of data to
10+
/// generate, formatting options, and other generation-specific settings.
11+
#[derive(Debug, Clone, Serialize, Deserialize, Setters, PartialEq, fake::Dummy)]
12+
#[setters(into, strip_option)]
13+
pub struct DataGenerationParameters {
14+
/// Path to input JSONL file for data generation
15+
pub input: PathBuf,
16+
17+
/// Path to JSON schema file for LLM tool definition
18+
pub schema: PathBuf,
19+
20+
/// Path to Handlebars template file for system prompt
21+
pub system_prompt: Option<PathBuf>,
22+
23+
/// Path to Handlebars template file for user prompt
24+
pub user_prompt: Option<PathBuf>,
25+
26+
/// Maximum number of concurrent LLM requests
27+
pub concurrency: usize,
28+
}

crates/forge_domain/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ mod compact;
99
mod context;
1010
mod conversation;
1111
mod conversation_html;
12+
mod data_gen;
1213
mod env;
1314
mod error;
1415
mod event;
@@ -58,6 +59,7 @@ pub use compact::*;
5859
pub use context::*;
5960
pub use conversation::*;
6061
pub use conversation_html::*;
62+
pub use data_gen::*;
6163
pub use env::*;
6264
pub use error::*;
6365
pub use event::*;

0 commit comments

Comments
 (0)