diff --git "a/examples/Zamba_2_1_2B.ipynb" "b/examples/Zamba_2_1_2B.ipynb"
new file mode 100644--- /dev/null
+++ "b/examples/Zamba_2_1_2B.ipynb"
@@ -0,0 +1,2010 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "f2dbda2ff7ba4b019af0f17975444078": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_31a19787f40749c3bc4f3fdaf14135b1",
+ "IPY_MODEL_00aee0329e044d31bd4c4ebe24097f2c",
+ "IPY_MODEL_a2dc535a5a7e45629d16d0ffdcf485bb"
+ ],
+ "layout": "IPY_MODEL_a52adc24c6244cc9b7b9d1fd1c255515"
+ }
+ },
+ "31a19787f40749c3bc4f3fdaf14135b1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0ee238c7f9bf439e9102126c9e9c5564",
+ "placeholder": "",
+ "style": "IPY_MODEL_b8427e036de44f1fa305b74437578d70",
+ "value": "config.json: 100%"
+ }
+ },
+ "00aee0329e044d31bd4c4ebe24097f2c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b560b0df040349fea0a1b05b4834c719",
+ "max": 1337,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_c6a427f52ce94de99ed357d54f2b1725",
+ "value": 1337
+ }
+ },
+ "a2dc535a5a7e45629d16d0ffdcf485bb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c3ec07880d864473af81da2726d42dc0",
+ "placeholder": "",
+ "style": "IPY_MODEL_e3f0244893824653bd080fd2d5b50c61",
+ "value": " 1.34k/1.34k [00:00<00:00, 19.8kB/s]"
+ }
+ },
+ "a52adc24c6244cc9b7b9d1fd1c255515": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0ee238c7f9bf439e9102126c9e9c5564": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b8427e036de44f1fa305b74437578d70": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b560b0df040349fea0a1b05b4834c719": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c6a427f52ce94de99ed357d54f2b1725": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "c3ec07880d864473af81da2726d42dc0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e3f0244893824653bd080fd2d5b50c61": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a08ec1110d7c4fe3acd58b5d7cf3051d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_7d36442fdd35419f91246fe08cede692",
+ "IPY_MODEL_867f946fe3a0459f965114d9585d4d76",
+ "IPY_MODEL_1e10ded196cd4fa58ccf4e755b476ad1"
+ ],
+ "layout": "IPY_MODEL_245953fc4e04473986803901006aa94e"
+ }
+ },
+ "7d36442fdd35419f91246fe08cede692": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_79326f7e145f47c79d95ea2ea5d49441",
+ "placeholder": "",
+ "style": "IPY_MODEL_4f81a97adef7417a8235efe631f65aff",
+ "value": "config.json: 100%"
+ }
+ },
+ "867f946fe3a0459f965114d9585d4d76": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c64dbd4e5eaf4b8b96d9a7707606b249",
+ "max": 1337,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_a2b9f172590c4d988f66cef60270d9ee",
+ "value": 1337
+ }
+ },
+ "1e10ded196cd4fa58ccf4e755b476ad1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_86ae8996a0be4ff0b4d5df3ed9b0629e",
+ "placeholder": "",
+ "style": "IPY_MODEL_51be35e151df4b37bd42d9075667cd93",
+ "value": " 1.34k/1.34k [00:00<00:00, 14.2kB/s]"
+ }
+ },
+ "245953fc4e04473986803901006aa94e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "79326f7e145f47c79d95ea2ea5d49441": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4f81a97adef7417a8235efe631f65aff": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "c64dbd4e5eaf4b8b96d9a7707606b249": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a2b9f172590c4d988f66cef60270d9ee": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "86ae8996a0be4ff0b4d5df3ed9b0629e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "51be35e151df4b37bd42d9075667cd93": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "3e5f50646e04474f92718dc4882cbe4b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_f69224acb6874b9abce791367de857a2",
+ "IPY_MODEL_55e667af53844d1f9f4f55a339bbba87",
+ "IPY_MODEL_0afa674b2cae440b8aa229b557c8dc8f"
+ ],
+ "layout": "IPY_MODEL_0c4d444eb00545ceb6cdccfa77171202"
+ }
+ },
+ "f69224acb6874b9abce791367de857a2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c5c27aa06cb745dba82328b6b422588e",
+ "placeholder": "",
+ "style": "IPY_MODEL_7347744554da4cd6bdf9b3397c3da1d6",
+ "value": "model.safetensors: 100%"
+ }
+ },
+ "55e667af53844d1f9f4f55a339bbba87": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_178e5bbfd3bf4ad081654f25110d253f",
+ "max": 2430175992,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_25c014e006d5444d93d8990d4646e1b4",
+ "value": 2430175992
+ }
+ },
+ "0afa674b2cae440b8aa229b557c8dc8f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_93cce87d3dfc46fc9ce2e84b4e485bbf",
+ "placeholder": "",
+ "style": "IPY_MODEL_b5835f0bb0764998ac3b587335bc6638",
+ "value": " 2.43G/2.43G [00:20<00:00, 86.1MB/s]"
+ }
+ },
+ "0c4d444eb00545ceb6cdccfa77171202": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c5c27aa06cb745dba82328b6b422588e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7347744554da4cd6bdf9b3397c3da1d6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "178e5bbfd3bf4ad081654f25110d253f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "25c014e006d5444d93d8990d4646e1b4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "93cce87d3dfc46fc9ce2e84b4e485bbf": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b5835f0bb0764998ac3b587335bc6638": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "3cfa20cc4dd24b499eca1c36a7f65320": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_03f47e6071d945afa9f422cfb21a8e30",
+ "IPY_MODEL_5dc0b8840cf84516a9f7e3a5deaead80",
+ "IPY_MODEL_a736cede537d4124b5da9e72fb4d1ba1"
+ ],
+ "layout": "IPY_MODEL_5e321e040972405eb9312a5f43be4228"
+ }
+ },
+ "03f47e6071d945afa9f422cfb21a8e30": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a189c93099a24d3796944497c3ae45b4",
+ "placeholder": "",
+ "style": "IPY_MODEL_be36fa86b54c4003ba948f303e97b9ff",
+ "value": "generation_config.json: 100%"
+ }
+ },
+ "5dc0b8840cf84516a9f7e3a5deaead80": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_752ffddbba474ad58534d5471f24085b",
+ "max": 137,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b1cb53eb6c5942d0bae3171e4eea22d8",
+ "value": 137
+ }
+ },
+ "a736cede537d4124b5da9e72fb4d1ba1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_85234acea70f4cf28c41e07bbb7eb8c2",
+ "placeholder": "",
+ "style": "IPY_MODEL_3bddfd87e42440778a098dfdfaefa263",
+ "value": " 137/137 [00:00<00:00, 6.42kB/s]"
+ }
+ },
+ "5e321e040972405eb9312a5f43be4228": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a189c93099a24d3796944497c3ae45b4": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be36fa86b54c4003ba948f303e97b9ff": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "752ffddbba474ad58534d5471f24085b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b1cb53eb6c5942d0bae3171e4eea22d8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "85234acea70f4cf28c41e07bbb7eb8c2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3bddfd87e42440778a098dfdfaefa263": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "
Zamba 2-1.2B
"
+ ],
+ "metadata": {
+ "id": "h9DjgpWw2C27"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This is a mamba based model with some transformer blocks in it. It is on its own fork of the transformers library so there are extra steps in installing its dependencies."
+ ],
+ "metadata": {
+ "id": "-bIoeS_I2MvR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/Zyphra/transformers_zamba2.git"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "TIdVR89J2fA4",
+ "outputId": "19ac8288-1fd8-4e53-b137-29aa8ecb2cad"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'transformers_zamba2'...\n",
+ "remote: Enumerating objects: 182990, done.\u001b[K\n",
+ "remote: Counting objects: 100% (200/200), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (105/105), done.\u001b[K\n",
+ "remote: Total 182990 (delta 109), reused 152 (delta 78), pack-reused 182790 (from 1)\u001b[K\n",
+ "Receiving objects: 100% (182990/182990), 206.54 MiB | 9.13 MiB/s, done.\n",
+ "Resolving deltas: 100% (130848/130848), done.\n",
+ "Updating files: 100% (4310/4310), done.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -e transformers_zamba2"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "y22A557n2pVF",
+ "outputId": "039b0d88-927e-472b-94ce-27337858ee06"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Obtaining file:///content/transformers_zamba2\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (3.16.0)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.24.6)\n",
+ "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (1.26.4)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (24.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (6.0.2)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (2024.5.15)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (2.32.3)\n",
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.19.1)\n",
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.4.5)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (4.66.5)\n",
+ "Collecting mamba-ssm==2.1.0 (from transformers==4.43.0.dev0)\n",
+ " Downloading mamba_ssm-2.1.0.tar.gz (84 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Collecting ninja (from transformers==4.43.0.dev0)\n",
+ " Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (5.3 kB)\n",
+ "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.8.0)\n",
+ "Collecting triton (from transformers==4.43.0.dev0)\n",
+ " Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\n",
+ "Collecting causal-conv1d==1.3.0.post1 (from transformers==4.43.0.dev0)\n",
+ " Downloading causal_conv1d-1.3.0.post1.tar.gz (8.1 kB)\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (2.4.0+cu121)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.43.0.dev0) (2024.6.1)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.43.0.dev0) (4.12.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (3.8)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (2024.8.30)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (1.13.2)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (3.3)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (3.1.4)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (2.1.5)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (1.3.0)\n",
+ "Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.2/307.2 kB\u001b[0m \u001b[31m18.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hBuilding wheels for collected packages: transformers, causal-conv1d, mamba-ssm\n",
+ " Building editable for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for transformers: filename=transformers-4.43.0.dev0-0.editable-py3-none-any.whl size=17274 sha256=2671982986cfe13f2155d632e8fa0ea7dfea3fabcb3ca6f1f13b2dd2276e53fb\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-b5wf_y35/wheels/bf/e2/60/f15fb160f61df44c33784d7a3d842b1a6d21217a8d1d7572c6\n",
+ " Building wheel for causal-conv1d (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for causal-conv1d: filename=causal_conv1d-1.3.0.post1-cp310-cp310-linux_x86_64.whl size=103872557 sha256=bd8a9034d54da9480c8bc41af83660796430673d5ea41a00389646105cfcb324\n",
+ " Stored in directory: /root/.cache/pip/wheels/6a/25/f8/ffbc841b608eb39aaa91ff5d1a5b830b7780855dd65796b774\n",
+ " Building wheel for mamba-ssm (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for mamba-ssm: filename=mamba_ssm-2.1.0-cp310-cp310-linux_x86_64.whl size=323986419 sha256=4e6940c8e19fdc6d9a17171546a89ace24984f65917e51494bc6f1ed820097b4\n",
+ " Stored in directory: /root/.cache/pip/wheels/62/1a/a0/88447e865ca478b954b6560317096bd11c79fb03f6312a64bc\n",
+ "Successfully built transformers causal-conv1d mamba-ssm\n",
+ "Installing collected packages: ninja, triton, causal-conv1d, mamba-ssm, transformers\n",
+ " Attempting uninstall: transformers\n",
+ " Found existing installation: transformers 4.44.2\n",
+ " Uninstalling transformers-4.44.2:\n",
+ " Successfully uninstalled transformers-4.44.2\n",
+ "Successfully installed causal-conv1d-1.3.0.post1 mamba-ssm-2.1.0 ninja-1.11.1.1 transformers-4.43.0.dev0 triton-3.0.0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install accelerate"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zike0qpHVtVm",
+ "outputId": "f7b08f0a-d79e-4175-f131-7c108e9d7575"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.34.2)\n",
+ "Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.4)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (24.1)\n",
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.2)\n",
+ "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.4.0+cu121)\n",
+ "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.24.6)\n",
+ "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.5)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.0)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.5)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.13.2)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.3)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.4)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.8)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### NOTE\n",
+ "Restart session after installing the custom transformers fork for these changes to take effect."
+ ],
+ "metadata": {
+ "id": "OXJUYhrC42fd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "import torch"
+ ],
+ "metadata": {
+ "id": "9gkimtr040n6"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "USE_MAMBA_KERNELS = True if torch.cuda.is_available() else False"
+ ],
+ "metadata": {
+ "id": "0B3xoCWY4-Qo"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 501,
+ "referenced_widgets": [
+ "f2dbda2ff7ba4b019af0f17975444078",
+ "31a19787f40749c3bc4f3fdaf14135b1",
+ "00aee0329e044d31bd4c4ebe24097f2c",
+ "a2dc535a5a7e45629d16d0ffdcf485bb",
+ "a52adc24c6244cc9b7b9d1fd1c255515",
+ "0ee238c7f9bf439e9102126c9e9c5564",
+ "b8427e036de44f1fa305b74437578d70",
+ "b560b0df040349fea0a1b05b4834c719",
+ "c6a427f52ce94de99ed357d54f2b1725",
+ "c3ec07880d864473af81da2726d42dc0",
+ "e3f0244893824653bd080fd2d5b50c61",
+ "a08ec1110d7c4fe3acd58b5d7cf3051d",
+ "7d36442fdd35419f91246fe08cede692",
+ "867f946fe3a0459f965114d9585d4d76",
+ "1e10ded196cd4fa58ccf4e755b476ad1",
+ "245953fc4e04473986803901006aa94e",
+ "79326f7e145f47c79d95ea2ea5d49441",
+ "4f81a97adef7417a8235efe631f65aff",
+ "c64dbd4e5eaf4b8b96d9a7707606b249",
+ "a2b9f172590c4d988f66cef60270d9ee",
+ "86ae8996a0be4ff0b4d5df3ed9b0629e",
+ "51be35e151df4b37bd42d9075667cd93",
+ "3e5f50646e04474f92718dc4882cbe4b",
+ "f69224acb6874b9abce791367de857a2",
+ "55e667af53844d1f9f4f55a339bbba87",
+ "0afa674b2cae440b8aa229b557c8dc8f",
+ "0c4d444eb00545ceb6cdccfa77171202",
+ "c5c27aa06cb745dba82328b6b422588e",
+ "7347744554da4cd6bdf9b3397c3da1d6",
+ "178e5bbfd3bf4ad081654f25110d253f",
+ "25c014e006d5444d93d8990d4646e1b4",
+ "93cce87d3dfc46fc9ce2e84b4e485bbf",
+ "b5835f0bb0764998ac3b587335bc6638",
+ "3cfa20cc4dd24b499eca1c36a7f65320",
+ "03f47e6071d945afa9f422cfb21a8e30",
+ "5dc0b8840cf84516a9f7e3a5deaead80",
+ "a736cede537d4124b5da9e72fb4d1ba1",
+ "5e321e040972405eb9312a5f43be4228",
+ "a189c93099a24d3796944497c3ae45b4",
+ "be36fa86b54c4003ba948f303e97b9ff",
+ "752ffddbba474ad58534d5471f24085b",
+ "b1cb53eb6c5942d0bae3171e4eea22d8",
+ "85234acea70f4cf28c41e07bbb7eb8c2",
+ "3bddfd87e42440778a098dfdfaefa263"
+ ]
+ },
+ "id": "kuKpJo2gVUg1",
+ "outputId": "f5c1963f-d60a-4a48-b6ea-d8090814c033"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.34k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "f2dbda2ff7ba4b019af0f17975444078"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.34k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a08ec1110d7c4fe3acd58b5d7cf3051d"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/2.43G [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "3e5f50646e04474f92718dc4882cbe4b"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/137 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "3cfa20cc4dd24b499eca1c36a7f65320"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "error",
+ "ename": "RuntimeError",
+ "evalue": "Expected x.is_cuda() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0minput_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_text\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDEVICE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/transformers_zamba2/src/transformers/generation/utils.py\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1967\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1968\u001b[0m \u001b[0;31m# 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1969\u001b[0;31m result = self._sample(\n\u001b[0m\u001b[1;32m 1970\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1971\u001b[0m \u001b[0mlogits_processor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprepared_logits_processor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/transformers_zamba2/src/transformers/generation/utils.py\u001b[0m in \u001b[0;36m_sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)\u001b[0m\n\u001b[1;32m 2911\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2912\u001b[0m \u001b[0;31m# forward pass to get next token\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2913\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2914\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2915\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msynced_gpus\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mthis_peer_finished\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/modeling_zamba2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)\u001b[0m\n\u001b[1;32m 1433\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1434\u001b[0m \u001b[0;31m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1435\u001b[0;31m outputs = self.model(\n\u001b[0m\u001b[1;32m 1436\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1437\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/modeling_zamba2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 1277\u001b[0m )\n\u001b[1;32m 1278\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1279\u001b[0;31m layer_outputs = next(mamba_layers)(\n\u001b[0m\u001b[1;32m 1280\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[0mtransformer_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransformer_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/modeling_zamba2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, transformer_hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)\u001b[0m\n\u001b[1;32m 982\u001b[0m )\n\u001b[1;32m 983\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_layernorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 984\u001b[0;31m hidden_states = self.mamba(\n\u001b[0m\u001b[1;32m 985\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 986\u001b[0m \u001b[0minference_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/mamba2_layer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, u, from_shared_proj, seqlen, seq_idx, inference_params, attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 228\u001b[0m ) # (B, L, self.d_ssm + 2 * ngroups * d_state)\n\u001b[1;32m 229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m xBC = causal_conv1d_fn(\n\u001b[0m\u001b[1;32m 231\u001b[0m \u001b[0mxBC\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"d 1 w -> d w\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py\u001b[0m in \u001b[0;36mcausal_conv1d_fn\u001b[0;34m(x, weight, bias, seq_idx, initial_states, return_final_states, final_states_out, activation)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseqlen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \"\"\"\n\u001b[0;32m--> 121\u001b[0;31m return CausalConv1dFn.apply(\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0;31m# See NOTE: [functorch vjp and autograd interaction]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_functorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munwrap_dead_wrappers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_setup_ctx_defined\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, x, weight, bias, seq_idx, initial_states, return_final_states, final_states_out, activation)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mfinal_states_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mactivation\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"silu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"swish\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m out = causal_conv1d_cuda.causal_conv1d_fwd(\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseq_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfinal_states_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m )\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: Expected x.is_cuda() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(\"Zyphra/Zamba2-1.2B\")\n",
+ "model = AutoModelForCausalLM.from_pretrained(\"Zyphra/Zamba2-1.2B\",\n",
+ " device_map=DEVICE,\n",
+ " torch_dtype=torch.bfloat16,\n",
+ " use_mamba_kernels=USE_MAMBA_KERNELS,\n",
+ " # force_download=True,\n",
+ " )\n",
+ "\n",
+ "input_text = \"A funny prompt would be \"\n",
+ "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(DEVICE)\n",
+ "\n",
+ "outputs = model.generate(**input_ids, max_new_tokens=100)\n",
+ "print(tokenizer.decode(outputs[0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "help(tokenizer(input_text, return_tensors=\"pt\").to)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Jbft3pGh8c6X",
+ "outputId": "e6877b68-a529-4eea-eadd-572035761607"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Help on method to in module transformers.tokenization_utils_base:\n",
+ "\n",
+ "to(device: Union[str, ForwardRef('torch.device')]) -> 'BatchEncoding' method of transformers.tokenization_utils_base.BatchEncoding instance\n",
+ " Send all values to device by calling `v.to(device)` (PyTorch only).\n",
+ " \n",
+ " Args:\n",
+ " device (`str` or `torch.device`): The device to put the tensors on.\n",
+ " \n",
+ " Returns:\n",
+ " [`BatchEncoding`]: The same instance after modification.\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "help(AutoModelForCausalLM.from_pretrained)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "r5f0R-jI4A1I",
+ "outputId": "8aadb38d-92d6-49dc-d944-0730b562db78"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Help on method from_pretrained in module transformers.models.auto.auto_factory:\n",
+ "\n",
+ "from_pretrained(*model_args, **kwargs) method of builtins.type instance\n",
+ " Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.\n",
+ " \n",
+ " The model class to instantiate is selected based on the `model_type` property of the config object (either\n",
+ " passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by\n",
+ " falling back to using pattern matching on `pretrained_model_name_or_path`:\n",
+ " \n",
+ " - **bart** -- [`BartForCausalLM`] (BART model)\n",
+ " - **bert** -- [`BertLMHeadModel`] (BERT model)\n",
+ " - **bert-generation** -- [`BertGenerationDecoder`] (Bert Generation model)\n",
+ " - **big_bird** -- [`BigBirdForCausalLM`] (BigBird model)\n",
+ " - **bigbird_pegasus** -- [`BigBirdPegasusForCausalLM`] (BigBird-Pegasus model)\n",
+ " - **biogpt** -- [`BioGptForCausalLM`] (BioGpt model)\n",
+ " - **blenderbot** -- [`BlenderbotForCausalLM`] (Blenderbot model)\n",
+ " - **blenderbot-small** -- [`BlenderbotSmallForCausalLM`] (BlenderbotSmall model)\n",
+ " - **bloom** -- [`BloomForCausalLM`] (BLOOM model)\n",
+ " - **camembert** -- [`CamembertForCausalLM`] (CamemBERT model)\n",
+ " - **code_llama** -- [`LlamaForCausalLM`] (CodeLlama model)\n",
+ " - **codegen** -- [`CodeGenForCausalLM`] (CodeGen model)\n",
+ " - **cohere** -- [`CohereForCausalLM`] (Cohere model)\n",
+ " - **cpmant** -- [`CpmAntForCausalLM`] (CPM-Ant model)\n",
+ " - **ctrl** -- [`CTRLLMHeadModel`] (CTRL model)\n",
+ " - **data2vec-text** -- [`Data2VecTextForCausalLM`] (Data2VecText model)\n",
+ " - **dbrx** -- [`DbrxForCausalLM`] (DBRX model)\n",
+ " - **electra** -- [`ElectraForCausalLM`] (ELECTRA model)\n",
+ " - **ernie** -- [`ErnieForCausalLM`] (ERNIE model)\n",
+ " - **falcon** -- [`FalconForCausalLM`] (Falcon model)\n",
+ " - **fuyu** -- [`FuyuForCausalLM`] (Fuyu model)\n",
+ " - **gemma** -- [`GemmaForCausalLM`] (Gemma model)\n",
+ " - **gemma2** -- [`Gemma2ForCausalLM`] (Gemma2 model)\n",
+ " - **git** -- [`GitForCausalLM`] (GIT model)\n",
+ " - **gpt-sw3** -- [`GPT2LMHeadModel`] (GPT-Sw3 model)\n",
+ " - **gpt2** -- [`GPT2LMHeadModel`] (OpenAI GPT-2 model)\n",
+ " - **gpt_bigcode** -- [`GPTBigCodeForCausalLM`] (GPTBigCode model)\n",
+ " - **gpt_neo** -- [`GPTNeoForCausalLM`] (GPT Neo model)\n",
+ " - **gpt_neox** -- [`GPTNeoXForCausalLM`] (GPT NeoX model)\n",
+ " - **gpt_neox_japanese** -- [`GPTNeoXJapaneseForCausalLM`] (GPT NeoX Japanese model)\n",
+ " - **gptj** -- [`GPTJForCausalLM`] (GPT-J model)\n",
+ " - **jamba** -- [`JambaForCausalLM`] (Jamba model)\n",
+ " - **jetmoe** -- [`JetMoeForCausalLM`] (JetMoe model)\n",
+ " - **llama** -- [`LlamaForCausalLM`] (LLaMA model)\n",
+ " - **mamba** -- [`MambaForCausalLM`] (Mamba model)\n",
+ " - **marian** -- [`MarianForCausalLM`] (Marian model)\n",
+ " - **mbart** -- [`MBartForCausalLM`] (mBART model)\n",
+ " - **mega** -- [`MegaForCausalLM`] (MEGA model)\n",
+ " - **megatron-bert** -- [`MegatronBertForCausalLM`] (Megatron-BERT model)\n",
+ " - **mistral** -- [`MistralForCausalLM`] (Mistral model)\n",
+ " - **mixtral** -- [`MixtralForCausalLM`] (Mixtral model)\n",
+ " - **mpt** -- [`MptForCausalLM`] (MPT model)\n",
+ " - **musicgen** -- [`MusicgenForCausalLM`] (MusicGen model)\n",
+ " - **musicgen_melody** -- [`MusicgenMelodyForCausalLM`] (MusicGen Melody model)\n",
+ " - **mvp** -- [`MvpForCausalLM`] (MVP model)\n",
+ " - **olmo** -- [`OlmoForCausalLM`] (OLMo model)\n",
+ " - **open-llama** -- [`OpenLlamaForCausalLM`] (OpenLlama model)\n",
+ " - **openai-gpt** -- [`OpenAIGPTLMHeadModel`] (OpenAI GPT model)\n",
+ " - **opt** -- [`OPTForCausalLM`] (OPT model)\n",
+ " - **pegasus** -- [`PegasusForCausalLM`] (Pegasus model)\n",
+ " - **persimmon** -- [`PersimmonForCausalLM`] (Persimmon model)\n",
+ " - **phi** -- [`PhiForCausalLM`] (Phi model)\n",
+ " - **phi3** -- [`Phi3ForCausalLM`] (Phi3 model)\n",
+ " - **plbart** -- [`PLBartForCausalLM`] (PLBart model)\n",
+ " - **prophetnet** -- [`ProphetNetForCausalLM`] (ProphetNet model)\n",
+ " - **qdqbert** -- [`QDQBertLMHeadModel`] (QDQBert model)\n",
+ " - **qwen2** -- [`Qwen2ForCausalLM`] (Qwen2 model)\n",
+ " - **qwen2_moe** -- [`Qwen2MoeForCausalLM`] (Qwen2MoE model)\n",
+ " - **recurrent_gemma** -- [`RecurrentGemmaForCausalLM`] (RecurrentGemma model)\n",
+ " - **reformer** -- [`ReformerModelWithLMHead`] (Reformer model)\n",
+ " - **rembert** -- [`RemBertForCausalLM`] (RemBERT model)\n",
+ " - **roberta** -- [`RobertaForCausalLM`] (RoBERTa model)\n",
+ " - **roberta-prelayernorm** -- [`RobertaPreLayerNormForCausalLM`] (RoBERTa-PreLayerNorm model)\n",
+ " - **roc_bert** -- [`RoCBertForCausalLM`] (RoCBert model)\n",
+ " - **roformer** -- [`RoFormerForCausalLM`] (RoFormer model)\n",
+ " - **rwkv** -- [`RwkvForCausalLM`] (RWKV model)\n",
+ " - **speech_to_text_2** -- [`Speech2Text2ForCausalLM`] (Speech2Text2 model)\n",
+ " - **stablelm** -- [`StableLmForCausalLM`] (StableLm model)\n",
+ " - **starcoder2** -- [`Starcoder2ForCausalLM`] (Starcoder2 model)\n",
+ " - **transfo-xl** -- [`TransfoXLLMHeadModel`] (Transformer-XL model)\n",
+ " - **trocr** -- [`TrOCRForCausalLM`] (TrOCR model)\n",
+ " - **whisper** -- [`WhisperForCausalLM`] (Whisper model)\n",
+ " - **xglm** -- [`XGLMForCausalLM`] (XGLM model)\n",
+ " - **xlm** -- [`XLMWithLMHeadModel`] (XLM model)\n",
+ " - **xlm-prophetnet** -- [`XLMProphetNetForCausalLM`] (XLM-ProphetNet model)\n",
+ " - **xlm-roberta** -- [`XLMRobertaForCausalLM`] (XLM-RoBERTa model)\n",
+ " - **xlm-roberta-xl** -- [`XLMRobertaXLForCausalLM`] (XLM-RoBERTa-XL model)\n",
+ " - **xlnet** -- [`XLNetLMHeadModel`] (XLNet model)\n",
+ " - **xmod** -- [`XmodForCausalLM`] (X-MOD model)\n",
+ " - **zamba2** -- [`Zamba2ForCausalLM`] (Zamba2 model)\n",
+ " \n",
+ " The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are\n",
+ " deactivated). To train the model, you should first set it back in training mode with `model.train()`\n",
+ " \n",
+ " Args:\n",
+ " pretrained_model_name_or_path (`str` or `os.PathLike`):\n",
+ " Can be either:\n",
+ " \n",
+ " - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.\n",
+ " - A path to a *directory* containing model weights saved using\n",
+ " [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.\n",
+ " - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In\n",
+ " this case, `from_tf` should be set to `True` and a configuration object should be provided as\n",
+ " `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a\n",
+ " PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.\n",
+ " model_args (additional positional arguments, *optional*):\n",
+ " Will be passed along to the underlying model `__init__()` method.\n",
+ " config ([`PretrainedConfig`], *optional*):\n",
+ " Configuration for the model to use instead of an automatically loaded configuration. Configuration can\n",
+ " be automatically loaded when:\n",
+ " \n",
+ " - The model is a model provided by the library (loaded with the *model id* string of a pretrained\n",
+ " model).\n",
+ " - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the\n",
+ " save directory.\n",
+ " - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a\n",
+ " configuration JSON file named *config.json* is found in the directory.\n",
+ " state_dict (*Dict[str, torch.Tensor]*, *optional*):\n",
+ " A state dictionary to use instead of a state dictionary loaded from saved weights file.\n",
+ " \n",
+ " This option can be used if you want to create a model from a pretrained configuration but load your own\n",
+ " weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and\n",
+ " [`~PreTrainedModel.from_pretrained`] is not a simpler option.\n",
+ " cache_dir (`str` or `os.PathLike`, *optional*):\n",
+ " Path to a directory in which a downloaded pretrained model configuration should be cached if the\n",
+ " standard cache should not be used.\n",
+ " from_tf (`bool`, *optional*, defaults to `False`):\n",
+ " Load the model weights from a TensorFlow checkpoint save file (see docstring of\n",
+ " `pretrained_model_name_or_path` argument).\n",
+ " force_download (`bool`, *optional*, defaults to `False`):\n",
+ " Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n",
+ " cached versions if they exist.\n",
+ " resume_download:\n",
+ " Deprecated and ignored. All downloads are now resumed by default when possible.\n",
+ " Will be removed in v5 of Transformers.\n",
+ " proxies (`Dict[str, str]`, *optional*):\n",
+ " A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',\n",
+ " 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n",
+ " output_loading_info(`bool`, *optional*, defaults to `False`):\n",
+ " Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.\n",
+ " local_files_only(`bool`, *optional*, defaults to `False`):\n",
+ " Whether or not to only look at local files (e.g., not try downloading the model).\n",
+ " revision (`str`, *optional*, defaults to `\"main\"`):\n",
+ " The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a\n",
+ " git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any\n",
+ " identifier allowed by git.\n",
+ " trust_remote_code (`bool`, *optional*, defaults to `False`):\n",
+ " Whether or not to allow for custom models defined on the Hub in their own modeling files. This option\n",
+ " should only be set to `True` for repositories you trust and in which you have read the code, as it will\n",
+ " execute code present on the Hub on your local machine.\n",
+ " code_revision (`str`, *optional*, defaults to `\"main\"`):\n",
+ " The specific revision to use for the code on the Hub, if the code leaves in a different repository than\n",
+ " the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based\n",
+ " system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier\n",
+ " allowed by git.\n",
+ " kwargs (additional keyword arguments, *optional*):\n",
+ " Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,\n",
+ " `output_attentions=True`). Behaves differently depending on whether a `config` is provided or\n",
+ " automatically loaded:\n",
+ " \n",
+ " - If a configuration is provided with `config`, `**kwargs` will be directly passed to the\n",
+ " underlying model's `__init__` method (we assume all relevant updates to the configuration have\n",
+ " already been done)\n",
+ " - If a configuration is not provided, `kwargs` will be first passed to the configuration class\n",
+ " initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that\n",
+ " corresponds to a configuration attribute will be used to override said attribute with the\n",
+ " supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute\n",
+ " will be passed to the underlying model's `__init__` function.\n",
+ " \n",
+ " Examples:\n",
+ " \n",
+ " ```python\n",
+ " >>> from transformers import AutoConfig, AutoModelForCausalLM\n",
+ " \n",
+ " >>> # Download model and configuration from huggingface.co and cache.\n",
+ " >>> model = AutoModelForCausalLM.from_pretrained(\"google-bert/bert-base-cased\")\n",
+ " \n",
+ " >>> # Update configuration during loading\n",
+ " >>> model = AutoModelForCausalLM.from_pretrained(\"google-bert/bert-base-cased\", output_attentions=True)\n",
+ " >>> model.config.output_attentions\n",
+ " True\n",
+ " \n",
+ " >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)\n",
+ " >>> config = AutoConfig.from_pretrained(\"./tf_model/bert_tf_model_config.json\")\n",
+ " >>> model = AutoModelForCausalLM.from_pretrained(\n",
+ " ... \"./tf_model/bert_tf_checkpoint.ckpt.index\", from_tf=True, config=config\n",
+ " ... )\n",
+ " ```\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "K62B1nCQ7ul3"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
|