Skip to content

Commit bc367f4

Browse files
committed
Add sample
1 parent 7581f9f commit bc367f4

File tree

7 files changed

+81
-11
lines changed

7 files changed

+81
-11
lines changed

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ print(src, tgt)
1818
# source.txt.shuffled target.txt.shuffled
1919
```
2020

21+
You can also simultaneously sample/isolate a certain number of sentences from the dataset (which are then removed from the shuffled result)
22+
23+
```
24+
from fastshuffle import file_shuffle_sample
25+
26+
src, tgt, src_sample, tgt_sample = file_shuffle("source.txt", "target.txt", 5) # Sample 5 sentences
27+
print(src, tgt, src_sample, tgt_sample)
28+
# source.txt.shuffled target.txt.shuffled source.txt.shuffled.sample target.txt.shuffled.sample
29+
```
30+
2131
## Notes
2232

2333
Source and target must have the same number of lines. No validation checks are made.

fastshuffle.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace py = pybind11;
88

99
PYBIND11_MODULE(fastshuffle, m) {
1010
m.def("file_shuffle", &shuffle);
11+
m.def("file_shuffle_sample", &shuffle_sample);
12+
1113

1214
#ifdef VERSION_INFO
1315
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);

main.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ int main(int argc, char **argv) {
77
options.add_options()
88
("s,src", "Input source corpus", cxxopts::value<std::string>())
99
("t,tgt", "Input target corpus", cxxopts::value<std::string>())
10+
("sample", "Also sample these many sentences", cxxopts::value<long long>()->default_value("0"))
1011
("h,help", "Print usage")
1112
;
1213
options.parse_positional({ "src", "tgt" });
@@ -30,10 +31,19 @@ int main(int argc, char **argv) {
3031
try {
3132
const auto source = result["src"].as<std::string>();
3233
const auto target = result["tgt"].as<std::string>();
34+
const auto sample = result["sample"].as<long long>();
3335

34-
auto result = shuffle(source, target);
35-
std::cout << "W\t" << std::get<0>(result) << std::endl;
36-
std::cout << "W\t" << std::get<1>(result) << std::endl;
36+
if (sample > 0){
37+
auto result = shuffle_sample(source, target, sample);
38+
std::cout << "W\t" << std::get<0>(result) << std::endl;
39+
std::cout << "W\t" << std::get<1>(result) << std::endl;
40+
std::cout << "W\t" << std::get<2>(result) << std::endl;
41+
std::cout << "W\t" << std::get<3>(result) << std::endl;
42+
}else{
43+
auto result = shuffle(source, target);
44+
std::cout << "W\t" << std::get<0>(result) << std::endl;
45+
std::cout << "W\t" << std::get<1>(result) << std::endl;
46+
}
3747
}catch (std::exception &e) {
3848
std::cerr << "Error: " << e.what() << std::endl;
3949
return EXIT_FAILURE;

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def build_extension(self, ext: CMakeExtension) -> None:
124124

125125
setup(
126126
name="fastshuffle",
127-
version="1.0.0",
127+
version="1.0.1",
128128
author="Piero Toffanin",
129129
author_email="[email protected]",
130130
url="https://github.com/LibreTranslate/FastShuffle",

shuffle.cpp

+39-5
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,20 @@ size_t line_count(mmap_handle_t mht){
7171
}
7272

7373
std::tuple<std::string, std::string> shuffle(const std::string &src, const std::string &tgt){
74+
auto res = shuffle_sample(src, tgt, 0);
75+
return std::make_tuple(std::get<0>(res), std::get<1>(res));
76+
}
77+
78+
std::tuple<std::string, std::string, std::string, std::string> shuffle_sample(const std::string &src, const std::string &tgt, long long sample){
7479
mmap_handle_t smht = mmap_open(src);
7580
mmap_handle_t tmht = mmap_open(tgt);
7681

7782
std::string src_out = src + ".shuffled";
7883
std::string tgt_out = tgt + ".shuffled";
7984

85+
std::string src_sample_out = src_out + ".sample";
86+
std::string tgt_sample_out = tgt_out + ".sample";
87+
8088
// size_t src_count = line_count(smht);
8189

8290
std::vector<line_off_t> offsets;
@@ -122,12 +130,38 @@ std::tuple<std::string, std::string> shuffle(const std::string &src, const std::
122130
std::ofstream src_of(src_out, std::ios::trunc);
123131
std::ofstream tgt_of(tgt_out, std::ios::trunc);
124132

133+
std::ofstream *src_sample_of = nullptr;
134+
std::ofstream *tgt_sample_of = nullptr;
135+
136+
if (sample > 0){
137+
src_sample_of = new std::ofstream(src_sample_out, std::ios::trunc);
138+
tgt_sample_of = new std::ofstream(tgt_sample_out, std::ios::trunc);
139+
}
140+
125141
// std::cout << offsets.size() << std::endl << std::endl;
126142
for (size_t i = 0; i < offsets.size(); i++){
127-
src_of.write(offsets[i].src_start, offsets[i].src_end - offsets[i].src_start);
128-
src_of.write("\n", 1);
129-
tgt_of.write(offsets[i].tgt_start, offsets[i].tgt_end - offsets[i].tgt_start);
130-
tgt_of.write("\n", 1);
143+
if (sample > 0 && i < sample){
144+
src_sample_of->write(offsets[i].src_start, offsets[i].src_end - offsets[i].src_start);
145+
src_sample_of->write("\n", 1);
146+
tgt_sample_of->write(offsets[i].tgt_start, offsets[i].tgt_end - offsets[i].tgt_start);
147+
tgt_sample_of->write("\n", 1);
148+
}else{
149+
src_of.write(offsets[i].src_start, offsets[i].src_end - offsets[i].src_start);
150+
src_of.write("\n", 1);
151+
tgt_of.write(offsets[i].tgt_start, offsets[i].tgt_end - offsets[i].tgt_start);
152+
tgt_of.write("\n", 1);
153+
}
154+
}
155+
156+
if (tgt_sample_of != nullptr){
157+
tgt_sample_of->close();
158+
delete tgt_sample_of;
159+
tgt_sample_of = nullptr;
160+
}
161+
if (src_sample_of != nullptr){
162+
src_sample_of->close();
163+
delete src_sample_of;
164+
src_sample_of = nullptr;
131165
}
132166

133167
src_of.close();
@@ -136,5 +170,5 @@ std::tuple<std::string, std::string> shuffle(const std::string &src, const std::
136170
mmap_close(smht);
137171
mmap_close(tmht);
138172

139-
return std::make_tuple(src_out, tgt_out);
173+
return std::make_tuple(src_out, tgt_out, src_sample_out, tgt_sample_out);
140174
}

shuffle.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ typedef struct{
4242
mmap_handle_t mmap_open(const std::string &file);
4343
void mmap_close(mmap_handle_t mht);
4444
size_t line_count(mmap_handle_t mht);
45-
std::tuple<std::string, std::string> shuffle(const std::string &src, const std::string &tgt);
45+
std::tuple<std::string, std::string> shuffle(const std::string &src, const std::string &tgt);
46+
std::tuple<std::string, std::string, std::string, std::string> shuffle_sample(const std::string &src, const std::string &tgt, long long sample);

tests/test_shuffle.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from fastshuffle import file_shuffle
2+
from fastshuffle import file_shuffle, file_shuffle_sample
33
import os
44

55
def test_shuffle():
@@ -15,4 +15,17 @@ def test_shuffle():
1515
assert sum(1 for _ in open(src)) == sum(1 for _ in open(in_src))
1616
assert sum(1 for _ in open(tgt)) == sum(1 for _ in open(in_tgt))
1717

18+
19+
def test_shuffle_sample():
20+
cwd = os.path.dirname(__file__)
21+
22+
in_src = os.path.join(cwd, "data", "src.txt")
23+
in_tgt = os.path.join(cwd, "data", "tgt.txt")
24+
src, tgt, ssrc, stgt = file_shuffle_sample(in_src, in_tgt, 1)
25+
26+
assert os.path.isfile(src)
27+
assert os.path.isfile(tgt)
28+
assert os.path.isfile(ssrc)
29+
assert os.path.isfile(stgt)
30+
1831

0 commit comments

Comments
 (0)