Spaces:
Sleeping
Sleeping
File size: 1,578 Bytes
3860419 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
"""
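Tests for the collect_learnings function in the cli/collect module.

NOTE: the test below is currently disabled (commented out), so this module
defines no active tests at the moment.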
"""
import pytest
# def test_collect_learnings(monkeypatch):
# monkeypatch.setattr(rudder_analytics, "track", MagicMock())
#
# model = "test_model"
# temperature = 0.5
# steps = [simple_gen]
# dbs = FileRepositories(
# OnDiskRepository("/tmp"),
# OnDiskRepository("/tmp"),
# OnDiskRepository("/tmp"),
# OnDiskRepository("/tmp"),
# OnDiskRepository("/tmp"),
# OnDiskRepository("/tmp"),
# OnDiskRepository("/tmp"),
# )
# dbs.input = {
# "prompt": "test prompt\n with newlines",
# "feedback": "test feedback",
# }
# code = "this is output\n\nit contains code"
# dbs.logs = {steps[0].__name__: json.dumps([{"role": "system", "content": code}])}
# dbs.memory = {"all_output.txt": "test workspace\n" + code}
#
# collect_learnings(model, temperature, steps, dbs)
#
# learnings = extract_learning(
# model, temperature, steps, dbs, steps_file_hash=steps_file_hash()
# )
# assert rudder_analytics.track.call_count == 1
# assert rudder_analytics.track.call_args[1]["event"] == "learning"
# a = {
# k: v
# for k, v in rudder_analytics.track.call_args[1]["properties"].items()
# if k != "timestamp"
# }
# b = {k: v for k, v in learnings.to_dict().items() if k != "timestamp"}
# assert a == b
#
# assert json.dumps(code) in learnings.logs
# assert code in learnings.workspace
if __name__ == "__main__":
    # Allow running this file directly (`python <this file>`).
    # Pass __file__ so pytest collects only this module rather than every
    # test discovered under the current working directory, and forward
    # pytest's exit status to the shell so failures (or "no tests
    # collected") are not reported as success.
    raise SystemExit(pytest.main(["-v", __file__]))
|