|
15 | 15 | package dev.cel.extensions; |
16 | 16 |
|
17 | 17 | import com.google.common.collect.ImmutableList; |
18 | | -import com.google.common.collect.ImmutableMap; |
19 | 18 | import com.google.common.collect.ImmutableSet; |
20 | 19 | import com.google.errorprone.annotations.Immutable; |
21 | 20 | import com.google.re2j.Matcher; |
|
25 | 24 | import dev.cel.common.CelFunctionDecl; |
26 | 25 | import dev.cel.common.CelOverloadDecl; |
27 | 26 | import dev.cel.common.types.ListType; |
28 | | -import dev.cel.common.types.MapType; |
29 | 27 | import dev.cel.common.types.OptionalType; |
30 | 28 | import dev.cel.common.types.SimpleType; |
31 | 29 | import dev.cel.compiler.CelCompilerLibrary; |
|
40 | 38 | final class CelRegexExtensions implements CelCompilerLibrary, CelRuntimeLibrary { |
41 | 39 |
|
42 | 40 | private static final String REGEX_REPLACE_FUNCTION = "regex.replace"; |
43 | | - private static final String REGEX_CAPTURE_FUNCTION = "regex.capture"; |
44 | | - private static final String REGEX_CAPTUREALL_FUNCTION = "regex.captureAll"; |
45 | | - private static final String REGEX_CAPTUREALLNAMED_FUNCTION = "regex.captureAllNamed"; |
| 41 | + private static final String REGEX_EXTRACT_FUNCTION = "regex.extract"; |
| 42 | + private static final String REGEX_EXTRACT_ALL_FUNCTION = "regex.extractAll"; |
46 | 43 |
|
47 | 44 | enum Function { |
48 | 45 | REPLACE( |
@@ -83,52 +80,36 @@ enum Function { |
83 | 80 | long count = (long) args[3]; |
84 | 81 | return CelRegexExtensions.replace(target, pattern, replaceStr, count); |
85 | 82 | }))), |
86 | | - CAPTURE( |
| 83 | + EXTRACT( |
87 | 84 | CelFunctionDecl.newFunctionDeclaration( |
88 | | - REGEX_CAPTURE_FUNCTION, |
| 85 | + REGEX_EXTRACT_FUNCTION, |
89 | 86 | CelOverloadDecl.newGlobalOverload( |
90 | | - "regex_capture_string_string", |
| 87 | + "regex_extract_string_string", |
91 | 88 | "Returns the first substring that matches the regex.", |
92 | 89 | OptionalType.create(SimpleType.STRING), |
93 | 90 | SimpleType.STRING, |
94 | 91 | SimpleType.STRING)), |
95 | 92 | ImmutableSet.of( |
96 | 93 | CelFunctionBinding.from( |
97 | | - "regex_capture_string_string", |
| 94 | + "regex_extract_string_string", |
98 | 95 | String.class, |
99 | 96 | String.class, |
100 | | - CelRegexExtensions::captureFirstMatch))), |
101 | | - CAPTUREALL( |
| 97 | + CelRegexExtensions::extract))), |
| 98 | + EXTRACTALL( |
102 | 99 | CelFunctionDecl.newFunctionDeclaration( |
103 | | - REGEX_CAPTUREALL_FUNCTION, |
| 100 | + REGEX_EXTRACT_ALL_FUNCTION, |
104 | 101 | CelOverloadDecl.newGlobalOverload( |
105 | | - "regex_captureAll_string_string", |
106 | | - "Returns an arrat of all substrings that match the regex.", |
| 102 | + "regex_extractAll_string_string", |
| 103 | + "Returns an array of all substrings that match the regex.", |
107 | 104 | ListType.create(SimpleType.STRING), |
108 | 105 | SimpleType.STRING, |
109 | 106 | SimpleType.STRING)), |
110 | 107 | ImmutableSet.of( |
111 | 108 | CelFunctionBinding.from( |
112 | | - "regex_captureAll_string_string", |
| 109 | + "regex_extractAll_string_string", |
113 | 110 | String.class, |
114 | 111 | String.class, |
115 | | - CelRegexExtensions::captureAllMatches))), |
116 | | - CAPTUREALLNAMED( |
117 | | - CelFunctionDecl.newFunctionDeclaration( |
118 | | - REGEX_CAPTUREALLNAMED_FUNCTION, |
119 | | - CelOverloadDecl.newGlobalOverload( |
120 | | - "regex_captureAllNamed_string_string", |
121 | | - "Returns a map of all named captured groups as <named_group_name, captured_string>." |
122 | | - + " Ignores the unnamed capture groups.", |
123 | | - MapType.create(SimpleType.STRING, SimpleType.STRING), |
124 | | - SimpleType.STRING, |
125 | | - SimpleType.STRING)), |
126 | | - ImmutableSet.of( |
127 | | - CelFunctionBinding.from( |
128 | | - "regex_captureAllNamed_string_string", |
129 | | - String.class, |
130 | | - String.class, |
131 | | - CelRegexExtensions::captureAllNamedGroups))); |
| 112 | + CelRegexExtensions::extractAll))); |
132 | 113 |
|
133 | 114 | private final CelFunctionDecl functionDecl; |
134 | 115 | private final ImmutableSet<CelFunctionBinding> functionBindings; |
@@ -200,67 +181,49 @@ private static String replace(String target, String regex, String replaceStr, lo |
200 | 181 | return sb.toString(); |
201 | 182 | } |
202 | 183 |
|
203 | | - private static Optional<String> captureFirstMatch(String target, String regex) { |
| 184 | + private static Optional<String> extract(String target, String regex) { |
204 | 185 | Pattern pattern = compileRegexPattern(regex); |
205 | 186 | Matcher matcher = pattern.matcher(target); |
206 | 187 |
|
207 | | - if (matcher.find()) { |
208 | | - // If there are capture groups, return the first one. |
209 | | - if (matcher.groupCount() > 0) { |
210 | | - return Optional.ofNullable(matcher.group(1)); |
211 | | - } else { |
212 | | - // If there are no capture groups, return the entire match. |
213 | | - return Optional.of(matcher.group(0)); |
214 | | - } |
| 188 | + if (!matcher.find()) { |
| 189 | + return Optional.empty(); |
215 | 190 | } |
216 | 191 |
|
217 | | - return Optional.empty(); |
218 | | - } |
219 | | - |
220 | | - private static ImmutableList<String> captureAllMatches(String target, String regex) { |
221 | | - Pattern pattern = compileRegexPattern(regex); |
222 | | - |
223 | | - Matcher matcher = pattern.matcher(target); |
224 | | - ImmutableList.Builder<String> builder = ImmutableList.builder(); |
225 | | - |
226 | | - while (matcher.find()) { |
227 | | - // If there are capture groups, return all of them. Otherwise, return the entire match. |
228 | | - if (matcher.groupCount() > 0) { |
229 | | - // Add all the capture groups to the result list. |
230 | | - for (int i = 1; i <= matcher.groupCount(); i++) { |
231 | | - String group = matcher.group(i); |
232 | | - if (group != null) { |
233 | | - builder.add(group); |
234 | | - } |
235 | | - } |
236 | | - } else { |
237 | | - builder.add(matcher.group(0)); |
238 | | - } |
| 192 | + int groupCount = matcher.groupCount(); |
| 193 | + if (groupCount > 1) { |
| 194 | + throw new IllegalArgumentException( |
| 195 | + "Regular expression has more than one capturing group: " + regex); |
239 | 196 | } |
240 | 197 |
|
241 | | - return builder.build(); |
| 198 | + String result = (groupCount == 1) ? matcher.group(1) : matcher.group(0); |
| 199 | + |
| 200 | + return Optional.ofNullable(result); |
242 | 201 | } |
243 | 202 |
|
244 | | - private static ImmutableMap<String, String> captureAllNamedGroups(String target, String regex) { |
245 | | - ImmutableMap.Builder<String, String> builder = ImmutableMap.builder(); |
| 203 | + private static ImmutableList<String> extractAll(String target, String regex) { |
246 | 204 | Pattern pattern = compileRegexPattern(regex); |
| 205 | + Matcher matcher = pattern.matcher(target); |
247 | 206 |
|
248 | | - Set<String> groupNames = pattern.namedGroups().keySet(); |
249 | | - if (groupNames.isEmpty()) { |
250 | | - return builder.buildOrThrow(); |
| 207 | + if (matcher.groupCount() > 1) { |
| 208 | + throw new IllegalArgumentException( |
| 209 | + "Regular expression has more than one capturing group: " + regex); |
251 | 210 | } |
252 | 211 |
|
253 | | - Matcher matcher = pattern.matcher(target); |
| 212 | + ImmutableList.Builder<String> builder = ImmutableList.builder(); |
| 213 | + boolean hasOneGroup = matcher.groupCount() == 1; |
254 | 214 |
|
255 | 215 | while (matcher.find()) { |
256 | | - |
257 | | - for (String groupName : groupNames) { |
258 | | - String capturedValue = matcher.group(groupName); |
259 | | - if (capturedValue != null) { |
260 | | - builder.put(groupName, capturedValue); |
| 216 | + if (hasOneGroup) { |
| 217 | + String group = matcher.group(1); |
| 218 | + // Add the captured group's content only if it's not null (e.g. optional group didn't match) |
| 219 | + if (group != null) { |
| 220 | + builder.add(group); |
261 | 221 | } |
| 222 | + } else { // No capturing groups (matcher.groupCount() == 0) |
| 223 | + builder.add(matcher.group(0)); |
262 | 224 | } |
263 | 225 | } |
264 | | - return builder.buildOrThrow(); |
| 226 | + |
| 227 | + return builder.build(); |
265 | 228 | } |
266 | 229 | } |
0 commit comments